1use crate::RragResult;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// Expands user queries with synonyms, related terms, semantic cues and
/// domain-specific vocabulary, producing candidate reformulations that can
/// improve retrieval recall.
pub struct QueryExpander {
    /// Tuning knobs: which strategies run, how many terms each may add,
    /// and the minimum confidence a result must reach.
    config: ExpansionConfig,

    /// Static synonym dictionary: lowercase word -> interchangeable words.
    synonyms: HashMap<String, Vec<String>>,

    /// Static association dictionary: lowercase word -> topically related words.
    related_terms: HashMap<String, Vec<String>>,

    /// Per-domain dictionaries: domain name -> (term/abbreviation -> expansions).
    domain_expansions: HashMap<String, HashMap<String, Vec<String>>>,
}
24
/// Configuration controlling which expansion strategies run and how
/// aggressively each one adds terms.
#[derive(Debug, Clone)]
pub struct ExpansionConfig {
    /// Maximum synonyms taken per query token.
    pub max_synonyms: usize,

    /// Maximum related terms taken per query token.
    pub max_related_terms: usize,

    /// Enables dictionary synonym expansion.
    pub enable_synonyms: bool,

    /// Enables related-term expansion.
    pub enable_related_terms: bool,

    /// Enables pattern-based semantic expansion.
    pub enable_semantic_expansion: bool,

    /// Enables domain-specific abbreviation/jargon expansion.
    pub enable_domain_expansion: bool,

    /// Results with a confidence below this threshold are discarded.
    pub min_relevance_score: f32,
}
49
50impl Default for ExpansionConfig {
51 fn default() -> Self {
52 Self {
53 max_synonyms: 3,
54 max_related_terms: 2,
55 enable_synonyms: true,
56 enable_related_terms: true,
57 enable_semantic_expansion: true,
58 enable_domain_expansion: true,
59 min_relevance_score: 0.6,
60 }
61 }
62}
63
/// Identifies which strategy produced a given expansion result.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ExpansionStrategy {
    /// Interchangeable words from the synonym dictionary (e.g. "fast" -> "quick").
    Synonyms,
    /// Topically associated words (e.g. "database" -> "query").
    RelatedTerms,
    /// Intent-level additions inferred from query phrasing.
    Semantic,
    /// Abbreviation/jargon expansion for a detected domain.
    DomainSpecific,
    /// Context-driven expansion. NOTE(review): no code in this module
    /// constructs this variant — presumably reserved for future use.
    Contextual,
}
78
/// The outcome of one expansion strategy applied to a query.
#[derive(Debug, Clone)]
pub struct ExpansionResult {
    /// The query as the caller supplied it.
    pub original_query: String,

    /// Original query followed by the added terms, space-separated.
    pub expanded_query: String,

    /// Terms appended to the query by this strategy.
    pub added_terms: Vec<String>,

    /// Which strategy produced this result.
    pub expansion_type: ExpansionStrategy,

    /// Overall confidence in the expansion (strategy-specific constant).
    pub confidence: f32,

