// oxirs_vec/hybrid_search/query_expansion.rs
use std::collections::{HashMap, HashSet};
4
/// Expands a free-text query with additional related terms.
///
/// Two expansion sources are supported: a synonym table owned by the expander
/// (see [`QueryExpander::expand`]) and a caller-supplied co-occurrence map
/// (see [`QueryExpander::expand_with_cooccurrence`]).
///
/// Queries are lowercased before lookup, so both tables are expected to be
/// keyed by lowercase terms.
#[derive(Debug, Clone)]
pub struct QueryExpander {
    /// Synonym table keyed by lowercase term.
    synonyms: HashMap<String, Vec<String>>,
    /// Upper bound on the number of terms in an expanded query. Terms that
    /// appear in the query itself are always kept, even beyond this bound.
    max_expanded_terms: usize,
}

impl QueryExpander {
    /// Creates an expander pre-loaded with the built-in synonym table.
    ///
    /// `max_expanded_terms` caps how many terms an expansion may contain;
    /// expansion terms stop being added once the cap is reached.
    pub fn new(max_expanded_terms: usize) -> Self {
        Self {
            synonyms: Self::build_default_synonyms(),
            max_expanded_terms,
        }
    }

    /// Builds the built-in synonym table (eight common English terms).
    fn build_default_synonyms() -> HashMap<String, Vec<String>> {
        const DEFAULT_SYNONYMS: &[(&str, &[&str])] = &[
            ("search", &["find", "lookup", "query"]),
            ("find", &["search", "locate"]),
            ("fast", &["quick", "rapid", "speedy"]),
            ("slow", &["sluggish", "gradual"]),
            ("big", &["large", "huge", "massive"]),
            ("small", &["tiny", "little", "compact"]),
            ("good", &["great", "excellent", "superb"]),
            ("bad", &["poor", "terrible"]),
        ];

        DEFAULT_SYNONYMS
            .iter()
            .map(|&(term, alternatives)| {
                let alternatives: Vec<String> =
                    alternatives.iter().copied().map(String::from).collect();
                (term.to_owned(), alternatives)
            })
            .collect()
    }

    /// Registers `synonyms` for `term`, replacing any existing entry.
    ///
    /// The key is stored lowercased because [`QueryExpander::expand`] only
    /// ever looks up lowercased query terms; a mixed-case key would otherwise
    /// never match. The synonym strings themselves are stored unchanged.
    pub fn add_synonyms(&mut self, term: &str, synonyms: Vec<String>) {
        self.synonyms.insert(term.to_lowercase(), synonyms);
    }

    /// Expands `query` with synonyms from the expander's table.
    ///
    /// The result is deterministic: the distinct query terms come first, in
    /// query order, followed by their synonyms in table order. Duplicates are
    /// dropped. Synonyms are only added while the result holds fewer than
    /// `max_expanded_terms` entries; the query's own terms are never dropped.
    pub fn expand(&self, query: &str) -> Vec<String> {
        let originals = Self::tokenize(query);
        let synonyms = originals
            .iter()
            .filter_map(|term| self.synonyms.get(term))
            .flatten()
            .map(String::as_str);

        self.merge_within_limit(&originals, synonyms)
    }

    /// Expands `query` with terms from `cooccurrence_map`.
    ///
    /// For every query term found in the map, each associated term whose
    /// score is at least `threshold` becomes an expansion candidate. Ordering,
    /// deduplication and the `max_expanded_terms` cap behave exactly as in
    /// [`QueryExpander::expand`]. Map keys must be lowercase to match.
    pub fn expand_with_cooccurrence(
        &self,
        query: &str,
        cooccurrence_map: &HashMap<String, Vec<(String, f32)>>,
        threshold: f32,
    ) -> Vec<String> {
        let originals = Self::tokenize(query);
        let coterms = originals
            .iter()
            .filter_map(|term| cooccurrence_map.get(term))
            .flatten()
            .filter(|entry| entry.1 >= threshold)
            .map(|entry| entry.0.as_str());

        self.merge_within_limit(&originals, coterms)
    }

    /// Returns the number of terms that have a synonym entry.
    pub fn synonym_count(&self) -> usize {
        self.synonyms.len()
    }

