use super::{FilterBackend, FilterResult};
use async_trait::async_trait;
use std::collections::{HashMap, HashSet};
/// Case-insensitive mapping from a term to the set of its synonyms.
///
/// Every term and synonym is lowercased before it is stored or looked up.
/// When the dictionary is bidirectional, each registered pair is recorded in
/// both directions.
#[derive(Debug, Clone)]
pub struct SynonymDictionary {
    /// Lowercased term -> lowercased synonyms registered for that term.
    synonyms: HashMap<String, HashSet<String>>,
    /// Whether `add_synonym` also records the reverse mapping.
    bidirectional: bool,
}

impl Default for SynonymDictionary {
    /// Equivalent to [`SynonymDictionary::new`]: empty and bidirectional.
    ///
    /// Implemented by hand because a derived `Default` would set
    /// `bidirectional` to `false` and thereby disagree with `new()`.
    fn default() -> Self {
        Self::new()
    }
}

impl SynonymDictionary {
    /// Creates an empty, bidirectional dictionary.
    pub fn new() -> Self {
        Self::with_bidirectional(true)
    }

    /// Creates an empty dictionary with an explicit directionality.
    ///
    /// With `bidirectional == false`, `add_synonym(a, b)` makes `b` a synonym
    /// of `a` but not the other way round.
    pub fn with_bidirectional(bidirectional: bool) -> Self {
        Self {
            synonyms: HashMap::new(),
            bidirectional,
        }
    }

    /// Registers `synonym` as a synonym of `term` (both are lowercased).
    ///
    /// In a bidirectional dictionary the reverse mapping is stored as well.
    /// A pair whose two sides are equal after lowercasing is ignored, so a
    /// term never becomes its own synonym.
    pub fn add_synonym(&mut self, term: impl Into<String>, synonym: impl Into<String>) {
        let term = term.into().to_lowercase();
        let synonym = synonym.into().to_lowercase();
        if term == synonym {
            // Expanding a term to itself would only duplicate it.
            return;
        }
        if self.bidirectional {
            self.synonyms
                .entry(synonym.clone())
                .or_default()
                .insert(term.clone());
        }
        self.synonyms.entry(term).or_default().insert(synonym);
    }

    /// Registers every item of `synonyms` as a synonym of `term`.
    ///
    /// Each pair goes through [`SynonymDictionary::add_synonym`], so the same
    /// lowercasing and directionality rules apply.
    pub fn add_synonyms(
        &mut self,
        term: impl Into<String>,
        synonyms: impl IntoIterator<Item = impl Into<String>>,
    ) {
        let term = term.into();
        for synonym in synonyms {
            self.add_synonym(term.clone(), synonym);
        }
    }

    /// Returns the synonyms registered for `term`, sorted ascending.
    ///
    /// The lookup is case-insensitive. An unknown term yields an empty vector.
    /// Sorting gives a stable order: the backing `HashSet` iterates in an
    /// arbitrary order, which would make any truncation of the result
    /// nondeterministic.
    pub fn get_synonyms(&self, term: &str) -> Vec<String> {
        let mut found: Vec<String> = self
            .synonyms
            .get(&term.to_lowercase())
            .map(|set| set.iter().cloned().collect())
            .unwrap_or_default();
        found.sort_unstable();
        found
    }

    /// Builds a bidirectional dictionary in which all terms of a group are
    /// synonyms of each other.
    pub fn from_groups(groups: Vec<Vec<impl Into<String>>>) -> Self {
        let mut dict = Self::new();
        for group in groups {
            let terms: Vec<String> = group.into_iter().map(Into::into).collect();
            // `dict` is bidirectional, so visiting each unordered pair once
            // records both directions.
            for (idx, term) in terms.iter().enumerate() {
                for other in &terms[idx + 1..] {
                    dict.add_synonym(term.as_str(), other.as_str());
                }
            }
        }
        dict
    }

    /// Number of distinct terms that have at least one synonym registered.
    pub fn len(&self) -> usize {
        self.synonyms.len()
    }

    /// Returns `true` when no synonym has been registered.
    pub fn is_empty(&self) -> bool {
        self.synonyms.is_empty()
    }
}
/// Filter that widens a free-text search with synonyms from a
/// [`SynonymDictionary`] and turns the result into `content LIKE '%…%'`
/// conditions joined with `OR`.
#[derive(Debug)]
pub struct SynonymExpander {
    /// Source of synonyms for each search term.
    dictionary: SynonymDictionary,
    /// When `false`, the filter returns the SQL untouched.
    enabled: bool,
    /// Maximum number of synonyms added per term; `None` means no cap.
    expansion_limit: Option<usize>,
    /// Terms with fewer characters than this are not expanded.
    min_term_length: usize,
}

