1use crate::{
4 core::KnowledgeGraph,
5 retrieval::{QueryAnalysisResult, QueryType, RetrievalSystem, SearchResult},
6 summarization::DocumentTree,
7 vector::VectorIndex,
8 Result,
9};
10use std::collections::HashMap;
11
/// Relative weight given to each retrieval strategy for one query profile.
///
/// Each weight does two things when results are gathered: it scales how many results the
/// strategy is asked for, and it multiplies the scores of the results that strategy returns.
/// A weight of `0.0` (or less) disables the strategy entirely.
#[derive(Debug, Clone, PartialEq)]
pub struct StrategyWeights {
    /// Weight for dense vector (embedding) search.
    pub vector_weight: f32,
    /// Weight for knowledge-graph search.
    pub graph_weight: f32,
    /// Weight for hierarchical (document-tree) search.
    pub hierarchical_weight: f32,
    /// Weight for BM25 keyword search.
    pub bm25_weight: f32,
}

impl Default for StrategyWeights {
    /// Splits the weight evenly across all four strategies (0.25 each).
    fn default() -> Self {
        Self {
            vector_weight: 0.25,
            graph_weight: 0.25,
            hierarchical_weight: 0.25,
            bm25_weight: 0.25,
        }
    }
}
35
/// Configuration for [`AdaptiveRetriever`]: one [`StrategyWeights`] profile per query type,
/// plus the knobs that control when a specialized profile is used and how many results
/// each strategy is asked for.
#[derive(Debug, Clone)]
pub struct AdaptiveConfig {
    /// Profile used for `QueryType::EntityFocused` queries.
    pub entity_weights: StrategyWeights,
    /// Profile used for `QueryType::Conceptual` queries.
    pub conceptual_weights: StrategyWeights,
    /// Profile used for `QueryType::Factual` queries.
    pub factual_weights: StrategyWeights,
    /// Profile used for `QueryType::Relationship` queries.
    pub relational_weights: StrategyWeights,
    /// Profile used for `QueryType::Exploratory` queries, and as the fallback whenever the
    /// query analysis confidence is below `min_confidence_for_specialization`.
    pub complex_weights: StrategyWeights,
    /// Query-analysis confidence below which the type-specific profile is ignored and
    /// `complex_weights` is used instead.
    pub min_confidence_for_specialization: f32,
    /// Base number of results per strategy; each strategy's request size is this value
    /// scaled by that strategy's weight.
    pub results_per_strategy: usize,
}
54
55impl Default for AdaptiveConfig {
56 fn default() -> Self {
57 Self {
58 entity_weights: StrategyWeights {
59 vector_weight: 0.2,
60 graph_weight: 0.5,
61 hierarchical_weight: 0.2,
62 bm25_weight: 0.1,
63 },
64 conceptual_weights: StrategyWeights {
65 vector_weight: 0.6,
66 graph_weight: 0.1,
67 hierarchical_weight: 0.3,
68 bm25_weight: 0.0,
69 },
70 factual_weights: StrategyWeights {
71 vector_weight: 0.2,
72 graph_weight: 0.1,
73 hierarchical_weight: 0.1,
74 bm25_weight: 0.6,
75 },
76 relational_weights: StrategyWeights {
77 vector_weight: 0.2,
78 graph_weight: 0.6,
79 hierarchical_weight: 0.1,
80 bm25_weight: 0.1,
81 },
82 complex_weights: StrategyWeights::default(),
83 min_confidence_for_specialization: 0.6,
84 results_per_strategy: 10,
85 }
86 }
87}
88
/// Outcome of [`AdaptiveRetriever::retrieve`]: the fused results plus the metadata
/// describing how they were produced.
#[derive(Debug, Clone)]
pub struct AdaptiveRetrievalResult {
    /// Deduplicated, score-sorted, diversity-filtered results (at most the `max_results`
    /// requested from `retrieve`).
    pub results: Vec<SearchResult>,
    /// The weight profile that was selected for this query.
    pub strategy_weights_used: StrategyWeights,
    /// Copy of the query analysis the profile was selected from.
    pub query_analysis: QueryAnalysisResult,
    /// Name of the fusion method; `retrieve` currently always sets `"weighted_score_fusion"`.
    pub fusion_method: String,
    /// Number of weighted results gathered from all strategies before deduplication.
    pub total_results_before_fusion: usize,
}
103
/// Retriever that picks a strategy-weight profile from the query analysis, runs each
/// enabled strategy through a [`RetrievalSystem`], and fuses the weighted results.
pub struct AdaptiveRetriever {
    /// Weight profiles and per-strategy budget.
    config: AdaptiveConfig,
    /// Backend that executes the vector, graph, hierarchical and BM25 searches.
    retrieval_system: RetrievalSystem,
}
109
110impl AdaptiveRetriever {
111 pub fn new(
113 config: AdaptiveConfig,
114 _vector_index: VectorIndex,
115 _knowledge_graph: KnowledgeGraph,
116 _document_trees: HashMap<String, DocumentTree>,
117 ) -> Result<Self> {
118 let default_config = crate::config::Config::default();
120 let retrieval_system = RetrievalSystem::new(&default_config)?;
121
122 Ok(Self {
123 config,
124 retrieval_system,
125 })
126 }
127
128 pub async fn retrieve(
130 &mut self,
131 query: &str,
132 query_analysis: &QueryAnalysisResult,
133 max_results: usize,
134 ) -> Result<AdaptiveRetrievalResult> {
135 let strategy_weights = self.select_strategy_weights(query_analysis);
137
138 let mut all_results = Vec::new();
140
141 if strategy_weights.vector_weight > 0.0 {
143 let vector_results = self
144 .retrieval_system
145 .vector_search(
146 query,
147 (self.config.results_per_strategy as f32 * strategy_weights.vector_weight)
148 as usize,
149 )
150 .await?;
151 all_results.extend(self.weight_results(vector_results, strategy_weights.vector_weight));
152 }
153
154 if strategy_weights.graph_weight > 0.0 {
156 let graph_results = self.retrieval_system.graph_search(
157 query,
158 (self.config.results_per_strategy as f32 * strategy_weights.graph_weight) as usize,
159 )?;
160 all_results.extend(self.weight_results(graph_results, strategy_weights.graph_weight));
161 }
162
163 if strategy_weights.hierarchical_weight > 0.0 {
165 let max_results = (self.config.results_per_strategy as f32
166 * strategy_weights.hierarchical_weight) as usize;
167 let hierarchical_results = self
168 .retrieval_system
169 .public_hierarchical_search(query, max_results)?;
170 all_results.extend(
171 self.weight_results(hierarchical_results, strategy_weights.hierarchical_weight),
172 );
173 }
174
175 if strategy_weights.bm25_weight > 0.0 {
177 let bm25_results = self.retrieval_system.bm25_search(
178 query,
179 (self.config.results_per_strategy as f32 * strategy_weights.bm25_weight) as usize,
180 )?;
181 all_results.extend(self.weight_results(bm25_results, strategy_weights.bm25_weight));
182 }
183
184 let total_results_before_fusion = all_results.len();
185
186 let fused_results = self.cross_strategy_fusion(all_results, max_results)?;
188
189 Ok(AdaptiveRetrievalResult {
190 results: fused_results,
191 strategy_weights_used: strategy_weights,
192 query_analysis: query_analysis.clone(),
193 fusion_method: "weighted_score_fusion".to_string(),
194 total_results_before_fusion,
195 })
196 }
197
198 fn select_strategy_weights(&self, query_analysis: &QueryAnalysisResult) -> StrategyWeights {
200 if query_analysis.confidence < self.config.min_confidence_for_specialization {
202 return self.config.complex_weights.clone();
203 }
204
205 match query_analysis.query_type {
207 QueryType::EntityFocused => self.config.entity_weights.clone(),
208 QueryType::Conceptual => self.config.conceptual_weights.clone(),
209 QueryType::Factual => self.config.factual_weights.clone(),
210 QueryType::Relationship => self.config.relational_weights.clone(),
211 QueryType::Exploratory => self.config.complex_weights.clone(),
212 }
213 }
214
215 fn weight_results(&self, mut results: Vec<SearchResult>, weight: f32) -> Vec<SearchResult> {
217 for result in &mut results {
218 result.score *= weight;
219 }
220 results
221 }
222
223 fn cross_strategy_fusion(
225 &self,
226 results: Vec<SearchResult>,
227 max_results: usize,
228 ) -> Result<Vec<SearchResult>> {
229 let mut seen_chunks = HashMap::new();
231 let mut deduplicated_results = Vec::new();
232
233 for result in results {
234 let chunk_id = &result.id;
235
236 if let Some(existing_score) = seen_chunks.get(chunk_id) {
237 if result.score > *existing_score {
238 seen_chunks.insert(chunk_id.clone(), result.score);
240 deduplicated_results.retain(|r: &SearchResult| r.id != *chunk_id);
242 deduplicated_results.push(result);
243 }
244 } else {
245 seen_chunks.insert(chunk_id.clone(), result.score);
246 deduplicated_results.push(result);
247 }
248 }
249
250 deduplicated_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
252
253 let final_results = self.diversity_aware_selection(deduplicated_results, max_results);
255
256 Ok(final_results)
257 }
258
259 fn diversity_aware_selection(
261 &self,
262 results: Vec<SearchResult>,
263 max_results: usize,
264 ) -> Vec<SearchResult> {
265 let mut selected_results = Vec::new();
266 let mut selected_entities = std::collections::HashSet::new();
267 let _remaining_results = results.clone();
268
269 for result in &results {
270 if selected_results.len() >= max_results {
271 break;
272 }
273
274 let has_new_entities = result
276 .entities
277 .iter()
278 .any(|entity| !selected_entities.contains(entity));
279
280 if result.score > 0.8 || has_new_entities || selected_results.len() < max_results / 2 {
282 for entity in &result.entities {
283 selected_entities.insert(entity.clone());
284 }
285 selected_results.push(result.clone());
286 }
287 }
288
289 if selected_results.len() < max_results {
291 for result in results {
292 if selected_results.len() >= max_results {
293 break;
294 }
295 if !selected_results.iter().any(|r| r.id == result.id) {
296 selected_results.push(result);
297 }
298 }
299 }
300
301 selected_results
302 }
303
304 pub fn get_statistics(&self) -> AdaptiveRetrieverStatistics {
306 AdaptiveRetrieverStatistics {
307 config: self.config.clone(),
308 retrieval_system_stats: format!("RetrievalSystem with {} strategies", 4),
309 }
310 }
311}
312
/// Snapshot returned by `AdaptiveRetriever::get_statistics`.
#[derive(Debug)]
pub struct AdaptiveRetrieverStatistics {
    /// Copy of the retriever's configuration at the time of the snapshot.
    pub config: AdaptiveConfig,
    /// Human-readable one-line summary of the underlying retrieval system.
    pub retrieval_system_stats: String,
}
321
322impl AdaptiveRetrieverStatistics {
323 pub fn print(&self) {
325 println!("Adaptive Retriever Statistics:");
326 println!(
327 " Min confidence for specialization: {:.2}",
328 self.config.min_confidence_for_specialization
329 );
330 println!(
331 " Results per strategy: {}",
332 self.config.results_per_strategy
333 );
334 println!(
335 " Entity weights: V:{:.2} G:{:.2} H:{:.2} B:{:.2}",
336 self.config.entity_weights.vector_weight,
337 self.config.entity_weights.graph_weight,
338 self.config.entity_weights.hierarchical_weight,
339 self.config.entity_weights.bm25_weight
340 );
341 println!(
342 " Factual weights: V:{:.2} G:{:.2} H:{:.2} B:{:.2}",
343 self.config.factual_weights.vector_weight,
344 self.config.factual_weights.graph_weight,
345 self.config.factual_weights.hierarchical_weight,
346 self.config.factual_weights.bm25_weight
347 );
348 println!(" {}", self.retrieval_system_stats);
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of the four strategy weights in a profile.
    fn weight_sum(weights: &StrategyWeights) -> f32 {
        weights.vector_weight + weights.graph_weight + weights.hierarchical_weight
            + weights.bm25_weight
    }

    #[test]
    fn test_strategy_weight_selection() {
        let config = AdaptiveConfig::default();

        let entity_analysis = QueryAnalysisResult {
            query_type: QueryType::EntityFocused,
            confidence: 0.8,
            keywords_matched: vec!["who".to_string()],
            suggested_strategies: vec!["entity_search".to_string()],
            complexity_score: 0.2,
        };

        assert_eq!(entity_analysis.query_type, QueryType::EntityFocused);
        // Confident enough that the specialized entity profile applies, rather than the
        // `complex_weights` fallback.
        assert!(entity_analysis.confidence >= config.min_confidence_for_specialization);

        // The entity profile should lean on graph search more than on any other strategy.
        let entity = &config.entity_weights;
        assert!(entity.graph_weight > entity.vector_weight);
        assert!(entity.graph_weight > entity.hierarchical_weight);
        assert!(entity.graph_weight > entity.bm25_weight);
    }

    #[test]
    fn test_diversity_aware_selection() {
        // NOTE(review): `diversity_aware_selection` needs an `AdaptiveRetriever`, which in
        // turn needs a real `RetrievalSystem`; until one can be built in tests, only the
        // configuration it depends on is checked here.
        let config = AdaptiveConfig::default();

        assert!(config.min_confidence_for_specialization > 0.0);
        assert!(config.results_per_strategy > 0);
    }

    #[test]
    fn test_default_weight_profiles_are_normalized() {
        let config = AdaptiveConfig::default();
        let profiles = [
            &config.entity_weights,
            &config.conceptual_weights,
            &config.factual_weights,
            &config.relational_weights,
            &config.complex_weights,
        ];

        for weights in profiles {
            assert!(
                (weight_sum(weights) - 1.0).abs() < 1e-5,
                "profile does not sum to 1: {:?}",
                weights
            );
        }
    }
}