lucisearch 0.8.0

Embeddable, in-process search engine — the SQLite/DuckDB of Elasticsearch
Documentation
//! Result-set fusion: combine multiple ranked result lists.
//!
//! Supports rank-based (RRF) and score-based (sum, mean) fusion methods.
//! See [[feature-rrf-retrievers]] and [[reciprocal-rank-fusion]].

use crate::core::{DocId, SegmentId};
use crate::query::ast::FusionMethod;
use crate::search::results::HitRef;
use crate::search::searcher::ScoringResults;
use std::collections::HashMap;

/// Default `rank_constant` (the `k` in `1 / (k + rank)`) used by Reciprocal
/// Rank Fusion; mirrors the Elasticsearch default of 60.
pub const DEFAULT_RANK_CONSTANT: f32 = 60_f32;

/// Combine multiple search result lists using Reciprocal Rank Fusion.
///
/// Weighted RRF: `score(d) = Σ weight_i / (rank_constant + rank_i(d))`
/// where rank starts at 1. Unweighted: all weights = 1.0.
///
/// * `result_lists` — the ranked sources. Each list's `hits` order defines the
///   ranks; the per-hit `score` is not used.
/// * `rank_constant` — `k` in the formula above (see `DEFAULT_RANK_CONSTANT`).
/// * `weights` — optional per-source weights, indexed like `result_lists`.
///   `None`, or a slice shorter than `result_lists`, means 1.0 for the
///   missing sources.
/// * `top_k` — maximum number of fused hits returned.
///
/// Hits are ordered by fused score, highest first. NaN scores rank last, and
/// ties are broken by `(segment_id, doc_id)` ascending, so the output (and which
/// tied documents survive the `top_k` cut) is deterministic. Returned hits carry
/// no `sort_values` or `collapse_key`.
///
/// `total_hits` is the union of unique documents across all sources, counted
/// before the `top_k` cut.
pub(crate) fn reciprocal_rank_fusion(
    result_lists: &[&ScoringResults],
    rank_constant: f32,
    weights: Option<&[f32]>,
    top_k: usize,
) -> ScoringResults {
    // Out-of-range indices fall back to 1.0 below, so "no weights" is just an empty slice.
    let w: &[f32] = weights.unwrap_or(&[]);

    let mut scores: HashMap<(u64, u32), f32> = HashMap::new();

    for (list_idx, results) in result_lists.iter().enumerate() {
        let weight = w.get(list_idx).copied().unwrap_or(1.0);
        for (rank, hit) in results.hits.iter().enumerate() {
            let key = (hit.segment_id.as_u64(), hit.doc_id.as_u32());
            // `enumerate` is 0-based; RRF ranks start at 1.
            let rrf_contribution = weight / (rank_constant + (rank + 1) as f32);
            *scores.entry(key).or_insert(0.0) += rrf_contribution;
        }
    }

    let total_unique = scores.len() as u64;

    let mut ranked: Vec<_> = scores.into_iter().collect();
    // HashMap iteration order is randomized, so equal scores must be tie-broken
    // explicitly. NaN is mapped to -inf so the comparator is a total order (keys
    // are unique, hence no two entries ever compare Equal and an unstable sort is safe).
    ranked.sort_unstable_by(|(key_a, score_a), (key_b, score_b)| {
        let a = if score_a.is_nan() { f32::NEG_INFINITY } else { *score_a };
        let b = if score_b.is_nan() { f32::NEG_INFINITY } else { *score_b };
        b.partial_cmp(&a)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| key_a.cmp(key_b))
    });
    ranked.truncate(top_k);

    let hits: Vec<HitRef> = ranked
        .into_iter()
        .map(|((seg_id, doc_id), score)| HitRef {
            doc_id: DocId::new(doc_id),
            segment_id: SegmentId::new(seg_id),
            score,
            sort_values: None,
            collapse_key: None,
        })
        .collect();

    ScoringResults {
        hits,
        total_hits: crate::search::TotalHits::exact(total_unique),
        aggregations: HashMap::new(),
    }
}