impl Default for SynonymExpander {
    /// Equivalent to [`SynonymExpander::new`].
    ///
    /// Implemented by hand because a derived `Default` would produce a
    /// disabled expander with no limit and a zero minimum length, which
    /// disagrees with `new()`.
    fn default() -> Self {
        Self::new()
    }
}

impl SynonymExpander {
    /// SQL keyword looked for when splicing the generated condition.
    const WHERE_KEYWORD: &'static str = "WHERE";

    /// Creates an enabled expander with an empty dictionary, at most 10
    /// synonyms per term and a minimum term length of 3.
    pub fn new() -> Self {
        Self {
            dictionary: SynonymDictionary::new(),
            enabled: true,
            expansion_limit: Some(10),
            min_term_length: 3,
        }
    }

    /// Replaces the synonym dictionary.
    pub fn with_dictionary(mut self, dictionary: SynonymDictionary) -> Self {
        self.dictionary = dictionary;
        self
    }

    /// Caps the number of synonyms added for each term.
    pub fn with_expansion_limit(mut self, limit: usize) -> Self {
        self.expansion_limit = Some(limit);
        self
    }

    /// Sets the minimum length (in characters) a term needs to be expanded.
    pub fn with_min_term_length(mut self, length: usize) -> Self {
        self.min_term_length = length;
        self
    }

