Skip to main content

bm25_rerank/
lib.rs

//! # bm25-rerank
//!
//! Stateless BM25 reranker. Given a query and a candidate set, computes
//! per-doc BM25 scores against an in-memory term-frequency corpus
//! derived from the candidates themselves.
//!
//! This is the second-stage reranker pattern: dense retrieval pulls
//! ~50 candidates, BM25 reranks them against the literal query terms
//! to surface keyword matches that the embedding may have missed.
//!
//! ## Example
//!
//! ```
//! use bm25_rerank::rerank;
//! let docs = [
//!     "the quick brown fox",
//!     "a brown dog sleeps",
//!     "lazy fox jumps over",
//! ];
//! let order = rerank("fox", &docs, Default::default());
//! // Docs 0 and 2 both contain "fox" once and have the same length, so they
//! // score equally; the stable sort keeps doc 0 first. Doc 1 has no match.
//! assert_eq!(order[0], 0);
//! ```
24
25#![deny(missing_docs)]
26
/// BM25 hyperparameters.
///
/// The [`Default`] implementation yields `k1 = 1.2` and `b = 0.75`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Bm25Opts {
    /// Term-frequency saturation. 1.2 is the Lucene default.
    pub k1: f32,
    /// Length normalization. 0.75 is the Lucene default.
    pub b: f32,
}

impl Default for Bm25Opts {
    /// Returns `k1 = 1.2`, `b = 0.75`.
    fn default() -> Self {
        Self { k1: 1.2, b: 0.75 }
    }
}
41
42/// Returns the candidate indices in BM25-score-descending order.
43///
44/// Pass an empty query to get the original order (stable).
45pub fn rerank<S: AsRef<str>>(query: &str, docs: &[S], opts: Bm25Opts) -> Vec<usize> {
46    let scores = score(query, docs, opts);
47    let mut indices: Vec<usize> = (0..docs.len()).collect();
48    indices.sort_by(|&a, &b| {
49        scores[b]
50            .partial_cmp(&scores[a])
51            .unwrap_or(std::cmp::Ordering::Equal)
52    });
53    indices
54}
55
56/// Per-doc BM25 scores, parallel to `docs`.
57pub fn score<S: AsRef<str>>(query: &str, docs: &[S], opts: Bm25Opts) -> Vec<f32> {
58    let q_terms: Vec<String> = tokenize(query);
59    if q_terms.is_empty() || docs.is_empty() {
60        return vec![0.0; docs.len()];
61    }
62
63    let doc_tokens: Vec<Vec<String>> = docs.iter().map(|d| tokenize(d.as_ref())).collect();
64    let lens: Vec<f32> = doc_tokens.iter().map(|t| t.len() as f32).collect();
65    let avgdl: f32 = if lens.is_empty() {
66        0.0
67    } else {
68        lens.iter().sum::<f32>() / lens.len() as f32
69    };
70    let n = doc_tokens.len() as f32;
71
72    let mut scores = vec![0.0_f32; doc_tokens.len()];
73    for term in &q_terms {
74        // df = how many docs contain this term
75        let df = doc_tokens
76            .iter()
77            .filter(|t| t.iter().any(|x| x == term))
78            .count() as f32;
79        if df == 0.0 {
80            continue;
81        }
82        // Lucene-style IDF with smoothing.
83        let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
84
85        for (i, tokens) in doc_tokens.iter().enumerate() {
86            let tf = tokens.iter().filter(|x| *x == term).count() as f32;
87            if tf == 0.0 {
88                continue;
89            }
90            let dl = lens[i];
91            let denom = tf + opts.k1 * (1.0 - opts.b + opts.b * (dl / avgdl.max(1.0)));
92            scores[i] += idf * (tf * (opts.k1 + 1.0)) / denom;
93        }
94    }
95    scores
96}
97
/// Splits `s` into lowercase alphanumeric tokens.
///
/// Every character for which [`char::is_alphanumeric`] is false acts as a
/// separator, and empty pieces between consecutive separators are dropped.
/// Case folding uses the Unicode-aware [`str::to_lowercase`] so that it
/// agrees with the Unicode-aware splitting: non-ASCII words such as "Über"
/// and "über" map to the same token.
fn tokenize(s: &str) -> Vec<String> {
    s.split(|c: char| !c.is_alphanumeric())
        .filter(|t| !t.is_empty())
        .map(str::to_lowercase)
        .collect()
}