#![allow(dead_code)]
#![allow(clippy::cast_precision_loss)]
use std::collections::HashMap;
pub use crate::scored_result::ScoredResult;
/// Selects how vector-search and graph-search result lists are combined.
///
/// NOTE(review): the dispatch from a variant to a fusion function is not in this
/// file; the per-variant notes below are inferred from the names — confirm.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum FusionStrategy {
    /// Reciprocal Rank Fusion (presumably [`fuse_rrf`]). This is the default.
    #[default]
    Rrf,
    /// Weighted sum of normalized scores (presumably [`fuse_weighted`]).
    WeightedSum,
    /// Per-id maximum of normalized scores (presumably [`fuse_maximum`]).
    Maximum,
}
/// Parameters for [`fuse_rrf`].
#[derive(Debug, Clone)]
pub struct RrfConfig {
    /// Constant added to each 1-based rank before taking the reciprocal.
    /// Larger values shrink the gap between top-ranked and lower-ranked items.
    pub k: u32,
}

impl RrfConfig {
    /// Rank constant used by [`RrfConfig::default`].
    const DEFAULT_K: u32 = 60;

    /// Builds a config with an explicit rank constant `k`.
    #[must_use]
    pub fn with_k(k: u32) -> Self {
        RrfConfig { k }
    }
}

impl Default for RrfConfig {
    /// Uses `k = 60`.
    fn default() -> Self {
        Self::with_k(Self::DEFAULT_K)
    }
}
/// Per-source weights for [`fuse_weighted`].
///
/// The weights are stored as given and used as plain multipliers. They are not
/// validated or re-normalized, so they need not sum to 1.
#[derive(Debug, Clone)]
pub struct WeightedConfig {
    /// Multiplier applied to each normalized vector-search score.
    pub vector_weight: f32,
    /// Multiplier applied to each normalized graph-search score.
    pub graph_weight: f32,
}

impl WeightedConfig {
    /// Builds a config from explicit per-source weights.
    #[must_use]
    pub fn new(vector_weight: f32, graph_weight: f32) -> Self {
        WeightedConfig {
            vector_weight,
            graph_weight,
        }
    }
}

