use std::fmt;
use std::marker::PhantomData;
use crate::config::DatabaseType;
use crate::error::{Error, Result};
use crate::model::Model;
use crate::internal::{
ConnectionTrait, Statement, FromQueryResult,
};
use crate::database::try_db;
/// Tunable options for a full-text search query.
///
/// Built fluently via the `FullTextConfig::new()` builder methods or passed to
/// `FullTextSearch::search_with_config`.
#[derive(Debug, Clone, Default)]
pub struct FullTextConfig {
/// Text-search language / dictionary name; the PostgreSQL builders fall back to
/// `"english"` when this is `None`.
pub language: Option<String>,
/// How the query string is interpreted (natural, boolean, phrase, ...).
pub mode: SearchMode,
/// NOTE(review): not read by the query builders in this file — confirm whether
/// anything else consumes it.
pub min_word_length: Option<u32>,
/// NOTE(review): not read by the query builders in this file.
pub max_word_length: Option<u32>,
/// NOTE(review): not read by the query builders in this file.
pub stop_words: Vec<String>,
/// Optional PostgreSQL rank weights; a default array is used when `None`.
pub weights: Option<SearchWeights>,
}
impl FullTextConfig {
    /// Returns a configuration with every option at its default.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the text-search language / dictionary.
    pub fn language(self, lang: impl Into<String>) -> Self {
        Self {
            language: Some(lang.into()),
            ..self
        }
    }

    /// Sets how the query string is interpreted.
    pub fn mode(self, mode: SearchMode) -> Self {
        Self { mode, ..self }
    }

    /// Sets the minimum indexed word length.
    pub fn min_word_length(self, len: u32) -> Self {
        Self {
            min_word_length: Some(len),
            ..self
        }
    }

    /// Sets the maximum indexed word length.
    pub fn max_word_length(self, len: u32) -> Self {
        Self {
            max_word_length: Some(len),
            ..self
        }
    }

    /// Replaces the stop-word list.
    pub fn stop_words(self, words: Vec<String>) -> Self {
        Self {
            stop_words: words,
            ..self
        }
    }

    /// Sets the rank weights used by ranked queries.
    pub fn weights(self, weights: SearchWeights) -> Self {
        Self {
            weights: Some(weights),
            ..self
        }
    }
}
/// How a full-text query string is interpreted.
///
/// Backend support differs; the SQL builders in this module decide how each mode is
/// translated. The SQLite builders do not consult the mode at all.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SearchMode {
/// Plain-language query (the default).
#[default]
Natural,
/// Query text uses the backend's boolean operator syntax and is passed through.
Boolean,
/// Match the words as an ordered phrase.
Phrase,
/// Match words by prefix.
Prefix,
/// On PostgreSQL this is currently built exactly like `Natural` (no fuzzy operator is used).
Fuzzy,
/// Words separated by the given distance; on PostgreSQL this becomes the `<N>` tsquery operator.
Proximity(u32),
}
impl fmt::Display for SearchMode {
    /// Writes the lowercase mode name, e.g. `boolean` or `proximity(3)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            // The only variant carrying data is formatted directly.
            SearchMode::Proximity(distance) => return write!(f, "proximity({})", distance),
            SearchMode::Natural => "natural",
            SearchMode::Boolean => "boolean",
            SearchMode::Phrase => "phrase",
            SearchMode::Prefix => "prefix",
            SearchMode::Fuzzy => "fuzzy",
        };
        f.write_str(label)
    }
}
/// Relative weights for the four PostgreSQL tsvector weight classes A–D.
///
/// Serialized by `to_pg_array` in PostgreSQL's `{D, C, B, A}` array order.
#[derive(Debug, Clone)]
pub struct SearchWeights {
/// Weight for class A.
pub a: f32,
/// Weight for class B.
pub b: f32,
/// Weight for class C.
pub c: f32,
/// Weight for class D.
pub d: f32,
}
impl Default for SearchWeights {
fn default() -> Self {
Self {
a: 1.0,
b: 0.4,
c: 0.2,
d: 0.1,
}
}
}
impl SearchWeights {
pub fn new(a: f32, b: f32, c: f32, d: f32) -> Self {
Self { a, b, c, d }
}
pub fn to_pg_array(&self) -> String {
format!("'{{{},{},{},{}}}'", self.d, self.c, self.b, self.a)
}
}
/// A matched record together with its relevance score and optional highlights.
#[derive(Debug, Clone)]
pub struct SearchResult<T> {
/// The decoded row.
pub record: T,
/// Relevance score read from the `_fts_rank` column by `get_ranked` (0.0 if it could not be
/// read). The scale and sign are backend-specific.
pub rank: f64,
/// Highlighted fragments; empty unless populated via `with_highlights`.
pub highlights: Vec<HighlightedField>,
}
impl<T> SearchResult<T> {
pub fn new(record: T, rank: f64) -> Self {
Self {
record,
rank,
highlights: Vec::new(),
}
}
pub fn with_highlights(mut self, highlights: Vec<HighlightedField>) -> Self {
self.highlights = highlights;
self
}
}
/// One field of a record with search terms wrapped in highlight tags.
#[derive(Debug, Clone)]
pub struct HighlightedField {
/// Name of the field / column.
pub field: String,
/// Field text with matches wrapped in highlight tags.
pub highlighted: String,
/// Field text without highlighting.
pub original: String,
/// Number of highlight start tags found in `highlighted`.
pub match_count: usize,
}
impl HighlightedField {
    /// Creates a highlighted field whose matches are marked with the default `<mark>` tag.
    ///
    /// `match_count` is the number of `<mark>` occurrences in `highlighted`.
    pub fn new(
        field: impl Into<String>,
        highlighted: impl Into<String>,
        original: impl Into<String>,
    ) -> Self {
        Self::with_start_tag(field, highlighted, original, "<mark>")
    }

