Skip to main content

ailake_query/
bm25.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! BM25 scoring and corpus IDF statistics for hybrid vector+lexical search.
3//!
4//! # Storage
5//!
6//! `IdfStats` is persisted as a compressed binary blob at
7//! `<table_root>/metadata/ailake_bm25_stats.bin` (zstd-compressed bincode).
8//! The path is recorded in Iceberg table properties under `ailake.bm25.stats-path`
9//! so readers know where to find it.
10//!
11//! # Accuracy
12//!
13//! IDF is computed from ALL documents written through `TableWriter::write_batch`
14//! when a `bm25_text_column` is configured. Concurrent writers may lose DF deltas
15//! due to a read-modify-write race on the stats file (same caveat as Iceberg without
16//! OCC). Compaction rebuilds stats accurately from all surviving data files.
17
18use std::collections::HashMap;
19
20use serde::{Deserialize, Serialize};
21
22use ailake_core::{AilakeError, AilakeResult};
23
/// BM25 `k1` hyperparameter: scales the term-frequency component of the score.
const K1: f32 = 1.2;
/// BM25 `b` hyperparameter: weight of the document-length penalty (`doc_len / avgdl`).
const B: f32 = 0.75;
/// Maximum vocabulary size. Prunes lowest-DF terms when exceeded.
const MAX_VOCAB: usize = 50_000;
/// Minimum term length, in characters, for a token to be indexed.
const MIN_TERM_LEN: usize = 2;

