// graphrag_core/query/adaptive_routing.rs

//! Adaptive Query Routing for Hierarchical GraphRAG
//!
//! Automatically selects the appropriate hierarchical level based on query complexity.
//!
//! # Strategy
//! - **Broad queries** (overview, themes, summary) → higher levels (2-3 with the default `max_level` of 3)
//! - **Specific queries** (relationships, details, entities) → lower levels (0-1)
//! - **Adaptive routing**: the level is chosen by scoring the query's keywords, length, and entity mentions

use std::fmt;

use serde::{Deserialize, Serialize};

12/// Configuration for adaptive query routing
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct AdaptiveRoutingConfig {
15    /// Enable adaptive query routing
16    pub enabled: bool,
17
18    /// Default level when query complexity is unclear
19    pub default_level: usize,
20
21    /// Maximum hierarchical level available
22    pub max_level: usize,
23
24    /// Weight for keyword-based level selection (0.0-1.0)
25    pub keyword_weight: f32,
26
27    /// Weight for query length-based selection (0.0-1.0)
28    pub length_weight: f32,
29
30    /// Weight for entity mention-based selection (0.0-1.0)
31    pub entity_weight: f32,
32}
33
34impl Default for AdaptiveRoutingConfig {
35    fn default() -> Self {
36        Self {
37            enabled: true,
38            default_level: 1,
39            max_level: 3,
40            keyword_weight: 0.5,
41            length_weight: 0.3,
42            entity_weight: 0.2,
43        }
44    }
45}
46
47/// Query complexity level
48#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
49pub enum QueryComplexity {
50    /// Very broad query (overview, themes)
51    VeryBroad,
52    /// Broad query (general understanding)
53    Broad,
54    /// Medium complexity
55    Medium,
56    /// Specific query (detailed information)
57    Specific,
58    /// Very specific query (precise relationships)
59    VerySpecific,
60}
61
62impl QueryComplexity {
63    /// Convert complexity to hierarchical level
64    pub fn to_level(self, max_level: usize) -> usize {
65        match self {
66            QueryComplexity::VeryBroad => max_level.max(2),
67            QueryComplexity::Broad => (max_level - 1).max(1),
68            QueryComplexity::Medium => 1,
69            QueryComplexity::Specific => 0,
70            QueryComplexity::VerySpecific => 0,
71        }
72    }
73}
74
75/// Analyzes query complexity and suggests appropriate hierarchical level
76#[derive(Debug)]
77pub struct QueryComplexityAnalyzer {
78    config: AdaptiveRoutingConfig,
79
80    // Keyword sets for classification
81    broad_keywords: Vec<&'static str>,
82    specific_keywords: Vec<&'static str>,
83}
84
85impl QueryComplexityAnalyzer {
86    /// Create a new query complexity analyzer
87    pub fn new(config: AdaptiveRoutingConfig) -> Self {
88        Self {
89            config,
90            broad_keywords: vec![
91                "overview",
92                "summary",
93                "summarize",
94                "main",
95                "general",
96                "all",
97                "themes",
98                "topics",
99                "overall",
100                "broadly",
101                "big picture",
102                "what are",
103                "list all",
104                "show me all",
105            ],
106            specific_keywords: vec![
107                "relationship between",
108                "how does",
109                "why does",
110                "specific",
111                "detail",
112                "exactly",
113                "precisely",
114                "what is the connection",
115                "explain how",
116                "describe the",
117                "between",
118                "and",
119            ],
120        }
121    }
122
123    /// Analyze query and determine complexity
124    pub fn analyze(&self, query: &str) -> QueryComplexity {
125        let query_lower = query.to_lowercase();
126
127        // Score components
128        let keyword_score = self.analyze_keywords(&query_lower);
129        let length_score = self.analyze_length(query);
130        let entity_score = self.analyze_entity_mentions(&query_lower);
131
132        // Weighted combination
133        let total_score = keyword_score * self.config.keyword_weight
134            + length_score * self.config.length_weight
135            + entity_score * self.config.entity_weight;
136
137        // Map score to complexity level
138        if total_score >= 0.7 {
139            QueryComplexity::VeryBroad
140        } else if total_score >= 0.4 {
141            QueryComplexity::Broad
142        } else if total_score >= -0.2 {
143            QueryComplexity::Medium
144        } else if total_score >= -0.5 {
145            QueryComplexity::Specific
146        } else {
147            QueryComplexity::VerySpecific
148        }
149    }
150
151    /// Analyze query keywords (-1.0 = very specific, +1.0 = very broad)
152    fn analyze_keywords(&self, query_lower: &str) -> f32 {
153        let mut score = 0.0;
154        let mut matches = 0;
155
156        // Check broad keywords (positive score)
157        for keyword in &self.broad_keywords {
158            if query_lower.contains(keyword) {
159                score += 1.0;
160                matches += 1;
161            }
162        }
163
164        // Check specific keywords (negative score)
165        for keyword in &self.specific_keywords {
166            if query_lower.contains(keyword) {
167                score -= 1.0;
168                matches += 1;
169            }
170        }
171
172        // Normalize
173        if matches > 0 {
174            score / matches as f32
175        } else {
176            0.0 // No keywords found = medium
177        }
178    }
179
180    /// Analyze query length (short = specific, long = broad)
181    fn analyze_length(&self, query: &str) -> f32 {
182        let words: Vec<&str> = query.split_whitespace().collect();
183        let word_count = words.len();
184
185        // Short queries (1-3 words) tend to be broad ("AI overview")
186        // Long queries (8+ words) tend to be specific
187        match word_count {
188            1..=3 => 0.5,   // Short → broad
189            4..=5 => 0.2,   // Medium-short
190            6..=7 => 0.0,   // Medium
191            8..=10 => -0.3, // Medium-long → specific
192            _ => -0.5,      // Long → very specific
193        }
194    }
195
196    /// Analyze entity mentions (many entities = specific query)
197    fn analyze_entity_mentions(&self, query_lower: &str) -> f32 {
198        // Count capital words (potential entity names in original query)
199        // and quoted phrases (explicit entity references)
200        let quoted_count = query_lower.matches('"').count() / 2;
201        let and_between = query_lower.matches(" and ").count();
202        let between_count = query_lower.matches("between").count();
203
204        let entity_indicators = quoted_count + and_between + between_count;
205
206        // More entity indicators = more specific
207        match entity_indicators {
208            0 => 0.3,  // No entities → broad
209            1 => 0.0,  // One entity → medium
210            2 => -0.4, // Two entities → specific
211            _ => -0.7, // Multiple entities → very specific
212        }
213    }
214
215    /// Suggest hierarchical level for query
216    pub fn suggest_level(&self, query: &str) -> usize {
217        let complexity = self.analyze(query);
218        complexity.to_level(self.config.max_level)
219    }
220
221    /// Get detailed analysis with explanation
222    pub fn analyze_detailed(&self, query: &str) -> QueryAnalysis {
223        let query_lower = query.to_lowercase();
224
225        let keyword_score = self.analyze_keywords(&query_lower);
226        let length_score = self.analyze_length(query);
227        let entity_score = self.analyze_entity_mentions(&query_lower);
228
229        let complexity = self.analyze(query);
230        let suggested_level = complexity.to_level(self.config.max_level);
231
232        QueryAnalysis {
233            query: query.to_string(),
234            complexity,
235            suggested_level,
236            keyword_score,
237            length_score,
238            entity_score,
239            explanation: self.generate_explanation(complexity, suggested_level),
240        }
241    }
242
243    /// Generate explanation for the routing decision
244    fn generate_explanation(&self, complexity: QueryComplexity, level: usize) -> String {
245        match complexity {
246            QueryComplexity::VeryBroad => format!(
247                "Very broad query detected → using level {} for high-level overview",
248                level
249            ),
250            QueryComplexity::Broad => format!(
251                "Broad query detected → using level {} for general understanding",
252                level
253            ),
254            QueryComplexity::Medium => format!(
255                "Medium complexity query → using level {} for balanced detail",
256                level
257            ),
258            QueryComplexity::Specific => format!(
259                "Specific query detected → using level {} for detailed information",
260                level
261            ),
262            QueryComplexity::VerySpecific => format!(
263                "Very specific query detected → using level {} for precise relationships",
264                level
265            ),
266        }
267    }
268}
269
270/// Detailed query analysis result
271#[derive(Debug, Clone, Serialize, Deserialize)]
272pub struct QueryAnalysis {
273    /// Original query
274    pub query: String,
275    /// Detected complexity level
276    pub complexity: QueryComplexity,
277    /// Suggested hierarchical level
278    pub suggested_level: usize,
279    /// Keyword analysis score
280    pub keyword_score: f32,
281    /// Length analysis score
282    pub length_score: f32,
283    /// Entity mention score
284    pub entity_score: f32,
285    /// Human-readable explanation
286    pub explanation: String,
287}
288
289impl QueryAnalysis {
290    /// Print detailed analysis
291    pub fn print(&self) {
292        println!("Query Analysis:");
293        println!("  Query: \"{}\"", self.query);
294        println!("  Complexity: {:?}", self.complexity);
295        println!("  Suggested Level: {}", self.suggested_level);
296        println!("  Scores:");
297        println!("    - Keywords: {:.2}", self.keyword_score);
298        println!("    - Length: {:.2}", self.length_score);
299        println!("    - Entities: {:.2}", self.entity_score);
300        println!("  {}", self.explanation);
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Analyzer built from the stock routing configuration.
    fn default_analyzer() -> QueryComplexityAnalyzer {
        QueryComplexityAnalyzer::new(AdaptiveRoutingConfig::default())
    }

    #[test]
    fn test_broad_query() {
        let analyzer = default_analyzer();
        let query = "Give me an overview of AI technologies";

        // An explicit "overview" request belongs on one of the upper levels.
        let complexity = analyzer.analyze(query);
        assert!(
            complexity == QueryComplexity::VeryBroad || complexity == QueryComplexity::Broad,
            "expected a broad classification, got {:?}",
            complexity
        );
        assert!(analyzer.suggest_level(query) >= 1);
    }

    #[test]
    fn test_specific_query() {
        let analyzer = default_analyzer();
        let query = "What is the relationship between Transformers and GPT?";

        // A question about how two named entities relate should go to the bottom level.
        let complexity = analyzer.analyze(query);
        assert!(
            complexity == QueryComplexity::Specific
                || complexity == QueryComplexity::VerySpecific,
            "expected a specific classification, got {:?}",
            complexity
        );
        assert_eq!(analyzer.suggest_level(query), 0);
    }

    #[test]
    fn test_medium_query() {
        let analyzer = default_analyzer();

        // A short "how does" question should not be routed above level 1.
        let level = analyzer.suggest_level("How does machine learning work?");
        assert!(level <= 1);
    }

    #[test]
    fn test_detailed_analysis() {
        let analysis = default_analyzer().analyze_detailed("Summarize the main themes");

        // "summarize", "main", and "themes" are all broad keywords.
        assert!(analysis.keyword_score > 0.0);
        assert!(!analysis.explanation.is_empty());
    }
}