collet 0.1.1

Relentless agentic coding orchestrator with zero-drop agent loops
Documentation
//! BM25 scoring and result-cutoff helpers shared across all search consumers.
//!
//! Each consumer supplies its own K1/B parameters to tune scoring for its
//! domain (e.g. code files vs. short knowledge entries).

use std::collections::HashMap;

// ── BM25 Scoring ─────────────────────────────────────────────────────

/// Tuning knobs for BM25 scoring.
///
/// - `k1` sets how quickly repeated occurrences of a term stop adding score
///   (term-frequency saturation).
/// - `b` sets how strongly scores are normalized by document length:
///   `0.0` disables length normalization, `1.0` applies it fully.
#[derive(Debug, Clone, Copy)]
pub struct Bm25Params {
    /// Term-frequency saturation parameter.
    pub k1: f64,
    /// Document-length normalization strength.
    pub b: f64,
}

// The preset constructors and their constants only exist under `cfg(test)`;
// production code builds `Bm25Params { k1, b }` literals itself.
#[cfg(test)]
mod defaults {
    /// `k1` shared by every preset.
    pub const DEFAULT_K1: f64 = 1.2;
    /// Length-normalization strength for source code files.
    pub const CODE_B: f64 = 0.5;
    /// Length-normalization strength for short knowledge entries.
    pub const KNOWLEDGE_B: f64 = 0.3;
}

#[cfg(test)]
impl Bm25Params {
    /// Preset tuned for source code files.
    pub fn code() -> Self {
        Self::with_default_k1(defaults::CODE_B)
    }

    /// Preset tuned for short knowledge entries (facts, summaries).
    pub fn knowledge() -> Self {
        Self::with_default_k1(defaults::KNOWLEDGE_B)
    }

    /// Pair the shared default `k1` with the given `b`.
    fn with_default_k1(b: f64) -> Self {
        Self {
            k1: defaults::DEFAULT_K1,
            b,
        }
    }
}

/// Compute the BM25 relevance score of a single document for a query.
///
/// Each query token that occurs in the document contributes `idf * tf_norm`:
///
/// - `idf = ln((N - df + 0.5) / (df + 0.5) + 1)` — the "+1" variant, which is
///   never negative, so very common terms cannot subtract from the score;
/// - `tf_norm = tf * (k1 + 1) / (tf + k1 * (1 - b + b * dl / avgdl))`.
///
/// Query tokens absent from the document contribute nothing.
///
/// # Arguments
/// - `tf` — term frequency map for the document
/// - `doc_len` — total tokens in the document
/// - `query_tokens` — deduplicated query tokens (a repeated token is scored
///   once per occurrence, so callers should deduplicate first)
/// - `doc_freq` — global document frequency map (token → # of docs containing it);
///   tokens missing from it are treated as `df = 0`
/// - `total_docs` — total number of documents in the corpus
/// - `avg_doc_len` — average document length across corpus. If this is not a
///   positive finite number (e.g. an empty corpus), length normalization is
///   skipped instead of producing NaN/inf.
/// - `params` — BM25 K1/B parameters
///
/// # Returns
/// The summed score; `0.0` when no query token occurs in the document.
pub fn bm25_score(
    tf: &HashMap<String, u32>,
    doc_len: u32,
    query_tokens: &[String],
    doc_freq: &HashMap<String, u32>,
    total_docs: usize,
    avg_doc_len: f64,
    params: &Bm25Params,
) -> f64 {
    let n = total_docs as f64;
    let dl = f64::from(doc_len);

    // Length-normalization factor `1 - b + b * dl / avgdl`. A zero or invalid
    // average would make the division yield inf, or NaN for 0/0 (which also
    // happens when `b == 0.0`), poisoning the score and any later sort. Fall
    // back to 1.0, i.e. no length normalization.
    let length_factor = if avg_doc_len.is_finite() && avg_doc_len > 0.0 {
        1.0 - params.b + params.b * dl / avg_doc_len
    } else {
        1.0
    };
    // Loop-invariant part of the tf_norm denominator.
    let saturation = params.k1 * length_factor;

    query_tokens
        .iter()
        .filter_map(|token| {
            let count = tf.get(token).copied().unwrap_or(0);
            if count == 0 {
                // Term does not occur in this document.
                return None;
            }
            let term_freq = f64::from(count);
            let df = f64::from(doc_freq.get(token).copied().unwrap_or(0));
            let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
            let tf_norm = (term_freq * (params.k1 + 1.0)) / (term_freq + saturation);
            Some(idf * tf_norm)
        })
        // Explicit fold: left-to-right summation starting from +0.0.
        .fold(0.0, |acc, term| acc + term)
}