    /// Per-term relevance score for each added term.
    pub term_scores: HashMap<String, f32>,
}
100
101impl QueryExpander {
102 pub fn new(config: ExpansionConfig) -> Self {
104 let synonyms = Self::init_synonyms();
105 let related_terms = Self::init_related_terms();
106 let domain_expansions = Self::init_domain_expansions();
107
108 Self {
109 config,
110 synonyms,
111 related_terms,
112 domain_expansions,
113 }
114 }
115
116 pub async fn expand(&self, query: &str) -> RragResult<Vec<ExpansionResult>> {
118 let mut results = Vec::new();
119
120 let tokens = self.tokenize(query);
122
123 if self.config.enable_synonyms {
125 if let Some(result) = self.expand_with_synonyms(query, &tokens) {
126 if result.confidence >= self.config.min_relevance_score {
127 results.push(result);
128 }
129 }
130 }
131
132 if self.config.enable_related_terms {
134 if let Some(result) = self.expand_with_related_terms(query, &tokens) {
135 if result.confidence >= self.config.min_relevance_score {
136 results.push(result);
137 }
138 }
139 }
140
141 if self.config.enable_semantic_expansion {
143 if let Some(result) = self.expand_semantically(query, &tokens) {
144 if result.confidence >= self.config.min_relevance_score {
145 results.push(result);
146 }
147 }
148 }
149
150 if self.config.enable_domain_expansion {
152 let domain_results = self.expand_domain_specific(query, &tokens);
153 results.extend(
154 domain_results
155 .into_iter()
156 .filter(|r| r.confidence >= self.config.min_relevance_score),
157 );
158 }
159
160 Ok(results)
161 }
162
163 fn expand_with_synonyms(&self, query: &str, tokens: &[String]) -> Option<ExpansionResult> {
165 let mut added_terms = Vec::new();
166 let mut term_scores = HashMap::new();
167
168 for token in tokens {
169 if let Some(synonyms) = self.synonyms.get(&token.to_lowercase()) {
170 for synonym in synonyms.iter().take(self.config.max_synonyms) {
171 if !tokens
172 .iter()
173 .any(|t| t.to_lowercase() == synonym.to_lowercase())
174 {
175 added_terms.push(synonym.clone());
176 term_scores.insert(synonym.clone(), 0.8); }
178 }
179 }
180 }
181
182 if !added_terms.is_empty() {
183 let expanded_query = format!("{} {}", query, added_terms.join(" "));
184 Some(ExpansionResult {
185 original_query: query.to_string(),
186 expanded_query,
187 added_terms,
188 expansion_type: ExpansionStrategy::Synonyms,
189 confidence: 0.8,
190 term_scores,
191 })
192 } else {
193 None
194 }
195 }
196
197 fn expand_with_related_terms(&self, query: &str, tokens: &[String]) -> Option<ExpansionResult> {
199 let mut added_terms = Vec::new();
200 let mut term_scores = HashMap::new();
201
202 for token in tokens {
203 if let Some(related) = self.related_terms.get(&token.to_lowercase()) {
204 for term in related.iter().take(self.config.max_related_terms) {
205 if !tokens
206 .iter()
207 .any(|t| t.to_lowercase() == term.to_lowercase())
208 {
209 added_terms.push(term.clone());
210 term_scores.insert(term.clone(), 0.7); }
212 }
213 }
214 }
215
216 if !added_terms.is_empty() {
217 let expanded_query = format!("{} {}", query, added_terms.join(" "));
218 Some(ExpansionResult {
219 original_query: query.to_string(),
220 expanded_query,
221 added_terms,
222 expansion_type: ExpansionStrategy::RelatedTerms,
223 confidence: 0.7,
224 term_scores,
225 })
226 } else {
227 None
228 }
229 }
230
231 fn expand_semantically(&self, query: &str, _tokens: &[String]) -> Option<ExpansionResult> {
233 let semantic_expansions = self.get_semantic_expansions(query);
236
237 if !semantic_expansions.is_empty() {
238 let mut term_scores = HashMap::new();
239 for term in &semantic_expansions {
240 term_scores.insert(term.clone(), 0.6);
241 }
242
243 let expanded_query = format!("{} {}", query, semantic_expansions.join(" "));
244 Some(ExpansionResult {
245 original_query: query.to_string(),
246 expanded_query,
247 added_terms: semantic_expansions,
248 expansion_type: ExpansionStrategy::Semantic,
249 confidence: 0.6,
250 term_scores,
251 })
252 } else {
253 None
254 }
255 }
256
257 fn expand_domain_specific(&self, query: &str, tokens: &[String]) -> Vec<ExpansionResult> {
259 let mut results = Vec::new();
260
261 let domain = self.detect_domain(tokens);
263
264 if let Some(domain_dict) = self.domain_expansions.get(&domain) {
265 for token in tokens {
266 if let Some(expansions) = domain_dict.get(&token.to_lowercase()) {
267 let mut term_scores = HashMap::new();
268 for term in expansions {
269 term_scores.insert(term.clone(), 0.75);
270 }
271
272 let expanded_query = format!("{} {}", query, expansions.join(" "));
273 results.push(ExpansionResult {
274 original_query: query.to_string(),
275 expanded_query,
276 added_terms: expansions.clone(),
277 expansion_type: ExpansionStrategy::DomainSpecific,
278 confidence: 0.75,
279 term_scores,
280 });
281 }
282 }
283 }
284
285 results
286 }
287
288 fn get_semantic_expansions(&self, query: &str) -> Vec<String> {
290 let mut expansions = Vec::new();
293
294 let query_lower = query.to_lowercase();
295
296 if query_lower.contains("learn") || query_lower.contains("study") {
297 expansions.extend_from_slice(&["education", "training", "tutorial"]);
298 }
299
300 if query_lower.contains("build") || query_lower.contains("create") {
301 expansions.extend_from_slice(&["develop", "construct", "implement"]);
302 }
303
304 if query_lower.contains("fast") || query_lower.contains("quick") {
305 expansions.extend_from_slice(&["rapid", "efficient", "performance"]);
306 }
307
308 if query_lower.contains("problem") || query_lower.contains("issue") {
309 expansions.extend_from_slice(&["solution", "fix", "troubleshoot"]);
310 }
311
312 expansions.into_iter().map(String::from).collect()
313 }
314
315 fn detect_domain(&self, tokens: &[String]) -> String {
317 let tech_terms = [
318 "code",
319 "programming",
320 "software",
321 "api",
322 "database",
323 "algorithm",
324 ];
325 let business_terms = ["market", "sales", "revenue", "customer", "profit"];
326 let science_terms = ["research", "study", "experiment", "theory", "analysis"];
327
328 let tokens_lower: Vec<String> = tokens.iter().map(|t| t.to_lowercase()).collect();
329
330 let tech_count = tech_terms
331 .iter()
332 .filter(|&&term| tokens_lower.iter().any(|t| t.contains(term)))
333 .count();
334 let business_count = business_terms
335 .iter()
336 .filter(|&&term| tokens_lower.iter().any(|t| t.contains(term)))
337 .count();
338 let science_count = science_terms
339 .iter()
340 .filter(|&&term| tokens_lower.iter().any(|t| t.contains(term)))
341 .count();
342
343 if tech_count > business_count && tech_count > science_count {
344 "technology".to_string()
345 } else if business_count > science_count {
346 "business".to_string()
347 } else if science_count > 0 {
348 "science".to_string()
349 } else {
350 "general".to_string()
351 }
352 }
353
354 fn tokenize(&self, query: &str) -> Vec<String> {
356 query
357 .to_lowercase()
358 .split_whitespace()
359 .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
360 .filter(|s| !s.is_empty())
361 .filter(|s| s.len() > 2) .map(String::from)
363 .collect()
364 }
365
366 fn init_synonyms() -> HashMap<String, Vec<String>> {
368 let mut synonyms = HashMap::new();
369
370 synonyms.insert(
372 "fast".to_string(),
373 vec![
374 "quick".to_string(),
375 "rapid".to_string(),
376 "speedy".to_string(),
377 ],
378 );
379 synonyms.insert(
380 "big".to_string(),
381 vec![
382 "large".to_string(),
383 "huge".to_string(),
384 "massive".to_string(),
385 ],
386 );
387 synonyms.insert(
388 "small".to_string(),
389 vec![
390 "tiny".to_string(),
391 "little".to_string(),
392 "compact".to_string(),
393 ],
394 );
395 synonyms.insert(
396 "good".to_string(),
397 vec![
398 "excellent".to_string(),
399 "great".to_string(),
400 "quality".to_string(),
401 ],
402 );
403 synonyms.insert(
404 "bad".to_string(),
405 vec![
406 "poor".to_string(),
407 "terrible".to_string(),
408 "awful".to_string(),
409 ],
410 );
411 synonyms.insert(
412 "simple".to_string(),
413 vec![
414 "easy".to_string(),
415 "basic".to_string(),
416 "straightforward".to_string(),
417 ],
418 );
419 synonyms.insert(
420 "difficult".to_string(),
421 vec![
422 "hard".to_string(),
423 "challenging".to_string(),
424 "complex".to_string(),
425 ],
426 );
427 synonyms.insert(
428 "method".to_string(),
429 vec![
430 "approach".to_string(),
431 "technique".to_string(),
432 "way".to_string(),
433 ],
434 );
435 synonyms.insert(
436 "create".to_string(),
437 vec![
438 "build".to_string(),
439 "make".to_string(),
440 "develop".to_string(),
441 ],
442 );
443 synonyms.insert(
444 "use".to_string(),
445 vec![
446 "utilize".to_string(),
447 "employ".to_string(),
448 "apply".to_string(),
449 ],
450 );
451
452 synonyms
453 }
454
455 fn init_related_terms() -> HashMap<String, Vec<String>> {
457 let mut related = HashMap::new();
458
459 related.insert(
461 "programming".to_string(),
462 vec![
463 "coding".to_string(),
464 "development".to_string(),
465 "software".to_string(),
466 ],
467 );
468 related.insert(
469 "database".to_string(),
470 vec![
471 "data".to_string(),
472 "storage".to_string(),
473 "query".to_string(),
474 ],
475 );
476 related.insert(
477 "algorithm".to_string(),
478 vec![
479 "logic".to_string(),
480 "computation".to_string(),
481 "optimization".to_string(),
482 ],
483 );
484 related.insert(
485 "machine".to_string(),
486 vec![
487 "learning".to_string(),
488 "ai".to_string(),
489 "model".to_string(),
490 ],
491 );
492 related.insert(
493 "web".to_string(),
494 vec![
495 "website".to_string(),
496 "internet".to_string(),
497 "browser".to_string(),
498 ],
499 );
500 related.insert(
501 "api".to_string(),
502 vec![
503 "interface".to_string(),
504 "endpoint".to_string(),
505 "service".to_string(),
506 ],
507 );
508 related.insert(
509 "security".to_string(),
510 vec![
511 "encryption".to_string(),
512 "authentication".to_string(),
513 "protection".to_string(),
514 ],
515 );
516 related.insert(
517 "performance".to_string(),
518 vec![
519 "speed".to_string(),
520 "optimization".to_string(),
521 "efficiency".to_string(),
522 ],
523 );
524
525 related
526 }
527
528 fn init_domain_expansions() -> HashMap<String, HashMap<String, Vec<String>>> {
530 let mut domains = HashMap::new();
531
532 let mut tech_expansions = HashMap::new();
534 tech_expansions.insert(
535 "ml".to_string(),
536 vec![
537 "machine learning".to_string(),
538 "artificial intelligence".to_string(),
539 ],
540 );
541 tech_expansions.insert(
542 "ai".to_string(),
543 vec![
544 "artificial intelligence".to_string(),
545 "machine learning".to_string(),
546 "neural networks".to_string(),
547 ],
548 );
549 tech_expansions.insert(
550 "nlp".to_string(),
551 vec![
552 "natural language processing".to_string(),
553 "text analysis".to_string(),
554 ],
555 );
556 tech_expansions.insert(
557 "api".to_string(),
558 vec![
559 "rest".to_string(),
560 "endpoint".to_string(),
561 "microservice".to_string(),
562 ],
563 );
564 tech_expansions.insert(
565 "db".to_string(),
566 vec![
567 "database".to_string(),
568 "sql".to_string(),
569 "storage".to_string(),
570 ],
571 );
572
573 domains.insert("technology".to_string(), tech_expansions);
574
575 let mut business_expansions = HashMap::new();
577 business_expansions.insert(
578 "roi".to_string(),
579 vec![
580 "return on investment".to_string(),
581 "profitability".to_string(),
582 ],
583 );
584 business_expansions.insert(
585 "kpi".to_string(),
586 vec![
587 "key performance indicator".to_string(),
588 "metrics".to_string(),
589 ],
590 );
591 business_expansions.insert(
592 "b2b".to_string(),
593 vec!["business to business".to_string(), "enterprise".to_string()],
594 );
595 business_expansions.insert(
596 "b2c".to_string(),
597 vec!["business to consumer".to_string(), "retail".to_string()],
598 );
599
600 domains.insert("business".to_string(), business_expansions);
601
602 domains
603 }
604}
605
#[cfg(test)]
mod tests {
    use super::*;

    // "fast" has a synonym-dictionary entry, so a Synonyms-strategy result
    // must appear and carry at least one of its synonyms.
    #[tokio::test]
    async fn test_synonym_expansion() {
        let expander = QueryExpander::new(ExpansionConfig::default());

        let results = expander.expand("fast algorithm").await.unwrap();

        let synonym_result = results
            .iter()
            .find(|r| r.expansion_type == ExpansionStrategy::Synonyms);
        assert!(synonym_result.is_some());

        let result = synonym_result.unwrap();
        assert!(result.expanded_query.contains("quick") || result.expanded_query.contains("rapid"));
    }

    // "ML" should be recognized as a technology-domain abbreviation and
    // expanded to its full form by the domain-specific strategy.
    #[tokio::test]
    async fn test_domain_expansion() {
        let expander = QueryExpander::new(ExpansionConfig::default());

        let results = expander.expand("ML model").await.unwrap();

        let domain_result = results
            .iter()
            .find(|r| r.expansion_type == ExpansionStrategy::DomainSpecific);
        assert!(domain_result.is_some());

        let result = domain_result.unwrap();
        assert!(result.expanded_query.contains("machine learning"));
    }
}