/// Combine multiple result lists using score-based fusion.
///
/// Each document's per-source scores `s_i` are combined according to `method`,
/// using per-source weights `w_i`. Weights are indexed like `result_lists`;
/// `None`, or a slice shorter than `result_lists`, means 1.0 for the missing
/// sources.
///
/// * `Sum` — `Σ w_i · s_i`
/// * `ArithmeticMean` — `Σ w_i · s_i / Σ w_i`
/// * `HarmonicMean` — `Σ w_i / Σ (w_i / s_i)`
/// * `GeometricMean` — `exp(Σ w_i · ln s_i / Σ w_i)`
///
/// All mean methods treat missing documents consistently: a document
/// absent from a source contributes score=0 with weight=0 to that
/// source's term. This means the denominator only counts sources where
/// the document actually appears. Harmonic and geometric means yield 0 for a
/// document with any non-positive score. Every mean yields 0 when its
/// denominator is not positive.
///
/// Hits are ordered by fused score, highest first. NaN scores rank last, and
/// ties are broken by `(segment_id, doc_id)` ascending, so output is
/// deterministic. At most `top_k` hits are returned.
///
/// `total_hits` is the union of unique documents across all sources.
///
/// # Panics
///
/// Panics if `method` is `FusionMethod::ReciprocalRank` and any source has at
/// least one hit; rank-based fusion goes through `reciprocal_rank_fusion`.
pub(crate) fn score_fusion(
    result_lists: &[&ScoringResults],
    method: &FusionMethod,
    weights: Option<&[f32]>,
    top_k: usize,
) -> ScoringResults {
    // `weight()` falls back to 1.0 for out-of-range indices, so no default vector is needed.
    let w: &[f32] = weights.unwrap_or(&[]);

    // Collect per-doc (source_idx, score) pairs
    let mut doc_scores: HashMap<(u64, u32), Vec<(usize, f32)>> = HashMap::new();
    for (list_idx, results) in result_lists.iter().enumerate() {
        for hit in &results.hits {
            let key = (hit.segment_id.as_u64(), hit.doc_id.as_u32());
            doc_scores
                .entry(key)
                .or_default()
                .push((list_idx, hit.score));
        }
    }

    let total_unique = doc_scores.len() as u64;

    let mut ranked: Vec<((u64, u32), f32)> = doc_scores
        .into_iter()
        .map(|(key, scores)| {
            let fused = match method {
                FusionMethod::Sum => scores.iter().map(|&(idx, s)| weight(w, idx) * s).sum(),
                FusionMethod::ArithmeticMean => {
                    // Denominator = sum of weights for PRESENT sources only.
                    // Missing sources contribute nothing (consistent with other means).
                    let present_weight: f32 =
                        scores.iter().map(|&(idx, _)| weight(w, idx)).sum();
                    let weighted_sum: f32 =
                        scores.iter().map(|&(idx, s)| weight(w, idx) * s).sum();
                    if present_weight > 0.0 {
                        weighted_sum / present_weight
                    } else {
                        0.0
                    }
                }
                FusionMethod::HarmonicMean => {
                    // H = (Σ w_i) / (Σ w_i/s_i). A non-positive score would make a
                    // reciprocal infinite or negative, so such documents get 0
                    // (the harmonic mean's limit as any s_i → 0⁺).
                    if scores.iter().any(|&(_, s)| s <= 0.0) {
                        0.0
                    } else {
                        let present_weight: f32 =
                            scores.iter().map(|&(idx, _)| weight(w, idx)).sum();
                        let weighted_recip: f32 =
                            scores.iter().map(|&(idx, s)| weight(w, idx) / s).sum();
                        if weighted_recip > 0.0 {
                            present_weight / weighted_recip
                        } else {
                            0.0
                        }
                    }
                }
                FusionMethod::GeometricMean => {
                    // G = exp((Σ w_i * ln(s_i)) / Σ w_i). ln is undefined for
                    // non-positive scores, so such documents get 0.
                    if scores.iter().any(|&(_, s)| s <= 0.0) {
                        0.0
                    } else {
                        let present_weight: f32 =
                            scores.iter().map(|&(idx, _)| weight(w, idx)).sum();
                        let log_sum: f32 =
                            scores.iter().map(|&(idx, s)| weight(w, idx) * s.ln()).sum();
                        if present_weight > 0.0 {
                            (log_sum / present_weight).exp()
                        } else {
                            0.0
                        }
                    }
                }
                FusionMethod::ReciprocalRank => {
                    unreachable!("RRF handled by reciprocal_rank_fusion")
                }
            };
            (key, fused)
        })
        .collect();

    // HashMap iteration order is randomized, so equal scores must be tie-broken
    // explicitly. NaN (e.g. from NaN input scores) is mapped to -inf so the
    // comparator is a total order. Keys are unique, so no two entries compare
    // Equal and an unstable sort is safe.
    ranked.sort_unstable_by(|(key_a, score_a), (key_b, score_b)| {
        let a = if score_a.is_nan() { f32::NEG_INFINITY } else { *score_a };
        let b = if score_b.is_nan() { f32::NEG_INFINITY } else { *score_b };
        b.partial_cmp(&a)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| key_a.cmp(key_b))
    });
    ranked.truncate(top_k);

    let hits: Vec<HitRef> = ranked
        .into_iter()
        .map(|((seg_id, doc_id), score)| HitRef {
            doc_id: DocId::new(doc_id),
            segment_id: SegmentId::new(seg_id),
            score,
            sort_values: None,
            collapse_key: None,
        })
        .collect();

    ScoringResults {
        hits,
        total_hits: crate::search::TotalHits::exact(total_unique),
        aggregations: HashMap::new(),
    }
}