impl Default for WeightedConfig {
    /// Equal weighting: `0.5` for each source.
    fn default() -> Self {
        Self::new(0.5, 0.5)
    }
}
/// Sorts `results` by score, highest first, and keeps at most `limit` entries.
///
/// The sort is stable, so entries with equal scores keep their relative input
/// order. NaN scores are placed after every real score and compare equal to each
/// other. This keeps the comparator a total order. A plain
/// `partial_cmp(..).unwrap_or(Equal)` comparator is not a total order once a NaN
/// is present, and `slice::sort_by` may panic or produce an unspecified order
/// when given one.
fn sort_descending_and_truncate(results: &mut Vec<ScoredResult>, limit: usize) {
    results.sort_by(|a, b| match (a.score.is_nan(), b.score.is_nan()) {
        // Both real: descending. `partial_cmp` is always `Some` for non-NaN floats.
        (false, false) => b
            .score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal),
        (true, true) => std::cmp::Ordering::Equal,
        // NaN sinks below any real score.
        (true, false) => std::cmp::Ordering::Greater,
        (false, true) => std::cmp::Ordering::Less,
    });
    results.truncate(limit);
}
/// Turns accumulated per-id scores into a ranked list of at most `limit` results.
///
/// Results are ordered by fused score, highest first. Ids with equal scores are
/// ordered by ascending id. `HashMap` iteration order is randomized per map
/// instance, and equal fused scores are common (under RRF, every list's first
/// item contributes `1 / (k + 1)`). Without a tie-breaker, the order of tied ids,
/// and which of them survive truncation to `limit`, would vary from run to run.
fn finalize_scores(scores: HashMap<u64, f32>, limit: usize) -> Vec<ScoredResult> {
    let mut fused: Vec<ScoredResult> = scores
        .into_iter()
        .map(|(id, score)| ScoredResult::new(id, score))
        .collect();
    // Map keys are unique, so this yields a fully determined order. The stable
    // score sort below then keeps tied scores in ascending-id order.
    fused.sort_unstable_by_key(|r| r.id);
    sort_descending_and_truncate(&mut fused, limit);
    fused
}
/// Adds one ranked list's Reciprocal Rank Fusion contributions to `scores`.
///
/// The item at 1-based position `p` contributes `1 / (k + p)`. Only the position
/// matters; the item's own `score` is ignored. Contributions are summed per id, so
/// an id seen in several lists (or repeated within one list) accumulates all of them.
fn accumulate_rrf_scores(scores: &mut HashMap<u64, f32>, results: &[ScoredResult], k: f32) {
    results.iter().zip(1_usize..).for_each(|(item, position)| {
        let contribution = 1.0 / (k + position as f32);
        *scores.entry(item.id).or_default() += contribution;
    });
}
/// Adds `item.score * weight` to each item's running total in `scores`.
///
/// An id with no entry yet starts from `0.0`. Ids not present in `results` are
/// left untouched.
fn accumulate_weighted_scores(
    scores: &mut HashMap<u64, f32>,
    results: &[ScoredResult],
    weight: f32,
) {
    results.iter().for_each(|item| {
        let running_total = scores.entry(item.id).or_default();
        *running_total += item.score * weight;
    });
}
/// Raises each item's entry in `scores` to `item.score` if that is larger.
///
/// Entries start at `0.0`, so an id seen only with negative scores ends at
/// `0.0`. `f32::max` ignores a NaN operand, so NaN scores never replace an entry.
fn accumulate_max_scores(scores: &mut HashMap<u64, f32>, results: &[ScoredResult]) {
    for item in results.iter() {
        let best = scores.entry(item.id).or_default();
        *best = f32::max(*best, item.score);
    }
}
/// Fuses vector and graph results with Reciprocal Rank Fusion (RRF).
///
/// Each list contributes `1 / (config.k + rank)` per item, with `rank` starting
/// at 1. Contributions for the same id are summed. Input scores are ignored;
/// only list positions matter. Returns at most `limit` results, highest fused
/// score first.
#[must_use]
pub fn fuse_rrf(
    vector_results: &[ScoredResult],
    graph_results: &[ScoredResult],
    config: &RrfConfig,
    limit: usize,
) -> Vec<ScoredResult> {
    let rank_constant = config.k as f32;
    let mut fused_scores = HashMap::new();
    for ranked_list in [vector_results, graph_results] {
        accumulate_rrf_scores(&mut fused_scores, ranked_list, rank_constant);
    }
    finalize_scores(fused_scores, limit)
}
/// Fuses vector and graph results as a weighted sum of normalized scores.
///
/// Each list is first min-max normalized with `normalize_scores`. An id's fused
/// score is `vector_weight * v + graph_weight * g`, where `v` and `g` are its
/// normalized scores; a list that lacks the id contributes nothing. Returns at
/// most `limit` results, highest fused score first.
#[must_use]
pub fn fuse_weighted(
    vector_results: &[ScoredResult],
    graph_results: &[ScoredResult],
    config: &WeightedConfig,
    limit: usize,
) -> Vec<ScoredResult> {
    let weighted_sources = [
        (normalize_scores(vector_results), config.vector_weight),
        (normalize_scores(graph_results), config.graph_weight),
    ];
    let mut fused_scores = HashMap::new();
    for (normalized, weight) in &weighted_sources {
        accumulate_weighted_scores(&mut fused_scores, normalized, *weight);
    }
    finalize_scores(fused_scores, limit)
}
/// Fuses vector and graph results by taking each id's best normalized score.
///
/// Both lists are min-max normalized with `normalize_scores`. Each id keeps the
/// larger of its normalized scores, with totals starting from `0.0`. Returns at
/// most `limit` results, highest score first.
#[must_use]
pub fn fuse_maximum(
    vector_results: &[ScoredResult],
    graph_results: &[ScoredResult],
    limit: usize,
) -> Vec<ScoredResult> {
    let normalized_lists = [
        normalize_scores(vector_results),
        normalize_scores(graph_results),
    ];
    let mut best_scores = HashMap::new();
    for normalized in &normalized_lists {
        accumulate_max_scores(&mut best_scores, normalized);
    }
    finalize_scores(best_scores, limit)
}
/// Rescales scores with min-max normalization: the lowest score maps to `0.0`,
/// the highest maps to `1.0`, and the rest fall proportionally in between.
///
/// Ids and input order are preserved, and empty input yields an empty vector.
/// If the spread between the highest and lowest score is below `f32::EPSILON`
/// (an absolute threshold), every item gets `1.0`.
///
/// NOTE(review): non-finite scores are not special-cased. An infinite score makes
/// the spread infinite, which turns that item's output into NaN, and NaN inputs
/// can produce NaN outputs. Confirm callers never pass such scores.
pub(crate) fn normalize_scores(results: &[ScoredResult]) -> Vec<ScoredResult> {
    if results.is_empty() {
        return Vec::new();
    }

    let (lowest, highest) = results.iter().fold(
        (f32::INFINITY, f32::NEG_INFINITY),
        |(lo, hi), r| (f32::min(lo, r.score), f32::max(hi, r.score)),
    );
    let spread = highest - lowest;
    let all_equal = spread.abs() < f32::EPSILON;

    results
        .iter()
        .map(|r| {
            let rescaled = if all_equal {
                1.0
            } else {
                (r.score - lowest) / spread
            };
            ScoredResult::new(r.id, rescaled)
        })
        .collect()
}
/// Restricts each input list to the ids that also appear in the other list.
///
/// Returns `(vector_subset, graph_subset)`. Each subset keeps its source list's
/// order and original scores. Duplicate entries of a shared id are all kept.
#[must_use]
pub fn intersect_results(
    vector_results: &[ScoredResult],
    graph_results: &[ScoredResult],
) -> (Vec<ScoredResult>, Vec<ScoredResult>) {
    /// Keeps the items of `source` whose id occurs anywhere in `reference`.
    fn retain_shared(source: &[ScoredResult], reference: &[ScoredResult]) -> Vec<ScoredResult> {
        let reference_ids: std::collections::HashSet<u64> =
            reference.iter().map(|r| r.id).collect();
        source
            .iter()
            .filter(|r| reference_ids.contains(&r.id))
            .copied()
            .collect()
    }

    let vector_subset = retain_shared(vector_results, graph_results);
    let graph_subset = retain_shared(graph_results, vector_results);
    (vector_subset, graph_subset)
}