    /// Enables or disables the expander (builder style: consumes `self`).
    pub fn set_enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }

    /// Splits `query` on whitespace and returns every term followed by its
    /// synonyms.
    ///
    /// Terms shorter than the minimum length are kept but not expanded.
    /// Synonyms are sorted before the expansion limit is applied so that the
    /// same query always produces the same expansion.
    pub fn expand_query(&self, query: &str) -> Vec<String> {
        let mut expanded_terms = Vec::new();
        for term in query.split_whitespace() {
            expanded_terms.push(term.to_string());
            // Measure in characters: byte length would over-count any
            // non-ASCII term (e.g. a two-character CJK word is six bytes).
            if term.chars().count() < self.min_term_length {
                continue;
            }
            let mut synonyms = self.dictionary.get_synonyms(term);
            // The dictionary stores synonyms in a hash set; fix the order
            // before truncating so the limit keeps a deterministic subset.
            synonyms.sort_unstable();
            if let Some(limit) = self.expansion_limit {
                synonyms.truncate(limit);
            }
            expanded_terms.extend(synonyms);
        }
        expanded_terms
    }

    /// Returns the byte offset of the first standalone `WHERE` keyword in
    /// `sql`, matched ASCII-case-insensitively.
    ///
    /// Occurrences embedded in a longer identifier (`elsewhere`,
    /// `where_col`) are skipped. This is a textual scan, not a SQL parser:
    /// keywords inside string literals, comments or sub-queries are not
    /// distinguished from the statement's own `WHERE`.
    fn find_where_keyword(sql: &str) -> Option<usize> {
        // Non-ASCII bytes are treated as identifier characters so a keyword
        // glued to a Unicode identifier is not mistaken for a clause.
        let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_' || b >= 0x80;
        // ASCII-only upper-casing keeps byte offsets identical to `sql`.
        let upper = sql.to_ascii_uppercase();
        let bytes = upper.as_bytes();
        let mut from = 0;
        while let Some(offset) = upper[from..].find(Self::WHERE_KEYWORD) {
            let start = from + offset;
            let end = start + Self::WHERE_KEYWORD.len();
            let boundary_before = start == 0 || !is_ident(bytes[start - 1]);
            let boundary_after = end == bytes.len() || !is_ident(bytes[end]);
            if boundary_before && boundary_after {
                return Some(start);
            }
            from = end;
        }
        None
    }

    /// Adds an OR-chain of `content LIKE '%term%'` conditions, one per
    /// expanded term, to `sql`.
    ///
    /// * If `sql` already has a `WHERE` keyword, the chain is inserted right
    ///   after it as `WHERE (<chain>) AND <existing conditions>`. The
    ///   parentheses are required because `AND` binds tighter than `OR`.
    /// * Otherwise ` WHERE <chain>` is appended to the end of `sql`.
    ///
    /// NOTE(review): the existing conditions are not parenthesised, so a
    /// pre-existing top-level `OR` still escapes the synonym filter, and a
    /// statement without `WHERE` but with trailing clauses (`ORDER BY`,
    /// `LIMIT`, …) gets the clause appended after them. Handling those cases
    /// needs a real SQL parser.
    /// NOTE(review): `%` and `_` inside a term are not escaped and act as
    /// LIKE wildcards.
    fn apply_expansion(&self, sql: String, search_terms: &str) -> FilterResult<String> {
        use reinhardt_query::prelude::{Cond, Expr, MySqlQueryBuilder, QueryStatementBuilder};
        let expanded_terms = self.expand_query(search_terms);
        if expanded_terms.is_empty() {
            return Ok(sql);
        }
        let mut cond = Cond::any();
        for term in &expanded_terms {
            // Backslashes first: MySQL treats `\` as an escape character in
            // string literals by default, so an input `\'` would otherwise
            // become `\''` and terminate the literal early. Then double the
            // single quotes as usual.
            let escaped_term = term.replace('\\', "\\\\").replace('\'', "''");
            let like_expr = format!("content LIKE '%{}%'", escaped_term);
            cond = cond.add(Expr::cust(like_expr));
        }
        // Render the condition through a throw-away SELECT and keep only the
        // text that follows its WHERE keyword.
        let dummy_query = reinhardt_query::prelude::Query::select()
            .expr(Expr::val(1))
            .cond_where(cond)
            .to_owned();
        let full_sql = dummy_query.to_string(MySqlQueryBuilder);
        let condition = match full_sql.find(Self::WHERE_KEYWORD) {
            Some(idx) => full_sql[idx + Self::WHERE_KEYWORD.len()..].trim(),
            // Not expected with a non-empty condition; leave the statement
            // untouched rather than splice the whole dummy SELECT into it.
            None => return Ok(sql),
        };
        match Self::find_where_keyword(&sql) {
            Some(idx) => {
                let (head, tail) = sql.split_at(idx);
                let rest = &tail[Self::WHERE_KEYWORD.len()..];
                Ok(format!("{}WHERE ({}) AND{}", head, condition, rest))
            }
            None => Ok(format!("{} WHERE {}", sql, condition)),
        }
    }
}
#[async_trait]
impl FilterBackend for SynonymExpander {
    /// Expands the search text found in the query parameters and adds the
    /// resulting conditions to `sql`.
    ///
    /// The search text is taken from the first of the `q`, `search` and
    /// `query` parameters that is present. The SQL is returned unchanged when
    /// the expander is disabled or none of those parameters exists.
    async fn filter_queryset(
        &self,
        query_params: &HashMap<String, String>,
        sql: String,
    ) -> FilterResult<String> {
        if !self.enabled {
            return Ok(sql);
        }
        // Parameter names in priority order.
        let found = ["q", "search", "query"]
            .iter()
            .find_map(|name| query_params.get(*name));
        match found {
            Some(text) => self.apply_expansion(sql, text),
            None => Ok(sql),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A new dictionary holds no entries.
    #[test]
    fn test_synonym_dictionary_creation() {
        let dict = SynonymDictionary::new();
        assert!(dict.is_empty());
        assert_eq!(dict.len(), 0);
    }

    /// A single registered pair is returned for its term.
    #[test]
    fn test_synonym_dictionary_add_synonym() {
        let mut dict = SynonymDictionary::new();
        dict.add_synonym("car", "automobile");
        let synonyms = dict.get_synonyms("car");
        assert_eq!(synonyms.len(), 1);
        assert!(synonyms.contains(&"automobile".to_string()));
    }

    /// `new()` dictionaries record both directions of a pair.
    #[test]
    fn test_synonym_dictionary_bidirectional() {
        let mut dict = SynonymDictionary::new();
        dict.add_synonym("car", "automobile");
        let car_synonyms = dict.get_synonyms("car");
        assert!(car_synonyms.contains(&"automobile".to_string()));
        let auto_synonyms = dict.get_synonyms("automobile");
        assert!(auto_synonyms.contains(&"car".to_string()));
    }

    /// A unidirectional dictionary records only term -> synonym.
    #[test]
    fn test_synonym_dictionary_unidirectional() {
        let mut dict = SynonymDictionary::with_bidirectional(false);
        dict.add_synonym("car", "automobile");
        let car_synonyms = dict.get_synonyms("car");
        assert!(car_synonyms.contains(&"automobile".to_string()));
        let auto_synonyms = dict.get_synonyms("automobile");
        assert!(auto_synonyms.is_empty());
    }

    /// `add_synonyms` registers every item of the list.
    #[test]
    fn test_synonym_dictionary_add_multiple() {
        let mut dict = SynonymDictionary::new();
        dict.add_synonyms("big", vec!["large", "huge", "enormous"]);
        let synonyms = dict.get_synonyms("big");
        assert_eq!(synonyms.len(), 3);
        assert!(synonyms.contains(&"large".to_string()));
        assert!(synonyms.contains(&"huge".to_string()));
        assert!(synonyms.contains(&"enormous".to_string()));
    }

    /// Members of a group are synonyms of each other, but not of themselves.
    #[test]
    fn test_synonym_dictionary_from_groups() {
        let groups = vec![
            vec!["happy", "joyful", "glad"],
            vec!["sad", "unhappy", "sorrowful"],
        ];
        let dict = SynonymDictionary::from_groups(groups);
        let happy_synonyms = dict.get_synonyms("happy");
        assert_eq!(happy_synonyms.len(), 2);
        assert!(happy_synonyms.contains(&"joyful".to_string()));
        assert!(happy_synonyms.contains(&"glad".to_string()));
    }

    /// Storage and lookup both ignore case.
    #[test]
    fn test_synonym_dictionary_case_insensitive() {
        let mut dict = SynonymDictionary::new();
        dict.add_synonym("Car", "Automobile");
        let synonyms = dict.get_synonyms("car");
        assert!(synonyms.contains(&"automobile".to_string()));
        let synonyms_upper = dict.get_synonyms("CAR");
        assert!(synonyms_upper.contains(&"automobile".to_string()));
    }

    /// Defaults chosen by `SynonymExpander::new`.
    #[test]
    fn test_synonym_expander_creation() {
        let expander = SynonymExpander::new();
        assert!(expander.enabled);
        assert_eq!(expander.min_term_length, 3);
        assert_eq!(expander.expansion_limit, Some(10));
    }

    /// A supplied dictionary is used for expansion.
    #[test]
    fn test_synonym_expander_with_dictionary() {
        let mut dict = SynonymDictionary::new();
        dict.add_synonym("fast", "quick");
        let expander = SynonymExpander::new().with_dictionary(dict);
        let expanded = expander.expand_query("fast");
        assert!(expanded.len() > 1);
    }

    /// Original terms and their synonyms all appear in the expansion.
    #[test]
    fn test_synonym_expander_expand_query() {
        let mut dict = SynonymDictionary::new();
        dict.add_synonym("fast", "quick");
        dict.add_synonym("fast", "rapid");
        let expander = SynonymExpander::new().with_dictionary(dict);
        let expanded = expander.expand_query("fast car");
        assert!(expanded.contains(&"fast".to_string()));
        assert!(expanded.contains(&"quick".to_string()));
        assert!(expanded.contains(&"rapid".to_string()));
        assert!(expanded.contains(&"car".to_string()));
    }

    /// Terms below the minimum length are kept but not expanded.
    #[test]
    fn test_synonym_expander_min_term_length() {
        let mut dict = SynonymDictionary::new();
        dict.add_synonym("is", "exists");
        let expander = SynonymExpander::new()
            .with_dictionary(dict)
            .with_min_term_length(3);
        let expanded = expander.expand_query("is");
        assert_eq!(expanded.len(), 1);
        assert_eq!(expanded[0], "is");
    }

    /// At most `limit` synonyms are added per term (plus the term itself).
    #[test]
    fn test_synonym_expander_expansion_limit() {
        let mut dict = SynonymDictionary::new();
        dict.add_synonyms(
            "big",
            vec!["large", "huge", "enormous", "massive", "gigantic"],
        );
        let expander = SynonymExpander::new()
            .with_dictionary(dict)
            .with_expansion_limit(2);
        let expanded = expander.expand_query("big");
        assert!(expanded.len() <= 3);
    }

    /// `set_enabled(false)` clears the flag.
    #[test]
    fn test_synonym_expander_disabled() {
        let expander = SynonymExpander::new().set_enabled(false);
        assert!(!expander.enabled);
    }

    /// Without any search parameter the SQL is returned unchanged.
    #[tokio::test]
    async fn test_synonym_expander_no_search_terms() {
        let expander = SynonymExpander::new();
        let params = HashMap::new();
        let sql = "SELECT * FROM articles".to_string();
        let result = expander
            .filter_queryset(&params, sql.clone())
            .await
            .unwrap();
        assert_eq!(result, sql);
    }

    /// A disabled expander ignores the search parameter.
    #[tokio::test]
    async fn test_synonym_expander_disabled_passthrough() {
        let expander = SynonymExpander::new().set_enabled(false);
        let mut params = HashMap::new();
        params.insert("q".to_string(), "fast".to_string());
        let sql = "SELECT * FROM articles".to_string();
        let result = expander
            .filter_queryset(&params, sql.clone())
            .await
            .unwrap();
        assert_eq!(result, sql);
    }

    /// One term with two synonyms yields three OR-ed LIKE conditions.
    #[tokio::test]
    async fn test_synonym_expander_single_term_expansion() {
        let mut dict = SynonymDictionary::new();
        dict.add_synonym("fast", "quick");
        dict.add_synonym("fast", "rapid");
        let expander = SynonymExpander::new().with_dictionary(dict);
        let mut params = HashMap::new();
        params.insert("q".to_string(), "fast".to_string());
        let sql = "SELECT * FROM articles".to_string();
        let result = expander.filter_queryset(&params, sql).await.unwrap();
        let where_start = result.find("WHERE").expect("WHERE clause not found");
        let where_clause = &result[where_start..];
        assert!(
            where_clause.starts_with("WHERE "),
            "Expected WHERE clause to start with 'WHERE ', got: {}",
            where_clause
        );
        assert!(
            where_clause.contains("content LIKE '%fast%'"),
            "Expected 'content LIKE '%fast%' in WHERE clause, got: {}",
            where_clause
        );
        assert!(
            where_clause.contains("content LIKE '%quick%'"),
            "Expected 'content LIKE '%quick%' in WHERE clause, got: {}",
            where_clause
        );
        assert!(
            where_clause.contains("content LIKE '%rapid%'"),
            "Expected 'content LIKE '%rapid%' in WHERE clause, got: {}",
            where_clause
        );
        let or_count = where_clause.matches(" OR ").count();
        assert_eq!(
            or_count, 2,
            "Expected 2 OR connectors for 3 terms, got: {}",
            or_count
        );
    }

    /// Every term of a multi-word query is expanded independently.
    #[tokio::test]
    async fn test_synonym_expander_multi_term_expansion() {
        let mut dict = SynonymDictionary::new();
        dict.add_synonym("fast", "quick");
        dict.add_synonym("car", "automobile");
        let expander = SynonymExpander::new().with_dictionary(dict);
        let mut params = HashMap::new();
        params.insert("q".to_string(), "fast car".to_string());
        let sql = "SELECT * FROM articles".to_string();
        let result = expander.filter_queryset(&params, sql).await.unwrap();
        let where_start = result.find("WHERE").expect("WHERE clause not found");
        let where_clause = &result[where_start..];
        let expected_terms = vec!["fast", "quick", "car", "automobile"];
        for term in expected_terms {
            assert!(
                where_clause.contains(&format!("content LIKE '%{}%'", term)),
                "Expected 'content LIKE '%{}%' in WHERE clause, got: {}",
                term,
                where_clause
            );
        }
        let or_count = where_clause.matches(" OR ").count();
        assert_eq!(
            or_count, 3,
            "Expected 3 OR connectors for 4 terms, got: {}",
            or_count
        );
    }

    /// The synonym conditions are placed before the pre-existing condition
    /// and joined to it with AND.
    #[tokio::test]
    async fn test_synonym_expander_existing_where_clause() {
        let mut dict = SynonymDictionary::new();
        dict.add_synonym("fast", "quick");
        let expander = SynonymExpander::new().with_dictionary(dict);
        let mut params = HashMap::new();
        params.insert("q".to_string(), "fast".to_string());
        let sql = "SELECT * FROM articles WHERE status = 'published'".to_string();
        let result = expander.filter_queryset(&params, sql).await.unwrap();
        assert!(
            result.contains("WHERE") && result.contains("content LIKE"),
            "Expected result to contain 'WHERE' and 'content LIKE', got: {}",
            result
        );
        assert!(
            result.contains("content LIKE '%fast%'"),
            "Expected 'content LIKE '%fast%' in result, got: {}",
            result
        );
        assert!(
            result.contains("content LIKE '%quick%'"),
            "Expected 'content LIKE '%quick%' in result, got: {}",
            result
        );
        assert!(
            result.contains("AND") && result.contains("status = 'published'"),
            "Expected 'AND status = 'published'' in result, got: {}",
            result
        );
        let synonym_pos = result.find("WHERE").unwrap();
        let and_pos = result.find("AND").unwrap();
        let status_pos = result.find("status = 'published'").unwrap();
        assert!(
            synonym_pos < and_pos && and_pos < status_pos,
            "Expected order: WHERE (synonyms) AND status, got: {}",
            result
        );
    }

    /// Single quotes in a synonym are doubled before it reaches the SQL.
    /// The expected text is lowercase because the dictionary lowercases
    /// everything it stores.
    #[tokio::test]
    async fn test_synonym_expander_sql_injection_protection() {
        let mut dict = SynonymDictionary::new();
        dict.add_synonym("test", "test'; DROP TABLE articles; --");
        let expander = SynonymExpander::new().with_dictionary(dict);
        let mut params = HashMap::new();
        params.insert("q".to_string(), "test".to_string());
        let sql = "SELECT * FROM articles".to_string();
        let result = expander.filter_queryset(&params, sql).await.unwrap();
        let expected_escaped_synonym = "test''; drop table articles; --";
        assert!(
            result.contains(expected_escaped_synonym),
            "Expected escaped malicious synonym '{}' in result, got: {}",
            expected_escaped_synonym,
            result
        );
        assert!(
            result.contains("content LIKE '%test%'"),
            "Expected original term 'content LIKE '%test%' in result, got: {}",
            result
        );
        assert!(
            result.contains(&format!("content LIKE '%{}%'", expected_escaped_synonym)),
            "Expected malicious content in LIKE clause 'content LIKE '%{}%', got: {}",
            expected_escaped_synonym,
            result
        );
        let where_start = result.find("WHERE").expect("WHERE clause not found");
        let where_clause = &result[where_start..];
        let or_count = where_clause.matches(" OR ").count();
        assert_eq!(
            or_count, 1,
            "Expected 1 OR connector for 2 terms (test + malicious synonym), got: {}",
            or_count
        );
    }

    /// The `search` parameter is accepted as the search text.
    #[tokio::test]
    async fn test_synonym_expander_with_search_param() {
        let mut dict = SynonymDictionary::new();
        dict.add_synonym("fast", "quick");
        let expander = SynonymExpander::new().with_dictionary(dict);
        let mut params = HashMap::new();
        params.insert("search".to_string(), "fast".to_string());
        let sql = "SELECT * FROM articles".to_string();
        let result = expander.filter_queryset(&params, sql).await.unwrap();
        let where_start = result.find("WHERE").expect("WHERE clause not found");
        let where_clause = &result[where_start..];
        assert!(
            where_clause.contains("content LIKE '%fast%'"),
            "Expected 'content LIKE '%fast%' in WHERE clause, got: {}",
            where_clause
        );
        assert!(
            where_clause.contains("content LIKE '%quick%'"),
            "Expected 'content LIKE '%quick%' in WHERE clause, got: {}",
            where_clause
        );
    }

    /// The `query` parameter is accepted as the search text.
    #[tokio::test]
    async fn test_synonym_expander_with_query_param() {
        let mut dict = SynonymDictionary::new();
        dict.add_synonym("fast", "quick");
        let expander = SynonymExpander::new().with_dictionary(dict);
        let mut params = HashMap::new();
        params.insert("query".to_string(), "fast".to_string());
        let sql = "SELECT * FROM articles".to_string();
        let result = expander.filter_queryset(&params, sql).await.unwrap();
        let where_start = result.find("WHERE").expect("WHERE clause not found");
        let where_clause = &result[where_start..];
        assert!(
            where_clause.contains("content LIKE '%fast%'"),
            "Expected 'content LIKE '%fast%' in WHERE clause, got: {}",
            where_clause
        );
        assert!(
            where_clause.contains("content LIKE '%quick%'"),
            "Expected 'content LIKE '%quick%' in WHERE clause, got: {}",
            where_clause
        );
    }
}