use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use super::context::{ChunkSource, ContextChunk, RetrievalContext};
use super::EntityType;
use crate::storage::{EntityId, RefType, Store};
/// Tuning parameters for [`ContextFusion`].
///
/// Each field is read by one step of the fusion pipeline; see the per-field
/// notes for which one.
#[derive(Debug, Clone)]
pub struct FusionConfig {
/// The `k` constant of reciprocal-rank fusion: a chunk at 1-based rank `r`
/// within its source contributes `weight * 1 / (rrf_k + r)`.
pub rrf_k: f32,
/// Weight applied to the reciprocal-rank term of chunks from vector sources.
pub vector_weight: f32,
/// Weight applied to the reciprocal-rank term of chunks from the graph source.
pub graph_weight: f32,
/// Weight applied to the reciprocal-rank term of chunks from table sources.
pub table_weight: f32,
/// Base factor of the graph re-rank boost; it is multiplied by the referencing
/// entity's mean relevance and by a per-reference-type multiplier.
pub cross_ref_boost: f32,
/// Character-trigram Jaccard similarity strictly above which two chunks are
/// treated as duplicates.
pub dedup_threshold: f32,
/// Whether `fuse` applies the per-entity-type cap (`max_per_type`).
pub diversify: bool,
/// Maximum number of chunks kept per entity type when `diversify` is on.
pub max_per_type: usize,
/// Whether `fuse` runs the graph re-rank step (which also needs a store).
pub graph_rerank: bool,
}
impl Default for FusionConfig {
fn default() -> Self {
Self {
rrf_k: 60.0,
vector_weight: 0.5,
graph_weight: 0.3,
table_weight: 0.2,
cross_ref_boost: 0.15,
dedup_threshold: 0.85,
diversify: true,
max_per_type: 5,
graph_rerank: true,
}
}
}
/// Merges, re-scores and prunes the chunks of a `RetrievalContext`.
///
/// Built with `new` / `with_config`, optionally followed by `with_store`.
pub struct ContextFusion {
/// Tuning parameters read by the individual pipeline steps.
config: FusionConfig,
/// Optional store queried for entity cross-references during graph
/// re-ranking; when `None`, that step returns without changing anything.
store: Option<Arc<Store>>,
}
impl ContextFusion {
    /// Creates a fusion engine with [`FusionConfig::default`] and no store.
    pub fn new() -> Self {
        Self::with_config(FusionConfig::default())
    }

    /// Creates a fusion engine with the given configuration and no store.
    pub fn with_config(config: FusionConfig) -> Self {
        Self {
            config,
            store: None,
        }
    }

    /// Attaches the store that graph re-ranking queries for cross-references.
    pub fn with_store(mut self, store: Arc<Store>) -> Self {
        self.store = Some(store);
        self
    }

    /// Runs the whole fusion pipeline on `context`, in place.
    ///
    /// Steps, in order:
    /// 1. per-source min-max score normalization;
    /// 2. reciprocal-rank fusion, only when more than one source was used;
    /// 3. graph re-ranking, when enabled in the config;
    /// 4. near-duplicate removal;
    /// 5. per-entity-type diversification, when enabled in the config;
    /// 6. final ordering, delegated to `RetrievalContext::sort_by_relevance`.
    pub fn fuse(&self, context: &mut RetrievalContext) {
        self.normalize_scores(context);
        if context.sources_used.len() > 1 {
            self.apply_rrf(context);
        }
        if self.config.graph_rerank {
            self.graph_rerank(context);
        }
        self.deduplicate(context);
        if self.config.diversify {
            self.diversify(context);
        }
        context.sort_by_relevance();
    }

    /// Orders two relevance scores from highest to lowest, with NaN last.
    ///
    /// `partial_cmp(..).unwrap_or(Equal)` is not a total order once a NaN is
    /// involved, and `slice::sort_by` requires one; this comparator is total.
    fn cmp_relevance_desc(a: f32, b: f32) -> std::cmp::Ordering {
        b.partial_cmp(&a)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }

    /// Min-max normalizes relevance separately for vector, graph and table
    /// chunks. Chunks from any other source keep their score.
    fn normalize_scores(&self, context: &mut RetrievalContext) {
        let mut vector_chunks: Vec<usize> = Vec::new();
        let mut graph_chunks: Vec<usize> = Vec::new();
        let mut table_chunks: Vec<usize> = Vec::new();
        for (i, chunk) in context.chunks.iter().enumerate() {
            match chunk.source {
                ChunkSource::Vector(_) => vector_chunks.push(i),
                ChunkSource::Graph => graph_chunks.push(i),
                ChunkSource::Table(_) => table_chunks.push(i),
                // Other sources are never normalized, so they are not collected.
                _ => {}
            }
        }
        self.normalize_group(&mut context.chunks, &vector_chunks);
        self.normalize_group(&mut context.chunks, &graph_chunks);
        self.normalize_group(&mut context.chunks, &table_chunks);
    }

