Skip to main content

graphrag_core/retrieval/
adaptive.rs

1//! Adaptive strategy selection for intelligent retrieval
2
3use crate::{
4    core::KnowledgeGraph,
5    retrieval::{QueryAnalysisResult, QueryType, RetrievalSystem, SearchResult},
6    summarization::DocumentTree,
7    vector::VectorIndex,
8    Result,
9};
10use std::collections::HashMap;
11
/// Relative importance of each retrieval strategy for one kind of query.
///
/// [`AdaptiveRetriever::retrieve`] only runs a strategy whose weight is
/// greater than `0.0`, and multiplies the scores that strategy returns by
/// its weight before the results are fused.
#[derive(Clone, Debug)]
pub struct StrategyWeights {
    /// Weight applied to vector-similarity retrieval.
    pub vector_weight: f32,
    /// Weight applied to graph-traversal retrieval.
    pub graph_weight: f32,
    /// Weight applied to hierarchical (document tree) retrieval.
    pub hierarchical_weight: f32,
    /// Weight applied to BM25 keyword retrieval.
    pub bm25_weight: f32,
}
24
25impl Default for StrategyWeights {
26    fn default() -> Self {
27        Self {
28            vector_weight: 0.25,
29            graph_weight: 0.25,
30            hierarchical_weight: 0.25,
31            bm25_weight: 0.25,
32        }
33    }
34}
35
36/// Configuration for adaptive strategy selection
37#[derive(Debug, Clone)]
38pub struct AdaptiveConfig {
39    /// Strategy weights for entity-focused queries
40    pub entity_weights: StrategyWeights,
41    /// Strategy weights for conceptual queries
42    pub conceptual_weights: StrategyWeights,
43    /// Strategy weights for factual queries
44    pub factual_weights: StrategyWeights,
45    /// Strategy weights for relational queries
46    pub relational_weights: StrategyWeights,
47    /// Strategy weights for complex multi-part queries
48    pub complex_weights: StrategyWeights,
49    /// Minimum confidence to use specialized weights
50    pub min_confidence_for_specialization: f32,
51    /// Number of results to retrieve per strategy
52    pub results_per_strategy: usize,
53}
54
55impl Default for AdaptiveConfig {
56    fn default() -> Self {
57        Self {
58            entity_weights: StrategyWeights {
59                vector_weight: 0.2,
60                graph_weight: 0.5,
61                hierarchical_weight: 0.2,
62                bm25_weight: 0.1,
63            },
64            conceptual_weights: StrategyWeights {
65                vector_weight: 0.6,
66                graph_weight: 0.1,
67                hierarchical_weight: 0.3,
68                bm25_weight: 0.0,
69            },
70            factual_weights: StrategyWeights {
71                vector_weight: 0.2,
72                graph_weight: 0.1,
73                hierarchical_weight: 0.1,
74                bm25_weight: 0.6,
75            },
76            relational_weights: StrategyWeights {
77                vector_weight: 0.2,
78                graph_weight: 0.6,
79                hierarchical_weight: 0.1,
80                bm25_weight: 0.1,
81            },
82            complex_weights: StrategyWeights::default(),
83            min_confidence_for_specialization: 0.6,
84            results_per_strategy: 10,
85        }
86    }
87}
88
89/// Result of adaptive strategy selection
90#[derive(Debug, Clone)]
91pub struct AdaptiveRetrievalResult {
92    /// Final ranked search results after fusion
93    pub results: Vec<SearchResult>,
94    /// Strategy weights applied during retrieval
95    pub strategy_weights_used: StrategyWeights,
96    /// Analysis results from query classification
97    pub query_analysis: QueryAnalysisResult,
98    /// Name of fusion method used
99    pub fusion_method: String,
100    /// Total number of results before deduplication
101    pub total_results_before_fusion: usize,
102}
103
104/// Adaptive retrieval system that selects strategies based on query analysis
105pub struct AdaptiveRetriever {
106    config: AdaptiveConfig,
107    retrieval_system: RetrievalSystem,
108}
109
110impl AdaptiveRetriever {
111    /// Create a new adaptive retriever
112    pub fn new(
113        config: AdaptiveConfig,
114        _vector_index: VectorIndex,
115        _knowledge_graph: KnowledgeGraph,
116        _document_trees: HashMap<String, DocumentTree>,
117    ) -> Result<Self> {
118        // Create a default config for the retrieval system
119        let default_config = crate::config::Config::default();
120        let retrieval_system = RetrievalSystem::new(&default_config)?;
121
122        Ok(Self {
123            config,
124            retrieval_system,
125        })
126    }
127
128    /// Perform adaptive retrieval based on query analysis
129    pub fn retrieve(
130        &mut self,
131        query: &str,
132        query_analysis: &QueryAnalysisResult,
133        max_results: usize,
134    ) -> Result<AdaptiveRetrievalResult> {
135        // Select strategy weights based on query type and confidence
136        let strategy_weights = self.select_strategy_weights(query_analysis);
137
138        // Retrieve results using different strategies
139        let mut all_results = Vec::new();
140
141        // Vector similarity search
142        if strategy_weights.vector_weight > 0.0 {
143            let vector_results = self.retrieval_system.vector_search(
144                query,
145                (self.config.results_per_strategy as f32 * strategy_weights.vector_weight) as usize,
146            )?;
147            all_results.extend(self.weight_results(vector_results, strategy_weights.vector_weight));
148        }
149
150        // Graph-based search
151        if strategy_weights.graph_weight > 0.0 {
152            let graph_results = self.retrieval_system.graph_search(
153                query,
154                (self.config.results_per_strategy as f32 * strategy_weights.graph_weight) as usize,
155            )?;
156            all_results.extend(self.weight_results(graph_results, strategy_weights.graph_weight));
157        }
158
159        // Hierarchical search
160        if strategy_weights.hierarchical_weight > 0.0 {
161            let max_results = (self.config.results_per_strategy as f32
162                * strategy_weights.hierarchical_weight) as usize;
163            let hierarchical_results = self
164                .retrieval_system
165                .public_hierarchical_search(query, max_results)?;
166            all_results.extend(
167                self.weight_results(hierarchical_results, strategy_weights.hierarchical_weight),
168            );
169        }
170
171        // BM25 search
172        if strategy_weights.bm25_weight > 0.0 {
173            let bm25_results = self.retrieval_system.bm25_search(
174                query,
175                (self.config.results_per_strategy as f32 * strategy_weights.bm25_weight) as usize,
176            )?;
177            all_results.extend(self.weight_results(bm25_results, strategy_weights.bm25_weight));
178        }
179
180        let total_results_before_fusion = all_results.len();
181
182        // Perform cross-strategy fusion
183        let fused_results = self.cross_strategy_fusion(all_results, max_results)?;
184
185        Ok(AdaptiveRetrievalResult {
186            results: fused_results,
187            strategy_weights_used: strategy_weights,
188            query_analysis: query_analysis.clone(),
189            fusion_method: "weighted_score_fusion".to_string(),
190            total_results_before_fusion,
191        })
192    }
193
194    /// Select strategy weights based on query analysis
195    fn select_strategy_weights(&self, query_analysis: &QueryAnalysisResult) -> StrategyWeights {
196        // If confidence is low, use default balanced weights
197        if query_analysis.confidence < self.config.min_confidence_for_specialization {
198            return self.config.complex_weights.clone();
199        }
200
201        // Select weights based on query type
202        match query_analysis.query_type {
203            QueryType::EntityFocused => self.config.entity_weights.clone(),
204            QueryType::Conceptual => self.config.conceptual_weights.clone(),
205            QueryType::Factual => self.config.factual_weights.clone(),
206            QueryType::Relationship => self.config.relational_weights.clone(),
207            QueryType::Exploratory => self.config.complex_weights.clone(),
208        }
209    }
210
211    /// Apply strategy weight to results
212    fn weight_results(&self, mut results: Vec<SearchResult>, weight: f32) -> Vec<SearchResult> {
213        for result in &mut results {
214            result.score *= weight;
215        }
216        results
217    }
218
219    /// Perform cross-strategy fusion of results
220    fn cross_strategy_fusion(
221        &self,
222        results: Vec<SearchResult>,
223        max_results: usize,
224    ) -> Result<Vec<SearchResult>> {
225        // Remove duplicates by chunk ID, keeping highest scored version
226        let mut seen_chunks = HashMap::new();
227        let mut deduplicated_results = Vec::new();
228
229        for result in results {
230            let chunk_id = &result.id;
231
232            if let Some(existing_score) = seen_chunks.get(chunk_id) {
233                if result.score > *existing_score {
234                    // Replace with higher scored version
235                    seen_chunks.insert(chunk_id.clone(), result.score);
236                    // Remove old version and add new one
237                    deduplicated_results.retain(|r: &SearchResult| r.id != *chunk_id);
238                    deduplicated_results.push(result);
239                }
240            } else {
241                seen_chunks.insert(chunk_id.clone(), result.score);
242                deduplicated_results.push(result);
243            }
244        }
245
246        // Sort by final weighted score
247        deduplicated_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
248
249        // Apply diversity-aware selection
250        let final_results = self.diversity_aware_selection(deduplicated_results, max_results);
251
252        Ok(final_results)
253    }
254
255    /// Apply diversity-aware selection to avoid redundant results
256    fn diversity_aware_selection(
257        &self,
258        results: Vec<SearchResult>,
259        max_results: usize,
260    ) -> Vec<SearchResult> {
261        let mut selected_results = Vec::new();
262        let mut selected_entities = std::collections::HashSet::new();
263        let _remaining_results = results.clone();
264
265        for result in &results {
266            if selected_results.len() >= max_results {
267                break;
268            }
269
270            // Check for entity diversity
271            let has_new_entities = result
272                .entities
273                .iter()
274                .any(|entity| !selected_entities.contains(entity));
275
276            // Always include high-scoring results or those with new entities
277            if result.score > 0.8 || has_new_entities || selected_results.len() < max_results / 2 {
278                for entity in &result.entities {
279                    selected_entities.insert(entity.clone());
280                }
281                selected_results.push(result.clone());
282            }
283        }
284
285        // If we don't have enough results, fill with remaining high-scoring ones
286        if selected_results.len() < max_results {
287            for result in results {
288                if selected_results.len() >= max_results {
289                    break;
290                }
291                if !selected_results.iter().any(|r| r.id == result.id) {
292                    selected_results.push(result);
293                }
294            }
295        }
296
297        selected_results
298    }
299
300    /// Get adaptive retrieval statistics
301    pub fn get_statistics(&self) -> AdaptiveRetrieverStatistics {
302        AdaptiveRetrieverStatistics {
303            config: self.config.clone(),
304            retrieval_system_stats: format!("RetrievalSystem with {} strategies", 4),
305        }
306    }
307}
308
309/// Statistics about the adaptive retriever
310#[derive(Debug)]
311pub struct AdaptiveRetrieverStatistics {
312    /// Configuration used by the retriever
313    pub config: AdaptiveConfig,
314    /// Summary statistics from underlying retrieval system
315    pub retrieval_system_stats: String,
316}
317
318impl AdaptiveRetrieverStatistics {
319    /// Print adaptive retriever statistics to stdout
320    pub fn print(&self) {
321        println!("Adaptive Retriever Statistics:");
322        println!(
323            "  Min confidence for specialization: {:.2}",
324            self.config.min_confidence_for_specialization
325        );
326        println!(
327            "  Results per strategy: {}",
328            self.config.results_per_strategy
329        );
330        println!(
331            "  Entity weights: V:{:.2} G:{:.2} H:{:.2} B:{:.2}",
332            self.config.entity_weights.vector_weight,
333            self.config.entity_weights.graph_weight,
334            self.config.entity_weights.hierarchical_weight,
335            self.config.entity_weights.bm25_weight
336        );
337        println!(
338            "  Factual weights: V:{:.2} G:{:.2} H:{:.2} B:{:.2}",
339            self.config.factual_weights.vector_weight,
340            self.config.factual_weights.graph_weight,
341            self.config.factual_weights.hierarchical_weight,
342            self.config.factual_weights.bm25_weight
343        );
344        println!("  {}", self.retrieval_system_stats);
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of the four strategy weights of one profile.
    fn total_weight(weights: &StrategyWeights) -> f32 {
        weights.vector_weight
            + weights.graph_weight
            + weights.hierarchical_weight
            + weights.bm25_weight
    }

    #[test]
    fn test_strategy_weight_selection() {
        let config = AdaptiveConfig::default();

        // Mock query analysis for an entity-focused query.
        let entity_analysis = QueryAnalysisResult {
            query_type: QueryType::EntityFocused,
            confidence: 0.8,
            keywords_matched: vec!["who".to_string()],
            suggested_strategies: vec!["entity_search".to_string()],
            complexity_score: 0.2,
        };

        assert_eq!(entity_analysis.query_type, QueryType::EntityFocused);
        // Check against the configured threshold rather than a hard-coded
        // literal, so the test follows the default configuration.
        assert!(entity_analysis.confidence >= config.min_confidence_for_specialization);

        // The profile selected for such a query must favour graph retrieval.
        let entity = &config.entity_weights;
        assert!(entity.graph_weight > entity.vector_weight);
        assert!(entity.graph_weight > entity.hierarchical_weight);
        assert!(entity.graph_weight > entity.bm25_weight);
    }

    #[test]
    fn test_diversity_aware_selection() {
        // NOTE: the selection routine itself is not exercised here, because an
        // `AdaptiveRetriever` can only be built together with a
        // `RetrievalSystem`. This test pins the configuration invariants that
        // retrieval relies on instead.
        let config = AdaptiveConfig::default();

        assert!(config.min_confidence_for_specialization > 0.0);
        assert!(config.results_per_strategy > 0);

        let profiles = [
            &config.entity_weights,
            &config.conceptual_weights,
            &config.factual_weights,
            &config.relational_weights,
            &config.complex_weights,
        ];
        for profile in profiles.iter() {
            assert!((total_weight(profile) - 1.0).abs() < 1e-5);
        }
    }
}