chaotic_semantic_memory 0.3.4

AI memory systems with hyperdimensional vectors and chaotic reservoirs
Documentation
//! Hybrid retrieval combining BM25 and HDC scores.
//!
//! Provides query-length-dependent weighting between keyword (BM25) and
//! semantic (HDC) search results.

use std::collections::HashMap;

/// Choose `(keyword_weight, semantic_weight)` for a query of `token_count` tokens.
///
/// Short queries lean on exact keyword (BM25) matching; longer queries shift
/// weight toward semantic (HDC) similarity. Each pair nominally sums to 1.0.
///
/// | Query Tokens | Keyword | Semantic | Rationale |
/// |-------------|---------|----------|-----------|
/// | 0 | 0.2 | 0.8 | No tokens: falls back to full semantic mode |
/// | 1-2 | 0.9 | 0.1 | Exact match dominates |
/// | 3-4 | 0.7 | 0.3 | Keyword still strong |
/// | 5-8 | 0.4 | 0.6 | Semantic takes over |
/// | 9+ | 0.2 | 0.8 | Full semantic mode |
pub fn compute_weights(token_count: usize) -> (f32, f32) {
    // Zero tokens deliberately shares the "very long query" weights.
    if token_count == 0 || token_count > 8 {
        (0.2, 0.8)
    } else if token_count <= 2 {
        (0.9, 0.1)
    } else if token_count <= 4 {
        (0.7, 0.3)
    } else {
        (0.4, 0.6)
    }
}

/// Normalize scores to the [0, 1] range using min-max normalization.
///
/// Each score `s` becomes `(s - min) / (max - min)`. IDs and input order are
/// preserved. Empty input yields an empty vector.
///
/// If all scores are (nearly) equal, meaning the range is below `1e-10`,
/// every entry is mapped to 1.0 instead of dividing by ~zero. A NaN range
/// (e.g. every score is `+inf`, where `inf - inf` is NaN) is treated the
/// same way.
///
/// NaN scores are skipped when computing `min`/`max` (as `f32::min` and
/// `f32::max` do). In an otherwise non-degenerate list they normalize to NaN.
pub fn normalize_scores(scores: &[(String, f32)]) -> Vec<(String, f32)> {
    const EPSILON: f32 = 1e-10;

    if scores.is_empty() {
        return Vec::new();
    }

    let min = scores.iter().map(|&(_, s)| s).fold(f32::INFINITY, f32::min);
    let max = scores
        .iter()
        .map(|&(_, s)| s)
        .fold(f32::NEG_INFINITY, f32::max);
    let range = max - min;

    // A NaN range would otherwise fail `range < EPSILON` and turn every
    // normalized score into NaN; treat it like the all-equal case.
    if range.is_nan() || range < EPSILON {
        return scores.iter().map(|(id, _)| (id.clone(), 1.0)).collect();
    }

    scores
        .iter()
        .map(|(id, score)| (id.clone(), (score - min) / range))
        .collect()
}

/// Merge BM25 and HDC results with given weights.
///
/// Each result set is min-max normalized on its own with `normalize_scores`.
/// The sets are then combined with a weighted sum:
/// `combined = kw_weight * bm25_norm + sem_weight * hdc_norm`.
///
/// How IDs combine:
/// - An ID present in only one list gets only that list's weighted term.
/// - An ID present in both lists gets the sum of the two terms.
/// - An ID repeated within one list adds each occurrence's weighted score.
///
/// # Arguments
/// * `bm25_results` - `(id, score)` pairs from keyword search.
/// * `hdc_results` - `(id, score)` pairs from semantic search.
/// * `weights` - `(keyword_weight, semantic_weight)`, e.g. from `compute_weights`.
///
/// # Returns
/// One `(id, combined_score)` entry per distinct ID, sorted by score
/// descending. Equal scores are ordered by ID ascending so the ranking is
/// deterministic. NaN scores sort last.
pub fn merge_results(
    bm25_results: &[(String, f32)],
    hdc_results: &[(String, f32)],
    weights: (f32, f32),
) -> Vec<(String, f32)> {
    let (kw_weight, sem_weight) = weights;

    let bm25_normalized = normalize_scores(bm25_results);
    let hdc_normalized = normalize_scores(hdc_results);

    let mut combined: HashMap<String, f32> =
        HashMap::with_capacity(bm25_normalized.len() + hdc_normalized.len());

    // BM25 contributions are accumulated before HDC ones for every ID.
    let weighted = bm25_normalized
        .iter()
        .map(|(id, score)| (id, kw_weight * score))
        .chain(
            hdc_normalized
                .iter()
                .map(|(id, score)| (id, sem_weight * score)),
        );
    for (id, contribution) in weighted {
        *combined.entry(id.clone()).or_insert(0.0) += contribution;
    }

    // HashMap iteration order is randomized, so ties must be broken explicitly
    // (by ID) to keep the ranking stable across runs. NaN is mapped below every
    // real score so the comparator stays a total order, which `sort_by` needs.
    let rank = |score: f32| if score.is_nan() { f32::NEG_INFINITY } else { score };
    let mut results: Vec<(String, f32)> = combined.into_iter().collect();
    results.sort_by(|a, b| {
        rank(b.1)
            .partial_cmp(&rank(a.1))
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.0.cmp(&b.0))
    });

    results
}

