use std::fmt;
use std::marker::PhantomData;
use crate::config::DatabaseType;
use crate::error::{Error, Result};
use crate::internal::{ConnectionTrait, FromQueryResult, Statement, Value};
use crate::model::Model;
/// Options that tune how a full-text search query is built.
#[derive(Debug, Clone, Default)]
pub struct FullTextConfig {
/// Text-search language; the Postgres builders in this file fall back to
/// `"english"` when this is `None`.
pub language: Option<String>,
/// How the search text is interpreted (see [`SearchMode`]).
pub mode: SearchMode,
/// Minimum word length. NOTE(review): not read by any SQL builder visible
/// in this file — confirm where it is consumed.
pub min_word_length: Option<u32>,
/// Maximum word length. NOTE(review): not read by any SQL builder visible
/// in this file — confirm where it is consumed.
pub max_word_length: Option<u32>,
/// Stop words. NOTE(review): not read by any SQL builder visible in this
/// file — confirm where it is consumed.
pub stop_words: Vec<String>,
/// Rank weights; the Postgres ranking SQL passes them to `ts_rank_cd`.
pub weights: Option<SearchWeights>,
}
impl FullTextConfig {
    /// Creates a configuration with every option at its default value.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the configuration with the text-search language replaced.
    pub fn language(self, lang: impl Into<String>) -> Self {
        Self {
            language: Some(lang.into()),
            ..self
        }
    }

    /// Returns the configuration with the search mode replaced.
    pub fn mode(self, mode: SearchMode) -> Self {
        Self { mode, ..self }
    }

    /// Returns the configuration with the minimum word length set.
    pub fn min_word_length(self, len: u32) -> Self {
        Self {
            min_word_length: Some(len),
            ..self
        }
    }

    /// Returns the configuration with the maximum word length set.
    pub fn max_word_length(self, len: u32) -> Self {
        Self {
            max_word_length: Some(len),
            ..self
        }
    }

    /// Returns the configuration with the stop-word list replaced.
    pub fn stop_words(self, words: Vec<String>) -> Self {
        Self {
            stop_words: words,
            ..self
        }
    }

    /// Returns the configuration with the rank weights set.
    pub fn weights(self, weights: SearchWeights) -> Self {
        Self {
            weights: Some(weights),
            ..self
        }
    }
}
/// Strategy used to interpret the user's search text.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SearchMode {
    /// Plain natural-language matching (the default).
    #[default]
    Natural,
    /// The search text is handed over as a boolean query expression.
    Boolean,
    /// The search text is matched as a phrase.
    Phrase,
    /// Every whitespace-separated word is treated as a prefix.
    Prefix,
    /// Fuzzy matching.
    Fuzzy,
    /// Words must appear within the given distance of each other.
    Proximity(u32),
}

impl fmt::Display for SearchMode {
    /// Writes the lowercase mode name; `Proximity` includes its distance.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            // The only variant carrying data needs real formatting.
            SearchMode::Proximity(distance) => return write!(f, "proximity({distance})"),
            SearchMode::Natural => "natural",
            SearchMode::Boolean => "boolean",
            SearchMode::Phrase => "phrase",
            SearchMode::Prefix => "prefix",
            SearchMode::Fuzzy => "fuzzy",
        };
        f.write_str(label)
    }
}
/// Per-class rank weights, named after the `A`..`D` weight labels.
#[derive(Debug, Clone)]
pub struct SearchWeights {
    /// Weight for class `A`.
    pub a: f32,
    /// Weight for class `B`.
    pub b: f32,
    /// Weight for class `C`.
    pub c: f32,
    /// Weight for class `D`.
    pub d: f32,
}

impl Default for SearchWeights {
    /// Defaults to `a = 1.0`, `b = 0.4`, `c = 0.2`, `d = 0.1`.
    fn default() -> Self {
        Self::new(1.0, 0.4, 0.2, 0.1)
    }
}

impl SearchWeights {
    /// Builds a weight set from the four class weights, `A` first.
    pub fn new(a: f32, b: f32, c: f32, d: f32) -> Self {
        Self { a, b, c, d }
    }

    /// Renders the weights as a single-quoted array literal.
    ///
    /// The order is reversed (`{d,c,b,a}`) relative to the constructor.
    pub fn to_pg_array(&self) -> String {
        let Self { a, b, c, d } = self;
        format!("'{{{d},{c},{b},{a}}}'")
    }
}
/// A decoded record paired with its relevance rank and optional highlights.
#[derive(Debug, Clone)]
pub struct SearchResult<T> {
/// The decoded row.
pub record: T,
/// Relevance value read from the `_fts_rank` column by `get_ranked`
/// (0.0 when that column cannot be read).
pub rank: f64,
/// Highlighted fields; empty unless attached via `with_highlights`.
pub highlights: Vec<HighlightedField>,
}
impl<T> SearchResult<T> {
    /// Wraps `record` with its `rank`; the highlight list starts empty.
    pub fn new(record: T, rank: f64) -> Self {
        let highlights = Vec::new();
        Self {
            record,
            rank,
            highlights,
        }
    }

    /// Returns the result with its highlight list replaced by `highlights`.
    pub fn with_highlights(self, highlights: Vec<HighlightedField>) -> Self {
        Self { highlights, ..self }
    }
}
/// One field of a search hit together with its highlighted rendering.
#[derive(Debug, Clone)]
pub struct HighlightedField {
    /// Name of the field the text came from.
    pub field: String,
    /// The text with highlight tags inserted.
    pub highlighted: String,
    /// The untouched source text.
    pub original: String,
    /// Number of highlight start tags found in `highlighted`.
    pub match_count: usize,
}

impl HighlightedField {
    /// Builds a highlighted field, counting matches by the default
    /// `<mark>` start tag.
    ///
    /// Use [`HighlightedField::with_start_tag`] when the text was highlighted
    /// with a custom tag; otherwise `match_count` would come out as 0.
    pub fn new(
        field: impl Into<String>,
        highlighted: impl Into<String>,
        original: impl Into<String>,
    ) -> Self {
        Self::with_start_tag(field, highlighted, original, "<mark>")
    }

