Skip to main content

graphrag_core/query/
adaptive_routing.rs

1//! Adaptive Query Routing for Hierarchical GraphRAG
2//!
3//! Automatically selects the appropriate hierarchical level based on query complexity.
4//!
5//! # Strategy
//! - **Broad queries** (overview, themes, summary) → higher levels (2-3 with the default `max_level` of 3)
//! - **Specific queries** (relationships, details, entities) → lower levels (0-1)
8//! - **Adaptive routing** based on query analysis
9
10use serde::{Deserialize, Serialize};
11
12/// Configuration for adaptive query routing
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct AdaptiveRoutingConfig {
15    /// Enable adaptive query routing
16    pub enabled: bool,
17
18    /// Default level when query complexity is unclear
19    pub default_level: usize,
20
21    /// Maximum hierarchical level available
22    pub max_level: usize,
23
24    /// Weight for keyword-based level selection (0.0-1.0)
25    pub keyword_weight: f32,
26
27    /// Weight for query length-based selection (0.0-1.0)
28    pub length_weight: f32,
29
30    /// Weight for entity mention-based selection (0.0-1.0)
31    pub entity_weight: f32,
32}
33
34impl Default for AdaptiveRoutingConfig {
35    fn default() -> Self {
36        Self {
37            enabled: true,
38            default_level: 1,
39            max_level: 3,
40            keyword_weight: 0.5,
41            length_weight: 0.3,
42            entity_weight: 0.2,
43        }
44    }
45}
46
47/// Query complexity level
48#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
49pub enum QueryComplexity {
50    /// Very broad query (overview, themes)
51    VeryBroad,
52    /// Broad query (general understanding)
53    Broad,
54    /// Medium complexity
55    Medium,
56    /// Specific query (detailed information)
57    Specific,
58    /// Very specific query (precise relationships)
59    VerySpecific,
60}
61
62impl QueryComplexity {
63    /// Convert complexity to hierarchical level
64    pub fn to_level(&self, max_level: usize) -> usize {
65        match self {
66            QueryComplexity::VeryBroad => max_level.max(2),
67            QueryComplexity::Broad => (max_level - 1).max(1),
68            QueryComplexity::Medium => 1,
69            QueryComplexity::Specific => 0,
70            QueryComplexity::VerySpecific => 0,
71        }
72    }
73}
74
75/// Analyzes query complexity and suggests appropriate hierarchical level
76#[derive(Debug)]
77pub struct QueryComplexityAnalyzer {
78    config: AdaptiveRoutingConfig,
79
80    // Keyword sets for classification
81    broad_keywords: Vec<&'static str>,
82    specific_keywords: Vec<&'static str>,
83}
84
85impl QueryComplexityAnalyzer {
86    /// Create a new query complexity analyzer
87    pub fn new(config: AdaptiveRoutingConfig) -> Self {
88        Self {
89            config,
90            broad_keywords: vec![
91                "overview", "summary", "summarize", "main", "general", "all",
92                "themes", "topics", "overall", "broadly", "big picture",
93                "what are", "list all", "show me all",
94            ],
95            specific_keywords: vec![
96                "relationship between", "how does", "why does", "specific",
97                "detail", "exactly", "precisely", "what is the connection",
98                "explain how", "describe the", "between", "and",
99            ],
100        }
101    }
102
103    /// Analyze query and determine complexity
104    pub fn analyze(&self, query: &str) -> QueryComplexity {
105        let query_lower = query.to_lowercase();
106
107        // Score components
108        let keyword_score = self.analyze_keywords(&query_lower);
109        let length_score = self.analyze_length(query);
110        let entity_score = self.analyze_entity_mentions(&query_lower);
111
112        // Weighted combination
113        let total_score =
114            keyword_score * self.config.keyword_weight +
115            length_score * self.config.length_weight +
116            entity_score * self.config.entity_weight;
117
118        // Map score to complexity level
119        if total_score >= 0.7 {
120            QueryComplexity::VeryBroad
121        } else if total_score >= 0.4 {
122            QueryComplexity::Broad
123        } else if total_score >= -0.2 {
124            QueryComplexity::Medium
125        } else if total_score >= -0.5 {
126            QueryComplexity::Specific
127        } else {
128            QueryComplexity::VerySpecific
129        }
130    }
131
132    /// Analyze query keywords (-1.0 = very specific, +1.0 = very broad)
133    fn analyze_keywords(&self, query_lower: &str) -> f32 {
134        let mut score = 0.0;
135        let mut matches = 0;
136
137        // Check broad keywords (positive score)
138        for keyword in &self.broad_keywords {
139            if query_lower.contains(keyword) {
140                score += 1.0;
141                matches += 1;
142            }
143        }
144
145        // Check specific keywords (negative score)
146        for keyword in &self.specific_keywords {
147            if query_lower.contains(keyword) {
148                score -= 1.0;
149                matches += 1;
150            }
151        }
152
153        // Normalize
154        if matches > 0 {
155            score / matches as f32
156        } else {
157            0.0 // No keywords found = medium
158        }
159    }
160
161    /// Analyze query length (short = specific, long = broad)
162    fn analyze_length(&self, query: &str) -> f32 {
163        let words: Vec<&str> = query.split_whitespace().collect();
164        let word_count = words.len();
165
166        // Short queries (1-3 words) tend to be broad ("AI overview")
167        // Long queries (8+ words) tend to be specific
168        match word_count {
169            1..=3 => 0.5,      // Short → broad
170            4..=5 => 0.2,      // Medium-short
171            6..=7 => 0.0,      // Medium
172            8..=10 => -0.3,    // Medium-long → specific
173            _ => -0.5,         // Long → very specific
174        }
175    }
176
177    /// Analyze entity mentions (many entities = specific query)
178    fn analyze_entity_mentions(&self, query_lower: &str) -> f32 {
179        // Count capital words (potential entity names in original query)
180        // and quoted phrases (explicit entity references)
181        let quoted_count = query_lower.matches('"').count() / 2;
182        let and_between = query_lower.matches(" and ").count();
183        let between_count = query_lower.matches("between").count();
184
185        let entity_indicators = quoted_count + and_between + between_count;
186
187        // More entity indicators = more specific
188        match entity_indicators {
189            0 => 0.3,      // No entities → broad
190            1 => 0.0,      // One entity → medium
191            2 => -0.4,     // Two entities → specific
192            _ => -0.7,     // Multiple entities → very specific
193        }
194    }
195
196    /// Suggest hierarchical level for query
197    pub fn suggest_level(&self, query: &str) -> usize {
198        let complexity = self.analyze(query);
199        complexity.to_level(self.config.max_level)
200    }
201
202    /// Get detailed analysis with explanation
203    pub fn analyze_detailed(&self, query: &str) -> QueryAnalysis {
204        let query_lower = query.to_lowercase();
205
206        let keyword_score = self.analyze_keywords(&query_lower);
207        let length_score = self.analyze_length(query);
208        let entity_score = self.analyze_entity_mentions(&query_lower);
209
210        let complexity = self.analyze(query);
211        let suggested_level = complexity.to_level(self.config.max_level);
212
213        QueryAnalysis {
214            query: query.to_string(),
215            complexity,
216            suggested_level,
217            keyword_score,
218            length_score,
219            entity_score,
220            explanation: self.generate_explanation(complexity, suggested_level),
221        }
222    }
223
224    /// Generate explanation for the routing decision
225    fn generate_explanation(&self, complexity: QueryComplexity, level: usize) -> String {
226        match complexity {
227            QueryComplexity::VeryBroad => format!(
228                "Very broad query detected → using level {} for high-level overview",
229                level
230            ),
231            QueryComplexity::Broad => format!(
232                "Broad query detected → using level {} for general understanding",
233                level
234            ),
235            QueryComplexity::Medium => format!(
236                "Medium complexity query → using level {} for balanced detail",
237                level
238            ),
239            QueryComplexity::Specific => format!(
240                "Specific query detected → using level {} for detailed information",
241                level
242            ),
243            QueryComplexity::VerySpecific => format!(
244                "Very specific query detected → using level {} for precise relationships",
245                level
246            ),
247        }
248    }
249}
250
251/// Detailed query analysis result
252#[derive(Debug, Clone, Serialize, Deserialize)]
253pub struct QueryAnalysis {
254    /// Original query
255    pub query: String,
256    /// Detected complexity level
257    pub complexity: QueryComplexity,
258    /// Suggested hierarchical level
259    pub suggested_level: usize,
260    /// Keyword analysis score
261    pub keyword_score: f32,
262    /// Length analysis score
263    pub length_score: f32,
264    /// Entity mention score
265    pub entity_score: f32,
266    /// Human-readable explanation
267    pub explanation: String,
268}
269
270impl QueryAnalysis {
271    /// Print detailed analysis
272    pub fn print(&self) {
273        println!("Query Analysis:");
274        println!("  Query: \"{}\"", self.query);
275        println!("  Complexity: {:?}", self.complexity);
276        println!("  Suggested Level: {}", self.suggested_level);
277        println!("  Scores:");
278        println!("    - Keywords: {:.2}", self.keyword_score);
279        println!("    - Length: {:.2}", self.length_score);
280        println!("    - Entities: {:.2}", self.entity_score);
281        println!("  {}", self.explanation);
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Analyzer built from the stock configuration, shared by every test.
    fn default_analyzer() -> QueryComplexityAnalyzer {
        QueryComplexityAnalyzer::new(AdaptiveRoutingConfig::default())
    }

    #[test]
    fn test_broad_query() {
        let analyzer = default_analyzer();
        let query = "Give me an overview of AI technologies";

        // An overview request falls in one of the two broad classes and is
        // routed above the leaf level.
        let complexity = analyzer.analyze(query);
        assert!(matches!(
            complexity,
            QueryComplexity::VeryBroad | QueryComplexity::Broad
        ));
        assert!(analyzer.suggest_level(query) >= 1);
    }

    #[test]
    fn test_specific_query() {
        let analyzer = default_analyzer();
        let query = "What is the relationship between Transformers and GPT?";

        // A relationship question between two entities is specific and is
        // routed to the leaf level.
        let complexity = analyzer.analyze(query);
        assert!(matches!(
            complexity,
            QueryComplexity::Specific | QueryComplexity::VerySpecific
        ));
        assert_eq!(analyzer.suggest_level(query), 0);
    }

    #[test]
    fn test_medium_query() {
        // Only an upper bound is pinned: a "how does" question may be scored
        // Medium (level 1) or more specific (level 0).
        let level = default_analyzer().suggest_level("How does machine learning work?");
        assert!(level <= 1);
    }

    #[test]
    fn test_detailed_analysis() {
        let analysis = default_analyzer().analyze_detailed("Summarize the main themes");

        // "summarize", "main" and "themes" are broad keywords.
        assert!(analysis.keyword_score > 0.0);
        assert!(!analysis.explanation.is_empty());
    }
}