Skip to main content

nodedb_cluster/distributed_document/
bm25_global.rs

1// SPDX-License-Identifier: BUSL-1.1
2
3//! Global IDF computation for distributed BM25 text search.
4//!
5//! BM25 scoring depends on IDF (Inverse Document Frequency) — how rare a
6//! term is across the ENTIRE corpus. When documents are sharded, each shard
7//! only knows its local DF. Scores from different shards are incomparable
8//! without global IDF.
9//!
10//! Two-phase scatter-gather:
11//! 1. **Phase 1 (DF collection)**: Ask all shards for local document
12//!    frequencies and total doc counts for the search terms.
13//! 2. **Coordinator**: Compute global IDF from aggregated DFs.
14//! 3. **Phase 2 (Scored search)**: Send global IDF to all shards. Each
15//!    shard computes BM25 with the shared IDF, returns its local top-K.
16//! 4. **Coordinator**: Merge-sort by BM25 score, return global top-K.
17
18use std::collections::HashMap;
19
20use serde::{Deserialize, Serialize};
21
/// Per-shard document frequency report (Phase 1 response).
///
/// Sent by each shard to the coordinator so global IDF can be computed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardDfReport {
    // Identifier of the reporting shard.
    pub shard_id: u32,
    /// Total documents on this shard.
    pub total_docs: u64,
    /// Sum of all document lengths on this shard (for global avg_doc_len).
    pub total_token_sum: u64,
    /// Per-term document frequency: `term → count of docs containing term`.
    pub term_dfs: HashMap<String, u64>,
}
33
/// Global IDF and avg_doc_len computed from all shard DF reports.
///
/// Broadcast to every shard in Phase 2 so all local BM25 scores are
/// computed against the same corpus-wide statistics and are comparable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalIdf {
    /// Total documents across all shards.
    pub total_docs: u64,
    /// Global average document length (total_token_sum / total_docs).
    /// Shards MUST use this instead of their local avg_doc_len for BM25.
    pub avg_doc_len: f64,
    /// Per-term IDF: `term → idf_score`.
    pub term_idfs: HashMap<String, f64>,
}
45
/// A scored search hit from a shard (Phase 2 response).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredHit {
    // Document identifier, unique within the corpus.
    pub doc_id: String,
    // BM25 score computed with the shared global IDF, so it is
    // comparable across shards.
    pub bm25_score: f64,
    // Shard that produced this hit (useful for fetching the document body).
    pub shard_id: u32,
}
53
/// Coordinator for 2-phase distributed BM25.
///
/// One instance per query: it gathers Phase-1 DF reports, derives the
/// shared [`GlobalIdf`], and merges Phase-2 scored hits.
pub struct GlobalIdfCoordinator {
    /// Search terms for this query.
    terms: Vec<String>,
    /// Collected DF reports from shards.
    df_reports: Vec<ShardDfReport>,
    /// Number of shards expected.
    expected_shards: usize,
    /// Computed global IDF (available after Phase 1 complete).
    global_idf: Option<GlobalIdf>,
}
65
66impl GlobalIdfCoordinator {
67    pub fn new(terms: Vec<String>, expected_shards: usize) -> Self {
68        Self {
69            terms,
70            df_reports: Vec::with_capacity(expected_shards),
71            expected_shards,
72            global_idf: None,
73        }
74    }
75
76    // -- Phase 1: Collect DFs --
77
78    /// Record a shard's DF report.
79    pub fn add_df_report(&mut self, report: ShardDfReport) {
80        self.df_reports.push(report);
81    }
82
83    /// Whether all shards have reported Phase 1.
84    pub fn phase1_complete(&self) -> bool {
85        self.df_reports.len() >= self.expected_shards
86    }
87
88    /// Compute global IDF from all shard DF reports.
89    ///
90    /// Call this after `phase1_complete()` returns true.
91    /// Uses the standard BM25 IDF formula:
92    /// `idf(t) = ln((N - df(t) + 0.5) / (df(t) + 0.5) + 1)`
93    /// where N = total docs, df(t) = docs containing term t.
94    pub fn compute_global_idf(&mut self) -> &GlobalIdf {
95        let total_docs: u64 = self.df_reports.iter().map(|r| r.total_docs).sum();
96        let total_token_sum: u64 = self.df_reports.iter().map(|r| r.total_token_sum).sum();
97        let avg_doc_len = if total_docs > 0 {
98            total_token_sum as f64 / total_docs as f64
99        } else {
100            1.0
101        };
102
103        let mut global_dfs: HashMap<String, u64> = HashMap::new();
104        for report in &self.df_reports {
105            for (term, &df) in &report.term_dfs {
106                *global_dfs.entry(term.clone()).or_insert(0) += df;
107            }
108        }
109
110        let mut term_idfs = HashMap::new();
111        let n = total_docs as f64;
112        for term in &self.terms {
113            let df = *global_dfs.get(term).unwrap_or(&0) as f64;
114            let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
115            term_idfs.insert(term.clone(), idf);
116        }
117
118        self.global_idf = Some(GlobalIdf {
119            total_docs,
120            avg_doc_len,
121            term_idfs,
122        });
123        // Safety: we just assigned Some above.
124        match &self.global_idf {
125            Some(idf) => idf,
126            None => unreachable!(),
127        }
128    }
129
130    /// Get the computed global IDF (None if Phase 1 not complete).
131    pub fn global_idf(&self) -> Option<&GlobalIdf> {
132        self.global_idf.as_ref()
133    }
134
135    // -- Phase 2: Merge scored results --
136
137    /// Merge scored hits from all shards, return global top-K by BM25 score.
138    pub fn merge_scored_hits(shard_results: &[Vec<ScoredHit>], top_k: usize) -> Vec<ScoredHit> {
139        let mut all_hits: Vec<ScoredHit> = shard_results
140            .iter()
141            .flat_map(|r| r.iter().cloned())
142            .collect();
143        all_hits.sort_by(|a, b| {
144            b.bm25_score
145                .partial_cmp(&a.bm25_score)
146                .unwrap_or(std::cmp::Ordering::Equal)
147        });
148        all_hits.truncate(top_k);
149        all_hits
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `ScoredHit` in the merge tests.
    fn hit(doc_id: &str, bm25_score: f64, shard_id: u32) -> ScoredHit {
        ScoredHit {
            doc_id: doc_id.into(),
            bm25_score,
            shard_id,
        }
    }

    #[test]
    fn global_idf_two_shards() {
        let mut coordinator =
            GlobalIdfCoordinator::new(vec!["rust".into(), "database".into()], 2);

        // (shard_id, total_docs, total_token_sum, df("rust"), df("database"))
        for (shard_id, total_docs, total_token_sum, rust_df, db_df) in
            [(0, 1000, 100_000, 50, 200), (1, 1000, 120_000, 30, 300)]
        {
            coordinator.add_df_report(ShardDfReport {
                shard_id,
                total_docs,
                total_token_sum,
                term_dfs: HashMap::from([("rust".into(), rust_df), ("database".into(), db_df)]),
            });
        }

        assert!(coordinator.phase1_complete());
        let global = coordinator.compute_global_idf();

        assert_eq!(global.total_docs, 2000);
        // Global avg_doc_len = (100_000 + 120_000) / 2000 = 110.0
        assert!((global.avg_doc_len - 110.0).abs() < f64::EPSILON);
        // "rust": df=80, N=2000 → idf = ln((2000-80+0.5)/(80+0.5)+1) ≈ 3.2
        assert!(global.term_idfs["rust"] > 3.0);
        // "database": df=500, N=2000 → idf = ln((2000-500+0.5)/(500+0.5)+1) ≈ 1.4
        assert!(global.term_idfs["database"] > 1.0);
        // "rust" is the rarer term, so it must score higher.
        assert!(global.term_idfs["database"] < global.term_idfs["rust"]);
    }

    #[test]
    fn merge_scored_hits() {
        let shard_a = vec![hit("a1", 5.0, 0), hit("a2", 3.0, 0)];
        let shard_b = vec![hit("b1", 4.5, 1), hit("b2", 2.0, 1)];

        let merged = GlobalIdfCoordinator::merge_scored_hits(&[shard_a, shard_b], 3);
        let ids: Vec<&str> = merged.iter().map(|h| h.doc_id.as_str()).collect();
        // Descending by score (5.0, 4.5, 3.0); "b2" (2.0) is cut by top_k = 3.
        assert_eq!(ids, ["a1", "b1", "a2"]);
    }

    #[test]
    fn rare_term_has_higher_idf() {
        let mut coordinator =
            GlobalIdfCoordinator::new(vec!["rare".into(), "common".into()], 1);
        coordinator.add_df_report(ShardDfReport {
            shard_id: 0,
            total_docs: 10_000,
            total_token_sum: 1_000_000,
            term_dfs: HashMap::from([("rare".into(), 5), ("common".into(), 9000)]),
        });
        let global = coordinator.compute_global_idf();
        assert!(global.term_idfs["rare"] > global.term_idfs["common"]);
    }
}
233}