    /// Creates a highlighted field whose matches were marked with a custom `start_tag`.
    ///
    /// Use this when highlights were produced with custom tags (e.g. via
    /// `FullTextSearchBuilder::with_highlights`); `new` only counts `<mark>`.
    /// `match_count` is the number of non-overlapping `start_tag` occurrences in
    /// `highlighted`. It is 0 for an empty `start_tag`, because an empty pattern would
    /// otherwise "match" between every character.
    pub fn with_start_tag(
        field: impl Into<String>,
        highlighted: impl Into<String>,
        original: impl Into<String>,
        start_tag: &str,
    ) -> Self {
        let highlighted = highlighted.into();
        let original = original.into();
        let match_count = if start_tag.is_empty() {
            0
        } else {
            highlighted.matches(start_tag).count()
        };
        Self {
            field: field.into(),
            highlighted,
            original,
            match_count,
        }
    }
}
/// Full-text search entry points available on every `Model`.
///
/// Each method only constructs a [`FullTextSearchBuilder`]; nothing touches the database
/// until a terminal method (`get`, `get_ranked`, `first`, `count`) is awaited.
pub trait FullTextSearch: Model + Sized {
    /// Searches `columns` for `query` with default settings.
    fn search(columns: &[&str], query: &str) -> FullTextSearchBuilder<Self> {
        FullTextSearchBuilder::new(columns, query)
    }

    /// Searches `columns` for `query` using an explicit configuration.
    fn search_with_config(
        columns: &[&str],
        query: &str,
        config: FullTextConfig,
    ) -> FullTextSearchBuilder<Self> {
        Self::search(columns, query).config(config)
    }

    /// Searches `columns` for `query` with relevance ranking enabled.
    fn search_ranked(columns: &[&str], query: &str) -> FullTextSearchBuilder<Self> {
        Self::search(columns, query).with_ranking()
    }

