1use std::collections::HashMap;
2use uuid::Uuid;
3
/// Produces alternative phrasings of a search query.
///
/// An implementor maps one query string to the list of query strings that should be
/// searched (for example, `RuleBasedExpander` returns the original query first, followed
/// by synonym substitutions). Implementors must be `Send + Sync` so an expander can be
/// shared between threads, e.g. behind `Arc<dyn QueryExpander>`.
pub trait QueryExpander
where
    Self: Send + Sync,
{
    /// Returns the query variants to search for `query`.
    fn expand(&self, query: &str) -> Vec<String>;
}
10
/// Query expander driven by a `word -> synonyms` lookup table.
///
/// Each whitespace-separated query token is lowercased and looked up in the table; every
/// hit produces a rewritten query with that single token replaced by one synonym.
#[derive(Debug, Clone)]
pub struct RuleBasedExpander {
    /// Trigger word -> alternatives substituted for it, in order. Lookups use lowercased
    /// query tokens, so only lowercase keys can ever match.
    synonyms: HashMap<String, Vec<String>>,
}
17
18impl RuleBasedExpander {
19 const MAX_EXPANSIONS: usize = 5;
21
22 pub fn new() -> Self {
24 let mut synonyms = HashMap::new();
25 synonyms.insert(
26 "error".to_string(),
27 vec![
28 "bug".to_string(),
29 "issue".to_string(),
30 "problem".to_string(),
31 "exception".to_string(),
32 ],
33 );
34 synonyms.insert(
35 "function".to_string(),
36 vec![
37 "method".to_string(),
38 "fn".to_string(),
39 "procedure".to_string(),
40 "routine".to_string(),
41 ],
42 );
43 synonyms.insert(
44 "create".to_string(),
45 vec![
46 "make".to_string(),
47 "build".to_string(),
48 "generate".to_string(),
49 "new".to_string(),
50 ],
51 );
52 synonyms.insert(
53 "delete".to_string(),
54 vec![
55 "remove".to_string(),
56 "drop".to_string(),
57 "destroy".to_string(),
58 "erase".to_string(),
59 ],
60 );
61 synonyms.insert(
62 "update".to_string(),
63 vec![
64 "modify".to_string(),
65 "change".to_string(),
66 "edit".to_string(),
67 "patch".to_string(),
68 ],
69 );
70 synonyms.insert(
71 "list".to_string(),
72 vec![
73 "array".to_string(),
74 "vector".to_string(),
75 "collection".to_string(),
76 "slice".to_string(),
77 ],
78 );
79 synonyms.insert(
80 "config".to_string(),
81 vec![
82 "configuration".to_string(),
83 "settings".to_string(),
84 "options".to_string(),
85 ],
86 );
87 synonyms.insert(
88 "auth".to_string(),
89 vec![
90 "authentication".to_string(),
91 "authorization".to_string(),
92 "login".to_string(),
93 ],
94 );
95 synonyms.insert(
96 "db".to_string(),
97 vec![
98 "database".to_string(),
99 "storage".to_string(),
100 "datastore".to_string(),
101 ],
102 );
103 synonyms.insert(
104 "api".to_string(),
105 vec![
106 "endpoint".to_string(),
107 "interface".to_string(),
108 "service".to_string(),
109 ],
110 );
111
112 Self { synonyms }
113 }
114
115 pub fn with_synonyms(synonyms: HashMap<String, Vec<String>>) -> Self {
117 Self { synonyms }
118 }
119
120 pub fn add_synonym(&mut self, word: &str, alternatives: Vec<String>) {
122 self.synonyms.insert(word.to_lowercase(), alternatives);
123 }
124}
125
126impl Default for RuleBasedExpander {
127 fn default() -> Self {
128 Self::new()
129 }
130}
131
132impl QueryExpander for RuleBasedExpander {
133 fn expand(&self, query: &str) -> Vec<String> {
134 let mut results = vec![query.to_string()];
136
137 let tokens: Vec<&str> = query.split_whitespace().collect();
139
140 if tokens.is_empty() {
141 return results;
142 }
143
144 for (i, token) in tokens.iter().enumerate() {
147 let lower = token.to_lowercase();
148 if let Some(syns) = self.synonyms.get(&lower) {
149 for syn in syns {
150 if results.len() >= Self::MAX_EXPANSIONS {
151 break;
152 }
153 let expanded: Vec<String> = tokens
155 .iter()
156 .enumerate()
157 .map(|(j, t)| if j == i { syn.clone() } else { t.to_string() })
158 .collect();
159 let expanded_query = expanded.join(" ");
160
161 if !results.contains(&expanded_query) {
163 results.push(expanded_query);
164 }
165 }
166 }
167 if results.len() >= Self::MAX_EXPANSIONS {
168 break;
169 }
170 }
171
172 results
173 }
174}
175
176pub fn deduplicate_results(results: Vec<(Uuid, f32)>) -> Vec<(Uuid, f32)> {
179 let mut best: HashMap<Uuid, f32> = HashMap::new();
180
181 for (id, score) in results {
182 let entry = best.entry(id).or_insert(score);
183 if score > *entry {
184 *entry = score;
185 }
186 }
187
188 let mut deduped: Vec<(Uuid, f32)> = best.into_iter().collect();
189 deduped.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
190
191 deduped
192}
193
#[cfg(test)]
// In tests a failed unwrap/expect is simply a test failure, so these lints don't apply.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_rule_based_expansion() {
        let expander = RuleBasedExpander::new();
        let results = expander.expand("fix the error in auth");
        assert!(results.len() > 1);
        assert_eq!(results[0], "fix the error in auth");
        let has_bug = results.iter().any(|r| r.contains("bug"));
        assert!(has_bug, "Should expand 'error' to 'bug': {results:?}");
    }

    #[test]
    fn test_empty_query() {
        let expander = RuleBasedExpander::new();
        let results = expander.expand("");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0], "");
    }

    #[test]
    fn test_no_synonyms_match() {
        let expander = RuleBasedExpander::new();
        let results = expander.expand("hello world");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0], "hello world");
    }

    #[test]
    fn test_custom_synonyms() {
        let mut synonyms = HashMap::new();
        synonyms.insert(
            "fast".to_string(),
            vec!["quick".to_string(), "rapid".to_string()],
        );
        let expander = RuleBasedExpander::with_synonyms(synonyms);
        let results = expander.expand("fast code");
        assert!(results.len() > 1);
        assert!(results.iter().any(|r| r.contains("quick")));
    }

    #[test]
    fn test_expansion_is_capped() {
        // "create" alone has four synonyms, which already fills the cap with the original.
        let expander = RuleBasedExpander::new();
        let results = expander.expand("create error");
        assert_eq!(results.len(), RuleBasedExpander::MAX_EXPANSIONS);
        assert_eq!(results[0], "create error");
    }

    #[test]
    fn test_lookup_is_case_insensitive() {
        let expander = RuleBasedExpander::new();
        let results = expander.expand("ERROR handling");
        assert!(results.contains(&"bug handling".to_string()), "{results:?}");
    }

    #[test]
    fn test_add_synonym_normalizes_key() {
        let mut expander = RuleBasedExpander::with_synonyms(HashMap::new());
        expander.add_synonym("Fast", vec!["quick".to_string()]);
        let results = expander.expand("FAST code");
        assert_eq!(results, vec!["FAST code".to_string(), "quick code".to_string()]);
    }

    #[test]
    fn test_dedup_results() {
        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();
        let results = vec![(id1, 0.5), (id2, 0.3), (id1, 0.8)];
        let deduped = deduplicate_results(results);
        assert_eq!(deduped.len(), 2);
        let id1_result = deduped.iter().find(|(id, _)| *id == id1).unwrap();
        assert!((id1_result.1 - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_dedup_sorts_descending() {
        let (a, b, c) = (Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4());
        let deduped = deduplicate_results(vec![(a, 0.1), (b, 0.9), (c, 0.5), (a, 0.2)]);
        let scores: Vec<f32> = deduped.iter().map(|(_, s)| *s).collect();
        assert_eq!(scores, vec![0.9, 0.5, 0.2]);
    }

    #[test]
    fn test_dedup_empty() {
        assert!(deduplicate_results(Vec::new()).is_empty());
    }
}