    /// Rescales the relevance of `chunks[indices]` to `[0, 1]`.
    ///
    /// A group whose score range is at most `0.0001` (including a single
    /// chunk) is left untouched to avoid dividing by a near-zero range.
    fn normalize_group(&self, chunks: &mut [ContextChunk], indices: &[usize]) {
        if indices.is_empty() {
            return;
        }
        let (min_score, max_score) = indices.iter().map(|&i| chunks[i].relevance).fold(
            (f32::INFINITY, f32::NEG_INFINITY),
            |(lo, hi), score| (lo.min(score), hi.max(score)),
        );
        let range = max_score - min_score;
        if range > 0.0001 {
            for &i in indices {
                chunks[i].relevance = (chunks[i].relevance - min_score) / range;
            }
        }
    }

    /// Blends each chunk's relevance with a weighted reciprocal-rank score.
    ///
    /// Chunks are ranked inside their own source (every vector collection and
    /// every table is ranked separately, the graph is one source). The
    /// 1-based rank `r` yields `weight * 1 / (rrf_k + r)`, and the final score
    /// is `0.6 * relevance + 0.4 * rrf * 100`. Chunks from other sources get
    /// an RRF score of zero.
    fn apply_rrf(&self, context: &mut RetrievalContext) {
        // source key -> (weight of that source kind, indices of its chunks)
        let mut groups: HashMap<String, (f32, Vec<usize>)> = HashMap::new();
        for (i, chunk) in context.chunks.iter().enumerate() {
            let (source_key, weight) = match &chunk.source {
                ChunkSource::Vector(c) => (format!("vector:{}", c), self.config.vector_weight),
                ChunkSource::Graph => ("graph".to_string(), self.config.graph_weight),
                ChunkSource::Table(t) => (format!("table:{}", t), self.config.table_weight),
                _ => continue,
            };
            groups
                .entry(source_key)
                .or_insert_with(|| (weight, Vec::new()))
                .1
                .push(i);
        }

        let k = self.config.rrf_k;
        // Every chunk belongs to at most one group, so a flat per-index vector
        // is enough; group iteration order cannot affect the result.
        let mut rrf_scores = vec![0.0_f32; context.chunks.len()];
        for (_, (weight, mut members)) in groups {
            // Best score first; ties keep the earlier chunk ahead.
            members.sort_by(|&a, &b| {
                Self::cmp_relevance_desc(context.chunks[a].relevance, context.chunks[b].relevance)
                    .then_with(|| a.cmp(&b))
            });
            for (rank, idx) in members.into_iter().enumerate() {
                rrf_scores[idx] += weight * (1.0 / (k + (rank + 1) as f32));
            }
        }

        for (chunk, rrf_score) in context.chunks.iter_mut().zip(rrf_scores) {
            chunk.relevance = 0.6 * chunk.relevance + 0.4 * rrf_score * 100.0;
        }
    }

    /// Boosts chunks whose entity is referenced by another retrieved entity.
    ///
    /// Does nothing without a store. For each reference `A -> B` where both
    /// entities have chunks in `context`, every chunk of `B` gains
    /// `cross_ref_boost * mean_relevance(A) * multiplier(ref_type)`. Boosts
    /// are computed from the scores as they were on entry and applied
    /// afterwards, so one boost never feeds into another.
    fn graph_rerank(&self, context: &mut RetrievalContext) {
        let store = match &self.store {
            Some(s) => s,
            None => return,
        };

        // Only chunks whose entity id parses as `u64` take part.
        let mut entity_chunks: HashMap<EntityId, Vec<usize>> = HashMap::new();
        for (i, chunk) in context.chunks.iter().enumerate() {
            if let Some(ref id_str) = chunk.entity_id {
                if let Ok(id) = id_str.parse::<u64>() {
                    entity_chunks.entry(EntityId(id)).or_default().push(i);
                }
            }
        }

        let mut boosts: HashMap<usize, f32> = HashMap::new();
        for (entity_id, chunk_indices) in &entity_chunks {
            // Invariant across this entity's references, so computed once.
            // `chunk_indices` is never empty: entries are created by a push.
            let source_relevance: f32 = chunk_indices
                .iter()
                .map(|&i| context.chunks[i].relevance)
                .sum::<f32>()
                / chunk_indices.len() as f32;

            for (target_id, ref_type, _collection) in store.get_refs_from(*entity_id) {
                let target_chunks = match entity_chunks.get(&target_id) {
                    Some(targets) => targets,
                    None => continue,
                };
                let type_multiplier = match ref_type {
                    RefType::RelatedTo | RefType::DerivesFrom => 1.0,
                    RefType::Mentions | RefType::Contains => 0.8,
                    RefType::DependsOn => 0.7,
                    RefType::SimilarTo => 0.5,
                    _ => 0.3,
                };
                let boost = self.config.cross_ref_boost * source_relevance * type_multiplier;
                for &chunk_idx in target_chunks {
                    *boosts.entry(chunk_idx).or_insert(0.0) += boost;
                }
            }
        }

        for (idx, boost) in boosts {
            context.chunks[idx].relevance += boost;
        }
    }

