alaya 0.4.8

A memory engine for conversational AI agents, inspired by neuroscience and Buddhist psychology
Documentation
use crate::types::NodeRef;
use std::collections::HashMap;

/// Reciprocal Rank Fusion (RRF) merges multiple ranked result sets.
///
/// For each document d: score(d) = sum(weight_i / (k + rank_i + 1))
/// where rank_i is the 0-based rank of d in result set i and weight_i
/// is the optional per-set multiplier (defaults to 1.0).
///
/// Reference: Cormack, Clarke & Buettcher (2009)
#[cfg(test)]
pub fn rrf_merge(result_sets: &[Vec<(NodeRef, f64)>], k: u32) -> Vec<(NodeRef, f64)> {
    rrf_merge_weighted(result_sets, k, None)
}

/// Like [`rrf_merge`] but accepts optional per-set boost weights.
/// When `weights` is `Some`, `weights[i]` is used as a multiplier for
/// the RRF contribution from `result_sets[i]`.
pub fn rrf_merge_weighted(
    result_sets: &[Vec<(NodeRef, f64)>],
    k: u32,
    weights: Option<&[f32]>,
) -> Vec<(NodeRef, f64)> {
    let mut scores: HashMap<NodeRef, f64> = HashMap::new();

    for (set_idx, result_set) in result_sets.iter().enumerate() {
        let weight = weights.and_then(|w| w.get(set_idx)).copied().unwrap_or(1.0) as f64;
        for (rank, (node_ref, _original_score)) in result_set.iter().enumerate() {
            *scores.entry(*node_ref).or_default() += weight / (k as f64 + rank as f64 + 1.0);
        }
    }

    let mut merged: Vec<(NodeRef, f64)> = scores.into_iter().collect();
    merged.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    merged
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::*;
    use proptest::prelude::*;

    proptest! {
        #[test]
        fn prop_rrf_scores_are_positive(
            scores in proptest::collection::vec(0.0f64..1.0f64, 1..10),
        ) {
            let items: Vec<(NodeRef, f64)> = scores.into_iter()
                .enumerate()
                .map(|(i, s)| (NodeRef::Episode(EpisodeId(i as i64 + 1)), s))
                .collect();
            let result = rrf_merge(&[items], 60);
            for (_, score) in &result {
                prop_assert!(*score > 0.0, "RRF score should be positive, got {}", score);
            }
        }

        #[test]
        fn prop_rrf_preserves_ordering(
            n in 2usize..10,
        ) {
            // Items ranked 1..n should produce monotonically decreasing RRF scores
            let items: Vec<(NodeRef, f64)> = (0..n)
                .map(|i| (NodeRef::Episode(EpisodeId(i as i64 + 1)), 1.0 - (i as f64 / n as f64)))
                .collect();
            let result = rrf_merge(&[items], 60);
            for i in 1..result.len() {
                prop_assert!(
                    result[i - 1].1 >= result[i].1,
                    "RRF should preserve ordering: {} < {}",
                    result[i - 1].1,
                    result[i].1
                );
            }
        }
    }

    #[test]
    fn test_rrf_single_set() {
        let set = vec![
            (NodeRef::Episode(EpisodeId(1)), 0.9),
            (NodeRef::Episode(EpisodeId(2)), 0.5),
        ];
        let merged = rrf_merge(&[set], 60);
        assert_eq!(merged.len(), 2);
        // First item should have higher score
        assert!(merged[0].1 > merged[1].1);
    }

    #[test]
    fn test_rrf_two_sets_overlap() {
        let set_a = vec![
            (NodeRef::Episode(EpisodeId(1)), 0.9),
            (NodeRef::Episode(EpisodeId(2)), 0.5),
        ];
        let set_b = vec![
            (NodeRef::Episode(EpisodeId(2)), 0.8),
            (NodeRef::Episode(EpisodeId(3)), 0.3),
        ];
        let merged = rrf_merge(&[set_a, set_b], 60);
        // Episode 2 appears in both sets, should have highest combined score
        assert_eq!(merged[0].0, NodeRef::Episode(EpisodeId(2)));
    }

    #[test]
    fn test_rrf_disjoint() {
        let set_a = vec![(NodeRef::Episode(EpisodeId(1)), 0.9)];
        let set_b = vec![(NodeRef::Episode(EpisodeId(2)), 0.8)];
        let merged = rrf_merge(&[set_a, set_b], 60);
        assert_eq!(merged.len(), 2);
        // Both at rank 0, so equal RRF scores
        assert!((merged[0].1 - merged[1].1).abs() < 1e-10);
    }

    #[test]
    fn test_rrf_empty_sets() {
        let sets: Vec<Vec<(NodeRef, f64)>> = vec![];
        let result = rrf_merge(&sets, 60);
        assert!(result.is_empty());
    }

    #[test]
    fn test_rrf_single_item() {
        let sets = vec![vec![(NodeRef::Episode(EpisodeId(1)), 1.0)]];
        let result = rrf_merge(&sets, 60);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, NodeRef::Episode(EpisodeId(1)));
    }
}