    /// Splits `query` into lowercase terms.
    ///
    /// The query is split on whitespace; leading and trailing
    /// non-alphanumeric characters are stripped from each piece (interior
    /// punctuation is kept) and pieces that end up empty are discarded.
    fn tokenize(query: &str) -> Vec<String> {
        query
            .to_lowercase()
            .split_whitespace()
            .map(|piece| piece.trim_matches(|c: char| !c.is_alphanumeric()))
            .filter(|piece| !piece.is_empty())
            .map(String::from)
            .collect()
    }

    /// Merges `originals` and `candidates` into one ordered, duplicate-free list.
    ///
    /// All distinct originals are kept. Candidates are appended in iteration
    /// order while the list is shorter than `max_expanded_terms`.
    fn merge_within_limit<'a, I>(&self, originals: &'a [String], candidates: I) -> Vec<String>
    where
        I: IntoIterator<Item = &'a str>,
    {
        let mut seen: HashSet<&'a str> = HashSet::with_capacity(originals.len());
        let mut ordered: Vec<String> = Vec::with_capacity(originals.len());

        for term in originals {
            if seen.insert(term.as_str()) {
                ordered.push(term.clone());
            }
        }

        for candidate in candidates {
            // Once the cap is hit nothing further can be added, so stop early.
            if ordered.len() >= self.max_expanded_terms {
                break;
            }
            if seen.insert(candidate) {
                ordered.push(candidate.to_owned());
            }
        }

        ordered
    }
}

impl Default for QueryExpander {
    /// Creates an expander with the built-in synonyms and a cap of 10 terms.
    fn default() -> Self {
        Self::new(10)
    }
}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an order-independent membership check.
    fn has(terms: &[String], wanted: &str) -> bool {
        terms.iter().any(|term| term == wanted)
    }

    /// Query terms are kept and synonyms of each term are added.
    #[test]
    fn test_basic_expansion() {
        let expander = QueryExpander::new(10);
        let expanded = expander.expand("fast search");

        assert!(has(&expanded, "fast"));
        assert!(has(&expanded, "search"));
        assert!(has(&expanded, "quick"));
        assert!(has(&expanded, "find"));
        assert!(expanded.len() > 2);
    }

    /// The cap limits added synonyms without dropping the query's own terms.
    #[test]
    fn test_max_expansion_limit() {
        let expander = QueryExpander::new(3);
        let expanded = expander.expand("fast search");

        assert!(expanded.len() <= 3);
        assert!(has(&expanded, "fast"));
        assert!(has(&expanded, "search"));
    }

    /// Synonyms registered at runtime are used by later expansions.
    #[test]
    fn test_custom_synonyms() {
        let mut expander = QueryExpander::new(10);
        expander.add_synonyms("ml", vec!["machine learning".to_string(), "ai".to_string()]);

        let expanded = expander.expand("ml");
        assert!(has(&expanded, "ml"));
        assert!(has(&expanded, "machine learning"));
        assert!(has(&expanded, "ai"));
        assert_eq!(expanded.len(), 3);
    }

    /// Only co-occurring terms scoring at or above the threshold are added.
    #[test]
    fn test_cooccurrence_expansion() {
        let expander = QueryExpander::new(10);
        let mut cooccurrence = HashMap::new();
        cooccurrence.insert(
            "machine".to_string(),
            vec![
                ("learning".to_string(), 0.9),
                ("intelligence".to_string(), 0.7),
                ("car".to_string(), 0.2),
            ],
        );

        let expanded = expander.expand_with_cooccurrence("machine", &cooccurrence, 0.5);

        assert!(has(&expanded, "machine"));
        assert!(has(&expanded, "learning"));
        assert!(has(&expanded, "intelligence"));
        assert!(!has(&expanded, "car"));
        assert_eq!(expanded.len(), 3);
    }

    /// An empty query expands to nothing.
    #[test]
    fn test_empty_query() {
        let expander = QueryExpander::new(10);
        let expanded = expander.expand("");
        assert!(expanded.is_empty());
    }

    /// Terms without synonyms pass through unchanged.
    #[test]
    fn test_unknown_terms() {
        let expander = QueryExpander::new(10);
        let expanded = expander.expand("zzz xyz abc");

        assert_eq!(expanded.len(), 3);
        assert!(has(&expanded, "zzz"));
        assert!(has(&expanded, "xyz"));
        assert!(has(&expanded, "abc"));
    }
}