Skip to main content

reddb_server/storage/query/rag/
fusion.rs

//! Context Fusion and Re-ranking
//!
//! Algorithms for combining and ranking results from multiple retrieval
//! sources (vectors, graphs, tables) to produce a ranked RAG context.
//!
//! # Algorithms
//!
//! - **Reciprocal Rank Fusion (RRF)**: Combines per-source rankings into one score
//! - **Graph-Aware Re-ranking**: Boosts chunks whose entities are referenced by other
//!   retrieved entities, in proportion to the referencing entity's relevance
//! - **Deduplication**: Removes near-duplicate chunks (character-trigram Jaccard similarity)
//! - **Diversification**: Caps the number of chunks kept per entity type
12
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15
16use super::context::{ChunkSource, ContextChunk, RetrievalContext};
17use super::EntityType;
18use crate::storage::{EntityId, RefType, Store};
19
/// Tuning parameters for [`ContextFusion`].
#[derive(Debug, Clone, PartialEq)]
pub struct FusionConfig {
    /// Smoothing constant `k` in the reciprocal-rank term `1 / (k + rank)` (typically 60).
    pub rrf_k: f32,
    /// Weight of the RRF term for chunks retrieved from vector collections.
    pub vector_weight: f32,
    /// Weight of the RRF term for chunks retrieved from the graph.
    pub graph_weight: f32,
    /// Weight of the RRF term for chunks retrieved from tables / structured data.
    pub table_weight: f32,
    /// Scale of the boost a chunk receives when another retrieved entity references it.
    pub cross_ref_boost: f32,
    /// Chunks whose trigram Jaccard similarity is strictly greater than this value
    /// are treated as duplicates.
    pub dedup_threshold: f32,
    /// Whether to cap the number of chunks kept per entity type.
    pub diversify: bool,
    /// Maximum chunks kept per entity type when `diversify` is enabled.
    pub max_per_type: usize,
    /// Whether to apply graph-aware re-ranking (only effective with a store attached).
    pub graph_rerank: bool,
}
42
43impl Default for FusionConfig {
44    fn default() -> Self {
45        Self {
46            rrf_k: 60.0,
47            vector_weight: 0.5,
48            graph_weight: 0.3,
49            table_weight: 0.2,
50            cross_ref_boost: 0.15,
51            dedup_threshold: 0.85,
52            diversify: true,
53            max_per_type: 5,
54            graph_rerank: true,
55        }
56    }
57}
58
59/// Context fusion engine
60pub struct ContextFusion {
61    /// Fusion configuration
62    config: FusionConfig,
63    /// Optional store for cross-reference lookup
64    store: Option<Arc<Store>>,
65}
66
67impl ContextFusion {
68    /// Create a new fusion engine with default config
69    pub fn new() -> Self {
70        Self {
71            config: FusionConfig::default(),
72            store: None,
73        }
74    }
75
76    /// Create with custom config
77    pub fn with_config(config: FusionConfig) -> Self {
78        Self {
79            config,
80            store: None,
81        }
82    }
83
84    /// Attach store for graph-aware operations
85    pub fn with_store(mut self, store: Arc<Store>) -> Self {
86        self.store = Some(store);
87        self
88    }
89
90    /// Apply full fusion pipeline to a context
91    pub fn fuse(&self, context: &mut RetrievalContext) {
92        // 1. Normalize scores per source
93        self.normalize_scores(context);
94
95        // 2. Apply RRF if multiple sources
96        if context.sources_used.len() > 1 {
97            self.apply_rrf(context);
98        }
99
100        // 3. Graph-aware re-ranking
101        if self.config.graph_rerank {
102            self.graph_rerank(context);
103        }
104
105        // 4. Deduplicate similar chunks
106        self.deduplicate(context);
107
108        // 5. Diversify results
109        if self.config.diversify {
110            self.diversify(context);
111        }
112
113        // 6. Final sort
114        context.sort_by_relevance();
115    }
116
117    /// Normalize scores within each source type to [0, 1]
118    fn normalize_scores(&self, context: &mut RetrievalContext) {
119        // Group by source type
120        let mut vector_chunks: Vec<usize> = Vec::new();
121        let mut graph_chunks: Vec<usize> = Vec::new();
122        let mut table_chunks: Vec<usize> = Vec::new();
123        let mut other_chunks: Vec<usize> = Vec::new();
124
125        for (i, chunk) in context.chunks.iter().enumerate() {
126            match chunk.source {
127                ChunkSource::Vector(_) => vector_chunks.push(i),
128                ChunkSource::Graph => graph_chunks.push(i),
129                ChunkSource::Table(_) => table_chunks.push(i),
130                _ => other_chunks.push(i),
131            }
132        }
133
134        // Normalize each group
135        self.normalize_group(&mut context.chunks, &vector_chunks);
136        self.normalize_group(&mut context.chunks, &graph_chunks);
137        self.normalize_group(&mut context.chunks, &table_chunks);
138    }
139
140    /// Normalize a group of chunks by index
141    fn normalize_group(&self, chunks: &mut [ContextChunk], indices: &[usize]) {
142        if indices.is_empty() {
143            return;
144        }
145
146        let max_score = indices
147            .iter()
148            .map(|&i| chunks[i].relevance)
149            .fold(f32::NEG_INFINITY, f32::max);
150        let min_score = indices
151            .iter()
152            .map(|&i| chunks[i].relevance)
153            .fold(f32::INFINITY, f32::min);
154
155        let range = max_score - min_score;
156        if range > 0.0001 {
157            for &i in indices {
158                chunks[i].relevance = (chunks[i].relevance - min_score) / range;
159            }
160        }
161    }
162
163    /// Apply Reciprocal Rank Fusion across sources
164    fn apply_rrf(&self, context: &mut RetrievalContext) {
165        // Build rankings per source
166        let mut vector_rankings: HashMap<String, usize> = HashMap::new();
167        let mut graph_rankings: HashMap<String, usize> = HashMap::new();
168        let mut table_rankings: HashMap<String, usize> = HashMap::new();
169
170        // Sort by relevance within each source and assign ranks
171        let mut by_source: HashMap<String, Vec<(usize, f32)>> = HashMap::new();
172        for (i, chunk) in context.chunks.iter().enumerate() {
173            let source_key = match &chunk.source {
174                ChunkSource::Vector(c) => format!("vector:{}", c),
175                ChunkSource::Graph => "graph".to_string(),
176                ChunkSource::Table(t) => format!("table:{}", t),
177                _ => "other".to_string(),
178            };
179            by_source
180                .entry(source_key)
181                .or_default()
182                .push((i, chunk.relevance));
183        }
184
185        // Assign ranks
186        for (source, mut items) in by_source {
187            items.sort_by(|a, b| {
188                b.1.partial_cmp(&a.1)
189                    .unwrap_or(std::cmp::Ordering::Equal)
190                    .then_with(|| a.0.cmp(&b.0))
191            });
192            for (rank, (idx, _)) in items.iter().enumerate() {
193                let key = format!("chunk_{}", idx);
194                if source.starts_with("vector") {
195                    vector_rankings.insert(key, rank + 1);
196                } else if source == "graph" {
197                    graph_rankings.insert(key, rank + 1);
198                } else if source.starts_with("table") {
199                    table_rankings.insert(key, rank + 1);
200                }
201            }
202        }
203
204        // Calculate RRF scores
205        let k = self.config.rrf_k;
206        for (i, chunk) in context.chunks.iter_mut().enumerate() {
207            let key = format!("chunk_{}", i);
208
209            let mut rrf_score = 0.0;
210
211            if let Some(&rank) = vector_rankings.get(&key) {
212                rrf_score += self.config.vector_weight * (1.0 / (k + rank as f32));
213            }
214            if let Some(&rank) = graph_rankings.get(&key) {
215                rrf_score += self.config.graph_weight * (1.0 / (k + rank as f32));
216            }
217            if let Some(&rank) = table_rankings.get(&key) {
218                rrf_score += self.config.table_weight * (1.0 / (k + rank as f32));
219            }
220
221            // Blend RRF with original relevance
222            chunk.relevance = 0.6 * chunk.relevance + 0.4 * rrf_score * 100.0;
223        }
224    }
225
226    /// Re-rank based on graph relationships
227    fn graph_rerank(&self, context: &mut RetrievalContext) {
228        let store = match &self.store {
229            Some(s) => s,
230            None => return,
231        };
232
233        // Build entity ID to chunk index mapping
234        let mut entity_chunks: HashMap<EntityId, Vec<usize>> = HashMap::new();
235        for (i, chunk) in context.chunks.iter().enumerate() {
236            if let Some(ref id_str) = chunk.entity_id {
237                if let Ok(id) = id_str.parse::<u64>() {
238                    entity_chunks.entry(EntityId(id)).or_default().push(i);
239                }
240            }
241        }
242
243        // For each entity, boost chunks connected to it
244        let mut boosts: HashMap<usize, f32> = HashMap::new();
245
246        for (entity_id, chunk_indices) in &entity_chunks {
247            // Get cross-references from this entity
248            let refs_from = store.get_refs_from(*entity_id);
249
250            for (target_id, ref_type, _collection) in refs_from {
251                if let Some(target_chunks) = entity_chunks.get(&target_id) {
252                    // Calculate boost based on reference type and source relevance
253                    let source_relevance: f32 = chunk_indices
254                        .iter()
255                        .map(|&i| context.chunks[i].relevance)
256                        .sum::<f32>()
257                        / chunk_indices.len() as f32;
258
259                    let type_multiplier = match ref_type {
260                        RefType::RelatedTo | RefType::DerivesFrom => 1.0,
261                        RefType::Mentions | RefType::Contains => 0.8,
262                        RefType::DependsOn => 0.7,
263                        RefType::SimilarTo => 0.5,
264                        _ => 0.3,
265                    };
266
267                    let boost = self.config.cross_ref_boost * source_relevance * type_multiplier;
268
269                    for &chunk_idx in target_chunks {
270                        *boosts.entry(chunk_idx).or_insert(0.0) += boost;
271                    }
272                }
273            }
274        }
275
276        // Apply boosts
277        for (idx, boost) in boosts {
278            context.chunks[idx].relevance += boost;
279        }
280    }
281
282    /// Remove semantically similar chunks
283    fn deduplicate(&self, context: &mut RetrievalContext) {
284        if context.chunks.len() < 2 {
285            return;
286        }
287
288        let mut to_remove: HashSet<usize> = HashSet::new();
289        let threshold = self.config.dedup_threshold;
290
291        for i in 0..context.chunks.len() {
292            if to_remove.contains(&i) {
293                continue;
294            }
295
296            for j in (i + 1)..context.chunks.len() {
297                if to_remove.contains(&j) {
298                    continue;
299                }
300
301                let similarity =
302                    self.content_similarity(&context.chunks[i].content, &context.chunks[j].content);
303
304                if similarity > threshold {
305                    // Keep the one with higher relevance
306                    if context.chunks[i].relevance >= context.chunks[j].relevance {
307                        to_remove.insert(j);
308                    } else {
309                        to_remove.insert(i);
310                        break;
311                    }
312                }
313            }
314        }
315
316        // Remove duplicates (in reverse order to maintain indices)
317        let mut indices: Vec<usize> = to_remove.into_iter().collect();
318        indices.sort_by(|a, b| b.cmp(a));
319        for idx in indices {
320            context.chunks.remove(idx);
321        }
322    }
323
324    /// Calculate content similarity using Jaccard on n-grams
325    fn content_similarity(&self, a: &str, b: &str) -> f32 {
326        if a.is_empty() || b.is_empty() {
327            return 0.0;
328        }
329
330        let ngrams_a = self.extract_ngrams(a, 3);
331        let ngrams_b = self.extract_ngrams(b, 3);
332
333        if ngrams_a.is_empty() || ngrams_b.is_empty() {
334            return 0.0;
335        }
336
337        let intersection = ngrams_a.intersection(&ngrams_b).count();
338        let union = ngrams_a.union(&ngrams_b).count();
339
340        if union == 0 {
341            0.0
342        } else {
343            intersection as f32 / union as f32
344        }
345    }
346
347    /// Extract character n-grams from text
348    fn extract_ngrams(&self, text: &str, n: usize) -> HashSet<String> {
349        let text = text.to_lowercase();
350        let chars: Vec<char> = text.chars().collect();
351
352        if chars.len() < n {
353            return HashSet::new();
354        }
355
356        (0..=chars.len() - n)
357            .map(|i| chars[i..i + n].iter().collect())
358            .collect()
359    }
360
361    /// Diversify results by entity type
362    fn diversify(&self, context: &mut RetrievalContext) {
363        let max_per_type = self.config.max_per_type;
364
365        // Count by entity type
366        let mut type_counts: HashMap<EntityType, usize> = HashMap::new();
367        let mut to_remove: HashSet<usize> = HashSet::new();
368
369        // Process in relevance order (already sorted)
370        for (i, chunk) in context.chunks.iter().enumerate() {
371            let entity_type = chunk.entity_type.unwrap_or(EntityType::Unknown);
372            let count = type_counts.entry(entity_type).or_insert(0);
373
374            if *count >= max_per_type {
375                to_remove.insert(i);
376            } else {
377                *count += 1;
378            }
379        }
380
381        // Remove excess chunks
382        let mut indices: Vec<usize> = to_remove.into_iter().collect();
383        indices.sort_by(|a, b| b.cmp(a));
384        for idx in indices {
385            context.chunks.remove(idx);
386        }
387    }
388}
389
390impl Default for ContextFusion {
391    fn default() -> Self {
392        Self::new()
393    }
394}
395
396/// Result re-ranker for final scoring
397pub struct ResultReranker {
398    /// Weights for different scoring factors
399    pub relevance_weight: f32,
400    pub recency_weight: f32,
401    pub connection_weight: f32,
402    pub type_priority: HashMap<EntityType, f32>,
403}
404
405impl Default for ResultReranker {
406    fn default() -> Self {
407        let mut type_priority = HashMap::new();
408        type_priority.insert(EntityType::Vulnerability, 1.0);
409        type_priority.insert(EntityType::Host, 0.9);
410        type_priority.insert(EntityType::Service, 0.85);
411        type_priority.insert(EntityType::Credential, 0.95);
412        type_priority.insert(EntityType::Certificate, 0.7);
413        type_priority.insert(EntityType::Domain, 0.75);
414        type_priority.insert(EntityType::Unknown, 0.5);
415
416        Self {
417            relevance_weight: 0.6,
418            recency_weight: 0.2,
419            connection_weight: 0.2,
420            type_priority,
421        }
422    }
423}
424
425impl ResultReranker {
426    /// Rerank chunks with multiple factors
427    pub fn rerank(&self, context: &mut RetrievalContext) {
428        for chunk in &mut context.chunks {
429            let mut final_score = self.relevance_weight * chunk.relevance;
430
431            // Type priority boost
432            let type_boost = chunk
433                .entity_type
434                .and_then(|t| self.type_priority.get(&t))
435                .unwrap_or(&0.5);
436            final_score += 0.1 * type_boost;
437
438            // Connection bonus (from graph depth)
439            if let Some(depth) = chunk.graph_depth {
440                // Closer connections score higher
441                let connection_score = 1.0 / (1.0 + depth as f32);
442                final_score += self.connection_weight * connection_score;
443            }
444
445            chunk.relevance = final_score;
446        }
447
448        context.sort_by_relevance();
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_content_similarity() {
        let fusion = ContextFusion::new();

        let sim1 = fusion.content_similarity("This is a test string", "This is a test string");
        assert!((sim1 - 1.0).abs() < 0.001);

        let sim2 = fusion.content_similarity("completely different", "nothing alike");
        assert!(sim2 < 0.5);

        let sim3 = fusion.content_similarity("vulnerability in nginx", "vulnerability in apache");
        assert!(sim3 > 0.3 && sim3 < 0.8);
    }

    #[test]
    fn test_content_similarity_symmetric_and_case_insensitive() {
        let fusion = ContextFusion::new();

        let a = "vulnerability in nginx";
        let b = "Vulnerability In Apache";
        assert_eq!(fusion.content_similarity(a, b), fusion.content_similarity(b, a));

        let same = fusion.content_similarity("OpenSSH 8.9", "openssh 8.9");
        assert!((same - 1.0).abs() < 0.001);

        // Empty input never counts as a duplicate of anything.
        assert_eq!(fusion.content_similarity("", "anything"), 0.0);
    }

    #[test]
    fn test_ngram_extraction() {
        let fusion = ContextFusion::new();

        let ngrams = fusion.extract_ngrams("hello", 3);
        assert!(ngrams.contains("hel"));
        assert!(ngrams.contains("ell"));
        assert!(ngrams.contains("llo"));
        assert_eq!(ngrams.len(), 3);
    }

    #[test]
    fn test_ngram_extraction_short_and_repeated_input() {
        let fusion = ContextFusion::new();

        assert!(fusion.extract_ngrams("", 3).is_empty());
        assert!(fusion.extract_ngrams("hi", 3).is_empty());
        // Repeated trigrams collapse into one set entry.
        assert_eq!(fusion.extract_ngrams("aaaa", 3).len(), 1);
    }

    #[test]
    fn test_fusion_config_defaults() {
        let config = FusionConfig::default();
        assert_eq!(config.rrf_k, 60.0);
        assert!(config.diversify);
        assert!(config.graph_rerank);
        assert_eq!(config.max_per_type, 5);

        let total = config.vector_weight + config.graph_weight + config.table_weight;
        assert!((total - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_reranker_default_priorities() {
        let reranker = ResultReranker::default();
        assert_eq!(reranker.type_priority.get(&EntityType::Vulnerability), Some(&1.0));
        assert_eq!(reranker.type_priority.get(&EntityType::Unknown), Some(&0.5));
        assert_eq!(reranker.type_priority.len(), 7);
    }
}