    /// Removes near-duplicate chunks, keeping the more relevant of each pair.
    ///
    /// Two chunks are duplicates when their trigram similarity is strictly
    /// greater than `dedup_threshold`. On equal relevance the earlier chunk
    /// wins. Surviving chunks keep their relative order.
    fn deduplicate(&self, context: &mut RetrievalContext) {
        let len = context.chunks.len();
        if len < 2 {
            return;
        }
        let threshold = self.config.dedup_threshold;
        let mut removed = vec![false; len];
        for i in 0..len {
            if removed[i] {
                continue;
            }
            for j in (i + 1)..len {
                if removed[j] {
                    continue;
                }
                let similarity =
                    self.content_similarity(&context.chunks[i].content, &context.chunks[j].content);
                if similarity > threshold {
                    if context.chunks[i].relevance >= context.chunks[j].relevance {
                        removed[j] = true;
                    } else {
                        // `i` lost; nothing more to compare it against.
                        removed[i] = true;
                        break;
                    }
                }
            }
        }
        // `retain` visits elements once, in order, so the flags line up by
        // position; this replaces a quadratic series of `Vec::remove` calls.
        let mut removed = removed.into_iter();
        context.chunks.retain(|_| !removed.next().unwrap_or(false));
    }

    /// Jaccard similarity of the lower-cased character trigrams of `a` and `b`.
    ///
    /// Returns a value in `[0, 1]`; `0.0` when either string is empty or
    /// shorter than three characters.
    fn content_similarity(&self, a: &str, b: &str) -> f32 {
        if a.is_empty() || b.is_empty() {
            return 0.0;
        }
        let ngrams_a = self.extract_ngrams(a, 3);
        let ngrams_b = self.extract_ngrams(b, 3);
        if ngrams_a.is_empty() || ngrams_b.is_empty() {
            return 0.0;
        }
        let intersection = ngrams_a.intersection(&ngrams_b).count();
        // |A ∪ B| = |A| + |B| - |A ∩ B|; non-zero because both sets are non-empty.
        let union = ngrams_a.len() + ngrams_b.len() - intersection;
        intersection as f32 / union as f32
    }

    /// Returns the set of distinct character `n`-grams of lower-cased `text`.
    ///
    /// The set is empty when `n` is zero or `text` has fewer than `n`
    /// characters.
    fn extract_ngrams(&self, text: &str, n: usize) -> HashSet<String> {
        let chars: Vec<char> = text.to_lowercase().chars().collect();
        // `windows` panics on a zero size, and a 0-gram carries no information.
        if n == 0 || chars.len() < n {
            return HashSet::new();
        }
        chars
            .windows(n)
            .map(|gram| gram.iter().collect::<String>())
            .collect()
    }

