// rexis_rag/query/expander.rs

1//! # Query Expander
2//!
3//! Intelligent query expansion using synonyms, related terms, and semantic similarity.
4//! Improves recall by adding relevant terms that might appear in target documents.
5
6use crate::RragResult;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// Query expander that augments a user query with synonyms, related terms,
/// semantic associations, and domain-specific abbreviation expansions in
/// order to improve retrieval recall.
pub struct QueryExpander {
    /// Tuning knobs controlling which strategies run and how many terms they add
    config: ExpansionConfig,

    /// Synonym dictionary: lowercase term -> interchangeable alternatives
    synonyms: HashMap<String, Vec<String>>,

    /// Related-terms dictionary: lowercase term -> topically associated terms
    related_terms: HashMap<String, Vec<String>>,

    /// Domain-specific expansions: domain name -> (lowercase term -> expansions)
    domain_expansions: HashMap<String, HashMap<String, Vec<String>>>,
}
24
/// Configuration for query expansion.
///
/// Each strategy has an on/off switch; `min_relevance_score` is the floor a
/// produced expansion's confidence must reach to be kept.
#[derive(Debug, Clone)]
pub struct ExpansionConfig {
    /// Maximum number of synonyms to add
    pub max_synonyms: usize,

    /// Maximum number of related terms to add
    pub max_related_terms: usize,

    /// Enable synonym expansion
    pub enable_synonyms: bool,

    /// Enable related term expansion
    pub enable_related_terms: bool,

    /// Enable semantic expansion
    pub enable_semantic_expansion: bool,

    /// Enable domain-specific expansion
    pub enable_domain_expansion: bool,

    /// Minimum relevance score for expansions
    pub min_relevance_score: f32,
}

impl Default for ExpansionConfig {
    /// Sensible defaults: every strategy enabled, a handful of added terms
    /// per strategy, and a 0.6 relevance floor.
    fn default() -> Self {
        ExpansionConfig {
            enable_synonyms: true,
            enable_related_terms: true,
            enable_semantic_expansion: true,
            enable_domain_expansion: true,
            max_synonyms: 3,
            max_related_terms: 2,
            min_relevance_score: 0.6,
        }
    }
}
63
/// Expansion strategies.
///
/// Identifies which technique produced a given `ExpansionResult`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ExpansionStrategy {
    /// Dictionary synonyms of query tokens
    Synonyms,
    /// Topically related terms from the related-terms dictionary
    RelatedTerms,
    /// Semantic expansion using embeddings (currently a rule-based stand-in)
    Semantic,
    /// Domain-specific expansion (abbreviations per detected domain)
    DomainSpecific,
    /// Contextual expansion
    /// NOTE(review): no code path in this file produces this variant —
    /// confirm whether it is used elsewhere or is planned future work.
    Contextual,
}
78
/// Result of query expansion.
///
/// One instance is produced per strategy that added at least one term.
#[derive(Debug, Clone)]
pub struct ExpansionResult {
    /// The query as originally supplied
    pub original_query: String,

    /// Original query with the added terms appended, space-separated
    pub expanded_query: String,

    /// Terms that were appended to the query
    pub added_terms: Vec<String>,

    /// Which expansion strategy produced this result
    pub expansion_type: ExpansionStrategy,

    /// Strategy-level confidence (0.0 to 1.0); compared against
    /// `ExpansionConfig::min_relevance_score` by the caller
    pub confidence: f32,