    /// Searches `columns` for `query`, recording the highlight tags to use.
    fn search_highlighted(
        columns: &[&str],
        query: &str,
        start_tag: &str,
        end_tag: &str,
    ) -> FullTextSearchBuilder<Self> {
        Self::search(columns, query).with_highlights(start_tag, end_tag)
    }
}
// Blanket impl: every `Model` gets the `FullTextSearch` entry points via the trait's
// default methods.
impl<T: Model> FullTextSearch for T {}
/// Fluent builder for a full-text search over model `T`.
///
/// Produces backend-specific SQL (PostgreSQL, MySQL/MariaDB, SQLite) when a terminal
/// method is awaited.
pub struct FullTextSearchBuilder<T: Model> {
/// Column names to search.
columns: Vec<String>,
/// Raw (unescaped) user query text.
query: String,
/// Language, mode and weight settings.
config: FullTextConfig,
/// Whether results should be ordered by relevance.
with_ranking: bool,
/// Highlight settings; NOTE(review): stored but not read by any method in this file.
highlight_config: Option<HighlightConfig>,
/// Maximum number of rows to return.
limit: Option<u64>,
/// Number of rows to skip.
offset: Option<u64>,
/// Minimum relevance score for ranked queries.
min_rank: Option<f64>,
/// Ties the builder to the model type without storing a `T`.
_marker: PhantomData<T>,
}
/// Settings for highlighting matched terms.
#[derive(Debug, Clone)]
pub struct HighlightConfig {
/// Text inserted before each match (default `<mark>`).
pub start_tag: String,
/// Text inserted after each match (default `</mark>`).
pub end_tag: String,
/// NOTE(review): not read anywhere in this file.
pub max_length: Option<usize>,
/// Words of context around a match (default `Some(10)`); NOTE(review): not read by the
/// builder in this file.
pub fragment_words: Option<usize>,
}
impl Default for HighlightConfig {
fn default() -> Self {
Self {
start_tag: "<mark>".to_string(),
end_tag: "</mark>".to_string(),
max_length: None,
fragment_words: Some(10),
}
}
}
impl<T: Model> FullTextSearchBuilder<T> {
pub fn new(columns: &[&str], query: &str) -> Self {
Self {
columns: columns.iter().map(|s| s.to_string()).collect(),
query: query.to_string(),
config: FullTextConfig::default(),
with_ranking: false,
highlight_config: None,
limit: None,
offset: None,
min_rank: None,
_marker: PhantomData,
}
}
pub fn config(mut self, config: FullTextConfig) -> Self {
self.config = config;
self
}
pub fn with_ranking(mut self) -> Self {
self.with_ranking = true;
self
}
pub fn with_highlights(mut self, start_tag: &str, end_tag: &str) -> Self {
self.highlight_config = Some(HighlightConfig {
start_tag: start_tag.to_string(),
end_tag: end_tag.to_string(),
..Default::default()
});
self
}
pub fn highlight_config(mut self, config: HighlightConfig) -> Self {
self.highlight_config = Some(config);
self
}
pub fn limit(mut self, limit: u64) -> Self {
self.limit = Some(limit);
self
}
pub fn offset(mut self, offset: u64) -> Self {
self.offset = Some(offset);
self
}
pub fn min_rank(mut self, rank: f64) -> Self {
self.min_rank = Some(rank);
self
}
pub fn mode(mut self, mode: SearchMode) -> Self {
self.config.mode = mode;
self
}
pub fn language(mut self, lang: impl Into<String>) -> Self {
self.config.language = Some(lang.into());
self
}
pub async fn get(self) -> Result<Vec<T>>
where
T: FromQueryResult,
{
let db = try_db().ok_or_else(|| Error::connection("Database not initialized"))?;
let db_type = db.backend();
let sql = self.build_sql(db_type)?;
let backend = db.__internal_backend();
let statement = Statement::from_string(
backend,
sql,
);
let results = db.__internal_connection()
.query_all_raw(statement)
.await
.map_err(|e| Error::query(e.to_string()))?;
let mut records = Vec::new();
for row in results {
if let Ok(record) = T::from_query_result(&row, "") {
records.push(record);
}
}
Ok(records)
}
pub async fn get_ranked(self) -> Result<Vec<SearchResult<T>>>
where
T: FromQueryResult,
{
let db = try_db().ok_or_else(|| Error::connection("Database not initialized"))?;
let db_type = db.backend();
let sql = self.build_ranked_sql(db_type)?;
let backend = db.__internal_backend();
let statement = Statement::from_string(
backend,
sql,
);
let results = db.__internal_connection()
.query_all_raw(statement)
.await
.map_err(|e| Error::query(e.to_string()))?;
let mut records = Vec::new();
for row in results {
if let Ok(record) = T::from_query_result(&row, "") {
let rank = row.try_get::<f64>("", "_fts_rank")
.unwrap_or(0.0);
records.push(SearchResult::new(record, rank));
}
}
Ok(records)
}
pub async fn first(mut self) -> Result<Option<T>>
where
T: FromQueryResult,
{
self.limit = Some(1);
let results = self.get().await?;
Ok(results.into_iter().next())
}
pub async fn count(self) -> Result<u64> {
let db = try_db().ok_or_else(|| Error::connection("Database not initialized"))?;
let db_type = db.backend();
let sql = self.build_count_sql(db_type)?;
let backend = db.__internal_backend();
let statement = Statement::from_string(
backend,
sql,
);
let result = db.__internal_connection()
.query_one_raw(statement)
.await
.map_err(|e| Error::query(e.to_string()))?;
if let Some(row) = result {
let count: i64 = row.try_get("", "count")
.unwrap_or(0);
Ok(count as u64)
} else {
Ok(0)
}
}
fn build_sql(&self, db_type: DatabaseType) -> Result<String> {
match db_type {
DatabaseType::Postgres => self.build_postgres_sql(),
DatabaseType::MySQL | DatabaseType::MariaDB => self.build_mysql_sql(),
DatabaseType::SQLite => self.build_sqlite_sql(),
}
}
fn build_ranked_sql(&self, db_type: DatabaseType) -> Result<String> {
match db_type {
DatabaseType::Postgres => self.build_postgres_ranked_sql(),
DatabaseType::MySQL | DatabaseType::MariaDB => self.build_mysql_ranked_sql(),
DatabaseType::SQLite => self.build_sqlite_ranked_sql(),
}
}
fn build_count_sql(&self, db_type: DatabaseType) -> Result<String> {
match db_type {
DatabaseType::Postgres => self.build_postgres_count_sql(),
DatabaseType::MySQL | DatabaseType::MariaDB => self.build_mysql_count_sql(),
DatabaseType::SQLite => self.build_sqlite_count_sql(),
}
}
fn build_postgres_sql(&self) -> Result<String> {
let table = T::table_name();
let escaped_query = escape_string(&self.query);
let language = self.config.language.as_deref().unwrap_or("english");
let tsvector_expr = self.build_pg_tsvector_expr(language);
let tsquery_expr = self.build_pg_tsquery_expr(language, &escaped_query);
let mut sql = format!(
"SELECT * FROM \"{}\" WHERE {} @@ {}",
table, tsvector_expr, tsquery_expr
);
if self.with_ranking {
let weights = self.config.weights.as_ref()
.map(|w| w.to_pg_array())
.unwrap_or_else(|| "'{0.1,0.2,0.4,1.0}'".to_string());
sql = format!(
"SELECT *, ts_rank_cd({}, {}, {}) AS _fts_rank FROM \"{}\" WHERE {} @@ {} ORDER BY _fts_rank DESC",
weights, tsvector_expr, tsquery_expr, table, tsvector_expr, tsquery_expr
);
}
if let Some(limit) = self.limit {
sql.push_str(&format!(" LIMIT {}", limit));
}
if let Some(offset) = self.offset {
sql.push_str(&format!(" OFFSET {}", offset));
}
Ok(sql)
}
fn build_postgres_ranked_sql(&self) -> Result<String> {
let table = T::table_name();
let escaped_query = escape_string(&self.query);
let language = self.config.language.as_deref().unwrap_or("english");
let tsvector_expr = self.build_pg_tsvector_expr(language);
let tsquery_expr = self.build_pg_tsquery_expr(language, &escaped_query);
let weights = self.config.weights.as_ref()
.map(|w| w.to_pg_array())
.unwrap_or_else(|| "'{0.1,0.2,0.4,1.0}'".to_string());
let mut sql = format!(
"SELECT *, ts_rank_cd({}, {}, {}) AS _fts_rank FROM \"{}\" WHERE {} @@ {}",
weights, tsvector_expr, tsquery_expr, table, tsvector_expr, tsquery_expr
);
if let Some(min_rank) = self.min_rank {
sql.push_str(&format!(" AND ts_rank_cd({}, {}, {}) >= {}",
weights, tsvector_expr, tsquery_expr, min_rank));
}
sql.push_str(" ORDER BY _fts_rank DESC");
if let Some(limit) = self.limit {
sql.push_str(&format!(" LIMIT {}", limit));
}
if let Some(offset) = self.offset {
sql.push_str(&format!(" OFFSET {}", offset));
}
Ok(sql)
}
fn build_postgres_count_sql(&self) -> Result<String> {
let table = T::table_name();
let escaped_query = escape_string(&self.query);
let language = self.config.language.as_deref().unwrap_or("english");
let tsvector_expr = self.build_pg_tsvector_expr(language);
let tsquery_expr = self.build_pg_tsquery_expr(language, &escaped_query);
Ok(format!(
"SELECT COUNT(*) as count FROM \"{}\" WHERE {} @@ {}",
table, tsvector_expr, tsquery_expr
))
}
fn build_pg_tsvector_expr(&self, language: &str) -> String {
if self.columns.len() == 1 {
format!("to_tsvector('{}', COALESCE(\"{}\", ''))", language, self.columns[0])
} else {
let cols: Vec<String> = self.columns.iter()
.map(|c| format!("COALESCE(\"{}\", '')", c))
.collect();
format!("to_tsvector('{}', {})", language, cols.join(" || ' ' || "))
}
}
fn build_pg_tsquery_expr(&self, language: &str, query: &str) -> String {
match self.config.mode {
SearchMode::Natural => {
format!("plainto_tsquery('{}', '{}')", language, query)
}
SearchMode::Boolean => {
format!("to_tsquery('{}', '{}')", language, query)
}
SearchMode::Phrase => {
format!("phraseto_tsquery('{}', '{}')", language, query)
}
SearchMode::Prefix => {
let words: Vec<&str> = query.split_whitespace().collect();
let prefixed: Vec<String> = words.iter()
.map(|w| format!("{}:*", w))
.collect();
format!("to_tsquery('{}', '{}')", language, prefixed.join(" & "))
}
SearchMode::Fuzzy => {
format!("plainto_tsquery('{}', '{}')", language, query)
}
SearchMode::Proximity(distance) => {
let words: Vec<&str> = query.split_whitespace().collect();
let proximity: Vec<String> = words.iter()
.map(|w| w.to_string())
.collect();
format!("to_tsquery('{}', '{}')", language, proximity.join(&format!(" <{}> ", distance)))
}
}
}
fn build_mysql_sql(&self) -> Result<String> {
let table = T::table_name();
let escaped_query = escape_string(&self.query);
let columns_str = self.columns.iter()
.map(|c| format!("`{}`", c))
.collect::<Vec<_>>()
.join(", ");
let mode_modifier = match self.config.mode {
SearchMode::Natural => "",
SearchMode::Boolean => " IN BOOLEAN MODE",
SearchMode::Phrase => " WITH QUERY EXPANSION",
_ => "",
};
let mut sql = format!(
"SELECT * FROM `{}` WHERE MATCH({}) AGAINST('{}'{}) ",
table, columns_str, escaped_query, mode_modifier
);
if let Some(limit) = self.limit {
sql.push_str(&format!("LIMIT {} ", limit));
}
if let Some(offset) = self.offset {
sql.push_str(&format!("OFFSET {} ", offset));
}
Ok(sql)
}
fn build_mysql_ranked_sql(&self) -> Result<String> {
let table = T::table_name();
let escaped_query = escape_string(&self.query);
let columns_str = self.columns.iter()
.map(|c| format!("`{}`", c))
.collect::<Vec<_>>()
.join(", ");
let mode_modifier = match self.config.mode {
SearchMode::Natural => "",
SearchMode::Boolean => " IN BOOLEAN MODE",
SearchMode::Phrase => " WITH QUERY EXPANSION",
_ => "",
};
let mut sql = format!(
"SELECT *, MATCH({}) AGAINST('{}'{}) AS _fts_rank FROM `{}` \
WHERE MATCH({}) AGAINST('{}'{}) ",
columns_str, escaped_query, mode_modifier,
table,
columns_str, escaped_query, mode_modifier
);
if let Some(min_rank) = self.min_rank {
sql.push_str(&format!("AND MATCH({}) AGAINST('{}'{}) >= {} ",
columns_str, escaped_query, mode_modifier, min_rank));
}
sql.push_str("ORDER BY _fts_rank DESC ");
if let Some(limit) = self.limit {
sql.push_str(&format!("LIMIT {} ", limit));
}
if let Some(offset) = self.offset {
sql.push_str(&format!("OFFSET {} ", offset));
}
Ok(sql)
}
fn build_mysql_count_sql(&self) -> Result<String> {
let table = T::table_name();
let escaped_query = escape_string(&self.query);
let columns_str = self.columns.iter()
.map(|c| format!("`{}`", c))
.collect::<Vec<_>>()
.join(", ");
let mode_modifier = match self.config.mode {
SearchMode::Natural => "",
SearchMode::Boolean => " IN BOOLEAN MODE",
_ => "",
};
Ok(format!(
"SELECT COUNT(*) as count FROM `{}` WHERE MATCH({}) AGAINST('{}'{})",
table, columns_str, escaped_query, mode_modifier
))
}
fn build_sqlite_sql(&self) -> Result<String> {
let table = T::table_name();
let fts_table = format!("{}_fts", table);
let escaped_query = escape_fts5_query(&self.query);
let mut sql = format!(
"SELECT t.* FROM \"{}\" t \
INNER JOIN \"{}\" fts ON t.rowid = fts.rowid \
WHERE \"{}\" MATCH '{}' ",
table, fts_table, fts_table, escaped_query
);
if let Some(limit) = self.limit {
sql.push_str(&format!("LIMIT {} ", limit));
}
if let Some(offset) = self.offset {
sql.push_str(&format!("OFFSET {} ", offset));
}
Ok(sql)
}
fn build_sqlite_ranked_sql(&self) -> Result<String> {
let table = T::table_name();
let fts_table = format!("{}_fts", table);
let escaped_query = escape_fts5_query(&self.query);
let mut sql = format!(
"SELECT t.*, bm25(\"{}\") AS _fts_rank FROM \"{}\" t \
INNER JOIN \"{}\" fts ON t.rowid = fts.rowid \
WHERE \"{}\" MATCH '{}' ",
fts_table, table, fts_table, fts_table, escaped_query
);
if let Some(min_rank) = self.min_rank {
sql.push_str(&format!("AND bm25(\"{}\") <= {} ", fts_table, -min_rank));
}
sql.push_str(&format!("ORDER BY bm25(\"{}\") ", fts_table));
if let Some(limit) = self.limit {
sql.push_str(&format!("LIMIT {} ", limit));
}
if let Some(offset) = self.offset {
sql.push_str(&format!("OFFSET {} ", offset));
}
Ok(sql)
}
fn build_sqlite_count_sql(&self) -> Result<String> {
let table = T::table_name();
let fts_table = format!("{}_fts", table);
let escaped_query = escape_fts5_query(&self.query);
Ok(format!(
"SELECT COUNT(*) as count FROM \"{}\" t \
INNER JOIN \"{}\" fts ON t.rowid = fts.rowid \
WHERE \"{}\" MATCH '{}'",
table, fts_table, fts_table, escaped_query
))
}
}
/// Definition of a full-text index that can be rendered as DDL for each backend.
#[derive(Debug, Clone)]
pub struct FullTextIndex {
/// Index name (PostgreSQL/MySQL index name).
pub name: String,
/// Table being indexed.
pub table: String,
/// Columns included in the index.
pub columns: Vec<String>,
/// Backend-specific options.
pub config: FullTextIndexConfig,
}
/// Backend-specific options for [`FullTextIndex`].
#[derive(Debug, Clone, Default)]
pub struct FullTextIndexConfig {
/// PostgreSQL text-search language; `"english"` is used when `None`.
pub language: Option<String>,
/// PostgreSQL index access method (GIN by default).
pub pg_index_type: PgFullTextIndexType,
/// MySQL parser plugin name, emitted as `WITH PARSER <name>` when set.
pub mysql_parser: Option<String>,
}
/// PostgreSQL index access method used for full-text indexes.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum PgFullTextIndexType {
/// `USING GIN` (the default).
#[default]
GIN,
/// `USING GiST`.
GiST,
}
impl FullTextIndex {
pub fn new(name: impl Into<String>, table: impl Into<String>, columns: Vec<String>) -> Self {
Self {
name: name.into(),
table: table.into(),
columns,
config: FullTextIndexConfig::default(),
}
}
pub fn language(mut self, lang: impl Into<String>) -> Self {
self.config.language = Some(lang.into());
self
}
pub fn pg_index_type(mut self, index_type: PgFullTextIndexType) -> Self {
self.config.pg_index_type = index_type;
self
}
pub fn to_postgres_sql(&self) -> String {
let language = self.config.language.as_deref().unwrap_or("english");
let index_type = match self.config.pg_index_type {
PgFullTextIndexType::GIN => "GIN",
PgFullTextIndexType::GiST => "GiST",
};
let tsvector_expr = if self.columns.len() == 1 {
format!("to_tsvector('{}', COALESCE(\"{}\", ''))", language, self.columns[0])
} else {
let cols: Vec<String> = self.columns.iter()
.map(|c| format!("COALESCE(\"{}\", '')", c))
.collect();
format!("to_tsvector('{}', {})", language, cols.join(" || ' ' || "))
};
format!(
"CREATE INDEX \"{}\" ON \"{}\" USING {} (({}))",
self.name, self.table, index_type, tsvector_expr
)
}
pub fn to_mysql_sql(&self) -> String {
let columns_str = self.columns.iter()
.map(|c| format!("`{}`", c))
.collect::<Vec<_>>()
.join(", ");
let parser = self.config.mysql_parser.as_ref()
.map(|p| format!(" WITH PARSER {}", p))
.unwrap_or_default();
format!(
"CREATE FULLTEXT INDEX `{}` ON `{}`({}){}",
self.name, self.table, columns_str, parser
)
}
pub fn to_sqlite_sql(&self) -> Vec<String> {
let fts_table = format!("{}_fts", self.table);
let columns_str = self.columns.join(", ");
vec![
format!(
"CREATE VIRTUAL TABLE IF NOT EXISTS \"{}\" USING fts5({}, content=\"{}\", content_rowid=\"rowid\")",
fts_table, columns_str, self.table
),
format!(
"CREATE TRIGGER IF NOT EXISTS \"{}_ai\" AFTER INSERT ON \"{}\" BEGIN \
INSERT INTO \"{}\"(rowid, {}) VALUES (new.rowid, {}); \
END",
self.table, self.table, fts_table, columns_str,
self.columns.iter().map(|c| format!("new.\"{}\"", c)).collect::<Vec<_>>().join(", ")
),
format!(
"CREATE TRIGGER IF NOT EXISTS \"{}_ad\" AFTER DELETE ON \"{}\" BEGIN \
INSERT INTO \"{}\"(\"{}\", rowid, {}) VALUES('delete', old.rowid, {}); \
END",
self.table, self.table, fts_table, fts_table, columns_str,
self.columns.iter().map(|c| format!("old.\"{}\"", c)).collect::<Vec<_>>().join(", ")
),
format!(
"CREATE TRIGGER IF NOT EXISTS \"{}_au\" AFTER UPDATE ON \"{}\" BEGIN \
INSERT INTO \"{}\"(\"{}\", rowid, {}) VALUES('delete', old.rowid, {}); \
INSERT INTO \"{}\"(rowid, {}) VALUES (new.rowid, {}); \
END",
self.table, self.table,
fts_table, fts_table, columns_str,
self.columns.iter().map(|c| format!("old.\"{}\"", c)).collect::<Vec<_>>().join(", "),
fts_table, columns_str,
self.columns.iter().map(|c| format!("new.\"{}\"", c)).collect::<Vec<_>>().join(", ")
),
]
}
pub fn to_sql(&self, db_type: DatabaseType) -> Vec<String> {
match db_type {
DatabaseType::Postgres => vec![self.to_postgres_sql()],
DatabaseType::MySQL | DatabaseType::MariaDB => vec![self.to_mysql_sql()],
DatabaseType::SQLite => self.to_sqlite_sql(),
}
}
}
/// Wraps every whole-word, case-insensitive occurrence of a query word in `text` with
/// `start_tag` / `end_tag`.
///
/// All words are matched in a single pass. Otherwise a later word could match text inside
/// tags inserted for an earlier one (e.g. query word `b` with `<b>` tags), and duplicate
/// query words would produce nested tags. Returns `text` unchanged when the query has no
/// words or the combined pattern cannot be compiled.
pub fn highlight_text(
    text: &str,
    query: &str,
    start_tag: &str,
    end_tag: &str,
) -> String {
    let mut words: Vec<&str> = query.split_whitespace().collect();
    if words.is_empty() {
        return text.to_string();
    }
    // Longest first: the regex alternation prefers earlier alternatives, so this favours
    // the longest term at a given position. The tie-break on content makes `dedup` effective.
    words.sort_unstable_by(|a, b| b.len().cmp(&a.len()).then_with(|| a.cmp(b)));
    words.dedup();
    let alternation = words
        .iter()
        .map(|word| regex::escape(word))
        .collect::<Vec<_>>()
        .join("|");
    let pattern = match regex::Regex::new(&format!(r"(?i)\b(?:{})\b", alternation)) {
        Ok(pattern) => pattern,
        Err(_) => return text.to_string(),
    };
    pattern
        .replace_all(text, |caps: &regex::Captures| {
            format!("{}{}{}", start_tag, &caps[0], end_tag)
        })
        .into_owned()
}
/// Builds a short snippet of `text` around the first word that matches the query.
///
/// A word matches when its lowercase form contains any lowercase query word as a
/// substring. The snippet keeps `fragment_words` words on each side of the first match
/// (the matched word itself is always included). Every matching word in the window is
/// wrapped in `start_tag` / `end_tag`, and `...` marks text cut off at either end. With
/// no match, the first `fragment_words` words are returned, followed by `...` if
/// truncated.
pub fn generate_snippet(
    text: &str,
    query: &str,
    fragment_words: usize,
    start_tag: &str,
    end_tag: &str,
) -> String {
    let words: Vec<&str> = text.split_whitespace().collect();
    let query_words: Vec<String> = query.split_whitespace().map(str::to_lowercase).collect();
    let is_match = |word: &str| {
        let lower = word.to_lowercase();
        query_words.iter().any(|q| lower.contains(q.as_str()))
    };

    match words.iter().position(|w| is_match(w)) {
        Some(pos) => {
            let start = pos.saturating_sub(fragment_words);
            // `end` is exclusive: keep the match plus `fragment_words` words after it,
            // mirroring the words kept before it. Saturating avoids overflow on huge values.
            let end = pos
                .saturating_add(fragment_words)
                .saturating_add(1)
                .min(words.len());
            let body = words[start..end]
                .iter()
                .map(|w| {
                    if is_match(w) {
                        format!("{}{}{}", start_tag, w, end_tag)
                    } else {
                        w.to_string()
                    }
                })
                .collect::<Vec<_>>()
                .join(" ");
            let prefix = if start > 0 { "..." } else { "" };
            let suffix = if end < words.len() { "..." } else { "" };
            format!("{}{}{}", prefix, body, suffix)
        }
        None => {
            let end = fragment_words.min(words.len());
            let snippet = words[..end].join(" ");
            if end < words.len() {
                format!("{}...", snippet)
            } else {
                snippet
            }
        }
    }
}
pub fn pg_headline_sql(
column: &str,
query: &str,
language: &str,
start_tag: &str,
end_tag: &str,
) -> String {
format!(
"ts_headline('{}', \"{}\", plainto_tsquery('{}', '{}'), \
'StartSel={}, StopSel={}, MaxWords=35, MinWords=15')",
language, column, language, escape_string(query), start_tag, end_tag
)
}
/// Escapes text for embedding in a single-quoted SQL string literal.
///
/// Doubles single quotes (`'` → `''`) and backslashes (`\` → `\\`).
/// NOTE(review): doubling backslashes suits MySQL's default string syntax, but on
/// PostgreSQL with `standard_conforming_strings = on` (the modern default) a backslash is
/// literal, so this yields two backslashes — confirm this is acceptable for PG callers.
fn escape_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\'' => escaped.push_str("''"),
            '\\' => escaped.push_str("\\\\"),
            other => escaped.push(other),
        }
    }
    escaped
}
/// Escapes a query for embedding inside a single-quoted SQLite `MATCH '...'` literal.
///
/// Doubles double quotes (`"` → `""`) and single quotes (`'` → `''`). NOTE(review): the
/// query is otherwise passed through as FTS5 query syntax, so FTS5 operators in user input
/// are interpreted — confirm that is intended.
fn escape_fts5_query(s: &str) -> String {
    s.chars().fold(String::with_capacity(s.len()), |mut acc, ch| {
        match ch {
            '"' => acc.push_str("\"\""),
            '\'' => acc.push_str("''"),
            other => acc.push(other),
        }
        acc
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-column `articles` index shared by the index tests.
    fn articles_index() -> FullTextIndex {
        let columns = vec![String::from("title"), String::from("content")];
        FullTextIndex::new("idx_articles_search", "articles", columns)
    }

    #[test]
    fn test_search_mode_display() {
        let expectations = [
            (SearchMode::Natural, "natural"),
            (SearchMode::Boolean, "boolean"),
            (SearchMode::Phrase, "phrase"),
            (SearchMode::Prefix, "prefix"),
            (SearchMode::Fuzzy, "fuzzy"),
            (SearchMode::Proximity(3), "proximity(3)"),
        ];
        for (mode, expected) in expectations.iter() {
            assert_eq!(mode.to_string(), *expected);
        }
    }

    #[test]
    fn test_search_weights() {
        let rendered = SearchWeights::new(1.0, 0.5, 0.3, 0.1).to_pg_array();
        assert_eq!(rendered, "'{0.1,0.3,0.5,1}'");
    }

    #[test]
    fn test_highlight_text() {
        let sentence = "The quick brown fox jumps over the lazy dog";
        let marked = highlight_text(sentence, "quick fox", "<b>", "</b>");
        for expected in ["<b>quick</b>", "<b>fox</b>"].iter() {
            assert!(marked.contains(*expected));
        }
    }

    #[test]
    fn test_generate_snippet() {
        let passage = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. \
            The quick brown fox jumps over the lazy dog. \
            Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.";
        let snippet = generate_snippet(passage, "fox", 5, "<mark>", "</mark>");
        assert!(snippet.contains("<mark>fox</mark>"));
        assert!(snippet.contains("..."));
    }

    #[test]
    fn test_fulltext_index_postgres() {
        let sql = articles_index()
            .language("english")
            .pg_index_type(PgFullTextIndexType::GIN)
            .to_postgres_sql();
        for fragment in ["CREATE INDEX", "USING GIN", "to_tsvector"].iter() {
            assert!(sql.contains(*fragment));
        }
    }

    #[test]
    fn test_fulltext_index_mysql() {
        let sql = articles_index().to_mysql_sql();
        assert!(sql.contains("CREATE FULLTEXT INDEX"));
        assert!(sql.contains("`title`, `content`"));
    }

    #[test]
    fn test_fulltext_index_mariadb() {
        let statements = articles_index().to_sql(DatabaseType::MariaDB);
        assert_eq!(statements.len(), 1);
        let only = &statements[0];
        assert!(only.contains("CREATE FULLTEXT INDEX"));
        assert!(only.contains("`title`, `content`"));
    }

    #[test]
    fn test_fulltext_index_sqlite() {
        let statements = articles_index().to_sqlite_sql();
        assert_eq!(statements.len(), 4);
        let virtual_table = &statements[0];
        assert!(virtual_table.contains("CREATE VIRTUAL TABLE"));
        assert!(virtual_table.contains("fts5"));
    }

    #[test]
    fn test_escape_string() {
        assert_eq!(escape_string("it's"), "it''s");
        assert_eq!(escape_string("back\\slash"), "back\\\\slash");
    }

    #[test]
    fn test_fulltext_config() {
        let cfg = FullTextConfig::new()
            .language("german")
            .mode(SearchMode::Boolean)
            .min_word_length(3)
            .max_word_length(50);
        assert_eq!(cfg.language.as_deref(), Some("german"));
        assert_eq!(cfg.mode, SearchMode::Boolean);
        assert_eq!(cfg.min_word_length, Some(3));
        assert_eq!(cfg.max_word_length, Some(50));
    }

    #[test]
    fn test_search_result() {
        let outcome: SearchResult<String> = SearchResult::new(String::from("test"), 0.95);
        assert_eq!(outcome.record, "test");
        assert_eq!(outcome.rank, 0.95);
        assert!(outcome.highlights.is_empty());
    }

    #[test]
    fn test_highlighted_field() {
        let highlighted = HighlightedField::new(
            "content",
            "The <mark>quick</mark> brown <mark>fox</mark>",
            "The quick brown fox",
        );
        assert_eq!(highlighted.field, "content");
        assert_eq!(highlighted.match_count, 2);
    }
}