1use crate::{
4 core::KnowledgeGraph,
5 retrieval::{QueryAnalysisResult, QueryType, RetrievalSystem, SearchResult},
6 summarization::DocumentTree,
7 vector::VectorIndex,
8 Result,
9};
10use std::collections::HashMap;
11
/// Relative weight given to each retrieval strategy when results are blended.
///
/// During retrieval a strategy is only run when its weight is greater than
/// zero. The weight then scales both the number of results requested from that
/// strategy and the score of every result it returns.
#[derive(Debug, Clone, PartialEq)]
pub struct StrategyWeights {
    /// Weight applied to vector search.
    pub vector_weight: f32,
    /// Weight applied to knowledge-graph search.
    pub graph_weight: f32,
    /// Weight applied to hierarchical search.
    pub hierarchical_weight: f32,
    /// Weight applied to BM25 search.
    pub bm25_weight: f32,
}
24
25impl Default for StrategyWeights {
26 fn default() -> Self {
27 Self {
28 vector_weight: 0.25,
29 graph_weight: 0.25,
30 hierarchical_weight: 0.25,
31 bm25_weight: 0.25,
32 }
33 }
34}
35
36#[derive(Debug, Clone)]
38pub struct AdaptiveConfig {
39 pub entity_weights: StrategyWeights,
41 pub conceptual_weights: StrategyWeights,
43 pub factual_weights: StrategyWeights,
45 pub relational_weights: StrategyWeights,
47 pub complex_weights: StrategyWeights,
49 pub min_confidence_for_specialization: f32,
51 pub results_per_strategy: usize,
53}
54
55impl Default for AdaptiveConfig {
56 fn default() -> Self {
57 Self {
58 entity_weights: StrategyWeights {
59 vector_weight: 0.2,
60 graph_weight: 0.5,
61 hierarchical_weight: 0.2,
62 bm25_weight: 0.1,
63 },
64 conceptual_weights: StrategyWeights {
65 vector_weight: 0.6,
66 graph_weight: 0.1,
67 hierarchical_weight: 0.3,
68 bm25_weight: 0.0,
69 },
70 factual_weights: StrategyWeights {
71 vector_weight: 0.2,
72 graph_weight: 0.1,
73 hierarchical_weight: 0.1,
74 bm25_weight: 0.6,
75 },
76 relational_weights: StrategyWeights {
77 vector_weight: 0.2,
78 graph_weight: 0.6,
79 hierarchical_weight: 0.1,
80 bm25_weight: 0.1,
81 },
82 complex_weights: StrategyWeights::default(),
83 min_confidence_for_specialization: 0.6,
84 results_per_strategy: 10,
85 }
86 }
87}
88
89#[derive(Debug, Clone)]
91pub struct AdaptiveRetrievalResult {
92 pub results: Vec<SearchResult>,
94 pub strategy_weights_used: StrategyWeights,
96 pub query_analysis: QueryAnalysisResult,
98 pub fusion_method: String,
100 pub total_results_before_fusion: usize,
102}
103
104pub struct AdaptiveRetriever {
106 config: AdaptiveConfig,
107 retrieval_system: RetrievalSystem,
108}
109
110impl AdaptiveRetriever {
111 pub fn new(
113 config: AdaptiveConfig,
114 _vector_index: VectorIndex,
115 _knowledge_graph: KnowledgeGraph,
116 _document_trees: HashMap<String, DocumentTree>,
117 ) -> Result<Self> {
118 let default_config = crate::config::Config::default();
120 let retrieval_system = RetrievalSystem::new(&default_config)?;
121
122 Ok(Self {
123 config,
124 retrieval_system,
125 })
126 }
127
128 pub fn retrieve(
130 &mut self,
131 query: &str,
132 query_analysis: &QueryAnalysisResult,
133 max_results: usize,
134 ) -> Result<AdaptiveRetrievalResult> {
135 let strategy_weights = self.select_strategy_weights(query_analysis);
137
138 let mut all_results = Vec::new();
140
141 if strategy_weights.vector_weight > 0.0 {
143 let vector_results = self.retrieval_system.vector_search(
144 query,
145 (self.config.results_per_strategy as f32 * strategy_weights.vector_weight) as usize,
146 )?;
147 all_results.extend(self.weight_results(vector_results, strategy_weights.vector_weight));
148 }
149
150 if strategy_weights.graph_weight > 0.0 {
152 let graph_results = self.retrieval_system.graph_search(
153 query,
154 (self.config.results_per_strategy as f32 * strategy_weights.graph_weight) as usize,
155 )?;
156 all_results.extend(self.weight_results(graph_results, strategy_weights.graph_weight));
157 }
158
159 if strategy_weights.hierarchical_weight > 0.0 {
161 let max_results = (self.config.results_per_strategy as f32
162 * strategy_weights.hierarchical_weight) as usize;
163 let hierarchical_results = self
164 .retrieval_system
165 .public_hierarchical_search(query, max_results)?;
166 all_results.extend(
167 self.weight_results(hierarchical_results, strategy_weights.hierarchical_weight),
168 );
169 }
170
171 if strategy_weights.bm25_weight > 0.0 {
173 let bm25_results = self.retrieval_system.bm25_search(
174 query,
175 (self.config.results_per_strategy as f32 * strategy_weights.bm25_weight) as usize,
176 )?;
177 all_results.extend(self.weight_results(bm25_results, strategy_weights.bm25_weight));
178 }
179
180 let total_results_before_fusion = all_results.len();
181
182 let fused_results = self.cross_strategy_fusion(all_results, max_results)?;
184
185 Ok(AdaptiveRetrievalResult {
186 results: fused_results,
187 strategy_weights_used: strategy_weights,
188 query_analysis: query_analysis.clone(),
189 fusion_method: "weighted_score_fusion".to_string(),
190 total_results_before_fusion,
191 })
192 }
193
194 fn select_strategy_weights(&self, query_analysis: &QueryAnalysisResult) -> StrategyWeights {
196 if query_analysis.confidence < self.config.min_confidence_for_specialization {
198 return self.config.complex_weights.clone();
199 }
200
201 match query_analysis.query_type {
203 QueryType::EntityFocused => self.config.entity_weights.clone(),
204 QueryType::Conceptual => self.config.conceptual_weights.clone(),
205 QueryType::Factual => self.config.factual_weights.clone(),
206 QueryType::Relationship => self.config.relational_weights.clone(),
207 QueryType::Exploratory => self.config.complex_weights.clone(),
208 }
209 }
210
211 fn weight_results(&self, mut results: Vec<SearchResult>, weight: f32) -> Vec<SearchResult> {
213 for result in &mut results {
214 result.score *= weight;
215 }
216 results
217 }
218
219 fn cross_strategy_fusion(
221 &self,
222 results: Vec<SearchResult>,
223 max_results: usize,
224 ) -> Result<Vec<SearchResult>> {
225 let mut seen_chunks = HashMap::new();
227 let mut deduplicated_results = Vec::new();
228
229 for result in results {
230 let chunk_id = &result.id;
231
232 if let Some(existing_score) = seen_chunks.get(chunk_id) {
233 if result.score > *existing_score {
234 seen_chunks.insert(chunk_id.clone(), result.score);
236 deduplicated_results.retain(|r: &SearchResult| r.id != *chunk_id);
238 deduplicated_results.push(result);
239 }
240 } else {
241 seen_chunks.insert(chunk_id.clone(), result.score);
242 deduplicated_results.push(result);
243 }
244 }
245
246 deduplicated_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
248
249 let final_results = self.diversity_aware_selection(deduplicated_results, max_results);
251
252 Ok(final_results)
253 }
254
255 fn diversity_aware_selection(
257 &self,
258 results: Vec<SearchResult>,
259 max_results: usize,
260 ) -> Vec<SearchResult> {
261 let mut selected_results = Vec::new();
262 let mut selected_entities = std::collections::HashSet::new();
263 let _remaining_results = results.clone();
264
265 for result in &results {
266 if selected_results.len() >= max_results {
267 break;
268 }
269
270 let has_new_entities = result
272 .entities
273 .iter()
274 .any(|entity| !selected_entities.contains(entity));
275
276 if result.score > 0.8 || has_new_entities || selected_results.len() < max_results / 2 {
278 for entity in &result.entities {
279 selected_entities.insert(entity.clone());
280 }
281 selected_results.push(result.clone());
282 }
283 }
284
285 if selected_results.len() < max_results {
287 for result in results {
288 if selected_results.len() >= max_results {
289 break;
290 }
291 if !selected_results.iter().any(|r| r.id == result.id) {
292 selected_results.push(result);
293 }
294 }
295 }
296
297 selected_results
298 }
299
300 pub fn get_statistics(&self) -> AdaptiveRetrieverStatistics {
302 AdaptiveRetrieverStatistics {
303 config: self.config.clone(),
304 retrieval_system_stats: format!("RetrievalSystem with {} strategies", 4),
305 }
306 }
307}
308
309#[derive(Debug)]
311pub struct AdaptiveRetrieverStatistics {
312 pub config: AdaptiveConfig,
314 pub retrieval_system_stats: String,
316}
317
318impl AdaptiveRetrieverStatistics {
319 pub fn print(&self) {
321 println!("Adaptive Retriever Statistics:");
322 println!(
323 " Min confidence for specialization: {:.2}",
324 self.config.min_confidence_for_specialization
325 );
326 println!(
327 " Results per strategy: {}",
328 self.config.results_per_strategy
329 );
330 println!(
331 " Entity weights: V:{:.2} G:{:.2} H:{:.2} B:{:.2}",
332 self.config.entity_weights.vector_weight,
333 self.config.entity_weights.graph_weight,
334 self.config.entity_weights.hierarchical_weight,
335 self.config.entity_weights.bm25_weight
336 );
337 println!(
338 " Factual weights: V:{:.2} G:{:.2} H:{:.2} B:{:.2}",
339 self.config.factual_weights.vector_weight,
340 self.config.factual_weights.graph_weight,
341 self.config.factual_weights.hierarchical_weight,
342 self.config.factual_weights.bm25_weight
343 );
344 println!(" {}", self.retrieval_system_stats);
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of the four strategy weights in a profile.
    fn total_weight(weights: &StrategyWeights) -> f32 {
        weights.vector_weight
            + weights.graph_weight
            + weights.hierarchical_weight
            + weights.bm25_weight
    }

    #[test]
    fn test_strategy_weight_selection() {
        let config = AdaptiveConfig::default();

        let entity_analysis = QueryAnalysisResult {
            query_type: QueryType::EntityFocused,
            confidence: 0.8,
            keywords_matched: vec!["who".to_string()],
            suggested_strategies: vec!["entity_search".to_string()],
            complexity_score: 0.2,
        };

        assert_eq!(entity_analysis.query_type, QueryType::EntityFocused);
        // Confident enough that the specialised profile applies, measured
        // against the configured threshold rather than a hard-coded copy.
        assert!(entity_analysis.confidence >= config.min_confidence_for_specialization);
        // The default entity profile favours graph search over the others.
        assert!(config.entity_weights.graph_weight > config.entity_weights.vector_weight);
        assert!(config.entity_weights.graph_weight > config.entity_weights.hierarchical_weight);
        assert!(config.entity_weights.graph_weight > config.entity_weights.bm25_weight);
    }

    #[test]
    fn test_diversity_aware_selection() {
        let config = AdaptiveConfig::default();

        assert!(config.min_confidence_for_specialization > 0.0);
        assert!(config.results_per_strategy > 0);
    }

    #[test]
    fn test_default_weight_profiles_are_normalized() {
        let config = AdaptiveConfig::default();
        let profiles = [
            &config.entity_weights,
            &config.conceptual_weights,
            &config.factual_weights,
            &config.relational_weights,
            &config.complex_weights,
        ];

        for profile in profiles {
            assert!((total_weight(profile) - 1.0).abs() < 1e-5);
        }
    }
}