    /// Builds a highlighted field, counting matches by occurrences of
    /// `start_tag` in `highlighted`.
    ///
    /// An empty `start_tag` yields a `match_count` of 0 (an empty pattern
    /// would otherwise match at every character boundary).
    pub fn with_start_tag(
        field: impl Into<String>,
        highlighted: impl Into<String>,
        original: impl Into<String>,
        start_tag: &str,
    ) -> Self {
        let highlighted = highlighted.into();
        let match_count = if start_tag.is_empty() {
            0
        } else {
            highlighted.matches(start_tag).count()
        };
        Self {
            field: field.into(),
            highlighted,
            original: original.into(),
            match_count,
        }
    }
}
/// Entry points for starting a full-text search on a model type.
///
/// Every method only constructs a [`FullTextSearchBuilder`]; nothing is
/// executed until one of the builder's async methods is awaited.
pub trait FullTextSearch: Model + Sized {
    /// Starts a search for `query` over `columns` with default settings.
    fn search(columns: &[&str], query: &str) -> FullTextSearchBuilder<Self> {
        FullTextSearchBuilder::<Self>::new(columns, query)
    }

    /// Starts a search that uses the supplied `config`.
    fn search_with_config(
        columns: &[&str],
        query: &str,
        config: FullTextConfig,
    ) -> FullTextSearchBuilder<Self> {
        let builder = FullTextSearchBuilder::<Self>::new(columns, query);
        builder.config(config)
    }

    /// Starts a search with ranking switched on.
    fn search_ranked(columns: &[&str], query: &str) -> FullTextSearchBuilder<Self> {
        let builder = FullTextSearchBuilder::<Self>::new(columns, query);
        builder.with_ranking()
    }

    /// Starts a search that records `start_tag` / `end_tag` as highlight tags.
    fn search_highlighted(
        columns: &[&str],
        query: &str,
        start_tag: &str,
        end_tag: &str,
    ) -> FullTextSearchBuilder<Self> {
        let builder = FullTextSearchBuilder::<Self>::new(columns, query);
        builder.with_highlights(start_tag, end_tag)
    }
}

/// Every model gets the full-text search entry points automatically.
impl<M: Model> FullTextSearch for M {}
/// Fluent builder that turns a search request into backend-specific SQL
/// and runs it for model `T`.
pub struct FullTextSearchBuilder<T: Model> {
// Columns the search runs over.
columns: Vec<String>,
// Raw search text supplied by the caller.
query: String,
// Language, mode and weight options.
config: FullTextConfig,
// When set, the Postgres non-ranked builder also selects `_fts_rank` and
// orders by it.
with_ranking: bool,
// Highlight tags recorded by `with_highlights` / `highlight_config`.
// NOTE(review): not read by any SQL builder visible in this file.
highlight_config: Option<HighlightConfig>,
// Optional row limit and offset, bound as query parameters.
limit: Option<u64>,
offset: Option<u64>,
// Minimum rank; only the `*_ranked_sql` builders apply it.
min_rank: Option<f64>,
// Ties the builder to the model type without storing a value of it.
_marker: PhantomData<T>,
}
/// Settings describing how matches should be highlighted.
#[derive(Debug, Clone)]
pub struct HighlightConfig {
    /// Text inserted before a match.
    pub start_tag: String,
    /// Text inserted after a match.
    pub end_tag: String,
    /// Optional maximum length of the highlighted output.
    pub max_length: Option<usize>,
    /// Optional number of words per fragment.
    pub fragment_words: Option<usize>,
}

