oxirs_vec/hybrid_search/
query_expansion.rs

1//! Query expansion for improved recall
2
3use std::collections::{HashMap, HashSet};
4
/// Query expander for improving recall.
///
/// Expands the terms of a free-text query with related terms, taken either
/// from an internal synonym dictionary ([`QueryExpander::expand`]) or from a
/// caller-supplied co-occurrence map
/// ([`QueryExpander::expand_with_cooccurrence`]).
#[derive(Debug, Clone)]
pub struct QueryExpander {
    /// Synonym map: lowercase term -> synonyms
    synonyms: HashMap<String, Vec<String>>,
    /// Maximum number of terms an expansion may grow to. Original query terms
    /// are always kept, even when there are more of them than this limit.
    max_expanded_terms: usize,
}

impl QueryExpander {
    /// Create a new query expander pre-loaded with the default synonym
    /// dictionary.
    ///
    /// `max_expanded_terms` caps the size of an expanded query: related terms
    /// are only added while the result holds fewer than this many terms.
    pub fn new(max_expanded_terms: usize) -> Self {
        Self {
            synonyms: Self::build_default_synonyms(),
            max_expanded_terms,
        }
    }

    /// Build default synonym dictionary
    fn build_default_synonyms() -> HashMap<String, Vec<String>> {
        // Common synonyms for search
        const DEFAULT_SYNONYMS: &[(&str, &[&str])] = &[
            ("search", &["find", "lookup", "query"]),
            ("find", &["search", "locate"]),
            ("fast", &["quick", "rapid", "speedy"]),
            ("slow", &["sluggish", "gradual"]),
            ("big", &["large", "huge", "massive"]),
            ("small", &["tiny", "little", "compact"]),
            ("good", &["great", "excellent", "superb"]),
            ("bad", &["poor", "terrible"]),
        ];

        DEFAULT_SYNONYMS
            .iter()
            .map(|(term, syns)| {
                let syns: Vec<String> = syns.iter().map(|syn| (*syn).to_string()).collect();
                ((*term).to_string(), syns)
            })
            .collect()
    }

    /// Split a query into lowercase terms.
    ///
    /// Terms are separated by whitespace; non-alphanumeric characters are
    /// trimmed from both ends of each term and terms left empty are dropped.
    fn tokenize(query: &str) -> Vec<String> {
        query
            .to_lowercase()
            .split_whitespace()
            .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
            .filter(|s| !s.is_empty())
            .map(String::from)
            .collect()
    }

    /// Append `term` to `ordered` unless it has already been recorded in `seen`.
    fn push_unique(seen: &mut HashSet<String>, ordered: &mut Vec<String>, term: &str) {
        if !seen.contains(term) {
            seen.insert(term.to_string());
            ordered.push(term.to_string());
        }
    }

    /// Add synonyms for a term
    ///
    /// The term is stored in lowercase because queries are lowercased before
    /// lookup; a key containing uppercase characters could otherwise never
    /// match. Any synonyms previously registered for the term are replaced.
    ///
    /// Query terms have leading and trailing non-alphanumeric characters
    /// removed before lookup, so a key that starts or ends with such a
    /// character will not be matched by [`QueryExpander::expand`].
    pub fn add_synonyms(&mut self, term: &str, synonyms: Vec<String>) {
        self.synonyms.insert(term.to_lowercase(), synonyms);
    }

    /// Expand a query with synonyms
    ///
    /// Returns the distinct original terms in query order, followed by their
    /// synonyms in dictionary order. Synonyms are added only while the result
    /// holds fewer than `max_expanded_terms` terms; original terms are always
    /// returned. An empty query yields an empty vector.
    pub fn expand(&self, query: &str) -> Vec<String> {
        let original_terms = Self::tokenize(query);
        let mut seen: HashSet<String> = HashSet::new();
        let mut expanded: Vec<String> = Vec::new();

        // Original terms come first and are never subject to the limit.
        for term in &original_terms {
            Self::push_unique(&mut seen, &mut expanded, term);
        }

        // Add synonyms
        for term in &original_terms {
            if let Some(syns) = self.synonyms.get(term) {
                for syn in syns {
                    // The result never shrinks, so nothing more can be added
                    // once the limit has been reached.
                    if expanded.len() >= self.max_expanded_terms {
                        return expanded;
                    }
                    Self::push_unique(&mut seen, &mut expanded, syn);
                }
            }
        }

        expanded
    }