/// Dynamic elbow detection: keep the leading run of results before scores drop sharply.
///
/// `scored` is expected to be sorted by descending score (this is not checked).
/// Entries are walked in order, and each one is compared with the previously kept
/// entry. The walk stops at the first entry where
/// `prev_score / (score + 0.001) > drop_threshold`. The `0.001` offset keeps the
/// ratio finite when `score` is `0.0`. As a result, a zero score is cut as soon
/// as the previous score exceeds `0.001 * drop_threshold`.
///
/// The first entry is always kept when `max_k > 0`, and at most `max_k` entries
/// are returned. `max_k == 0` yields an empty result.
///
/// # Arguments
/// - `scored` — `(index, score)` pairs, sorted by descending score
/// - `max_k` — maximum number of entries to return
/// - `drop_threshold` — ratio between consecutive scores that marks the elbow
pub fn elbow_cutoff(
    scored: &[(usize, f64)],
    max_k: usize,
    drop_threshold: f64,
) -> Vec<(usize, f64)> {
    let mut result: Vec<(usize, f64)> = Vec::with_capacity(max_k.min(scored.len()));

    // The cap is applied before anything is pushed, so `max_k == 0` returns an
    // empty Vec rather than the first entry.
    for &(i, score) in scored.iter().take(max_k) {
        if let Some(&(_, prev_score)) = result.last() {
            if prev_score / (score + 0.001) > drop_threshold {
                break;
            }
        }
        result.push((i, score));
    }

    result
}

// ── Tests ────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned `term → count` map from borrowed pairs.
    fn counts(pairs: &[(&str, u32)]) -> HashMap<String, u32> {
        pairs
            .iter()
            .map(|&(term, count)| (term.to_owned(), count))
            .collect()
    }

    /// Build an owned token list from borrowed words.
    fn tokens(words: &[&str]) -> Vec<String> {
        words.iter().map(|&word| word.to_owned()).collect()
    }

    #[test]
    fn test_bm25_score_basic() {
        let tf = counts(&[("error", 3), ("handling", 1)]);
        let doc_freq = counts(&[("error", 2), ("handling", 1)]);
        let query = tokens(&["error", "handling"]);

        let params = Bm25Params::code();
        let score = bm25_score(&tf, 10, &query, &doc_freq, 5, 8.0, &params);
        assert!(score > 0.0, "score should be positive for matching terms");
    }

    #[test]
    fn test_bm25_score_zero_for_no_match() {
        // The document is empty, so the query term cannot contribute.
        let empty_doc = counts(&[]);
        let doc_freq = counts(&[("error", 1)]);
        let query = tokens(&["error"]);

        let params = Bm25Params::knowledge();
        let score = bm25_score(&empty_doc, 0, &query, &doc_freq, 3, 5.0, &params);
        assert_eq!(score, 0.0);
    }

    #[test]
    fn test_elbow_cutoff_stops_at_drop() {
        // 8.0 -> 2.0 is a 4x drop, above the 3.0 threshold, so the cut lands
        // after index 2.
        let scored: &[(usize, f64)] = &[(0, 10.0), (1, 9.5), (2, 8.0), (3, 2.0), (4, 1.0)];
        let kept = elbow_cutoff(scored, 10, 3.0);
        assert_eq!(kept.len(), 3);
    }

    #[test]
    fn test_elbow_cutoff_respects_max_k() {
        // Scores decline gently, so only the max_k cap limits the result.
        let scored: &[(usize, f64)] = &[(0, 10.0), (1, 9.5), (2, 9.0), (3, 8.5)];
        assert_eq!(elbow_cutoff(scored, 2, 3.0).len(), 2);
    }

    #[test]
    fn test_params_presets() {
        let Bm25Params { k1, b } = Bm25Params::code();
        assert_eq!((k1, b), (1.2, 0.5));

        let Bm25Params { k1, b } = Bm25Params::knowledge();
        assert_eq!((k1, b), (1.2, 0.3));
    }
}