Skip to main content

graphrag_core/retrieval/
adaptive.rs

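//! Adaptive strategy selection for intelligent retrieval.
//!
//! Strategy weights are chosen from the query analysis, the vector, graph, hierarchical and
//! BM25 searches are run with those weights, and the weighted results are fused into one list.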
2
3use crate::{
4    core::KnowledgeGraph,
5    retrieval::{QueryAnalysisResult, QueryType, RetrievalSystem, SearchResult},
6    summarization::DocumentTree,
7    vector::VectorIndex,
8    Result,
9};
10use std::collections::HashMap;
11
/// Relative weights for the four retrieval strategies.
///
/// Within this module a weight scales the scores returned by its strategy and the number of
/// results requested from it; a strategy whose weight is not greater than zero is not run.
#[derive(Debug, Clone, PartialEq)]
pub struct StrategyWeights {
    /// Weight for vector similarity-based retrieval
    pub vector_weight: f32,
    /// Weight for graph-based traversal retrieval
    pub graph_weight: f32,
    /// Weight for hierarchical document tree retrieval
    pub hierarchical_weight: f32,
    /// Weight for BM25 keyword-based retrieval
    pub bm25_weight: f32,
}

impl Default for StrategyWeights {
    /// Balanced weights: each of the four strategies gets an equal share of 0.25.
    fn default() -> Self {
        const EQUAL_SHARE: f32 = 0.25;

        Self {
            vector_weight: EQUAL_SHARE,
            graph_weight: EQUAL_SHARE,
            hierarchical_weight: EQUAL_SHARE,
            bm25_weight: EQUAL_SHARE,
        }
    }
}
35
36/// Configuration for adaptive strategy selection
37#[derive(Debug, Clone)]
38pub struct AdaptiveConfig {
39    /// Strategy weights for entity-focused queries
40    pub entity_weights: StrategyWeights,
41    /// Strategy weights for conceptual queries
42    pub conceptual_weights: StrategyWeights,
43    /// Strategy weights for factual queries
44    pub factual_weights: StrategyWeights,
45    /// Strategy weights for relational queries
46    pub relational_weights: StrategyWeights,
47    /// Strategy weights for complex multi-part queries
48    pub complex_weights: StrategyWeights,
49    /// Minimum confidence to use specialized weights
50    pub min_confidence_for_specialization: f32,
51    /// Number of results to retrieve per strategy
52    pub results_per_strategy: usize,
53}
54
55impl Default for AdaptiveConfig {
56    fn default() -> Self {
57        Self {
58            entity_weights: StrategyWeights {
59                vector_weight: 0.2,
60                graph_weight: 0.5,
61                hierarchical_weight: 0.2,
62                bm25_weight: 0.1,
63            },
64            conceptual_weights: StrategyWeights {
65                vector_weight: 0.6,
66                graph_weight: 0.1,
67                hierarchical_weight: 0.3,
68                bm25_weight: 0.0,
69            },
70            factual_weights: StrategyWeights {
71                vector_weight: 0.2,
72                graph_weight: 0.1,
73                hierarchical_weight: 0.1,
74                bm25_weight: 0.6,
75            },
76            relational_weights: StrategyWeights {
77                vector_weight: 0.2,
78                graph_weight: 0.6,
79                hierarchical_weight: 0.1,
80                bm25_weight: 0.1,
81            },
82            complex_weights: StrategyWeights::default(),
83            min_confidence_for_specialization: 0.6,
84            results_per_strategy: 10,
85        }
86    }
87}
88
89/// Result of adaptive strategy selection
90#[derive(Debug, Clone)]
91pub struct AdaptiveRetrievalResult {
92    /// Final ranked search results after fusion
93    pub results: Vec<SearchResult>,
94    /// Strategy weights applied during retrieval
95    pub strategy_weights_used: StrategyWeights,
96    /// Analysis results from query classification
97    pub query_analysis: QueryAnalysisResult,
98    /// Name of fusion method used
99    pub fusion_method: String,
100    /// Total number of results before deduplication
101    pub total_results_before_fusion: usize,
102}
103
104/// Adaptive retrieval system that selects strategies based on query analysis
105pub struct AdaptiveRetriever {
106    config: AdaptiveConfig,
107    retrieval_system: RetrievalSystem,
108}
109
110impl AdaptiveRetriever {
111    /// Create a new adaptive retriever
112    pub fn new(
113        config: AdaptiveConfig,
114        _vector_index: VectorIndex,
115        _knowledge_graph: KnowledgeGraph,
116        _document_trees: HashMap<String, DocumentTree>,
117    ) -> Result<Self> {
118        // Create a default config for the retrieval system
119        let default_config = crate::config::Config::default();
120        let retrieval_system = RetrievalSystem::new(&default_config)?;
121
122        Ok(Self {
123            config,
124            retrieval_system,
125        })
126    }
127
128    /// Perform adaptive retrieval based on query analysis
129    pub async fn retrieve(
130        &mut self,
131        query: &str,
132        query_analysis: &QueryAnalysisResult,
133        max_results: usize,
134    ) -> Result<AdaptiveRetrievalResult> {
135        // Select strategy weights based on query type and confidence
136        let strategy_weights = self.select_strategy_weights(query_analysis);
137
138        // Retrieve results using different strategies
139        let mut all_results = Vec::new();
140
141        // Vector similarity search
142        if strategy_weights.vector_weight > 0.0 {
143            let vector_results = self
144                .retrieval_system
145                .vector_search(
146                    query,
147                    (self.config.results_per_strategy as f32 * strategy_weights.vector_weight)
148                        as usize,
149                )
150                .await?;
151            all_results.extend(self.weight_results(vector_results, strategy_weights.vector_weight));
152        }
153
154        // Graph-based search
155        if strategy_weights.graph_weight > 0.0 {
156            let graph_results = self.retrieval_system.graph_search(
157                query,
158                (self.config.results_per_strategy as f32 * strategy_weights.graph_weight) as usize,
159            )?;
160            all_results.extend(self.weight_results(graph_results, strategy_weights.graph_weight));
161        }
162
163        // Hierarchical search
164        if strategy_weights.hierarchical_weight > 0.0 {
165            let max_results = (self.config.results_per_strategy as f32
166                * strategy_weights.hierarchical_weight) as usize;
167            let hierarchical_results = self
168                .retrieval_system
169                .public_hierarchical_search(query, max_results)?;
170            all_results.extend(
171                self.weight_results(hierarchical_results, strategy_weights.hierarchical_weight),
172            );
173        }
174
175        // BM25 search
176        if strategy_weights.bm25_weight > 0.0 {
177            let bm25_results = self.retrieval_system.bm25_search(
178                query,
179                (self.config.results_per_strategy as f32 * strategy_weights.bm25_weight) as usize,
180            )?;
181            all_results.extend(self.weight_results(bm25_results, strategy_weights.bm25_weight));
182        }
183
184        let total_results_before_fusion = all_results.len();
185
186        // Perform cross-strategy fusion
187        let fused_results = self.cross_strategy_fusion(all_results, max_results)?;
188
189        Ok(AdaptiveRetrievalResult {
190            results: fused_results,
191            strategy_weights_used: strategy_weights,
192            query_analysis: query_analysis.clone(),
193            fusion_method: "weighted_score_fusion".to_string(),
194            total_results_before_fusion,
195        })
196    }
197
198    /// Select strategy weights based on query analysis
199    fn select_strategy_weights(&self, query_analysis: &QueryAnalysisResult) -> StrategyWeights {
200        // If confidence is low, use default balanced weights
201        if query_analysis.confidence < self.config.min_confidence_for_specialization {
202            return self.config.complex_weights.clone();
203        }
204
205        // Select weights based on query type
206        match query_analysis.query_type {
207            QueryType::EntityFocused => self.config.entity_weights.clone(),
208            QueryType::Conceptual => self.config.conceptual_weights.clone(),
209            QueryType::Factual => self.config.factual_weights.clone(),
210            QueryType::Relationship => self.config.relational_weights.clone(),
211            QueryType::Exploratory => self.config.complex_weights.clone(),
212        }
213    }
214
215    /// Apply strategy weight to results
216    fn weight_results(&self, mut results: Vec<SearchResult>, weight: f32) -> Vec<SearchResult> {
217        for result in &mut results {
218            result.score *= weight;
219        }
220        results
221    }
222
223    /// Perform cross-strategy fusion of results
224    fn cross_strategy_fusion(
225        &self,
226        results: Vec<SearchResult>,
227        max_results: usize,
228    ) -> Result<Vec<SearchResult>> {
229        // Remove duplicates by chunk ID, keeping highest scored version
230        let mut seen_chunks = HashMap::new();
231        let mut deduplicated_results = Vec::new();
232
233        for result in results {
234            let chunk_id = &result.id;
235
236            if let Some(existing_score) = seen_chunks.get(chunk_id) {
237                if result.score > *existing_score {
238                    // Replace with higher scored version
239                    seen_chunks.insert(chunk_id.clone(), result.score);
240                    // Remove old version and add new one
241                    deduplicated_results.retain(|r: &SearchResult| r.id != *chunk_id);
242                    deduplicated_results.push(result);
243                }
244            } else {
245                seen_chunks.insert(chunk_id.clone(), result.score);
246                deduplicated_results.push(result);
247            }
248        }
249
250        // Sort by final weighted score
251        deduplicated_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
252
253        // Apply diversity-aware selection
254        let final_results = self.diversity_aware_selection(deduplicated_results, max_results);
255
256        Ok(final_results)
257    }
258
259    /// Apply diversity-aware selection to avoid redundant results
260    fn diversity_aware_selection(
261        &self,
262        results: Vec<SearchResult>,
263        max_results: usize,
264    ) -> Vec<SearchResult> {
265        let mut selected_results = Vec::new();
266        let mut selected_entities = std::collections::HashSet::new();
267        let _remaining_results = results.clone();
268
269        for result in &results {
270            if selected_results.len() >= max_results {
271                break;
272            }
273
274            // Check for entity diversity
275            let has_new_entities = result
276                .entities
277                .iter()
278                .any(|entity| !selected_entities.contains(entity));
279
280            // Always include high-scoring results or those with new entities
281            if result.score > 0.8 || has_new_entities || selected_results.len() < max_results / 2 {
282                for entity in &result.entities {
283                    selected_entities.insert(entity.clone());
284                }
285                selected_results.push(result.clone());
286            }
287        }
288
289        // If we don't have enough results, fill with remaining high-scoring ones
290        if selected_results.len() < max_results {
291            for result in results {
292                if selected_results.len() >= max_results {
293                    break;
294                }
295                if !selected_results.iter().any(|r| r.id == result.id) {
296                    selected_results.push(result);
297                }
298            }
299        }
300
301        selected_results
302    }
303
304    /// Get adaptive retrieval statistics
305    pub fn get_statistics(&self) -> AdaptiveRetrieverStatistics {
306        AdaptiveRetrieverStatistics {
307            config: self.config.clone(),
308            retrieval_system_stats: format!("RetrievalSystem with {} strategies", 4),
309        }
310    }
311}
312
313/// Statistics about the adaptive retriever
314#[derive(Debug)]
315pub struct AdaptiveRetrieverStatistics {
316    /// Configuration used by the retriever
317    pub config: AdaptiveConfig,
318    /// Summary statistics from underlying retrieval system
319    pub retrieval_system_stats: String,
320}
321
322impl AdaptiveRetrieverStatistics {
323    /// Print adaptive retriever statistics to stdout
324    pub fn print(&self) {
325        println!("Adaptive Retriever Statistics:");
326        println!(
327            "  Min confidence for specialization: {:.2}",
328            self.config.min_confidence_for_specialization
329        );
330        println!(
331            "  Results per strategy: {}",
332            self.config.results_per_strategy
333        );
334        println!(
335            "  Entity weights: V:{:.2} G:{:.2} H:{:.2} B:{:.2}",
336            self.config.entity_weights.vector_weight,
337            self.config.entity_weights.graph_weight,
338            self.config.entity_weights.hierarchical_weight,
339            self.config.entity_weights.bm25_weight
340        );
341        println!(
342            "  Factual weights: V:{:.2} G:{:.2} H:{:.2} B:{:.2}",
343            self.config.factual_weights.vector_weight,
344            self.config.factual_weights.graph_weight,
345            self.config.factual_weights.hierarchical_weight,
346            self.config.factual_weights.bm25_weight
347        );
348        println!("  {}", self.retrieval_system_stats);
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of the four strategy weights in one weight set.
    fn total_weight(weights: &StrategyWeights) -> f32 {
        weights.vector_weight
            + weights.graph_weight
            + weights.hierarchical_weight
            + weights.bm25_weight
    }

    #[test]
    fn test_strategy_weight_selection() {
        let config = AdaptiveConfig::default();

        // Mock query analysis for an entity-focused query
        let entity_analysis = QueryAnalysisResult {
            query_type: QueryType::EntityFocused,
            confidence: 0.8,
            keywords_matched: vec!["who".to_string()],
            suggested_strategies: vec!["entity_search".to_string()],
            complexity_score: 0.2,
        };

        // The analysis clears the configured threshold, so the specialised entity weights
        // (rather than the fallback set) are the ones that apply to it.
        assert_eq!(entity_analysis.query_type, QueryType::EntityFocused);
        assert!(entity_analysis.confidence >= config.min_confidence_for_specialization);

        // By default, entity-focused queries lean on graph traversal above all else.
        let entity = &config.entity_weights;
        assert!(entity.graph_weight > entity.vector_weight);
        assert!(entity.graph_weight > entity.hierarchical_weight);
        assert!(entity.graph_weight > entity.bm25_weight);
    }

    #[test]
    fn test_default_weight_sets_are_normalized() {
        let config = AdaptiveConfig::default();

        let weight_sets = [
            &config.entity_weights,
            &config.conceptual_weights,
            &config.factual_weights,
            &config.relational_weights,
            &config.complex_weights,
        ];
        for weights in weight_sets {
            assert!((total_weight(weights) - 1.0).abs() < 1e-5);
        }
    }

    #[test]
    fn test_diversity_aware_selection() {
        // Exercising the selection itself needs a fully constructed retriever and search
        // results; until such fixtures exist, only the defaults it relies on are checked.
        let config = AdaptiveConfig::default();

        assert!(config.min_confidence_for_specialization > 0.0);
        assert!(config.results_per_strategy > 0);
    }
}