/// Returns the weight configured for source `idx`.
///
/// Sources beyond the end of `weights` (including every source when `weights`
/// is empty) get the neutral weight 1.0, so this never panics.
fn weight(weights: &[f32], idx: usize) -> f32 {
    match weights.get(idx) {
        Some(&configured) => configured,
        None => 1.0,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a result list whose hits all live in segment 1, in the given order
    /// (so slice position doubles as rank), with `total_hits` equal to the hit count.
    fn single_segment_results(pairs: &[(u32, f32)]) -> ScoringResults {
        let mut hits = Vec::new();
        for &(doc, score) in pairs {
            hits.push(HitRef {
                doc_id: DocId::new(doc),
                segment_id: SegmentId::new(1),
                score,
                sort_values: None,
                collapse_key: None,
            });
        }
        ScoringResults {
            hits,
            total_hits: crate::search::TotalHits::exact(pairs.len() as u64),
            aggregations: HashMap::new(),
        }
    }

    #[test]
    fn rrf_single_list() {
        let source = single_segment_results(&[(0, 3.0), (1, 2.0), (2, 1.0)]);
        let out = reciprocal_rank_fusion(&[&source], 60.0, None, 10);
        assert_eq!(out.hits.len(), 3);
        assert_eq!(out.hits[0].doc_id, DocId::new(0));
    }

    #[test]
    fn rrf_two_lists_overlap() {
        let first = single_segment_results(&[(0, 3.0), (1, 2.0), (2, 1.0)]);
        let second = single_segment_results(&[(1, 3.0), (2, 2.0), (0, 1.0)]);

        let out = reciprocal_rank_fusion(&[&first, &second], 60.0, None, 10);
        assert_eq!(out.hits.len(), 3);
        // Doc 1: 1/62 + 1/61 beats doc 0 (1/61 + 1/63) and doc 2 (1/63 + 1/62).
        assert_eq!(out.hits[0].doc_id, DocId::new(1));
    }

    #[test]
    fn rrf_disjoint_lists() {
        let left = single_segment_results(&[(0, 1.0), (1, 0.5)]);
        let right = single_segment_results(&[(2, 1.0), (3, 0.5)]);

        let out = reciprocal_rank_fusion(&[&left, &right], 60.0, None, 10);
        assert_eq!(out.hits.len(), 4);
        // The two rank-1 documents tie at 1/61 and must fill the top two slots.
        let leaders: Vec<u32> = out.hits.iter().take(2).map(|h| h.doc_id.as_u32()).collect();
        assert!(leaders.contains(&0));
        assert!(leaders.contains(&2));
    }

    #[test]
    fn rrf_top_k() {
        let source = single_segment_results(&[(0, 1.0), (1, 0.5), (2, 0.3)]);
        let out = reciprocal_rank_fusion(&[&source], 60.0, None, 2);
        assert_eq!(out.hits.len(), 2);
    }

    #[test]
    fn rrf_empty() {
        let source = single_segment_results(&[]);
        let out = reciprocal_rank_fusion(&[&source], 60.0, None, 10);
        assert!(out.hits.is_empty());
    }

    #[test]
    fn rrf_doc_id_preserved() {
        let source = single_segment_results(&[(42, 1.0)]);
        let out = reciprocal_rank_fusion(&[&source], 60.0, None, 10);
        assert_eq!(out.hits[0].doc_id, DocId::new(42));
    }

    #[test]
    fn rrf_weighted() {
        // Source 0 counts three times as much as source 1.
        let first = single_segment_results(&[(0, 1.0), (1, 0.5)]);
        let second = single_segment_results(&[(1, 1.0), (0, 0.5)]);
        let weights = [3.0_f32, 1.0];

        let out = reciprocal_rank_fusion(&[&first, &second], 60.0, Some(&weights[..]), 10);
        // Doc 0: 3/61 + 1/62 ≈ 0.0653
        // Doc 1: 3/62 + 1/61 ≈ 0.0648
        // Doc 0 wins because it leads the heavier source.
        assert_eq!(out.hits[0].doc_id, DocId::new(0));
    }

    #[test]
    fn rrf_total_hits_counts_unique() {
        let first = single_segment_results(&[(0, 1.0), (1, 0.5)]);
        let second = single_segment_results(&[(1, 1.0), (2, 0.5)]); // doc 1 appears in both
        let out = reciprocal_rank_fusion(&[&first, &second], 60.0, None, 1);
        // Three distinct documents are counted even though only one hit is returned.
        assert_eq!(out.total_hits.value, 3);
    }

    #[test]
    fn score_fusion_sum() {
        let first = single_segment_results(&[(0, 2.0), (1, 1.0)]);
        let second = single_segment_results(&[(0, 3.0), (2, 1.0)]);
        let out = score_fusion(&[&first, &second], &FusionMethod::Sum, None, 10);

        // Doc 0: 2.0 + 3.0 = 5.0
        let top = &out.hits[0];
        assert_eq!(top.doc_id, DocId::new(0));
        assert!((top.score - 5.0).abs() < 0.01);
    }

    #[test]
    fn score_fusion_arithmetic_mean_present_only() {
        // The denominator only sums weights of sources the document appears in.
        let first = single_segment_results(&[(0, 4.0)]); // doc 0 only in source 0
        let second = single_segment_results(&[(1, 2.0)]); // doc 1 only in source 1
        let out = score_fusion(&[&first, &second], &FusionMethod::ArithmeticMean, None, 10);

        // Doc 0: 4.0 / 1.0 = 4.0; doc 1: 2.0 / 1.0 = 2.0
        let top = &out.hits[0];
        assert_eq!(top.doc_id, DocId::new(0));
        assert!((top.score - 4.0).abs() < 0.01);
    }

    #[test]
    fn score_fusion_harmonic_mean_zero_score() {
        let first = single_segment_results(&[(0, 0.0), (1, 4.0)]);
        let second = single_segment_results(&[(0, 5.0), (1, 4.0)]);
        let out = score_fusion(&[&first, &second], &FusionMethod::HarmonicMean, None, 10);

        // Doc 0 has a zero score, so its harmonic mean is 0; doc 1: H(4, 4) = 4.0.
        assert_eq!(out.hits[0].doc_id, DocId::new(1));
        assert!((out.hits[0].score - 4.0).abs() < 0.01);
        assert!(out.hits[1].score.abs() < 0.01);
    }

    #[test]
    fn score_fusion_weight_bounds_safe() {
        // Fewer weights than sources must not panic.
        let first = single_segment_results(&[(0, 1.0)]);
        let second = single_segment_results(&[(0, 2.0)]);
        let weights = [0.5_f32]; // one weight for two sources
        let out = score_fusion(&[&first, &second], &FusionMethod::Sum, Some(&weights[..]), 10);

        // Source 0: 0.5 * 1.0 = 0.5; source 1 falls back to 1.0 * 2.0 = 2.0.
        assert!((out.hits[0].score - 2.5).abs() < 0.01);
    }
}