bm25-rerank 0.1.0

BM25 reranker for RAG: in-memory term-frequency reranking against a small candidate set. Stateless, zero deps.
Documentation
//! # bm25-rerank
//!
//! Stateless BM25 reranker. Given a query and a candidate set, computes
//! per-doc BM25 scores against an in-memory term-frequency corpus
//! derived from the candidates themselves.
//!
//! This is the second-stage reranker pattern: dense retrieval pulls
//! ~50 candidates, BM25 reranks them against the literal query terms
//! to surface keyword matches that the embedding may have missed.
//!
//! ## Example
//!
//! ```
//! use bm25_rerank::rerank;
//! let docs = [
//!     "the quick brown fox",
//!     "a brown dog sleeps",
//!     "lazy fox jumps over",
//! ];
//! let order = rerank("fox", &docs, Default::default());
//! // Docs 0 and 2 both contain "fox" and have the same length, so their
//! // scores tie; the stable sort keeps doc 0 ahead of doc 2.
//! assert_eq!(order[0], 0);
//! ```

#![deny(missing_docs)]

/// BM25 hyperparameters.
///
/// Both fields are plain `f32` values and are used as-is by the scoring
/// formula; no validation or clamping is applied to them.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Bm25Opts {
    /// Term-frequency saturation. 1.2 is the Lucene default.
    ///
    /// Larger values let repeated occurrences of a term keep adding to the
    /// score for longer before the contribution flattens out.
    pub k1: f32,
    /// Length normalization. 0.75 is the Lucene default.
    ///
    /// `0.0` disables length normalization entirely; `1.0` scales the
    /// term-frequency component fully by relative document length.
    pub b: f32,
}

impl Default for Bm25Opts {
    /// Returns the conventional defaults: `k1 = 1.2`, `b = 0.75`.
    fn default() -> Self {
        Self { k1: 1.2, b: 0.75 }
    }
}

/// Returns the candidate indices in BM25-score-descending order.
///
/// The result is a permutation of `0..docs.len()`. The sort is stable, so
/// candidates with equal scores keep their original relative order; in
/// particular, an empty query (all scores zero) returns the original order.
///
/// Scores that are NaN (possible only with degenerate `opts`, e.g. a NaN
/// `k1`) are ranked last rather than being treated as equal to everything,
/// which would not be a consistent ordering for the sort.
pub fn rerank<S: AsRef<str>>(query: &str, docs: &[S], opts: Bm25Opts) -> Vec<usize> {
    // Replace NaN with -inf up front so the comparator below is a genuine
    // total preorder over the keys; finite scores are left untouched.
    let keys: Vec<f32> = score(query, docs, opts)
        .into_iter()
        .map(|s| if s.is_nan() { f32::NEG_INFINITY } else { s })
        .collect();

    let mut indices: Vec<usize> = (0..docs.len()).collect();
    // Compare `b` against `a` to get descending order.
    indices.sort_by(|&a, &b| {
        keys[b]
            .partial_cmp(&keys[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    indices
}

/// Per-doc BM25 scores, parallel to `docs`.
///
/// The candidate set doubles as the corpus: document frequency and average
/// document length are both computed from `docs` themselves. The query and
/// every document go through the same `tokenize` function, and terms are
/// matched by exact token equality.
///
/// Returns a vector of zeros (one per doc) when the query yields no tokens
/// or `docs` is empty. A query term that appears more than once in `query`
/// contributes once per occurrence.
pub fn score<S: AsRef<str>>(query: &str, docs: &[S], opts: Bm25Opts) -> Vec<f32> {
    let q_terms = tokenize(query);
    if q_terms.is_empty() || docs.is_empty() {
        return vec![0.0; docs.len()];
    }

    let doc_tokens: Vec<Vec<String>> = docs.iter().map(|d| tokenize(d.as_ref())).collect();
    let lens: Vec<f32> = doc_tokens.iter().map(|t| t.len() as f32).collect();
    let n = doc_tokens.len() as f32;
    // `docs` is non-empty past the early return, so the mean is well-defined.
    let avgdl = lens.iter().sum::<f32>() / n;
    // The average length is clamped to at least 1.0 before dividing. Note
    // that this also affects normalization when the true average is below
    // one token (i.e. most candidates tokenize to nothing).
    let norm_len = avgdl.max(1.0);

    let mut scores = vec![0.0_f32; doc_tokens.len()];
    for term in &q_terms {
        // One pass per doc yields the term frequency; the document frequency
        // is then just the number of docs with a non-zero count.
        let tfs: Vec<f32> = doc_tokens
            .iter()
            .map(|tokens| tokens.iter().filter(|x| *x == term).count() as f32)
            .collect();
        let df = tfs.iter().filter(|&&tf| tf > 0.0).count() as f32;
        if df == 0.0 {
            continue;
        }
        // Lucene-style IDF with smoothing; the `+ 1.0` inside the log keeps
        // it positive even when the term occurs in every doc.
        let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();

        for ((total, &tf), &dl) in scores.iter_mut().zip(&tfs).zip(&lens) {
            if tf == 0.0 {
                continue;
            }
            let denom = tf + opts.k1 * (1.0 - opts.b + opts.b * (dl / norm_len));
            *total += idf * (tf * (opts.k1 + 1.0)) / denom;
        }
    }
    scores
}

/// Splits `s` into lowercase alphanumeric tokens.
///
/// Every character for which `char::is_alphanumeric` is false acts as a
/// separator, and empty pieces between separators are dropped. Tokens are
/// lowercased with the Unicode-aware `str::to_lowercase`, matching the
/// Unicode-aware splitting, so that e.g. "École" and "école" produce the
/// same token.
fn tokenize(s: &str) -> Vec<String> {
    s.split(|c: char| !c.is_alphanumeric())
        .filter(|t| !t.is_empty())
        .map(str::to_lowercase)
        .collect()
}