impl Default for HighlightConfig {
    /// Defaults to `<mark>` / `</mark>` tags, no length cap and
    /// 10-word fragments.
    fn default() -> Self {
        let start_tag = String::from("<mark>");
        let end_tag = String::from("</mark>");
        Self {
            start_tag,
            end_tag,
            max_length: None,
            fragment_words: Some(10),
        }
    }
}
impl<T: Model> FullTextSearchBuilder<T> {
/// Starts a builder that searches `columns` for `query`.
///
/// Everything else starts at its default: default config, no ranking,
/// no highlights, no limit/offset and no minimum rank.
pub fn new(columns: &[&str], query: &str) -> Self {
    let columns: Vec<String> = columns.iter().copied().map(String::from).collect();
    Self {
        columns,
        query: String::from(query),
        config: FullTextConfig::default(),
        with_ranking: false,
        highlight_config: None,
        limit: None,
        offset: None,
        min_rank: None,
        _marker: PhantomData,
    }
}
pub fn config(mut self, config: FullTextConfig) -> Self {
self.config = config;
self
}
pub fn with_ranking(mut self) -> Self {
self.with_ranking = true;
self
}
pub fn with_highlights(mut self, start_tag: &str, end_tag: &str) -> Self {
self.highlight_config = Some(HighlightConfig {
start_tag: start_tag.to_string(),
end_tag: end_tag.to_string(),
..Default::default()
});
self
}
pub fn highlight_config(mut self, config: HighlightConfig) -> Self {
self.highlight_config = Some(config);
self
}
pub fn limit(mut self, limit: u64) -> Self {
self.limit = Some(limit);
self
}
pub fn offset(mut self, offset: u64) -> Self {
self.offset = Some(offset);
self
}
pub fn min_rank(mut self, rank: f64) -> Self {
self.min_rank = Some(rank);
self
}
/// Overrides the search mode while leaving the rest of the configuration
/// untouched.
pub fn mode(mut self, mode: SearchMode) -> Self {
    // Delegate to the config's own fluent setter and store the result back.
    self.config = self.config.mode(mode);
    self
}
/// Overrides the text-search language while leaving the rest of the
/// configuration untouched.
pub fn language(mut self, lang: impl Into<String>) -> Self {
    // Delegate to the config's own fluent setter and store the result back.
    self.config = self.config.language(lang);
    self
}
pub async fn get(self) -> Result<Vec<T>>
where
T: FromQueryResult,
{
use crate::database::Connection;
let db = crate::database::__current_db()?;
let db_type = db.backend();
let (sql, params) = self.build_sql(db_type)?;
let backend = db.__internal_backend()?;
let statement = Statement::from_sql_and_values(backend, &sql, params);
let results = match db.__get_connection()? {
crate::database::ConnectionRef::Database(conn) => {
crate::profiling::__profile_future(conn.connection().query_all_raw(statement)).await
}
crate::database::ConnectionRef::Transaction(tx) => {
crate::profiling::__profile_future(tx.as_ref().query_all_raw(statement)).await
}
}
.map_err(|e| Error::query(e.to_string()))?;
let mut records = Vec::new();
for row in results {
if let Ok(record) = T::from_query_result(&row, "") {
records.push(record);
}
}
Ok(records)
}
pub async fn get_ranked(self) -> Result<Vec<SearchResult<T>>>
where
T: FromQueryResult,
{
use crate::database::Connection;
let db = crate::database::__current_db()?;
let db_type = db.backend();
let (sql, params) = self.build_ranked_sql(db_type)?;
let backend = db.__internal_backend()?;
let statement = Statement::from_sql_and_values(backend, &sql, params);
let results = match db.__get_connection()? {
crate::database::ConnectionRef::Database(conn) => {
crate::profiling::__profile_future(conn.connection().query_all_raw(statement)).await
}
crate::database::ConnectionRef::Transaction(tx) => {
crate::profiling::__profile_future(tx.as_ref().query_all_raw(statement)).await
}
}
.map_err(|e| Error::query(e.to_string()))?;
let mut records = Vec::new();
for row in results {
if let Ok(record) = T::from_query_result(&row, "") {
let rank = row.try_get::<f64>("", "_fts_rank").unwrap_or(0.0);
records.push(SearchResult::new(record, rank));
}
}
Ok(records)
}
/// Runs the search with the limit forced to 1 and returns the first
/// decoded row, if any.
///
/// # Errors
///
/// Propagates every error `get` can return.
pub async fn first(self) -> Result<Option<T>>
where
    T: FromQueryResult,
{
    let rows = self.limit(1).get().await?;
    let first_row = rows.into_iter().next();
    Ok(first_row)
}
/// Counts the rows matching the search.
///
/// Returns 0 when the count query yields no row at all.
///
/// # Errors
///
/// Returns an error when no current database/connection is available, when
/// the query fails, when the `count` column cannot be read as `i64`, or
/// when `count_to_u64` rejects the value.
pub async fn count(self) -> Result<u64> {
    use crate::database::Connection;
    let db = crate::database::__current_db()?;
    let db_type = db.backend();
    let (sql, params) = self.build_count_sql(db_type)?;
    let backend = db.__internal_backend()?;
    let statement = Statement::from_sql_and_values(backend, &sql, params);
    let result = match db.__get_connection()? {
        crate::database::ConnectionRef::Database(conn) => {
            crate::profiling::__profile_future(conn.connection().query_one_raw(statement)).await
        }
        crate::database::ConnectionRef::Transaction(tx) => {
            crate::profiling::__profile_future(tx.as_ref().query_one_raw(statement)).await
        }
    }
    .map_err(|e| Error::query(e.to_string()))?;
    match result {
        Some(row) => {
            // A count that cannot be decoded is an error, not "zero matches".
            let count: i64 = row
                .try_get("", "count")
                .map_err(|e| Error::query(e.to_string()))?;
            crate::internal::count_to_u64(count, "fulltext count")
        }
        None => Ok(0),
    }
}
pub(crate) fn build_sql(&self, db_type: DatabaseType) -> Result<(String, Vec<Value>)> {
match db_type {
DatabaseType::Postgres => self.build_postgres_sql(),
DatabaseType::MySQL | DatabaseType::MariaDB => self.build_mysql_sql(),
DatabaseType::SQLite => self.build_sqlite_sql(),
}
}
pub(crate) fn build_ranked_sql(&self, db_type: DatabaseType) -> Result<(String, Vec<Value>)> {
match db_type {
DatabaseType::Postgres => self.build_postgres_ranked_sql(),
DatabaseType::MySQL | DatabaseType::MariaDB => self.build_mysql_ranked_sql(),
DatabaseType::SQLite => self.build_sqlite_ranked_sql(),
}
}
pub(crate) fn build_count_sql(&self, db_type: DatabaseType) -> Result<(String, Vec<Value>)> {
match db_type {
DatabaseType::Postgres => self.build_postgres_count_sql(),
DatabaseType::MySQL | DatabaseType::MariaDB => self.build_mysql_count_sql(),
DatabaseType::SQLite => self.build_sqlite_count_sql(),
}
}
/// Builds the Postgres search statement.
///
/// Parameters are pushed in the order language, query text, (weights when
/// ranking is on), limit, offset. With ranking on, the statement also
/// selects `_fts_rank` and orders by it; `min_rank` is not applied here.
fn build_postgres_sql(&self) -> Result<(String, Vec<Value>)> {
    let mut params = Vec::new();
    let language = self
        .config
        .language
        .clone()
        .unwrap_or_else(|| "english".to_string());
    let language_placeholder = self.push_param(
        DatabaseType::Postgres,
        &mut params,
        Value::String(Some(language)),
    );
    let tsvector_expr = self.build_pg_tsvector_expr(&language_placeholder);
    let tsquery_expr = self.build_pg_tsquery_expr(&language_placeholder, &mut params);
    let table = quote_ident(DatabaseType::Postgres, T::table_name());
    let mut sql = if self.with_ranking {
        // The weights value is only bound when the rank column is requested.
        let weights_placeholder = self.pg_weights_placeholder(&mut params);
        format!(
            "SELECT *, ts_rank_cd(CAST({weights_placeholder} AS real[]), {tsvector_expr}, {tsquery_expr}) AS _fts_rank FROM {table} WHERE {tsvector_expr} @@ {tsquery_expr} ORDER BY _fts_rank DESC"
        )
    } else {
        format!("SELECT * FROM {table} WHERE {tsvector_expr} @@ {tsquery_expr}")
    };
    self.append_limit_offset(DatabaseType::Postgres, &mut sql, &mut params)?;
    Ok((sql, params))
}
/// Builds the Postgres ranked search statement.
///
/// Always selects `_fts_rank` via `ts_rank_cd`, optionally filters on
/// `min_rank`, and orders by rank descending. Parameters are pushed in the
/// order language, query text, weights, (min rank), limit, offset.
fn build_postgres_ranked_sql(&self) -> Result<(String, Vec<Value>)> {
    let mut params = Vec::new();
    let language = self
        .config
        .language
        .clone()
        .unwrap_or_else(|| "english".to_string());
    let language_placeholder = self.push_param(
        DatabaseType::Postgres,
        &mut params,
        Value::String(Some(language)),
    );
    let tsvector_expr = self.build_pg_tsvector_expr(&language_placeholder);
    let tsquery_expr = self.build_pg_tsquery_expr(&language_placeholder, &mut params);
    let weights_placeholder = self.pg_weights_placeholder(&mut params);
    // The rank expression appears in the select list and, when filtering,
    // again in the WHERE clause.
    let rank_expr = format!(
        "ts_rank_cd(CAST({weights_placeholder} AS real[]), {tsvector_expr}, {tsquery_expr})"
    );
    let table = quote_ident(DatabaseType::Postgres, T::table_name());
    let mut sql = format!(
        "SELECT *, {rank_expr} AS _fts_rank FROM {table} WHERE {tsvector_expr} @@ {tsquery_expr}"
    );
    if let Some(threshold) = self.min_rank {
        let threshold_placeholder = self.push_param(
            DatabaseType::Postgres,
            &mut params,
            Value::Double(Some(threshold)),
        );
        sql.push_str(&format!(" AND {rank_expr} >= {threshold_placeholder}"));
    }
    sql.push_str(" ORDER BY _fts_rank DESC");
    self.append_limit_offset(DatabaseType::Postgres, &mut sql, &mut params)?;
    Ok((sql, params))
}
/// Builds the Postgres statement that counts matching rows.
///
/// Binds the language first and the query text second; limit, offset and
/// `min_rank` are not applied.
fn build_postgres_count_sql(&self) -> Result<(String, Vec<Value>)> {
    let mut params = Vec::new();
    let language = self
        .config
        .language
        .clone()
        .unwrap_or_else(|| "english".to_string());
    let language_placeholder = self.push_param(
        DatabaseType::Postgres,
        &mut params,
        Value::String(Some(language)),
    );
    let tsvector_expr = self.build_pg_tsvector_expr(&language_placeholder);
    let tsquery_expr = self.build_pg_tsquery_expr(&language_placeholder, &mut params);
    let table = quote_ident(DatabaseType::Postgres, T::table_name());
    let sql = format!(
        "SELECT COUNT(*) as count FROM {table} WHERE {tsvector_expr} @@ {tsquery_expr}"
    );
    Ok((sql, params))
}
/// Builds the `to_tsvector(...)` expression over the searched columns.
///
/// Each column is wrapped in `COALESCE(col, '')` and the pieces are
/// concatenated with a single space between them. A one-column search
/// needs no special case: joining one element yields the element itself.
fn build_pg_tsvector_expr(&self, language_placeholder: &str) -> String {
    let document = self
        .columns
        .iter()
        .map(|column| {
            format!(
                "COALESCE({}, '')",
                quote_ident(DatabaseType::Postgres, column)
            )
        })
        .collect::<Vec<String>>()
        .join(" || ' ' || ");
    format!("to_tsvector(CAST({language_placeholder} AS regconfig), {document})")
}
/// Builds the tsquery expression for the configured search mode and binds
/// the (possibly rewritten) query text as one parameter.
///
/// * `Natural` and `Fuzzy` both use `plainto_tsquery` on the raw text.
/// * `Boolean` uses `to_tsquery` on the raw text.
/// * `Phrase` uses `phraseto_tsquery` on the raw text.
/// * `Prefix` appends `:*` to every word and joins the words with `&`.
/// * `Proximity(n)` joins the words with `<n>`.
///
/// NOTE(review): in `Prefix` and `Proximity` mode the words are passed to
/// `to_tsquery` verbatim, so input containing tsquery operators or quotes
/// presumably makes the database reject the query — confirm whether callers
/// sanitise the text first.
fn build_pg_tsquery_expr(&self, language_placeholder: &str, params: &mut Vec<Value>) -> String {
    let (function, text) = match self.config.mode {
        SearchMode::Natural | SearchMode::Fuzzy => ("plainto_tsquery", self.query.clone()),
        SearchMode::Boolean => ("to_tsquery", self.query.clone()),
        SearchMode::Phrase => ("phraseto_tsquery", self.query.clone()),
        SearchMode::Prefix => {
            let terms: Vec<String> = self
                .query
                .split_whitespace()
                .map(|word| format!("{word}:*"))
                .collect();
            ("to_tsquery", terms.join(" & "))
        }
        SearchMode::Proximity(distance) => {
            let separator = format!(" <{distance}> ");
            let terms: Vec<&str> = self.query.split_whitespace().collect();
            ("to_tsquery", terms.join(&separator))
        }
    };
    let placeholder = self.push_param(
        DatabaseType::Postgres,
        params,
        Value::String(Some(text)),
    );
    format!("{function}(CAST({language_placeholder} AS regconfig), {placeholder})")
}
/// Builds the MySQL/MariaDB `MATCH ... AGAINST` search statement.
///
/// `Boolean` adds ` IN BOOLEAN MODE` and `Phrase` adds
/// ` WITH QUERY EXPANSION`; every other mode uses no modifier.
/// Parameters are pushed in the order query text, limit, offset.
fn build_mysql_sql(&self) -> Result<(String, Vec<Value>)> {
    let mut params = Vec::new();
    let quoted_columns: Vec<String> = self
        .columns
        .iter()
        .map(|column| quote_ident(DatabaseType::MySQL, column))
        .collect();
    let columns_str = quoted_columns.join(", ");
    let mode_modifier = match self.config.mode {
        SearchMode::Boolean => " IN BOOLEAN MODE",
        SearchMode::Phrase => " WITH QUERY EXPANSION",
        SearchMode::Natural
        | SearchMode::Prefix
        | SearchMode::Fuzzy
        | SearchMode::Proximity(_) => "",
    };
    let query_placeholder = self.push_param(
        DatabaseType::MySQL,
        &mut params,
        Value::String(Some(self.query.clone())),
    );
    let table = quote_ident(DatabaseType::MySQL, T::table_name());
    let mut sql = format!(
        "SELECT * FROM {table} WHERE MATCH({columns_str}) AGAINST({query_placeholder}{mode_modifier}) "
    );
    self.append_limit_offset(DatabaseType::MySQL, &mut sql, &mut params)?;
    Ok((sql, params))
}
/// Builds the MySQL/MariaDB ranked search statement.
///
/// The `MATCH ... AGAINST` expression is selected as `_fts_rank`, repeated
/// in the `WHERE` clause, optionally compared against `min_rank`, and used
/// for descending ordering.
///
/// MySQL binds `?` placeholders strictly by position, so values must be
/// pushed in exactly the order their placeholders appear in the SQL text:
/// query (select list), query (`WHERE`), then — when `min_rank` is set —
/// query (`AND MATCH`) followed by the threshold, then limit and offset.
fn build_mysql_ranked_sql(&self) -> Result<(String, Vec<Value>)> {
    let table = quote_ident(DatabaseType::MySQL, T::table_name());
    let mut params = Vec::new();
    let columns_str = self
        .columns
        .iter()
        .map(|c| quote_ident(DatabaseType::MySQL, c))
        .collect::<Vec<_>>()
        .join(", ");
    let mode_modifier = match self.config.mode {
        SearchMode::Boolean => " IN BOOLEAN MODE",
        SearchMode::Phrase => " WITH QUERY EXPANSION",
        SearchMode::Natural
        | SearchMode::Prefix
        | SearchMode::Fuzzy
        | SearchMode::Proximity(_) => "",
    };
    let rank_placeholder = self.push_param(
        DatabaseType::MySQL,
        &mut params,
        Value::String(Some(self.query.clone())),
    );
    let where_placeholder = self.push_param(
        DatabaseType::MySQL,
        &mut params,
        Value::String(Some(self.query.clone())),
    );
    let mut sql = format!(
        "SELECT *, MATCH({}) AGAINST({}{}) AS _fts_rank FROM {} \
         WHERE MATCH({}) AGAINST({}{}) ",
        columns_str,
        rank_placeholder,
        mode_modifier,
        table,
        columns_str,
        where_placeholder,
        mode_modifier
    );
    if let Some(min_rank) = self.min_rank {
        // The AGAINST(?) placeholder precedes the `>= ?` placeholder in the
        // SQL text, so the query text has to be bound before the threshold.
        // Binding them the other way round would compare
        // `AGAINST(<threshold>) >= '<query text>'`.
        let against_placeholder = self.push_param(
            DatabaseType::MySQL,
            &mut params,
            Value::String(Some(self.query.clone())),
        );
        let min_rank_placeholder = self.push_param(
            DatabaseType::MySQL,
            &mut params,
            Value::Double(Some(min_rank)),
        );
        sql.push_str(&format!(
            "AND MATCH({}) AGAINST({}{}) >= {} ",
            columns_str, against_placeholder, mode_modifier, min_rank_placeholder
        ));
    }
    sql.push_str("ORDER BY _fts_rank DESC ");
    self.append_limit_offset(DatabaseType::MySQL, &mut sql, &mut params)?;
    Ok((sql, params))
}
/// Builds the MySQL/MariaDB statement that counts matching rows.
///
/// Uses the same mode modifiers as the search statements so the count
/// describes the same result set: `Boolean` adds ` IN BOOLEAN MODE`,
/// `Phrase` adds ` WITH QUERY EXPANSION`, every other mode adds nothing.
/// Limit, offset and `min_rank` are not applied.
fn build_mysql_count_sql(&self) -> Result<(String, Vec<Value>)> {
    let table = quote_ident(DatabaseType::MySQL, T::table_name());
    let mut params = Vec::new();
    let columns_str = self
        .columns
        .iter()
        .map(|c| quote_ident(DatabaseType::MySQL, c))
        .collect::<Vec<_>>()
        .join(", ");
    let mode_modifier = match self.config.mode {
        SearchMode::Boolean => " IN BOOLEAN MODE",
        // Must mirror the search builders; without this arm a Phrase-mode
        // count would run a different query than the search it belongs to.
        SearchMode::Phrase => " WITH QUERY EXPANSION",
        SearchMode::Natural
        | SearchMode::Prefix
        | SearchMode::Fuzzy
        | SearchMode::Proximity(_) => "",
    };
    let query_placeholder = self.push_param(
        DatabaseType::MySQL,
        &mut params,
        Value::String(Some(self.query.clone())),
    );
    Ok((
        format!(
            "SELECT COUNT(*) as count FROM {} WHERE MATCH({}) AGAINST({}{})",
            table, columns_str, query_placeholder, mode_modifier
        ),
        params,
    ))
}
/// Builds the SQLite search statement.
///
/// Joins the model table with its companion `<table>_fts` table on `rowid`
/// and filters with `MATCH`. Parameters are pushed in the order query text
/// (after `escape_fts5_query`), limit, offset. The search mode is not
/// consulted.
fn build_sqlite_sql(&self) -> Result<(String, Vec<Value>)> {
    let base_name = T::table_name();
    let table = quote_ident(DatabaseType::SQLite, base_name);
    let fts_table = quote_ident(DatabaseType::SQLite, &format!("{}_fts", base_name));
    let mut params = Vec::new();
    let escaped_query = escape_fts5_query(&self.query);
    let query_placeholder = self.push_param(
        DatabaseType::SQLite,
        &mut params,
        Value::String(Some(escaped_query)),
    );
    let mut sql = format!(
        "SELECT t.* FROM {table} t INNER JOIN {fts_table} fts ON t.rowid = fts.rowid WHERE {fts_table} MATCH {query_placeholder} "
    );
    self.append_limit_offset(DatabaseType::SQLite, &mut sql, &mut params)?;
    Ok((sql, params))
}
/// Builds the SQLite ranked search statement.
///
/// Selects `bm25(<fts table>)` as `_fts_rank` and orders by it ascending.
/// When `min_rank` is set, the negated threshold is bound and rows are kept
/// when `bm25(...) <= -min_rank`. Parameters are pushed in the order query
/// text, (negated min rank), limit, offset.
fn build_sqlite_ranked_sql(&self) -> Result<(String, Vec<Value>)> {
    let base_name = T::table_name();
    let table = quote_ident(DatabaseType::SQLite, base_name);
    let fts_table = quote_ident(DatabaseType::SQLite, &format!("{}_fts", base_name));
    let rank_expr = format!("bm25({})", fts_table);
    let mut params = Vec::new();
    let escaped_query = escape_fts5_query(&self.query);
    let query_placeholder = self.push_param(
        DatabaseType::SQLite,
        &mut params,
        Value::String(Some(escaped_query)),
    );
    let mut sql = format!(
        "SELECT t.*, {rank_expr} AS _fts_rank FROM {table} t INNER JOIN {fts_table} fts ON t.rowid = fts.rowid WHERE {fts_table} MATCH {query_placeholder} "
    );
    if let Some(threshold) = self.min_rank {
        let threshold_placeholder = self.push_param(
            DatabaseType::SQLite,
            &mut params,
            Value::Double(Some(-threshold)),
        );
        sql.push_str(&format!("AND {rank_expr} <= {threshold_placeholder} "));
    }
    sql.push_str(&format!("ORDER BY {rank_expr} "));
    self.append_limit_offset(DatabaseType::SQLite, &mut sql, &mut params)?;
    Ok((sql, params))
}
/// Builds the SQLite statement that counts matching rows.
///
/// Uses the same join and `MATCH` filter as the search statement and binds
/// only the escaped query text; limit, offset and `min_rank` are not
/// applied.
fn build_sqlite_count_sql(&self) -> Result<(String, Vec<Value>)> {
    let base_name = T::table_name();
    let table = quote_ident(DatabaseType::SQLite, base_name);
    let fts_table = quote_ident(DatabaseType::SQLite, &format!("{}_fts", base_name));
    let mut params = Vec::new();
    let escaped_query = escape_fts5_query(&self.query);
    let query_placeholder = self.push_param(
        DatabaseType::SQLite,
        &mut params,
        Value::String(Some(escaped_query)),
    );
    let sql = format!(
        "SELECT COUNT(*) as count FROM {table} t INNER JOIN {fts_table} fts ON t.rowid = fts.rowid WHERE {fts_table} MATCH {query_placeholder}"
    );
    Ok((sql, params))
}
/// Appends `value` to `params` and returns the placeholder that refers to it.
///
/// Postgres gets a numbered placeholder (`$1`, `$2`, ...); MySQL, MariaDB
/// and SQLite get a positional `?`.
fn push_param(&self, db_type: DatabaseType, params: &mut Vec<Value>, value: Value) -> String {
    params.push(value);
    match db_type {
        // The value was just pushed, so the new length is its 1-based index.
        DatabaseType::Postgres => format!("${}", params.len()),
        DatabaseType::MySQL | DatabaseType::MariaDB | DatabaseType::SQLite => String::from("?"),
    }
}
/// Binds the rank weights as a string parameter and returns its placeholder.
///
/// Configured weights are rendered with `to_pg_array` and stripped of the
/// surrounding single quotes (the value is bound, not spliced into SQL);
/// without configured weights the literal `{0.1,0.2,0.4,1.0}` is used.
fn pg_weights_placeholder(&self, params: &mut Vec<Value>) -> String {
    let weights = match self.config.weights.as_ref() {
        Some(configured) => configured.to_pg_array().trim_matches('\'').to_owned(),
        None => String::from("{0.1,0.2,0.4,1.0}"),
    };
    self.push_param(DatabaseType::Postgres, params, Value::String(Some(weights)))
}
fn append_limit_offset(
&self,
db_type: DatabaseType,
sql: &mut String,
params: &mut Vec<Value>,
) -> Result<()> {
if let Some(limit) = self.limit {
let limit_value = i64::try_from(limit)
.map_err(|_| Error::query("Full-text search limit exceeds i64 range"))?;
let placeholder = self.push_param(db_type, params, Value::BigInt(Some(limit_value)));
sql.push_str(&format!(" LIMIT {}", placeholder));
}
if let Some(offset) = self.offset {
let offset_value = i64::try_from(offset)
.map_err(|_| Error::query("Full-text search offset exceeds i64 range"))?;
let placeholder = self.push_param(db_type, params, Value::BigInt(Some(offset_value)));
sql.push_str(&format!(" OFFSET {}", placeholder));
}
Ok(())
}
}
/// Description of a full-text index from which backend-specific DDL is
/// generated.
#[derive(Debug, Clone)]
pub struct FullTextIndex {
/// Index name (SQLite derives its FTS table and trigger names from
/// `table` instead).
pub name: String,
/// Table the index belongs to.
pub table: String,
/// Columns covered by the index.
pub columns: Vec<String>,
/// Backend-specific options.
pub config: FullTextIndexConfig,
}
/// Backend-specific options for a [`FullTextIndex`].
#[derive(Debug, Clone, Default)]
pub struct FullTextIndexConfig {
/// Text-search language for the Postgres index expression; `"english"` is
/// used when unset.
pub language: Option<String>,
/// Postgres index access method.
pub pg_index_type: PgFullTextIndexType,
/// MySQL parser name, appended as `WITH PARSER <name>` when set.
pub mysql_parser: Option<String>,
}
/// Postgres index access method used for the full-text index.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum PgFullTextIndexType {
/// Emitted as `USING GIN` (the default).
#[default]
GIN,
/// Emitted as `USING GiST`.
GiST,
}
impl FullTextIndex {
pub fn new(name: impl Into<String>, table: impl Into<String>, columns: Vec<String>) -> Self {
Self {
name: name.into(),
table: table.into(),
columns,
config: FullTextIndexConfig::default(),
}
}
pub fn language(mut self, lang: impl Into<String>) -> Self {
self.config.language = Some(lang.into());
self
}
pub fn pg_index_type(mut self, index_type: PgFullTextIndexType) -> Self {
self.config.pg_index_type = index_type;
self
}
/// Renders the Postgres `CREATE INDEX` statement for this index.
///
/// The index is an expression index over
/// `to_tsvector('<language>', <columns>)`, where each column is wrapped in
/// `COALESCE(col, '')` and the columns are concatenated with a single
/// space. The language defaults to `"english"`.
///
/// The language is spliced into the statement as a string literal, so any
/// single quote in it is doubled to keep the literal intact.
pub fn to_postgres_sql(&self) -> String {
    let language = self
        .config
        .language
        .as_deref()
        .unwrap_or("english")
        .replace('\'', "''");
    let index_type = match self.config.pg_index_type {
        PgFullTextIndexType::GIN => "GIN",
        PgFullTextIndexType::GiST => "GiST",
    };
    // Joining a single element yields that element, so one code path serves
    // both single- and multi-column indexes.
    let document = self
        .columns
        .iter()
        .map(|c| format!("COALESCE({}, '')", quote_ident(DatabaseType::Postgres, c)))
        .collect::<Vec<_>>()
        .join(" || ' ' || ");
    let tsvector_expr = format!("to_tsvector('{}', {})", language, document);
    format!(
        "CREATE INDEX {} ON {} USING {} (({}))",
        quote_ident(DatabaseType::Postgres, &self.name),
        quote_ident(DatabaseType::Postgres, &self.table),
        index_type,
        tsvector_expr
    )
}
/// Renders the MySQL `CREATE FULLTEXT INDEX` statement for this index.
///
/// When a parser is configured, ` WITH PARSER <name>` is appended.
///
/// NOTE(review): the parser name is inserted verbatim (not quoted or
/// escaped) — confirm it only ever comes from trusted configuration.
pub fn to_mysql_sql(&self) -> String {
    let quoted_columns: Vec<String> = self
        .columns
        .iter()
        .map(|column| quote_ident(DatabaseType::MySQL, column))
        .collect();
    let parser_clause = match &self.config.mysql_parser {
        Some(parser) => format!(" WITH PARSER {}", parser),
        None => String::new(),
    };
    let index_name = quote_ident(DatabaseType::MySQL, &self.name);
    let table_name = quote_ident(DatabaseType::MySQL, &self.table);
    let columns_str = quoted_columns.join(", ");
    format!("CREATE FULLTEXT INDEX {index_name} ON {table_name}({columns_str}){parser_clause}")
}
pub fn to_sqlite_sql(&self) -> Vec<String> {
let fts_table = format!("{}_fts", self.table);
let columns_str = self
.columns
.iter()
.map(|column| quote_ident(DatabaseType::SQLite, column))
.collect::<Vec<_>>()
.join(", ");
vec![
format!(
"CREATE VIRTUAL TABLE IF NOT EXISTS {} USING fts5({}, content={}, content_rowid={})",
quote_ident(DatabaseType::SQLite, &fts_table),
columns_str,
quote_ident(DatabaseType::SQLite, &self.table),
quote_ident(DatabaseType::SQLite, "rowid")
),
format!(
"CREATE TRIGGER IF NOT EXISTS {} AFTER INSERT ON {} BEGIN \
INSERT INTO \"{}\"(rowid, {}) VALUES (new.rowid, {}); \
END",
quote_ident(DatabaseType::SQLite, &format!("{}_ai", self.table)),
quote_ident(DatabaseType::SQLite, &self.table),
quote_ident(DatabaseType::SQLite, &fts_table),
columns_str,
self.columns
.iter()
.map(|c| format!("new.{}", quote_ident(DatabaseType::SQLite, c)))
.collect::<Vec<_>>()
.join(", ")
),
format!(
"CREATE TRIGGER IF NOT EXISTS {} AFTER DELETE ON {} BEGIN \
INSERT INTO {}({}, rowid, {}) VALUES('delete', old.rowid, {}); \
END",
quote_ident(DatabaseType::SQLite, &format!("{}_ad", self.table)),
quote_ident(DatabaseType::SQLite, &self.table),
quote_ident(DatabaseType::SQLite, &fts_table),
quote_ident(DatabaseType::SQLite, &fts_table),
columns_str,
self.columns
.iter()
.map(|c| format!("old.{}", quote_ident(DatabaseType::SQLite, c)))
.collect::<Vec<_>>()
.join(", ")
),
format!(
"CREATE TRIGGER IF NOT EXISTS {} AFTER UPDATE ON {} BEGIN \
INSERT INTO {}({}, rowid, {}) VALUES('delete', old.rowid, {}); \
INSERT INTO {}(rowid, {}) VALUES (new.rowid, {}); \
END",
quote_ident(DatabaseType::SQLite, &format!("{}_au", self.table)),
quote_ident(DatabaseType::SQLite, &self.table),
quote_ident(DatabaseType::SQLite, &fts_table),
quote_ident(DatabaseType::SQLite, &fts_table),
columns_str,
self.columns
.iter()
.map(|c| format!("old.{}", quote_ident(DatabaseType::SQLite, c)))
.collect::<Vec<_>>()
.join(", "),
quote_ident(DatabaseType::SQLite, &fts_table),
columns_str,
self.columns
.iter()
.map(|c| format!("new.{}", quote_ident(DatabaseType::SQLite, c)))
.collect::<Vec<_>>()
.join(", ")
),
]
}
pub fn to_sql(&self, db_type: DatabaseType) -> Vec<String> {
match db_type {
DatabaseType::Postgres => vec![self.to_postgres_sql()],
DatabaseType::MySQL | DatabaseType::MariaDB => vec![self.to_mysql_sql()],
DatabaseType::SQLite => self.to_sqlite_sql(),
}
}
}
/// Wraps every whole-word, case-insensitive occurrence of a query word in
/// `text` with `start_tag` / `end_tag`.
///
/// All query words are combined into one alternation and applied in a
/// single pass. Replacing word by word would let a later word match inside
/// tags inserted for an earlier one (e.g. the query word `mark` hitting
/// `<mark>`), and would wrap a repeated query word twice.
///
/// Returns `text` unchanged when the query contains no words or the
/// pattern cannot be compiled.
pub fn highlight_text(text: &str, query: &str, start_tag: &str, end_tag: &str) -> String {
    let alternatives: Vec<String> = query.split_whitespace().map(regex::escape).collect();
    if alternatives.is_empty() {
        return text.to_string();
    }
    let pattern = format!(r"(?i)\b(?:{})\b", alternatives.join("|"));
    match regex::Regex::new(&pattern) {
        Ok(matcher) => matcher
            .replace_all(text, |caps: &regex::Captures| {
                format!("{}{}{}", start_tag, &caps[0], end_tag)
            })
            .into_owned(),
        // Escaped literals should always compile; treat a failure (e.g. the
        // compiled-size limit) as "nothing to highlight".
        Err(_) => text.to_string(),
    }
}
/// Extracts a short snippet of `text` around the first word that contains
/// any query word (case-insensitive substring match), wrapping every such
/// word inside the snippet in `start_tag` / `end_tag`.
///
/// The window spans up to `fragment_words` words before the first match and
/// `fragment_words` words starting at the match. `...` is prepended and/or
/// appended when the window does not reach the start/end of the text.
/// Without any match, the first `fragment_words` words are returned,
/// followed by `...` when the text is longer.
///
/// The window arithmetic saturates, so arbitrarily large `fragment_words`
/// values are safe, and the matched word is always part of the snippet even
/// when `fragment_words` is 0.
pub fn generate_snippet(
    text: &str,
    query: &str,
    fragment_words: usize,
    start_tag: &str,
    end_tag: &str,
) -> String {
    let words: Vec<&str> = text.split_whitespace().collect();
    let needles: Vec<String> = query.split_whitespace().map(str::to_lowercase).collect();
    let is_match = |word: &str| {
        let lowered = word.to_lowercase();
        needles.iter().any(|needle| lowered.contains(needle.as_str()))
    };
    match words.iter().position(|&word| is_match(word)) {
        Some(pos) => {
            let start = pos.saturating_sub(fragment_words);
            // `pos + fragment_words` could overflow for huge values, and a
            // zero-width window would drop the very word that matched.
            let end = pos
                .saturating_add(fragment_words)
                .max(pos + 1)
                .min(words.len());
            let body = words[start..end]
                .iter()
                .map(|&word| {
                    if is_match(word) {
                        format!("{}{}{}", start_tag, word, end_tag)
                    } else {
                        word.to_string()
                    }
                })
                .collect::<Vec<_>>()
                .join(" ");
            let prefix = if start > 0 { "..." } else { "" };
            let suffix = if end < words.len() { "..." } else { "" };
            format!("{prefix}{body}{suffix}")
        }
        None => {
            let end = fragment_words.min(words.len());
            let snippet = words[..end].join(" ");
            if end < words.len() {
                format!("{}...", snippet)
            } else {
                snippet
            }
        }
    }
}
/// Renders a Postgres `ts_headline(...)` expression that highlights `query`
/// in `column`.
///
/// All arguments are spliced into the SQL text, so each one is escaped for
/// the position it lands in: `language`, `query`, `start_tag` and `end_tag`
/// go through `escape_string` (they sit inside single-quoted literals), and
/// double quotes in `column` are doubled (it sits inside a double-quoted
/// identifier). Inputs without quotes or backslashes render exactly as
/// they did before escaping was added.
pub fn pg_headline_sql(
    column: &str,
    query: &str,
    language: &str,
    start_tag: &str,
    end_tag: &str,
) -> String {
    let language = escape_string(language);
    format!(
        "ts_headline('{}', \"{}\", plainto_tsquery('{}', '{}'), \
         'StartSel={}, StopSel={}, MaxWords=35, MinWords=15')",
        language,
        column.replace('"', "\"\""),
        language,
        escape_string(query),
        escape_string(start_tag),
        escape_string(end_tag)
    )
}