/// Tokenize text into lowercase alphanumeric terms, dropping single-char tokens.
///
/// The input is split on every character that is not alphanumeric (Unicode-aware),
/// tokens shorter than `MIN_TERM_LEN` characters are discarded, and the survivors
/// are lowercased. Order and duplicates are preserved.
pub fn tokenize(text: &str) -> Vec<String> {
    text.split(|c: char| !c.is_alphanumeric())
        // Measure length in characters, not bytes: `str::len` is a byte count, so a
        // lone non-ASCII character such as "é" (2 bytes) would otherwise slip past
        // the single-character filter.
        .filter(|t| t.chars().count() >= MIN_TERM_LEN)
        .map(str::to_lowercase)
        .collect()
}
39
/// Corpus-level IDF statistics accumulated from all ingested documents.
///
/// Serialized via bincode + zstd and stored as a file alongside the Iceberg metadata.
/// `Default` yields an empty corpus (all counters zero, no terms).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct IdfStats {
    /// Number of documents (rows) seen.
    pub doc_count: u64,
    /// Total number of tokens across all documents seen (a token count, not a
    /// character count). Divided by `doc_count` to obtain the average document length.
    pub total_tokens: u64,
    /// Document frequency: number of documents containing each term.
    /// `merge_batch` caps this map at `MAX_VOCAB` entries, so low-DF terms may be absent.
    pub term_df: HashMap<String, u64>,
}
52
53impl IdfStats {
54    pub fn avg_doc_len(&self) -> f32 {
55        if self.doc_count == 0 {
56            1.0
57        } else {
58            self.total_tokens as f32 / self.doc_count as f32
59        }
60    }
61
62    /// BM25+ IDF: always positive, avoids negative values for terms appearing in >50% of docs.
63    pub fn idf(&self, term: &str) -> f32 {
64        let df = self.term_df.get(term).copied().unwrap_or(0) as f32;
65        let n = self.doc_count as f32;
66        // Terms absent from stats get max IDF (treated as appearing in 1 doc).
67        // ln((N - 0 + 0.5) / (0 + 0.5) + 1) ≈ ln(2N + 1) — correctly very high for rare terms.
68        ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
69    }
70
71    /// Merge DF counts from a new batch of text documents into this stats object.
72    ///
73    /// Each `&str` is one document. Prunes to `MAX_VOCAB` by dropping lowest-DF
74    /// terms after merge (keeps highest-DF terms which are most useful for BM25 normalization).
75    pub fn merge_batch(&mut self, texts: &[&str]) {
76        for &text in texts {
77            let terms = tokenize(text);
78            self.doc_count += 1;
79            self.total_tokens += terms.len() as u64;
80
81            // Count each unique term at most once per document (for DF, not TF).
82            let mut seen = HashMap::<&str, ()>::new();
83            for term in &terms {
84                if seen.insert(term.as_str(), ()).is_none() {
85                    *self.term_df.entry(term.clone()).or_insert(0) += 1;
86                }
87            }
88        }
89
90        if self.term_df.len() > MAX_VOCAB {
91            // Keep highest-DF terms — common terms anchor BM25 normalization (avgdl)
92            // and appear most in queries. Rare unseen terms get max-IDF approximation.
93            let mut pairs: Vec<(String, u64)> = self.term_df.drain().collect();
94            pairs.sort_unstable_by_key(|b| std::cmp::Reverse(b.1));
95            pairs.truncate(MAX_VOCAB);
96            self.term_df = pairs.into_iter().collect();
97        }
98    }
99
100    /// Serialize to zstd-compressed bincode bytes.
101    pub fn to_bytes(&self) -> AilakeResult<Vec<u8>> {
102        let raw = bincode::serialize(self).map_err(|e| AilakeError::Bincode(e.to_string()))?;
103        zstd::encode_all(&raw[..], 3).map_err(AilakeError::Io)
104    }
105
106    /// Deserialize from zstd-compressed bincode bytes.
107    pub fn from_bytes(bytes: &[u8]) -> AilakeResult<Self> {
108        let raw = zstd::decode_all(bytes).map_err(AilakeError::Io)?;
109        bincode::deserialize(&raw).map_err(|e| AilakeError::Bincode(e.to_string()))
110    }
111}
112
/// BM25 scorer backed by global [`IdfStats`].
///
/// Holds only a shared borrow of the statistics, so a scorer cannot outlive the
/// `IdfStats` it was created from.
pub struct BM25Scorer<'a> {
    /// Corpus statistics used for IDF values and the average document length.
    stats: &'a IdfStats,
}
117
118impl<'a> BM25Scorer<'a> {
119    pub fn new(stats: &'a IdfStats) -> Self {
120        Self { stats }
121    }
122
123    /// Score `doc_text` against `query_text`. Returns BM25 score (higher = more relevant).
124    pub fn score(&self, query_text: &str, doc_text: &str) -> f32 {
125        let query_terms = tokenize(query_text);
126        if query_terms.is_empty() {
127            return 0.0;
128        }
129
130        let doc_terms = tokenize(doc_text);
131        let doc_len = doc_terms.len() as f32;
132        let avgdl = self.stats.avg_doc_len();
133
134        let mut tf_map: HashMap<&str, u32> = HashMap::new();
135        for term in &doc_terms {
136            *tf_map.entry(term.as_str()).or_insert(0) += 1;
137        }
138
139        let mut score = 0.0f32;
140        for term in &query_terms {
141            let tf = tf_map.get(term.as_str()).copied().unwrap_or(0) as f32;
142            if tf == 0.0 {
143                continue;
144            }
145            let idf = self.stats.idf(term);
146            // BM25 TF normalization with length penalty
147            let tf_norm = tf * (K1 + 1.0) / (tf + K1 * (1.0 - B + B * doc_len / avgdl));
148            score += idf * tf_norm;
149        }
150        score
151    }
152
153    /// Compute BM25 scores for a slice of document texts. Returns parallel scores.
154    pub fn score_batch(&self, query_text: &str, docs: &[&str]) -> Vec<f32> {
155        docs.iter().map(|doc| self.score(query_text, doc)).collect()
156    }
157}
158
/// Fusion method for combining vector and BM25 ranked lists.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum HybridFusion {
    /// Reciprocal Rank Fusion: `score = w_vec/(k+rank_vec) + w_bm25/(k+rank_bm25)`.
    /// Default k=60 (standard RRF). Returned score is negated so sort-ascending = best first.
    #[default]
    Rrf,
    /// Linear combination after min-max normalization:
    /// `score = (1-bm25_weight) * norm_dist + bm25_weight * (1 - norm_bm25)`.
    Linear,
}

/// Configuration for hybrid vector+BM25 search.
///
/// Build one with [`HybridConfig::new`] and refine it with the `with_*` methods.
#[derive(Debug, Clone)]
pub struct HybridConfig {
    /// Raw text query. Tokenized internally for BM25 scoring.
    pub query_text: String,
    /// Parquet column(s) containing the text to score. Typically `["chunk_text"]`.
    /// When multiple columns are given, texts are concatenated with a space separator.
    pub text_columns: Vec<String>,
    /// Weight of the BM25 side of the fusion, expected in `0.0..=1.0`.
    /// The vector side is weighted by `1.0 - bm25_weight` in both `Rrf` and `Linear` mode.
    /// Only [`HybridConfig::with_bm25_weight`] enforces the range; direct writes are unchecked.
    pub bm25_weight: f32,
    /// Fusion strategy. Default: `Rrf`.
    pub fusion: HybridFusion,
    /// Minimum number of HNSW candidates to fetch before BM25 re-ranking.
    /// Ensures the BM25 pool is large enough to find lexically relevant results.
    /// Defaults to `max(rerank_factor * top_k, 10 * top_k)` if not set
    /// (that fallback is applied by the consumer of this config, not in this type).
    pub candidate_pool: Option<usize>,
}

