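//! `argentor_memory/query_expansion.rs`
//!
//! Query expansion helpers: the `QueryExpander` trait, the synonym-driven
//! `RuleBasedExpander`, and `deduplicate_results` for merging scored hits.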

1use std::collections::HashMap;
2use uuid::Uuid;
3
/// Strategy for turning a single query into several alternative queries.
///
/// Implementations should always place the original query first in the
/// returned list, followed by any alternatives they generate.
///
/// The `Send + Sync` supertraits require every implementor to be safe to move
/// and share between threads.
pub trait QueryExpander: Send + Sync {
    /// Expand `query` into multiple alternative queries.
    ///
    /// The returned vector includes the original query itself (by the trait's
    /// convention, as the first element).
    fn expand(&self, query: &str) -> Vec<String>;
}
10
/// Rule-based query expander that uses a synonym map to generate alternative queries.
///
/// Each whitespace-separated word of a query is checked against the synonym
/// map; when a match is found, new queries are generated by replacing that
/// word with each of its synonyms.
///
/// The type is `Debug` (so it can appear in logs and assertion messages) and
/// `Clone` (so a configured expander can be duplicated cheaply by callers).
#[derive(Debug, Clone)]
pub struct RuleBasedExpander {
    /// Word → alternatives. Lookups are performed with the lowercased query
    /// token, so keys are expected to be lowercase.
    synonyms: HashMap<String, Vec<String>>,
}
17
18impl RuleBasedExpander {
19    /// Maximum number of expanded queries to return (including the original).
20    const MAX_EXPANSIONS: usize = 5;
21
22    /// Create a new `RuleBasedExpander` with default programming-related synonyms.
23    pub fn new() -> Self {
24        let mut synonyms = HashMap::new();
25        synonyms.insert(
26            "error".to_string(),
27            vec![
28                "bug".to_string(),
29                "issue".to_string(),
30                "problem".to_string(),
31                "exception".to_string(),
32            ],
33        );
34        synonyms.insert(
35            "function".to_string(),
36            vec![
37                "method".to_string(),
38                "fn".to_string(),
39                "procedure".to_string(),
40                "routine".to_string(),
41            ],
42        );
43        synonyms.insert(
44            "create".to_string(),
45            vec![
46                "make".to_string(),
47                "build".to_string(),
48                "generate".to_string(),
49                "new".to_string(),
50            ],
51        );
52        synonyms.insert(
53            "delete".to_string(),
54            vec![
55                "remove".to_string(),
56                "drop".to_string(),
57                "destroy".to_string(),
58                "erase".to_string(),
59            ],
60        );
61        synonyms.insert(
62            "update".to_string(),
63            vec![
64                "modify".to_string(),
65                "change".to_string(),
66                "edit".to_string(),
67                "patch".to_string(),
68            ],
69        );
70        synonyms.insert(
71            "list".to_string(),
72            vec![
73                "array".to_string(),
74                "vector".to_string(),
75                "collection".to_string(),
76                "slice".to_string(),
77            ],
78        );
79        synonyms.insert(
80            "config".to_string(),
81            vec![
82                "configuration".to_string(),
83                "settings".to_string(),
84                "options".to_string(),
85            ],
86        );
87        synonyms.insert(
88            "auth".to_string(),
89            vec![
90                "authentication".to_string(),
91                "authorization".to_string(),
92                "login".to_string(),
93            ],
94        );
95        synonyms.insert(
96            "db".to_string(),
97            vec![
98                "database".to_string(),
99                "storage".to_string(),
100                "datastore".to_string(),
101            ],
102        );
103        synonyms.insert(
104            "api".to_string(),
105            vec![
106                "endpoint".to_string(),
107                "interface".to_string(),
108                "service".to_string(),
109            ],
110        );
111
112        Self { synonyms }
113    }
114
115    /// Create a new `RuleBasedExpander` with a custom synonym map.
116    pub fn with_synonyms(synonyms: HashMap<String, Vec<String>>) -> Self {
117        Self { synonyms }
118    }
119
120    /// Add a synonym entry. If the word already exists, replaces its alternatives.
121    pub fn add_synonym(&mut self, word: &str, alternatives: Vec<String>) {
122        self.synonyms.insert(word.to_lowercase(), alternatives);
123    }
124}
125
126impl Default for RuleBasedExpander {
127    fn default() -> Self {
128        Self::new()
129    }
130}
131
132impl QueryExpander for RuleBasedExpander {
133    fn expand(&self, query: &str) -> Vec<String> {
134        // Always include the original query first.
135        let mut results = vec![query.to_string()];
136
137        // Tokenize: split by whitespace and lowercase for lookup.
138        let tokens: Vec<&str> = query.split_whitespace().collect();
139
140        if tokens.is_empty() {
141            return results;
142        }
143
144        // For each token, if it has synonyms, generate a new query by replacing
145        // that token with each synonym.
146        for (i, token) in tokens.iter().enumerate() {
147            let lower = token.to_lowercase();
148            if let Some(syns) = self.synonyms.get(&lower) {
149                for syn in syns {
150                    if results.len() >= Self::MAX_EXPANSIONS {
151                        break;
152                    }
153                    // Build the expanded query by replacing token at position i.
154                    let expanded: Vec<String> = tokens
155                        .iter()
156                        .enumerate()
157                        .map(|(j, t)| if j == i { syn.clone() } else { t.to_string() })
158                        .collect();
159                    let expanded_query = expanded.join(" ");
160
161                    // Deduplicate: only add if not already present.
162                    if !results.contains(&expanded_query) {
163                        results.push(expanded_query);
164                    }
165                }
166            }
167            if results.len() >= Self::MAX_EXPANSIONS {
168                break;
169            }
170        }
171
172        results
173    }
174}
175
176/// Merge search results by `Uuid`, keeping the highest score for each unique ID.
177/// Returns results sorted by score in descending order.
178pub fn deduplicate_results(results: Vec<(Uuid, f32)>) -> Vec<(Uuid, f32)> {
179    let mut best: HashMap<Uuid, f32> = HashMap::new();
180
181    for (id, score) in results {
182        let entry = best.entry(id).or_insert(score);
183        if score > *entry {
184            *entry = score;
185        }
186    }
187
188    let mut deduped: Vec<(Uuid, f32)> = best.into_iter().collect();
189    deduped.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
190
191    deduped
192}
193
// Unit tests for query expansion and result de-duplication.
// `unwrap`/`expect` are allowed here: a panic is the desired failure mode in tests.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Default synonyms expand a matching word while keeping the original first.
    #[test]
    fn test_rule_based_expansion() {
        let expander = RuleBasedExpander::new();
        let results = expander.expand("fix the error in auth");
        assert!(results.len() > 1);
        // The original query must come first.
        assert_eq!(results[0], "fix the error in auth");
        // Only the matched word is replaced; the rest of the query is kept.
        assert!(
            results.contains(&"fix the bug in auth".to_string()),
            "Should expand 'error' to 'bug': {results:?}"
        );
    }

    /// The number of results never exceeds the cap, original included.
    #[test]
    fn test_expansion_is_capped() {
        let expander = RuleBasedExpander::new();
        let results = expander.expand("create function error");
        assert_eq!(results.len(), RuleBasedExpander::MAX_EXPANSIONS);
        assert_eq!(results[0], "create function error");
    }

    /// Query words are matched against the synonym map case-insensitively.
    #[test]
    fn test_lookup_is_case_insensitive() {
        let expander = RuleBasedExpander::new();
        let results = expander.expand("Fix the ERROR");
        assert_eq!(results[0], "Fix the ERROR");
        assert!(
            results.contains(&"Fix the bug".to_string()),
            "Should match 'ERROR' case-insensitively: {results:?}"
        );
    }

    /// An empty query yields only itself.
    #[test]
    fn test_empty_query() {
        let expander = RuleBasedExpander::new();
        let results = expander.expand("");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0], "");
    }

    /// A query without any known word yields only itself.
    #[test]
    fn test_no_synonyms_match() {
        let expander = RuleBasedExpander::new();
        let results = expander.expand("hello world");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0], "hello world");
    }

    /// A caller-supplied synonym map is honoured.
    #[test]
    fn test_custom_synonyms() {
        let mut synonyms = HashMap::new();
        synonyms.insert(
            "fast".to_string(),
            vec!["quick".to_string(), "rapid".to_string()],
        );
        let expander = RuleBasedExpander::with_synonyms(synonyms);
        let results = expander.expand("fast code");
        assert!(results.len() > 1);
        assert!(results.iter().any(|r| r.contains("quick")));
    }

    /// Duplicate IDs collapse to their best score and the output is sorted
    /// by descending score.
    #[test]
    fn test_dedup_results() {
        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();
        let results = vec![
            (id1, 0.5),
            (id2, 0.3),
            (id1, 0.8), // duplicate id1 with higher score
        ];
        let deduped = deduplicate_results(results);
        assert_eq!(deduped.len(), 2);
        // id1 should have the higher score (0.8)
        let id1_result = deduped.iter().find(|(id, _)| *id == id1).unwrap();
        assert!((id1_result.1 - 0.8).abs() < 0.001);
        // Highest score first: id1 (0.8) precedes id2 (0.3).
        assert_eq!(deduped[0].0, id1);
        assert_eq!(deduped[1].0, id2);
        assert!((deduped[1].1 - 0.3).abs() < 0.001);
    }

    /// No input results means no output results.
    #[test]
    fn test_dedup_empty() {
        assert!(deduplicate_results(Vec::new()).is_empty());
    }
}