    /// Per-term relevance scores for the added terms
    pub term_scores: HashMap<String, f32>,
}
100
101impl QueryExpander {
102    /// Create a new query expander
103    pub fn new(config: ExpansionConfig) -> Self {
104        let synonyms = Self::init_synonyms();
105        let related_terms = Self::init_related_terms();
106        let domain_expansions = Self::init_domain_expansions();
107
108        Self {
109            config,
110            synonyms,
111            related_terms,
112            domain_expansions,
113        }
114    }
115
116    /// Expand a query using all enabled strategies
117    pub async fn expand(&self, query: &str) -> RragResult<Vec<ExpansionResult>> {
118        let mut results = Vec::new();
119
120        // Tokenize query
121        let tokens = self.tokenize(query);
122
123        // Apply synonym expansion
124        if self.config.enable_synonyms {
125            if let Some(result) = self.expand_with_synonyms(query, &tokens) {
126                if result.confidence >= self.config.min_relevance_score {
127                    results.push(result);
128                }
129            }
130        }
131
132        // Apply related terms expansion
133        if self.config.enable_related_terms {
134            if let Some(result) = self.expand_with_related_terms(query, &tokens) {
135                if result.confidence >= self.config.min_relevance_score {
136                    results.push(result);
137                }
138            }
139        }
140
141        // Apply semantic expansion
142        if self.config.enable_semantic_expansion {
143            if let Some(result) = self.expand_semantically(query, &tokens) {
144                if result.confidence >= self.config.min_relevance_score {
145                    results.push(result);
146                }
147            }
148        }
149
150        // Apply domain-specific expansion
151        if self.config.enable_domain_expansion {
152            let domain_results = self.expand_domain_specific(query, &tokens);
153            results.extend(
154                domain_results
155                    .into_iter()
156                    .filter(|r| r.confidence >= self.config.min_relevance_score),
157            );
158        }
159
160        Ok(results)
161    }
162
163    /// Expand query with synonyms
164    fn expand_with_synonyms(&self, query: &str, tokens: &[String]) -> Option<ExpansionResult> {
165        let mut added_terms = Vec::new();
166        let mut term_scores = HashMap::new();
167
168        for token in tokens {
169            if let Some(synonyms) = self.synonyms.get(&token.to_lowercase()) {
170                for synonym in synonyms.iter().take(self.config.max_synonyms) {
171                    if !tokens
172                        .iter()
173                        .any(|t| t.to_lowercase() == synonym.to_lowercase())
174                    {
175                        added_terms.push(synonym.clone());
176                        term_scores.insert(synonym.clone(), 0.8); // Fixed score for synonyms
177                    }
178                }
179            }
180        }
181
182        if !added_terms.is_empty() {
183            let expanded_query = format!("{} {}", query, added_terms.join(" "));
184            Some(ExpansionResult {
185                original_query: query.to_string(),
186                expanded_query,
187                added_terms,
188                expansion_type: ExpansionStrategy::Synonyms,
189                confidence: 0.8,
190                term_scores,
191            })
192        } else {
193            None
194        }
195    }
196
197    /// Expand query with related terms
198    fn expand_with_related_terms(&self, query: &str, tokens: &[String]) -> Option<ExpansionResult> {
199        let mut added_terms = Vec::new();
200        let mut term_scores = HashMap::new();
201
202        for token in tokens {
203            if let Some(related) = self.related_terms.get(&token.to_lowercase()) {
204                for term in related.iter().take(self.config.max_related_terms) {
205                    if !tokens
206                        .iter()
207                        .any(|t| t.to_lowercase() == term.to_lowercase())
208                    {
209                        added_terms.push(term.clone());
210                        term_scores.insert(term.clone(), 0.7); // Slightly lower than synonyms
211                    }
212                }
213            }
214        }
215
216        if !added_terms.is_empty() {
217            let expanded_query = format!("{} {}", query, added_terms.join(" "));
218            Some(ExpansionResult {
219                original_query: query.to_string(),
220                expanded_query,
221                added_terms,
222                expansion_type: ExpansionStrategy::RelatedTerms,
223                confidence: 0.7,
224                term_scores,
225            })
226        } else {
227            None
228        }
229    }
230
231    /// Expand query semantically
232    fn expand_semantically(&self, query: &str, _tokens: &[String]) -> Option<ExpansionResult> {
233        // For now, implement a simple semantic expansion
234        // In production, this would use word embeddings or language models
235        let semantic_expansions = self.get_semantic_expansions(query);
236
237        if !semantic_expansions.is_empty() {
238            let mut term_scores = HashMap::new();
239            for term in &semantic_expansions {
240                term_scores.insert(term.clone(), 0.6);
241            }
242
243            let expanded_query = format!("{} {}", query, semantic_expansions.join(" "));
244            Some(ExpansionResult {
245                original_query: query.to_string(),
246                expanded_query,
247                added_terms: semantic_expansions,
248                expansion_type: ExpansionStrategy::Semantic,
249                confidence: 0.6,
250                term_scores,
251            })
252        } else {
253            None
254        }
255    }
256
257    /// Apply domain-specific expansions
258    fn expand_domain_specific(&self, query: &str, tokens: &[String]) -> Vec<ExpansionResult> {
259        let mut results = Vec::new();
260
261        // Detect domain
262        let domain = self.detect_domain(tokens);
263
264        if let Some(domain_dict) = self.domain_expansions.get(&domain) {
265            for token in tokens {
266                if let Some(expansions) = domain_dict.get(&token.to_lowercase()) {
267                    let mut term_scores = HashMap::new();
268                    for term in expansions {
269                        term_scores.insert(term.clone(), 0.75);
270                    }
271
272                    let expanded_query = format!("{} {}", query, expansions.join(" "));
273                    results.push(ExpansionResult {
274                        original_query: query.to_string(),
275                        expanded_query,
276                        added_terms: expansions.clone(),
277                        expansion_type: ExpansionStrategy::DomainSpecific,
278                        confidence: 0.75,
279                        term_scores,
280                    });
281                }
282            }
283        }
284
285        results
286    }
287
288    /// Get semantic expansions for a query
289    fn get_semantic_expansions(&self, query: &str) -> Vec<String> {
290        // Simple rule-based semantic expansion
291        // In production, use proper semantic models
292        let mut expansions = Vec::new();
293
294        let query_lower = query.to_lowercase();
295
296        if query_lower.contains("learn") || query_lower.contains("study") {
297            expansions.extend_from_slice(&["education", "training", "tutorial"]);
298        }
299
300        if query_lower.contains("build") || query_lower.contains("create") {
301            expansions.extend_from_slice(&["develop", "construct", "implement"]);
302        }
303
304        if query_lower.contains("fast") || query_lower.contains("quick") {
305            expansions.extend_from_slice(&["rapid", "efficient", "performance"]);
306        }
307
308        if query_lower.contains("problem") || query_lower.contains("issue") {
309            expansions.extend_from_slice(&["solution", "fix", "troubleshoot"]);
310        }
311
312        expansions.into_iter().map(String::from).collect()
313    }
314
315    /// Detect the domain of a query
316    fn detect_domain(&self, tokens: &[String]) -> String {
317        let tech_terms = [
318            "code",
319            "programming",
320            "software",
321            "api",
322            "database",
323            "algorithm",
324        ];
325        let business_terms = ["market", "sales", "revenue", "customer", "profit"];
326        let science_terms = ["research", "study", "experiment", "theory", "analysis"];
327
328        let tokens_lower: Vec<String> = tokens.iter().map(|t| t.to_lowercase()).collect();
329
330        let tech_count = tech_terms
331            .iter()
332            .filter(|&&term| tokens_lower.iter().any(|t| t.contains(term)))
333            .count();
334        let business_count = business_terms
335            .iter()
336            .filter(|&&term| tokens_lower.iter().any(|t| t.contains(term)))
337            .count();
338        let science_count = science_terms
339            .iter()
340            .filter(|&&term| tokens_lower.iter().any(|t| t.contains(term)))
341            .count();
342
343        if tech_count > business_count && tech_count > science_count {
344            "technology".to_string()
345        } else if business_count > science_count {
346            "business".to_string()
347        } else if science_count > 0 {
348            "science".to_string()
349        } else {
350            "general".to_string()
351        }
352    }
353
354    /// Tokenize query into individual terms
355    fn tokenize(&self, query: &str) -> Vec<String> {
356        query
357            .to_lowercase()
358            .split_whitespace()
359            .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
360            .filter(|s| !s.is_empty())
361            .filter(|s| s.len() > 2) // Filter out very short words
362            .map(String::from)
363            .collect()
364    }
365
366    /// Initialize synonym dictionary
367    fn init_synonyms() -> HashMap<String, Vec<String>> {
368        let mut synonyms = HashMap::new();
369
370        // Technology synonyms
371        synonyms.insert(
372            "fast".to_string(),
373            vec![
374                "quick".to_string(),
375                "rapid".to_string(),
376                "speedy".to_string(),
377            ],
378        );
379        synonyms.insert(
380            "big".to_string(),
381            vec![
382                "large".to_string(),
383                "huge".to_string(),
384                "massive".to_string(),
385            ],
386        );
387        synonyms.insert(
388            "small".to_string(),
389            vec![
390                "tiny".to_string(),
391                "little".to_string(),
392                "compact".to_string(),
393            ],
394        );
395        synonyms.insert(
396            "good".to_string(),
397            vec![
398                "excellent".to_string(),
399                "great".to_string(),
400                "quality".to_string(),
401            ],
402        );
403        synonyms.insert(
404            "bad".to_string(),
405            vec![
406                "poor".to_string(),
407                "terrible".to_string(),
408                "awful".to_string(),
409            ],
410        );
411        synonyms.insert(
412            "simple".to_string(),
413            vec![
414                "easy".to_string(),
415                "basic".to_string(),
416                "straightforward".to_string(),
417            ],
418        );
419        synonyms.insert(
420            "difficult".to_string(),
421            vec![
422                "hard".to_string(),
423                "challenging".to_string(),
424                "complex".to_string(),
425            ],
426        );
427        synonyms.insert(
428            "method".to_string(),
429            vec![
430                "approach".to_string(),
431                "technique".to_string(),
432                "way".to_string(),
433            ],
434        );
435        synonyms.insert(
436            "create".to_string(),
437            vec![
438                "build".to_string(),
439                "make".to_string(),
440                "develop".to_string(),
441            ],
442        );
443        synonyms.insert(
444            "use".to_string(),
445            vec![
446                "utilize".to_string(),
447                "employ".to_string(),
448                "apply".to_string(),
449            ],
450        );
451
452        synonyms
453    }
454
455    /// Initialize related terms dictionary
456    fn init_related_terms() -> HashMap<String, Vec<String>> {
457        let mut related = HashMap::new();
458
459        // Technology related terms
460        related.insert(
461            "programming".to_string(),
462            vec![
463                "coding".to_string(),
464                "development".to_string(),
465                "software".to_string(),
466            ],
467        );
468        related.insert(
469            "database".to_string(),
470            vec![
471                "data".to_string(),
472                "storage".to_string(),
473                "query".to_string(),
474            ],
475        );
476        related.insert(
477            "algorithm".to_string(),
478            vec![
479                "logic".to_string(),
480                "computation".to_string(),
481                "optimization".to_string(),
482            ],
483        );
484        related.insert(
485            "machine".to_string(),
486            vec![
487                "learning".to_string(),
488                "ai".to_string(),
489                "model".to_string(),
490            ],
491        );
492        related.insert(
493            "web".to_string(),
494            vec![
495                "website".to_string(),
496                "internet".to_string(),
497                "browser".to_string(),
498            ],
499        );
500        related.insert(
501            "api".to_string(),
502            vec![
503                "interface".to_string(),
504                "endpoint".to_string(),
505                "service".to_string(),
506            ],
507        );
508        related.insert(
509            "security".to_string(),
510            vec![
511                "encryption".to_string(),
512                "authentication".to_string(),
513                "protection".to_string(),
514            ],
515        );
516        related.insert(
517            "performance".to_string(),
518            vec![
519                "speed".to_string(),
520                "optimization".to_string(),
521                "efficiency".to_string(),
522            ],
523        );
524
525        related
526    }
527
528    /// Initialize domain-specific expansions
529    fn init_domain_expansions() -> HashMap<String, HashMap<String, Vec<String>>> {
530        let mut domains = HashMap::new();
531
532        // Technology domain
533        let mut tech_expansions = HashMap::new();
534        tech_expansions.insert(
535            "ml".to_string(),
536            vec![
537                "machine learning".to_string(),
538                "artificial intelligence".to_string(),
539            ],
540        );
541        tech_expansions.insert(
542            "ai".to_string(),
543            vec![
544                "artificial intelligence".to_string(),
545                "machine learning".to_string(),
546                "neural networks".to_string(),
547            ],
548        );
549        tech_expansions.insert(
550            "nlp".to_string(),
551            vec![
552                "natural language processing".to_string(),
553                "text analysis".to_string(),
554            ],
555        );
556        tech_expansions.insert(
557            "api".to_string(),
558            vec![
559                "rest".to_string(),
560                "endpoint".to_string(),
561                "microservice".to_string(),
562            ],
563        );
564        tech_expansions.insert(
565            "db".to_string(),
566            vec![
567                "database".to_string(),
568                "sql".to_string(),
569                "storage".to_string(),
570            ],
571        );
572
573        domains.insert("technology".to_string(), tech_expansions);
574
575        // Business domain
576        let mut business_expansions = HashMap::new();
577        business_expansions.insert(
578            "roi".to_string(),
579            vec![
580                "return on investment".to_string(),
581                "profitability".to_string(),
582            ],
583        );
584        business_expansions.insert(
585            "kpi".to_string(),
586            vec![
587                "key performance indicator".to_string(),
588                "metrics".to_string(),
589            ],
590        );
591        business_expansions.insert(
592            "b2b".to_string(),
593            vec!["business to business".to_string(), "enterprise".to_string()],
594        );
595        business_expansions.insert(
596            "b2c".to_string(),
597            vec!["business to consumer".to_string(), "retail".to_string()],
598        );
599
600        domains.insert("business".to_string(), business_expansions);
601
602        domains
603    }
604}
605
#[cfg(test)]
mod tests {
    use super::*;