impl Default for HybridConfig {
    /// Empty query over the `chunk_text` column, equal weights, RRF fusion and no
    /// explicit candidate pool.
    fn default() -> Self {
        Self {
            query_text: String::new(),
            text_columns: vec!["chunk_text".to_string()],
            bm25_weight: 0.5,
            fusion: HybridFusion::Rrf,
            candidate_pool: None,
        }
    }
}

impl HybridConfig {
    /// Create a configuration for `query_text` with every other field at its default.
    #[must_use]
    pub fn new(query_text: impl Into<String>) -> Self {
        Self {
            query_text: query_text.into(),
            ..Default::default()
        }
    }

    /// Score a single text column, replacing any previously configured columns.
    #[must_use]
    pub fn with_text_column(mut self, col: impl Into<String>) -> Self {
        self.text_columns = vec![col.into()];
        self
    }

    /// Score the given text columns, replacing any previously configured columns.
    #[must_use]
    pub fn with_text_columns(mut self, cols: Vec<String>) -> Self {
        self.text_columns = cols;
        self
    }

    /// Set the BM25 weight, clamped to `0.0..=1.0`.
    ///
    /// A NaN weight is ignored and the current weight is kept.
    #[must_use]
    pub fn with_bm25_weight(mut self, w: f32) -> Self {
        // `f32::clamp` returns NaN for a NaN input; storing it would turn every fused
        // score into NaN and make the final ordering meaningless.
        if !w.is_nan() {
            self.bm25_weight = w.clamp(0.0, 1.0);
        }
        self
    }

    /// Select the fusion strategy.
    #[must_use]
    pub fn with_fusion(mut self, fusion: HybridFusion) -> Self {
        self.fusion = fusion;
        self
    }

    /// Set an explicit minimum candidate pool size.
    #[must_use]
    pub fn with_candidate_pool(mut self, n: usize) -> Self {
        self.candidate_pool = Some(n);
        self
    }
}
235
/// Weighted Reciprocal Rank Fusion of a candidate's vector rank and BM25 rank.
///
/// Each side contributes `weight / (60 + rank)`: the BM25 side is weighted by
/// `bm25_weight` and the vector side by `1.0 - bm25_weight`.
///
/// The sum is returned negated, so sorting ascending (as with distances) puts the
/// best candidates first.
pub fn rrf_score(vec_rank: usize, bm25_rank: usize, bm25_weight: f32) -> f32 {
    const RRF_K: f32 = 60.0;

    let vector_part = (1.0 - bm25_weight) / (RRF_K + vec_rank as f32);
    let lexical_part = bm25_weight / (RRF_K + bm25_rank as f32);

    -(vector_part + lexical_part)
}
245
/// Linear fusion of a vector distance and a BM25 score after min-max normalization.
///
/// `min_*` / `max_*` are the extremes of each signal across the candidate set. Both
/// signals are rescaled to `[0, 1]` with them; the BM25 side is inverted (higher BM25
/// is better, lower result is better) and the two are blended as
/// `(1 - bm25_weight) * norm_vec + bm25_weight * (1 - norm_bm25)`.
///
/// When a signal's range is narrower than `f32::EPSILON` it carries no ranking
/// information and a constant is used instead: `0.0` for the vector distance and
/// `0.5` for BM25.
///
/// For inputs inside their `[min, max]` range and a weight in `[0, 1]`, the result
/// lies in `[0, 1]`, where 0 is best.
pub fn linear_score(
    vec_dist: f32,
    min_vec: f32,
    max_vec: f32,
    bm25: f32,
    min_bm25: f32,
    max_bm25: f32,
    bm25_weight: f32,
) -> f32 {
    /// Min-max rescale `value`, or return `degenerate` when the range is (near) empty.
    fn rescale(value: f32, lo: f32, hi: f32, degenerate: f32) -> f32 {
        let span = hi - lo;
        if span.abs() < f32::EPSILON {
            degenerate
        } else {
            (value - lo) / span
        }
    }

    let distance_part = rescale(vec_dist, min_vec, max_vec, 0.0);
    // Invert BM25 so that a higher lexical score lowers the fused distance.
    let lexical_part = 1.0 - rescale(bm25, min_bm25, max_bm25, 0.5);

    (1.0 - bm25_weight) * distance_part + bm25_weight * lexical_part
}
273
/// Iceberg table property key under which the location of the BM25 stats file is
/// recorded, so readers know where to find it.
pub const BM25_STATS_PATH_PROP: &str = "ailake.bm25.stats-path";
/// Default path of the BM25 stats file, relative to the table root
/// (i.e. `<table_root>/metadata/ailake_bm25_stats.bin`).
pub const BM25_STATS_FILE: &str = "metadata/ailake_bm25_stats.bin";
278
/// Unit tests for tokenization, IDF statistics, BM25 scoring and score fusion.
#[cfg(test)]
mod tests {
    use super::*;

