use std::collections::HashMap;
use uuid::Uuid;
/// Strategy for turning one search query into a set of alternative phrasings.
///
/// Implementors must be `Send + Sync` so a single expander can be shared across
/// threads (e.g. behind an `Arc`).
pub trait QueryExpander: Send + Sync {
    /// Returns the query variants to search for, derived from `query`.
    fn expand(&self, query: &str) -> Vec<String>;
}
/// A [`QueryExpander`] that rewrites queries by swapping individual words for
/// entries from a synonym table.
#[derive(Debug, Clone)]
pub struct RuleBasedExpander {
    /// Maps a word to its alternative spellings/synonyms. `expand` looks tokens up
    /// after lowercasing them, so keys are expected to be lowercase.
    synonyms: HashMap<String, Vec<String>>,
}
impl RuleBasedExpander {
    /// Upper bound on the number of queries `expand` returns, counting the
    /// original query itself.
    const MAX_EXPANSIONS: usize = 5;

    /// Creates an expander preloaded with a small table of programming-related
    /// synonyms (e.g. `error` -> `bug`, `issue`, `problem`, `exception`).
    pub fn new() -> Self {
        // Keys are lowercase because `expand` lowercases tokens before lookup.
        const DEFAULT_SYNONYMS: &[(&str, &[&str])] = &[
            ("error", &["bug", "issue", "problem", "exception"]),
            ("function", &["method", "fn", "procedure", "routine"]),
            ("create", &["make", "build", "generate", "new"]),
            ("delete", &["remove", "drop", "destroy", "erase"]),
            ("update", &["modify", "change", "edit", "patch"]),
            ("list", &["array", "vector", "collection", "slice"]),
            ("config", &["configuration", "settings", "options"]),
            ("auth", &["authentication", "authorization", "login"]),
            ("db", &["database", "storage", "datastore"]),
            ("api", &["endpoint", "interface", "service"]),
        ];

        let synonyms = DEFAULT_SYNONYMS
            .iter()
            .map(|&(word, alternatives)| {
                let alternatives = alternatives
                    .iter()
                    .map(|alt| alt.to_string())
                    .collect::<Vec<String>>();
                (word.to_string(), alternatives)
            })
            .collect();
        Self { synonyms }
    }

    /// Creates an expander from a caller-supplied synonym table.
    ///
    /// Keys are lowercased so that lookups in `expand` (which lowercases each
    /// token) match regardless of how the caller capitalised them, consistent
    /// with [`RuleBasedExpander::add_synonym`]. If two keys collide after
    /// lowercasing (e.g. `"Fast"` and `"fast"`), their alternatives are merged.
    pub fn with_synonyms(synonyms: HashMap<String, Vec<String>>) -> Self {
        let mut normalized: HashMap<String, Vec<String>> = HashMap::with_capacity(synonyms.len());
        for (word, alternatives) in synonyms {
            normalized
                .entry(word.to_lowercase())
                .or_default()
                .extend(alternatives);
        }
        Self {
            synonyms: normalized,
        }
    }

    /// Registers `alternatives` for `word` (stored lowercased), replacing any
    /// alternatives previously registered for that word.
    pub fn add_synonym(&mut self, word: &str, alternatives: Vec<String>) {
        self.synonyms.insert(word.to_lowercase(), alternatives);
    }
}
impl Default for RuleBasedExpander {
fn default() -> Self {
Self::new()
}
}
impl QueryExpander for RuleBasedExpander {
    /// Produces the original query followed by single-word synonym substitutions.
    ///
    /// The query is split on whitespace and every token is looked up (lowercased)
    /// in the synonym table. Each variant replaces exactly one token and re-joins
    /// all tokens with single spaces. Variants already present are skipped, and
    /// generation stops once `MAX_EXPANSIONS` queries (the original included)
    /// have been collected.
    fn expand(&self, query: &str) -> Vec<String> {
        let limit = Self::MAX_EXPANSIONS;
        let words: Vec<&str> = query.split_whitespace().collect();

        let mut out: Vec<String> = Vec::new();
        out.push(query.to_owned());

        'tokens: for (pos, word) in words.iter().enumerate() {
            let Some(alternatives) = self.synonyms.get(&word.to_lowercase()) else {
                continue;
            };
            for alt in alternatives {
                if out.len() >= limit {
                    // Cap reached: nothing further may be added for any token.
                    break 'tokens;
                }
                let candidate = words
                    .iter()
                    .enumerate()
                    .map(|(j, w)| if j == pos { alt.as_str() } else { *w })
                    .collect::<Vec<&str>>()
                    .join(" ");
                if !out.contains(&candidate) {
                    out.push(candidate);
                }
            }
        }

