reddb_server/storage/query/rag/
fusion.rs1use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15
16use super::context::{ChunkSource, ContextChunk, RetrievalContext};
17use super::EntityType;
18use crate::storage::{EntityId, RefType, Store};
19
/// Tuning parameters for context fusion across retrieval sources.
#[derive(Debug, Clone)]
pub struct FusionConfig {
    /// `k` constant in the RRF formula `1 / (k + rank)`; larger values
    /// flatten the difference between adjacent ranks.
    pub rrf_k: f32,
    /// RRF weight applied to vector-search results.
    pub vector_weight: f32,
    /// RRF weight applied to graph-traversal results.
    pub graph_weight: f32,
    /// RRF weight applied to table-query results.
    pub table_weight: f32,
    /// Scale factor for relevance transferred along graph references
    /// during graph reranking.
    pub cross_ref_boost: f32,
    /// Trigram-similarity threshold above which two chunks are treated
    /// as duplicates during deduplication.
    pub dedup_threshold: f32,
    /// Whether to cap the number of retained chunks per entity type.
    pub diversify: bool,
    /// Maximum chunks kept per entity type when `diversify` is on.
    pub max_per_type: usize,
    /// Whether to boost chunks cross-referenced in the entity graph
    /// (effective only when a store is attached).
    pub graph_rerank: bool,
}
42
43impl Default for FusionConfig {
44 fn default() -> Self {
45 Self {
46 rrf_k: 60.0,
47 vector_weight: 0.5,
48 graph_weight: 0.3,
49 table_weight: 0.2,
50 cross_ref_boost: 0.15,
51 dedup_threshold: 0.85,
52 diversify: true,
53 max_per_type: 5,
54 graph_rerank: true,
55 }
56 }
57}
58
/// Fuses and re-scores chunks retrieved from multiple sources (vector,
/// graph, table) into a single ranked context.
pub struct ContextFusion {
    /// Fusion tuning knobs; see [`FusionConfig`].
    config: FusionConfig,
    /// Optional store handle, required only for graph-based reranking.
    store: Option<Arc<Store>>,
}
66
67impl ContextFusion {
68 pub fn new() -> Self {
70 Self {
71 config: FusionConfig::default(),
72 store: None,
73 }
74 }
75
76 pub fn with_config(config: FusionConfig) -> Self {
78 Self {
79 config,
80 store: None,
81 }
82 }
83
84 pub fn with_store(mut self, store: Arc<Store>) -> Self {
86 self.store = Some(store);
87 self
88 }
89
90 pub fn fuse(&self, context: &mut RetrievalContext) {
92 self.normalize_scores(context);
94
95 if context.sources_used.len() > 1 {
97 self.apply_rrf(context);
98 }
99
100 if self.config.graph_rerank {
102 self.graph_rerank(context);
103 }
104
105 self.deduplicate(context);
107
108 if self.config.diversify {
110 self.diversify(context);
111 }
112
113 context.sort_by_relevance();
115 }
116
117 fn normalize_scores(&self, context: &mut RetrievalContext) {
119 let mut vector_chunks: Vec<usize> = Vec::new();
121 let mut graph_chunks: Vec<usize> = Vec::new();
122 let mut table_chunks: Vec<usize> = Vec::new();
123 let mut other_chunks: Vec<usize> = Vec::new();
124
125 for (i, chunk) in context.chunks.iter().enumerate() {
126 match chunk.source {
127 ChunkSource::Vector(_) => vector_chunks.push(i),
128 ChunkSource::Graph => graph_chunks.push(i),
129 ChunkSource::Table(_) => table_chunks.push(i),
130 _ => other_chunks.push(i),
131 }
132 }
133
134 self.normalize_group(&mut context.chunks, &vector_chunks);
136 self.normalize_group(&mut context.chunks, &graph_chunks);
137 self.normalize_group(&mut context.chunks, &table_chunks);
138 }
139
140 fn normalize_group(&self, chunks: &mut [ContextChunk], indices: &[usize]) {
142 if indices.is_empty() {
143 return;
144 }
145
146 let max_score = indices
147 .iter()
148 .map(|&i| chunks[i].relevance)
149 .fold(f32::NEG_INFINITY, f32::max);
150 let min_score = indices
151 .iter()
152 .map(|&i| chunks[i].relevance)
153 .fold(f32::INFINITY, f32::min);
154
155 let range = max_score - min_score;
156 if range > 0.0001 {
157 for &i in indices {
158 chunks[i].relevance = (chunks[i].relevance - min_score) / range;
159 }
160 }
161 }
162
163 fn apply_rrf(&self, context: &mut RetrievalContext) {
165 let mut vector_rankings: HashMap<String, usize> = HashMap::new();
167 let mut graph_rankings: HashMap<String, usize> = HashMap::new();
168 let mut table_rankings: HashMap<String, usize> = HashMap::new();
169
170 let mut by_source: HashMap<String, Vec<(usize, f32)>> = HashMap::new();
172 for (i, chunk) in context.chunks.iter().enumerate() {
173 let source_key = match &chunk.source {
174 ChunkSource::Vector(c) => format!("vector:{}", c),
175 ChunkSource::Graph => "graph".to_string(),
176 ChunkSource::Table(t) => format!("table:{}", t),
177 _ => "other".to_string(),
178 };
179 by_source
180 .entry(source_key)
181 .or_default()
182 .push((i, chunk.relevance));
183 }
184
185 for (source, mut items) in by_source {
187 items.sort_by(|a, b| {
188 b.1.partial_cmp(&a.1)
189 .unwrap_or(std::cmp::Ordering::Equal)
190 .then_with(|| a.0.cmp(&b.0))
191 });
192 for (rank, (idx, _)) in items.iter().enumerate() {
193 let key = format!("chunk_{}", idx);
194 if source.starts_with("vector") {
195 vector_rankings.insert(key, rank + 1);
196 } else if source == "graph" {
197 graph_rankings.insert(key, rank + 1);
198 } else if source.starts_with("table") {
199 table_rankings.insert(key, rank + 1);
200 }
201 }
202 }
203
204 let k = self.config.rrf_k;
206 for (i, chunk) in context.chunks.iter_mut().enumerate() {
207 let key = format!("chunk_{}", i);
208
209 let mut rrf_score = 0.0;
210
211 if let Some(&rank) = vector_rankings.get(&key) {
212 rrf_score += self.config.vector_weight * (1.0 / (k + rank as f32));
213 }
214 if let Some(&rank) = graph_rankings.get(&key) {
215 rrf_score += self.config.graph_weight * (1.0 / (k + rank as f32));
216 }
217 if let Some(&rank) = table_rankings.get(&key) {
218 rrf_score += self.config.table_weight * (1.0 / (k + rank as f32));
219 }
220
221 chunk.relevance = 0.6 * chunk.relevance + 0.4 * rrf_score * 100.0;
223 }
224 }
225
226 fn graph_rerank(&self, context: &mut RetrievalContext) {
228 let store = match &self.store {
229 Some(s) => s,
230 None => return,
231 };
232
233 let mut entity_chunks: HashMap<EntityId, Vec<usize>> = HashMap::new();
235 for (i, chunk) in context.chunks.iter().enumerate() {
236 if let Some(ref id_str) = chunk.entity_id {
237 if let Ok(id) = id_str.parse::<u64>() {
238 entity_chunks.entry(EntityId(id)).or_default().push(i);
239 }
240 }
241 }
242
243 let mut boosts: HashMap<usize, f32> = HashMap::new();
245
246 for (entity_id, chunk_indices) in &entity_chunks {
247 let refs_from = store.get_refs_from(*entity_id);
249
250 for (target_id, ref_type, _collection) in refs_from {
251 if let Some(target_chunks) = entity_chunks.get(&target_id) {
252 let source_relevance: f32 = chunk_indices
254 .iter()
255 .map(|&i| context.chunks[i].relevance)
256 .sum::<f32>()
257 / chunk_indices.len() as f32;
258
259 let type_multiplier = match ref_type {
260 RefType::RelatedTo | RefType::DerivesFrom => 1.0,
261 RefType::Mentions | RefType::Contains => 0.8,
262 RefType::DependsOn => 0.7,
263 RefType::SimilarTo => 0.5,
264 _ => 0.3,
265 };
266
267 let boost = self.config.cross_ref_boost * source_relevance * type_multiplier;
268
269 for &chunk_idx in target_chunks {
270 *boosts.entry(chunk_idx).or_insert(0.0) += boost;
271 }
272 }
273 }
274 }
275
276 for (idx, boost) in boosts {
278 context.chunks[idx].relevance += boost;
279 }
280 }
281
282 fn deduplicate(&self, context: &mut RetrievalContext) {
284 if context.chunks.len() < 2 {
285 return;
286 }
287
288 let mut to_remove: HashSet<usize> = HashSet::new();
289 let threshold = self.config.dedup_threshold;
290
291 for i in 0..context.chunks.len() {
292 if to_remove.contains(&i) {
293 continue;
294 }
295
296 for j in (i + 1)..context.chunks.len() {
297 if to_remove.contains(&j) {
298 continue;
299 }
300
301 let similarity =
302 self.content_similarity(&context.chunks[i].content, &context.chunks[j].content);
303
304 if similarity > threshold {
305 if context.chunks[i].relevance >= context.chunks[j].relevance {
307 to_remove.insert(j);
308 } else {
309 to_remove.insert(i);
310 break;
311 }
312 }
313 }
314 }
315
316 let mut indices: Vec<usize> = to_remove.into_iter().collect();
318 indices.sort_by(|a, b| b.cmp(a));
319 for idx in indices {
320 context.chunks.remove(idx);
321 }
322 }
323
324 fn content_similarity(&self, a: &str, b: &str) -> f32 {
326 if a.is_empty() || b.is_empty() {
327 return 0.0;
328 }
329
330 let ngrams_a = self.extract_ngrams(a, 3);
331 let ngrams_b = self.extract_ngrams(b, 3);
332
333 if ngrams_a.is_empty() || ngrams_b.is_empty() {
334 return 0.0;
335 }
336
337 let intersection = ngrams_a.intersection(&ngrams_b).count();
338 let union = ngrams_a.union(&ngrams_b).count();
339
340 if union == 0 {
341 0.0
342 } else {
343 intersection as f32 / union as f32
344 }
345 }
346
347 fn extract_ngrams(&self, text: &str, n: usize) -> HashSet<String> {
349 let text = text.to_lowercase();
350 let chars: Vec<char> = text.chars().collect();
351
352 if chars.len() < n {
353 return HashSet::new();
354 }
355
356 (0..=chars.len() - n)
357 .map(|i| chars[i..i + n].iter().collect())
358 .collect()
359 }
360
361 fn diversify(&self, context: &mut RetrievalContext) {
363 let max_per_type = self.config.max_per_type;
364
365 let mut type_counts: HashMap<EntityType, usize> = HashMap::new();
367 let mut to_remove: HashSet<usize> = HashSet::new();
368
369 for (i, chunk) in context.chunks.iter().enumerate() {
371 let entity_type = chunk.entity_type.unwrap_or(EntityType::Unknown);
372 let count = type_counts.entry(entity_type).or_insert(0);
373
374 if *count >= max_per_type {
375 to_remove.insert(i);
376 } else {
377 *count += 1;
378 }
379 }
380
381 let mut indices: Vec<usize> = to_remove.into_iter().collect();
383 indices.sort_by(|a, b| b.cmp(a));
384 for idx in indices {
385 context.chunks.remove(idx);
386 }
387 }
388}
389
390impl Default for ContextFusion {
391 fn default() -> Self {
392 Self::new()
393 }
394}
395
/// Re-scores retrieved chunks from relevance, entity-type priority, and
/// graph-connection proximity.
pub struct ResultReranker {
    /// Weight of the chunk's existing relevance score.
    pub relevance_weight: f32,
    /// Weight reserved for recency scoring.
    /// NOTE(review): not consulted by `rerank` as written — confirm
    /// whether recency scoring is implemented elsewhere or still pending.
    pub recency_weight: f32,
    /// Weight of the graph-depth connection score.
    pub connection_weight: f32,
    /// Per-entity-type priority multipliers applied as a small additive
    /// boost; types absent from the map fall back to 0.5.
    pub type_priority: HashMap<EntityType, f32>,
}
404
405impl Default for ResultReranker {
406 fn default() -> Self {
407 let mut type_priority = HashMap::new();
408 type_priority.insert(EntityType::Vulnerability, 1.0);
409 type_priority.insert(EntityType::Host, 0.9);
410 type_priority.insert(EntityType::Service, 0.85);
411 type_priority.insert(EntityType::Credential, 0.95);
412 type_priority.insert(EntityType::Certificate, 0.7);
413 type_priority.insert(EntityType::Domain, 0.75);
414 type_priority.insert(EntityType::Unknown, 0.5);
415
416 Self {
417 relevance_weight: 0.6,
418 recency_weight: 0.2,
419 connection_weight: 0.2,
420 type_priority,
421 }
422 }
423}
424
425impl ResultReranker {
426 pub fn rerank(&self, context: &mut RetrievalContext) {
428 for chunk in &mut context.chunks {
429 let mut final_score = self.relevance_weight * chunk.relevance;
430
431 let type_boost = chunk
433 .entity_type
434 .and_then(|t| self.type_priority.get(&t))
435 .unwrap_or(&0.5);
436 final_score += 0.1 * type_boost;
437
438 if let Some(depth) = chunk.graph_depth {
440 let connection_score = 1.0 / (1.0 + depth as f32);
442 final_score += self.connection_weight * connection_score;
443 }
444
445 chunk.relevance = final_score;
446 }
447
448 context.sort_by_relevance();
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    /// Identical strings are (near-)perfect matches, unrelated strings
    /// score low, and partial overlaps land in between.
    #[test]
    fn test_content_similarity() {
        let f = ContextFusion::new();

        let identical = f.content_similarity("This is a test string", "This is a test string");
        assert!((identical - 1.0).abs() < 0.001);

        let unrelated = f.content_similarity("completely different", "nothing alike");
        assert!(unrelated < 0.5);

        let partial = f.content_similarity("vulnerability in nginx", "vulnerability in apache");
        assert!(partial > 0.3);
        assert!(partial < 0.8);
    }

    /// "hello" must yield exactly the three trigrams hel / ell / llo.
    #[test]
    fn test_ngram_extraction() {
        let trigrams = ContextFusion::new().extract_ngrams("hello", 3);

        for expected in ["hel", "ell", "llo"] {
            assert!(trigrams.contains(expected));
        }
        assert_eq!(trigrams.len(), 3);
    }

    /// Spot-checks the documented default configuration values.
    #[test]
    fn test_fusion_config_defaults() {
        let cfg = FusionConfig::default();
        assert_eq!(cfg.rrf_k, 60.0);
        assert!(cfg.diversify && cfg.graph_rerank);
    }
}