    /// Punctuation splits tokens, case is folded and one-character tokens are dropped.
    #[test]
    fn tokenize_basic() {
        let tokens = tokenize("Hello, World! This is a test.");
        assert!(tokens.contains(&"hello".to_string()));
        assert!(tokens.contains(&"world".to_string()));
        assert!(tokens.contains(&"test".to_string()));
        // "a" is dropped (1 char < MIN_TERM_LEN); "is" has exactly 2 chars and is kept.
        assert!(!tokens.contains(&"a".to_string()));
    }

    /// With no documents ingested, an unseen term still gets a positive IDF.
    #[test]
    fn idf_empty_corpus_returns_positive() {
        let stats = IdfStats::default();
        let idf = stats.idf("unknown_term");
        assert!(idf > 0.0, "IDF should be positive for unseen term");
    }

    /// DF counts documents containing a term, accumulated across the batch.
    #[test]
    fn merge_batch_accumulates_df() {
        let mut stats = IdfStats::default();
        stats.merge_batch(&["the quick brown fox", "the lazy dog"]);
        assert_eq!(stats.doc_count, 2);
        assert_eq!(stats.term_df["the"], 2, "the appears in both docs");
        assert_eq!(stats.term_df["fox"], 1);
        assert_eq!(stats.term_df["dog"], 1);
    }

    /// Documents sharing terms with the query outrank a document sharing none.
    #[test]
    fn bm25_scorer_ranks_relevant_doc_higher() {
        let mut stats = IdfStats::default();
        let docs = [
            "rust programming language systems",
            "python machine learning data science",
            "rust memory safety zero cost abstractions",
        ];
        stats.merge_batch(&docs);

        let scorer = BM25Scorer::new(&stats);
        let query = "rust systems programming";
        let s0 = scorer.score(query, docs[0]);
        let s1 = scorer.score(query, docs[1]);
        let s2 = scorer.score(query, docs[2]);

        // docs[0] and docs[2] are about Rust — should score higher than docs[1]
        assert!(
            s0 > s1,
            "rust doc scores higher than python doc: s0={s0}, s1={s1}"
        );
        assert!(
            s2 > s1,
            "rust doc scores higher than python doc: s2={s2}, s1={s1}"
        );
    }

    /// `to_bytes` followed by `from_bytes` preserves the counters and DF entries.
    #[test]
    fn idf_stats_roundtrip() {
        let mut stats = IdfStats::default();
        stats.merge_batch(&["hello world foo bar", "foo baz qux"]);
        let bytes = stats.to_bytes().unwrap();
        let restored = IdfStats::from_bytes(&bytes).unwrap();
        assert_eq!(restored.doc_count, stats.doc_count);
        assert_eq!(restored.term_df["foo"], 2);
        assert_eq!(restored.term_df["hello"], 1);
    }

    /// A single document with more than `MAX_VOCAB` distinct terms triggers pruning.
    #[test]
    fn vocab_cap_prunes_to_max() {
        let mut stats = IdfStats::default();
        // Generate more than MAX_VOCAB unique terms
        let doc: String = (0..=MAX_VOCAB + 100)
            .map(|i| format!("term{i}"))
            .collect::<Vec<_>>()
            .join(" ");
        stats.merge_batch(&[doc.as_str()]);
        assert!(
            stats.term_df.len() <= MAX_VOCAB,
            "vocab should be capped at {MAX_VOCAB}"
        );
    }

    /// RRF scores are returned negated so that lower means better.
    #[test]
    fn rrf_score_is_negative() {
        let s = rrf_score(0, 0, 0.5);
        assert!(
            s < 0.0,
            "RRF score should be negated for sort-ascending convention"
        );
    }

    /// In-range inputs with a weight in [0,1] produce a fused score in [0,1].
    #[test]
    fn linear_score_in_range() {
        let s = linear_score(0.5, 0.0, 1.0, 0.8, 0.0, 1.0, 0.5);
        assert!(
            (0.0..=1.0).contains(&s),
            "linear score should be in [0,1]: {s}"
        );
    }
}