        out
    }
}
/// Collapses duplicate ids in a scored result list, keeping each id's best score.
///
/// The returned list is sorted by score, highest first. Ties are broken by id so
/// the output order is deterministic (the intermediate `HashMap` iterates in an
/// unspecified order). `NaN` scores rank below every real score: a `NaN` never
/// displaces a real score for the same id, and ids whose only score is `NaN`
/// sort to the end.
pub fn deduplicate_results(results: Vec<(Uuid, f32)>) -> Vec<(Uuid, f32)> {
    // Map NaN to -inf so it ranks lowest in both the merge and the sort.
    fn rank(score: f32) -> f32 {
        if score.is_nan() {
            f32::NEG_INFINITY
        } else {
            score
        }
    }

    let mut best: HashMap<Uuid, f32> = HashMap::with_capacity(results.len());
    for (id, score) in results {
        best.entry(id)
            .and_modify(|current| {
                if rank(score) > rank(*current) {
                    *current = score;
                }
            })
            .or_insert(score);
    }

    let mut deduped: Vec<(Uuid, f32)> = best.into_iter().collect();
    // `total_cmp` is a total order, so `sort_by` never sees an inconsistent
    // comparator; the id tie-break makes the order reproducible.
    deduped.sort_by(|a, b| {
        rank(b.1)
            .total_cmp(&rank(a.1))
            .then_with(|| a.0.cmp(&b.0))
    });
    deduped
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_rule_based_expansion() {
        let expander = RuleBasedExpander::new();
        let results = expander.expand("fix the error in auth");
        assert!(results.len() > 1);
        assert_eq!(results[0], "fix the error in auth");
        let has_bug = results.iter().any(|r| r.contains("bug"));
        assert!(has_bug, "Should expand 'error' to 'bug': {results:?}");
    }

    #[test]
    fn test_expansion_is_capped() {
        // "error" alone has four synonyms; together with the original query that
        // already reaches the cap, so "auth" is never substituted.
        let expander = RuleBasedExpander::new();
        let results = expander.expand("fix the error in auth");
        assert_eq!(results.len(), RuleBasedExpander::MAX_EXPANSIONS);
        assert!(results.iter().all(|r| r.ends_with(" auth")), "{results:?}");
    }

    #[test]
    fn test_case_insensitive_lookup() {
        let expander = RuleBasedExpander::new();
        let results = expander.expand("Delete user");
        assert!(results.contains(&"remove user".to_string()), "{results:?}");
    }

    #[test]
    fn test_empty_query() {
        let expander = RuleBasedExpander::new();
        let results = expander.expand("");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0], "");
    }

    #[test]
    fn test_no_synonyms_match() {
        let expander = RuleBasedExpander::new();
        let results = expander.expand("hello world");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0], "hello world");
    }

    #[test]
    fn test_custom_synonyms() {
        let mut synonyms = HashMap::new();
        synonyms.insert(
            "fast".to_string(),
            vec!["quick".to_string(), "rapid".to_string()],
        );
        let expander = RuleBasedExpander::with_synonyms(synonyms);
        let results = expander.expand("fast code");
        assert!(results.len() > 1);
        assert!(results.iter().any(|r| r.contains("quick")));
    }

    #[test]
    fn test_add_synonym_lowercases_key() {
        let mut expander = RuleBasedExpander::with_synonyms(HashMap::new());
        expander.add_synonym("Fast", vec!["quick".to_string()]);
        let results = expander.expand("FAST code");
        assert_eq!(
            results,
            vec!["FAST code".to_string(), "quick code".to_string()]
        );
    }

    #[test]
    fn test_dedup_results() {
        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();
        let results = vec![(id1, 0.5), (id2, 0.3), (id1, 0.8)];
        let deduped = deduplicate_results(results);
        assert_eq!(deduped.len(), 2);
        let id1_result = deduped.iter().find(|(id, _)| *id == id1).unwrap();
        assert!((id1_result.1 - 0.8).abs() < 0.001);
        // Highest score first.
        assert_eq!(deduped[0].0, id1);
        assert_eq!(deduped[1].0, id2);
    }

    #[test]
    fn test_dedup_empty() {
        assert!(deduplicate_results(Vec::new()).is_empty());
    }
}