/// Strategy for weighting keyword vs. semantic scores in hybrid retrieval.
#[derive(Debug, Default, Clone, Copy, PartialEq)]
pub enum HybridMode {
    /// Pick weights automatically from query length (the default).
    #[default]
    Auto,
    /// Use only semantic (HDC) scoring.
    SemanticOnly,
    /// Use only keyword (BM25) scoring.
    KeywordOnly,
    /// Explicit weight override.
    ///
    /// NOTE(review): this module does not show which weight the value sets
    /// (keyword vs. semantic) or its expected range; confirm with callers.
    Custom(f32),
}

/// Settings for hybrid retrieval.
///
/// `HybridConfig::default()` yields `HybridMode::Auto` and a `min_score` of 0.0.
#[derive(Debug, Clone, Default)]
pub struct HybridConfig {
    /// How keyword and semantic scores are weighted.
    pub mode: HybridMode,
    /// Minimum score threshold (0.0-1.0).
    pub min_score: f32,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compute_weights_short_query() {
        let (kw, sem) = compute_weights(1);
        assert_eq!(kw, 0.9);
        assert_eq!(sem, 0.1);

        let (kw, sem) = compute_weights(2);
        assert_eq!(kw, 0.9);
        assert_eq!(sem, 0.1);
    }

    #[test]
    fn test_compute_weights_medium_query() {
        let (kw, sem) = compute_weights(3);
        assert_eq!(kw, 0.7);
        assert_eq!(sem, 0.3);

        let (kw, sem) = compute_weights(4);
        assert_eq!(kw, 0.7);
        assert_eq!(sem, 0.3);
    }

    #[test]
    fn test_compute_weights_long_query() {
        let (kw, sem) = compute_weights(5);
        assert_eq!(kw, 0.4);
        assert_eq!(sem, 0.6);

        let (kw, sem) = compute_weights(8);
        assert_eq!(kw, 0.4);
        assert_eq!(sem, 0.6);
    }

    #[test]
    fn test_compute_weights_very_long_query() {
        let (kw, sem) = compute_weights(9);
        assert_eq!(kw, 0.2);
        assert_eq!(sem, 0.8);

        let (kw, sem) = compute_weights(100);
        assert_eq!(kw, 0.2);
        assert_eq!(sem, 0.8);
    }

    #[test]
    fn test_normalize_scores_basic() {
        let scores = vec![
            ("a".to_string(), 0.0),
            ("b".to_string(), 0.5),
            ("c".to_string(), 1.0),
        ];
        let normalized = normalize_scores(&scores);

        assert!((normalized[0].1 - 0.0).abs() < 1e-6);
        assert!((normalized[1].1 - 0.5).abs() < 1e-6);
        assert!((normalized[2].1 - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_scores_empty() {
        let normalized = normalize_scores(&[]);
        assert!(normalized.is_empty());
    }

    #[test]
    fn test_normalize_scores_equal() {
        let scores = vec![("a".to_string(), 5.0), ("b".to_string(), 5.0)];
        let normalized = normalize_scores(&scores);

        // All equal scores should normalize to 1.0
        assert!((normalized[0].1 - 1.0).abs() < 1e-6);
        assert!((normalized[1].1 - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_merge_results_basic() {
        let bm25 = vec![("doc1".to_string(), 1.0), ("doc2".to_string(), 0.5)];
        let hdc = vec![("doc1".to_string(), 0.5), ("doc3".to_string(), 1.0)];

        let merged = merge_results(&bm25, &hdc, (0.5, 0.5));

        // doc1 appears in both
        assert!(merged.iter().any(|(id, _)| id == "doc1"));
        // doc2 only in BM25
        assert!(merged.iter().any(|(id, _)| id == "doc2"));
        // doc3 only in HDC
        assert!(merged.iter().any(|(id, _)| id == "doc3"));
    }

    #[test]
    fn test_merge_results_weighted() {
        // "kw" is the best keyword hit; "sem" is the best semantic hit.
        let bm25 = vec![("kw".to_string(), 1.0), ("sem".to_string(), 0.0)];
        let hdc = vec![("kw".to_string(), 0.0), ("sem".to_string(), 1.0)];

        // With heavy keyword weight, the BM25 winner must rank first.
        let merged = merge_results(&bm25, &hdc, (0.9, 0.1));
        assert_eq!(merged.len(), 2);
        assert_eq!(merged[0].0, "kw");
        assert!((merged[0].1 - 0.9).abs() < 1e-6);
        assert!((merged[1].1 - 0.1).abs() < 1e-6);

        // Flipping the weights must flip the ranking.
        let merged = merge_results(&bm25, &hdc, (0.1, 0.9));
        assert_eq!(merged[0].0, "sem");
    }

    #[test]
    fn test_merge_results_sums_overlapping_ids() {
        let bm25 = vec![("doc1".to_string(), 1.0)];
        let hdc = vec![("doc1".to_string(), 1.0)];

        // Single-element lists normalize to 1.0, so the combined score is the
        // sum of both weights.
        let merged = merge_results(&bm25, &hdc, (0.9, 0.1));
        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].0, "doc1");
        assert!((merged[0].1 - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_merge_results_empty() {
        let merged = merge_results(&[], &[], (0.5, 0.5));
        assert!(merged.is_empty());

        let merged = merge_results(&[("a".to_string(), 1.0)], &[], (0.5, 0.5));
        assert_eq!(merged.len(), 1);

        let merged = merge_results(&[], &[("a".to_string(), 1.0)], (0.5, 0.5));
        assert_eq!(merged.len(), 1);
    }
}