    // "fast" is a key in the synonym dictionary, so synonym expansion must
    // fire and surface at least one of its listed synonyms.
    #[tokio::test]
    async fn test_synonym_expansion() {
        let expander = QueryExpander::new(ExpansionConfig::default());

        let results = expander.expand("fast algorithm").await.unwrap();

        let synonym_result = results
            .iter()
            .find(|r| r.expansion_type == ExpansionStrategy::Synonyms);
        assert!(synonym_result.is_some());

        let result = synonym_result.unwrap();
        assert!(result.expanded_query.contains("quick") || result.expanded_query.contains("rapid"));
    }

    // Domain expansion must expand the abbreviation "ml".
    // NOTE(review): this depends on two-character tokens reaching the domain
    // dictionaries — `tokenize` drops words shorter than three characters,
    // so the domain lookup must not rely solely on its output.
    #[tokio::test]
    async fn test_domain_expansion() {
        let expander = QueryExpander::new(ExpansionConfig::default());

        let results = expander.expand("ML model").await.unwrap();

        let domain_result = results
            .iter()
            .find(|r| r.expansion_type == ExpansionStrategy::DomainSpecific);
        assert!(domain_result.is_some());

        let result = domain_result.unwrap();
        assert!(result.expanded_query.contains("machine learning"));
    }
}
639}