    /// Caps the number of chunks per entity type at `max_per_type`.
    ///
    /// The quota is spent on the highest-relevance chunks of each type (ties
    /// go to the earlier chunk), because this step runs before the final
    /// sort and the incoming order is not relevance order. Chunks without an
    /// entity type count as `EntityType::Unknown`. Survivors keep their
    /// relative order.
    fn diversify(&self, context: &mut RetrievalContext) {
        let max_per_type = self.config.max_per_type;
        let len = context.chunks.len();

        // Visit order: best score first, NaN last, index as tie-breaker.
        let mut order: Vec<usize> = (0..len).collect();
        order.sort_by(|&a, &b| {
            Self::cmp_relevance_desc(context.chunks[a].relevance, context.chunks[b].relevance)
                .then_with(|| a.cmp(&b))
        });

        let mut type_counts: HashMap<EntityType, usize> = HashMap::new();
        let mut keep = vec![false; len];
        for idx in order {
            let entity_type = context.chunks[idx]
                .entity_type
                .unwrap_or(EntityType::Unknown);
            let count = type_counts.entry(entity_type).or_insert(0);
            if *count < max_per_type {
                *count += 1;
                keep[idx] = true;
            }
        }

        let mut keep = keep.into_iter();
        context.chunks.retain(|_| keep.next().unwrap_or(true));
    }
}
impl Default for ContextFusion {
fn default() -> Self {
Self::new()
}
}
/// Re-scores chunks from their relevance, entity type and graph depth.
pub struct ResultReranker {
/// Multiplier applied to a chunk's existing relevance.
pub relevance_weight: f32,
/// NOTE(review): not read by `rerank` in the code visible here — confirm
/// whether a recency term is still planned or the field can go.
pub recency_weight: f32,
/// Multiplier applied to the connection score `1 / (1 + graph_depth)` of
/// chunks that carry a graph depth.
pub connection_weight: f32,
/// Per-entity-type priority; `rerank` adds `0.1 * priority` to the score and
/// falls back to `0.5` for types missing from the map or chunks without a type.
pub type_priority: HashMap<EntityType, f32>,
}
impl Default for ResultReranker {
fn default() -> Self {
let mut type_priority = HashMap::new();
type_priority.insert(EntityType::Vulnerability, 1.0);
type_priority.insert(EntityType::Host, 0.9);
type_priority.insert(EntityType::Service, 0.85);
type_priority.insert(EntityType::Credential, 0.95);
type_priority.insert(EntityType::Certificate, 0.7);
type_priority.insert(EntityType::Domain, 0.75);
type_priority.insert(EntityType::Unknown, 0.5);
Self {
relevance_weight: 0.6,
recency_weight: 0.2,
connection_weight: 0.2,
type_priority,
}
}
}
impl ResultReranker {
    /// Replaces every chunk's relevance with a combined score, then asks the
    /// context to re-sort itself.
    ///
    /// The score is `relevance_weight * relevance + 0.1 * type_priority`,
    /// plus `connection_weight / (1 + graph_depth)` for chunks that carry a
    /// graph depth. The priority defaults to `0.5` when the chunk has no
    /// entity type or the type is not in `type_priority`.
    pub fn rerank(&self, context: &mut RetrievalContext) {
        for chunk in context.chunks.iter_mut() {
            let priority: f32 = match chunk.entity_type {
                Some(kind) => self.type_priority.get(&kind).cloned().unwrap_or(0.5),
                None => 0.5,
            };

            let mut score = self.relevance_weight * chunk.relevance;
            score += 0.1 * priority;
            if let Some(depth) = chunk.graph_depth {
                // Shallower graph hits count for more: depth 0 -> 1, depth 1 -> 1/2, ...
                let closeness = 1.0 / (1.0 + depth as f32);
                score += self.connection_weight * closeness;
            }
            chunk.relevance = score;
        }
        context.sort_by_relevance();
    }
}
// Unit tests for the text-similarity helpers and the default configuration.
#[cfg(test)]
mod tests {
use super::*;
// Identical strings score ~1.0, unrelated strings stay below 0.5, and
// strings sharing a long prefix land in between.
#[test]
fn test_content_similarity() {
let fusion = ContextFusion::new();
let sim1 = fusion.content_similarity("This is a test string", "This is a test string");
assert!((sim1 - 1.0).abs() < 0.001);
let sim2 = fusion.content_similarity("completely different", "nothing alike");
assert!(sim2 < 0.5);
let sim3 = fusion.content_similarity("vulnerability in nginx", "vulnerability in apache");
assert!(sim3 > 0.3 && sim3 < 0.8);
}
// A five-character word yields exactly three distinct character trigrams.
#[test]
fn test_ngram_extraction() {
let fusion = ContextFusion::new();
let ngrams = fusion.extract_ngrams("hello", 3);
assert!(ngrams.contains("hel"));
assert!(ngrams.contains("ell"));
assert!(ngrams.contains("llo"));
assert_eq!(ngrams.len(), 3);
}
// Spot-checks the default RRF constant and that both optional steps are on.
#[test]
fn test_fusion_config_defaults() {
let config = FusionConfig::default();
assert_eq!(config.rrf_k, 60.0);
assert!(config.diversify);
assert!(config.graph_rerank);
}
}