    /// Expand with co-occurrence based expansion
    ///
    /// Returns the distinct original terms in query order, followed by every
    /// co-occurring term from `cooccurrence_map` whose score is at least
    /// `threshold`, in the order the map lists them. Co-occurring terms are
    /// added only while the result holds fewer than `max_expanded_terms`
    /// terms. The synonym dictionary is not consulted.
    pub fn expand_with_cooccurrence(
        &self,
        query: &str,
        cooccurrence_map: &HashMap<String, Vec<(String, f32)>>,
        threshold: f32,
    ) -> Vec<String> {
        let original_terms = Self::tokenize(query);
        let mut seen: HashSet<String> = HashSet::new();
        let mut expanded: Vec<String> = Vec::new();

        // Original terms come first and are never subject to the limit.
        for term in &original_terms {
            Self::push_unique(&mut seen, &mut expanded, term);
        }

        // Add co-occurring terms
        for term in &original_terms {
            if let Some(cooccurrences) = cooccurrence_map.get(term) {
                for (coterm, score) in cooccurrences {
                    if expanded.len() >= self.max_expanded_terms {
                        return expanded;
                    }
                    if *score >= threshold {
                        Self::push_unique(&mut seen, &mut expanded, coterm);
                    }
                }
            }
        }

        expanded
    }

    /// Get synonym count
    ///
    /// This is the number of terms that have a synonym list, not the total
    /// number of synonyms.
    pub fn synonym_count(&self) -> usize {
        self.synonyms.len()
    }
}

impl Default for QueryExpander {
    /// Create an expander limited to 10 expanded terms.
    fn default() -> Self {
        Self::new(10)
    }
}
166
#[cfg(test)]
mod tests {
    //! Unit tests for `QueryExpander`. Assertions use membership and length
    //! rather than positions, so they do not depend on result ordering.

    use super::*;

    /// Original terms and their dictionary synonyms are all returned when the
    /// limit is not reached.
    #[test]
    fn test_basic_expansion() {
        let expander = QueryExpander::new(10);
        let expanded = expander.expand("fast search");

        assert!(expanded.contains(&"fast".to_string()));
        assert!(expanded.contains(&"search".to_string()));
        // Should have synonyms
        assert!(expanded.len() > 2);
        assert!(expanded.contains(&"quick".to_string()));
        assert!(expanded.contains(&"lookup".to_string()));
    }

    /// The limit caps the number of terms but never drops the original ones.
    #[test]
    fn test_max_expansion_limit() {
        let expander = QueryExpander::new(3);
        let expanded = expander.expand("fast search");

        assert!(expanded.len() <= 3);
        assert!(expanded.contains(&"fast".to_string()));
        assert!(expanded.contains(&"search".to_string()));
    }

    /// Original terms are kept even when there are more of them than the limit.
    #[test]
    fn test_original_terms_always_kept() {
        let expander = QueryExpander::new(1);
        let expanded = expander.expand("zzz xyz");

        assert_eq!(expanded.len(), 2);
        assert!(expanded.contains(&"zzz".to_string()));
        assert!(expanded.contains(&"xyz".to_string()));
    }

    /// Synonyms registered through `add_synonyms` are used for expansion.
    #[test]
    fn test_custom_synonyms() {
        let mut expander = QueryExpander::new(10);
        expander.add_synonyms("ml", vec!["machine learning".to_string(), "ai".to_string()]);

        let expanded = expander.expand("ml");
        assert!(expanded.contains(&"ml".to_string()));
        // The custom synonyms themselves must be part of the expansion.
        assert!(expanded.contains(&"machine learning".to_string()));
        assert!(expanded.contains(&"ai".to_string()));
        assert_eq!(expanded.len(), 3);
    }

    /// Only co-occurring terms scoring at or above the threshold are added.
    #[test]
    fn test_cooccurrence_expansion() {
        let expander = QueryExpander::new(10);
        let mut cooccurrence = HashMap::new();
        cooccurrence.insert(
            "machine".to_string(),
            vec![
                ("learning".to_string(), 0.9),
                ("intelligence".to_string(), 0.7),
                ("car".to_string(), 0.2),
            ],
        );

        let expanded = expander.expand_with_cooccurrence("machine", &cooccurrence, 0.5);

        assert!(expanded.contains(&"machine".to_string()));
        assert!(expanded.contains(&"learning".to_string()));
        assert!(expanded.contains(&"intelligence".to_string()));
        assert!(!expanded.contains(&"car".to_string())); // Below threshold
        assert_eq!(expanded.len(), 3);
    }

    /// An empty query expands to nothing.
    #[test]
    fn test_empty_query() {
        let expander = QueryExpander::new(10);
        let expanded = expander.expand("");
        assert!(expanded.is_empty());
    }

    /// Terms without dictionary entries are passed through unchanged.
    #[test]
    fn test_unknown_terms() {
        let expander = QueryExpander::new(10);
        let expanded = expander.expand("zzz xyz abc");

        // Should return original terms even without synonyms
        assert_eq!(expanded.len(), 3);
        for term in ["zzz", "xyz", "abc"] {
            assert!(expanded.contains(&term.to_string()));
        }
    }

    /// The default expander ships with the eight built-in dictionary entries.
    #[test]
    fn test_default_synonym_count() {
        assert_eq!(QueryExpander::default().synonym_count(), 8);
    }
}