Skip to main content

graphrag_core/query/
intelligence.rs

//! Query Intelligence and Rewriting
//!
//! This module provides intelligent query processing including:
//! - Query rewriting and expansion
//! - Synonym expansion
//! - Relevance feedback learning
//! - Query templates
//! - Natural language to structured query conversion
//! - Query performance analysis
10
11use std::collections::{HashMap, HashSet};
12use serde::{Deserialize, Serialize};
13
14/// Query intelligence engine
15pub struct QueryIntelligence {
16    synonyms: HashMap<String, Vec<String>>,
17    templates: Vec<QueryTemplate>,
18    stop_words: HashSet<String>,
19    relevance_scores: HashMap<String, f32>,
20}
21
22/// Query template
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct QueryTemplate {
25    /// Regex pattern to match queries
26    pub pattern: String,
27    /// Type of query this template matches
28    pub query_type: QueryType,
29    /// Rewrite template for query optimization
30    pub rewrite: String,
31}
32
33/// Query type classification
34#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
35pub enum QueryType {
36    /// Entity lookup queries
37    EntityLookup,
38    /// Relationship queries
39    Relationship,
40    /// Aggregation queries
41    Aggregation,
42    /// Comparison queries
43    Comparison,
44    /// Temporal queries
45    Temporal,
46    /// Causal queries
47    Causal,
48    /// General queries
49    General,
50}
51
52/// Rewritten query
53#[derive(Debug, Clone)]
54pub struct RewrittenQuery {
55    /// Original query text
56    pub original: String,
57    /// Rewritten optimized query
58    pub rewritten: String,
59    /// Detected query type
60    pub query_type: QueryType,
61    /// Expanded search terms
62    pub expanded_terms: Vec<String>,
63    /// Confidence score of rewrite
64    pub confidence: f32,
65}
66
67impl QueryIntelligence {
68    /// Create a new query intelligence engine with default settings
69    pub fn new() -> Self {
70        let mut engine = Self {
71            synonyms: HashMap::new(),
72            templates: Vec::new(),
73            stop_words: HashSet::new(),
74            relevance_scores: HashMap::new(),
75        };
76
77        // Initialize with default templates and synonyms
78        engine.load_default_synonyms();
79        engine.load_default_templates();
80        engine.load_default_stop_words();
81
82        engine
83    }
84
85    /// Rewrite and expand a query
86    ///
87    /// # Arguments
88    /// * `query` - The original query string
89    ///
90    /// # Returns
91    /// RewrittenQuery with expanded terms and detected query type
92    pub fn rewrite_query(&self, query: &str) -> RewrittenQuery {
93        // Normalize query
94        let normalized = self.normalize_query(query);
95
96        // Detect query type
97        let query_type = self.detect_query_type(&normalized);
98
99        // Apply template matching
100        let template_rewritten = self.apply_templates(&normalized, &query_type);
101
102        // Expand synonyms
103        let expanded = self.expand_synonyms(&template_rewritten);
104
105        // Extract key terms
106        let expanded_terms = self.extract_key_terms(&expanded);
107
108        // Calculate confidence
109        let confidence = self.calculate_confidence(&normalized, &expanded_terms);
110
111        RewrittenQuery {
112            original: query.to_string(),
113            rewritten: expanded,
114            query_type,
115            expanded_terms,
116            confidence,
117        }
118    }
119
120    /// Add a custom synonym mapping
121    ///
122    /// # Arguments
123    /// * `term` - The original term
124    /// * `synonyms` - List of synonyms
125    pub fn add_synonym(&mut self, term: impl Into<String>, synonyms: Vec<String>) {
126        // Normalize the term to lowercase for consistent lookup
127        self.synonyms.insert(term.into().to_lowercase(), synonyms);
128    }
129
130    /// Add a query template
131    ///
132    /// # Arguments
133    /// * `template` - Query template
134    pub fn add_template(&mut self, template: QueryTemplate) {
135        self.templates.push(template);
136    }
137
138    /// Record relevance feedback
139    ///
140    /// # Arguments
141    /// * `term` - The search term
142    /// * `score` - Relevance score (0.0 to 1.0)
143    pub fn record_feedback(&mut self, term: impl Into<String>, score: f32) {
144        let term = term.into();
145        let current_score = self.relevance_scores.get(&term).unwrap_or(&0.5);
146        // Weighted average with new feedback (equal weight for faster learning)
147        let new_score = current_score * 0.5 + score * 0.5;
148        self.relevance_scores.insert(term, new_score);
149    }
150
151    /// Get relevance score for a term
152    ///
153    /// # Arguments
154    /// * `term` - The term to check
155    ///
156    /// # Returns
157    /// Relevance score between 0.0 and 1.0
158    pub fn get_relevance(&self, term: &str) -> f32 {
159        *self.relevance_scores.get(term).unwrap_or(&0.5)
160    }
161
162    // --- Private methods ---
163
164    /// Normalize query (lowercase, trim, etc.)
165    fn normalize_query(&self, query: &str) -> String {
166        query.trim().to_lowercase()
167    }
168
169    /// Detect query type based on patterns
170    fn detect_query_type(&self, query: &str) -> QueryType {
171        let query_lower = query.to_lowercase();
172
173        // Relationship patterns (check before entity lookup to handle "what is the relationship...")
174        if query_lower.contains("relationship between")
175            || query_lower.contains("how does")
176            || query_lower.contains("related to")
177            || query_lower.contains("connection between")
178        {
179            return QueryType::Relationship;
180        }
181
182        // Entity lookup patterns
183        if query_lower.starts_with("who is")
184            || query_lower.starts_with("what is")
185            || query_lower.starts_with("define")
186        {
187            return QueryType::EntityLookup;
188        }
189
190        // Aggregation patterns
191        if query_lower.starts_with("how many")
192            || query_lower.starts_with("count")
193            || query_lower.contains("total")
194            || query_lower.contains("sum")
195            || query_lower.contains("average")
196        {
197            return QueryType::Aggregation;
198        }
199
200        // Comparison patterns
201        if query_lower.contains("compare")
202            || query_lower.contains("difference between")
203            || query_lower.contains("versus")
204            || query_lower.contains("vs")
205        {
206            return QueryType::Comparison;
207        }
208
209        // Temporal patterns
210        if query_lower.contains("when")
211            || query_lower.contains("before")
212            || query_lower.contains("after")
213            || query_lower.contains("during")
214            || query_lower.contains("timeline")
215        {
216            return QueryType::Temporal;
217        }
218
219        // Causal patterns
220        if query_lower.contains("why")
221            || query_lower.contains("because")
222            || query_lower.contains("cause")
223            || query_lower.contains("reason")
224            || query_lower.contains("led to")
225        {
226            return QueryType::Causal;
227        }
228
229        QueryType::General
230    }
231
232    /// Apply query templates
233    fn apply_templates(&self, query: &str, query_type: &QueryType) -> String {
234        for template in &self.templates {
235            if &template.query_type == query_type && query.contains(&template.pattern) {
236                return query.replace(&template.pattern, &template.rewrite);
237            }
238        }
239        query.to_string()
240    }
241
242    /// Expand query with synonyms
243    fn expand_synonyms(&self, query: &str) -> String {
244        let words: Vec<&str> = query.split_whitespace().collect();
245        let mut expanded_words = Vec::new();
246
247        for word in words {
248            expanded_words.push(word.to_string());
249
250            // Add synonyms if available
251            if let Some(synonyms) = self.synonyms.get(word) {
252                for synonym in synonyms {
253                    if !expanded_words.contains(synonym) {
254                        expanded_words.push(synonym.clone());
255                    }
256                }
257            }
258        }
259
260        expanded_words.join(" ")
261    }
262
263    /// Extract key terms (remove stop words)
264    fn extract_key_terms(&self, query: &str) -> Vec<String> {
265        query
266            .split_whitespace()
267            .filter(|word| !self.stop_words.contains(*word))
268            .map(|s| s.to_string())
269            .collect()
270    }
271
272    /// Calculate confidence score
273    fn calculate_confidence(&self, query: &str, expanded_terms: &[String]) -> f32 {
274        if expanded_terms.is_empty() {
275            return 0.5;
276        }
277
278        // Base confidence on query length and term count
279        let word_count = query.split_whitespace().count() as f32;
280        let term_count = expanded_terms.len() as f32;
281
282        // Higher confidence for more specific queries
283        let specificity_score = (term_count / (word_count + 1.0)).min(1.0);
284
285        // Factor in relevance feedback
286        let relevance_score: f32 = expanded_terms
287            .iter()
288            .map(|t| self.get_relevance(t))
289            .sum::<f32>()
290            / term_count;
291
292        // Weighted average
293        specificity_score * 0.6 + relevance_score * 0.4
294    }
295
296    /// Load default synonyms
297    fn load_default_synonyms(&mut self) {
298        // Common synonyms
299        self.add_synonym("find", vec!["search".to_string(), "locate".to_string()]);
300        self.add_synonym("person", vec!["individual".to_string(), "people".to_string()]);
301        self.add_synonym("company", vec!["organization".to_string(), "business".to_string(), "firm".to_string()]);
302        self.add_synonym("show", vec!["display".to_string(), "present".to_string()]);
303        self.add_synonym("get", vec!["retrieve".to_string(), "fetch".to_string()]);
304        self.add_synonym("large", vec!["big".to_string(), "huge".to_string(), "significant".to_string()]);
305        self.add_synonym("small", vec!["tiny".to_string(), "minor".to_string()]);
306        self.add_synonym("important", vec!["significant".to_string(), "critical".to_string(), "key".to_string()]);
307    }
308
309    /// Load default query templates
310    fn load_default_templates(&mut self) {
311        self.add_template(QueryTemplate {
312            pattern: "who is".to_string(),
313            query_type: QueryType::EntityLookup,
314            rewrite: "entity:".to_string(),
315        });
316
317        self.add_template(QueryTemplate {
318            pattern: "what is".to_string(),
319            query_type: QueryType::EntityLookup,
320            rewrite: "define:".to_string(),
321        });
322
323        self.add_template(QueryTemplate {
324            pattern: "how many".to_string(),
325            query_type: QueryType::Aggregation,
326            rewrite: "count:".to_string(),
327        });
328
329        self.add_template(QueryTemplate {
330            pattern: "compare".to_string(),
331            query_type: QueryType::Comparison,
332            rewrite: "compare:".to_string(),
333        });
334    }
335
336    /// Load default stop words
337    fn load_default_stop_words(&mut self) {
338        let stop_words = vec![
339            "a", "an", "and", "are", "as", "at", "be", "by", "for",
340            "from", "has", "he", "in", "is", "it", "its", "of", "on",
341            "that", "the", "to", "was", "will", "with",
342        ];
343
344        for word in stop_words {
345            self.stop_words.insert(word.to_string());
346        }
347    }
348}
349
350impl Default for QueryIntelligence {
351    fn default() -> Self {
352        Self::new()
353    }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Representative queries are classified into the expected type.
    #[test]
    fn test_query_type_detection() {
        let engine = QueryIntelligence::new();

        let result = engine.rewrite_query("who is the CEO of OpenAI?");
        assert_eq!(result.query_type, QueryType::EntityLookup);

        let result = engine.rewrite_query("how many employees work at Google?");
        assert_eq!(result.query_type, QueryType::Aggregation);

        // Starts with "what is" but must be recognised as a relationship query.
        let result = engine.rewrite_query("what is the relationship between Apple and Microsoft?");
        assert_eq!(result.query_type, QueryType::Relationship);
    }

    /// Default synonyms are appended for every word that has them.
    #[test]
    fn test_synonym_expansion() {
        let engine = QueryIntelligence::new();

        let result = engine.rewrite_query("find large companies");

        // Both "find" and "large" have default synonyms; require each expansion
        // so that a regression in either one is caught.
        assert!(result.expanded_terms.contains(&"search".to_string()));
        assert!(result.expanded_terms.contains(&"big".to_string()));
        // The original words are kept alongside their synonyms.
        assert!(result.expanded_terms.contains(&"find".to_string()));
        assert!(result.expanded_terms.contains(&"large".to_string()));
    }

    /// Stop words do not appear among the key terms.
    #[test]
    fn test_stop_word_removal() {
        let engine = QueryIntelligence::new();

        let result = engine.rewrite_query("what is the best approach");

        // "the" and "is" should be filtered out
        assert!(!result.expanded_terms.contains(&"the".to_string()));
        assert!(!result.expanded_terms.contains(&"is".to_string()));
        // Content words survive the filter.
        assert!(result.expanded_terms.contains(&"best".to_string()));
        assert!(result.expanded_terms.contains(&"approach".to_string()));
    }

    /// Repeated positive feedback raises a term's relevance above neutral.
    #[test]
    fn test_relevance_feedback() {
        let mut engine = QueryIntelligence::new();

        engine.record_feedback("artificial_intelligence", 0.9);
        engine.record_feedback("artificial_intelligence", 0.8);

        let score = engine.get_relevance("artificial_intelligence");
        assert!(score > 0.7);
    }

    /// Custom synonyms are found even when registered with uppercase letters.
    #[test]
    fn test_custom_synonyms() {
        let mut engine = QueryIntelligence::new();
        engine.add_synonym(
            "AI",
            vec![
                "artificial intelligence".to_string(),
                "machine learning".to_string(),
            ],
        );

        let result = engine.rewrite_query("AI applications");

        // Every registered synonym must be added, not just one of them.
        assert!(result.rewritten.contains("artificial intelligence"));
        assert!(result.rewritten.contains("machine learning"));
    }
}
422}