/// Escapes `s` for use inside a single-quoted SQL string literal by
/// doubling single quotes and backslashes.
fn escape_string(s: &str) -> String {
    s.replace('\'', "''").replace('\\', "\\\\")
}
/// Doubles every double quote and every single quote in `s`.
///
/// NOTE(review): the result is passed as a bound parameter by the SQLite
/// builders, so the single-quote doubling looks like SQL-literal escaping
/// that a bound value presumably does not need — confirm the intended FTS5
/// query syntax before changing it.
fn escape_fts5_query(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '"' => escaped.push_str("\"\""),
            '\'' => escaped.push_str("''"),
            other => escaped.push(other),
        }
    }
    escaped
}
/// Quotes `name` as an identifier for `db_type`.
///
/// Postgres and SQLite use double quotes, MySQL and MariaDB use backticks;
/// any occurrence of the quote character inside `name` is doubled.
fn quote_ident(db_type: DatabaseType, name: &str) -> String {
    let quote = match db_type {
        DatabaseType::MySQL | DatabaseType::MariaDB => '`',
        DatabaseType::Postgres | DatabaseType::SQLite => '"',
    };
    // Room for the name plus the two surrounding quote characters.
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push(quote);
    for ch in name.chars() {
        if ch == quote {
            quoted.push(quote);
        }
        quoted.push(ch);
    }
    quoted.push(quote);
    quoted
}
#[cfg(test)]
#[path = "testing/fulltext_tests.rs"]
mod tests;