// nodedb_cluster/distributed_document/bm25_global.rs

use std::collections::HashMap;

use serde::{Deserialize, Serialize};

/// Per-shard term statistics gathered in phase 1 of the two-phase
/// global-IDF exchange.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardDfReport {
    pub shard_id: u32,
    /// Number of documents stored on this shard.
    pub total_docs: u64,
    /// Total token count across all documents on this shard.
    pub total_token_sum: u64,
    /// Per-term document frequency (how many local docs contain the term).
    pub term_dfs: HashMap<String, u64>,
}

/// Corpus-wide statistics computed by the coordinator and broadcast to
/// every shard in phase 2.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalIdf {
    pub total_docs: u64,
    /// Mean document length (in tokens) across the whole cluster.
    pub avg_doc_len: f64,
    /// Smoothed BM25 IDF for each query term.
    pub term_idfs: HashMap<String, f64>,
}

/// A single document hit scored by a shard using the shared `GlobalIdf`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredHit {
    pub doc_id: String,
    pub bm25_score: f64,
    pub shard_id: u32,
}

/// Coordinator for the two-phase protocol: collect per-shard DF reports,
/// then compute and cache the corpus-wide IDF table.
pub struct GlobalIdfCoordinator {
    /// Query terms the coordinator is aggregating statistics for.
    terms: Vec<String>,
    /// Reports received so far in phase 1.
    df_reports: Vec<ShardDfReport>,
    /// Number of shards that must report before phase 1 is complete.
    expected_shards: usize,
    /// Cached result of `compute_global_idf`, if it has run.
    global_idf: Option<GlobalIdf>,
}

impl GlobalIdfCoordinator {
    pub fn new(terms: Vec<String>, expected_shards: usize) -> Self {
        Self {
            terms,
            df_reports: Vec::with_capacity(expected_shards),
            expected_shards,
            global_idf: None,
        }
    }

    /// Phase 1: record one document-frequency report from a shard.
    pub fn add_df_report(&mut self, report: ShardDfReport) {
        self.df_reports.push(report);
    }

    /// True once every expected shard has reported.
    pub fn phase1_complete(&self) -> bool {
        self.df_reports.len() >= self.expected_shards
    }

    /// Phase 2: fold the shard reports into corpus-wide statistics.
    ///
    /// Summing per-shard document frequencies assumes each document is
    /// stored on exactly one shard, so the local DFs are disjoint.
    pub fn compute_global_idf(&mut self) -> &GlobalIdf {
        let total_docs: u64 = self.df_reports.iter().map(|r| r.total_docs).sum();
        let total_token_sum: u64 = self.df_reports.iter().map(|r| r.total_token_sum).sum();
        // Fall back to 1.0 on an empty corpus to keep later divisions finite.
        let avg_doc_len = if total_docs > 0 {
            total_token_sum as f64 / total_docs as f64
        } else {
            1.0
        };

        // Merge per-shard document frequencies into global counts.
        let mut global_dfs: HashMap<String, u64> = HashMap::new();
        for report in &self.df_reports {
            for (term, &df) in &report.term_dfs {
                *global_dfs.entry(term.clone()).or_insert(0) += df;
            }
        }

        let mut term_idfs = HashMap::new();
        let n = total_docs as f64;
        for term in &self.terms {
            let df = *global_dfs.get(term).unwrap_or(&0) as f64;
            // Smoothed BM25 IDF: ln(1 + (N - df + 0.5) / (df + 0.5)).
            // The +1 inside the ln keeps the value positive even for very
            // common terms.
            let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
            term_idfs.insert(term.clone(), idf);
        }

        // Cache and return the freshly computed statistics; `Option::insert`
        // overwrites any previous value and hands back a reference to it.
        self.global_idf.insert(GlobalIdf {
            total_docs,
            avg_doc_len,
            term_idfs,
        })
    }

    /// The cached global statistics, or `None` if phase 2 has not run yet.
    pub fn global_idf(&self) -> Option<&GlobalIdf> {
        self.global_idf.as_ref()
    }

    /// Merge per-shard result lists into a single global top-`top_k` list.
    /// Scores are directly comparable because every shard scored with the
    /// same `GlobalIdf`.
    pub fn merge_scored_hits(shard_results: &[Vec<ScoredHit>], top_k: usize) -> Vec<ScoredHit> {
        let mut all_hits: Vec<ScoredHit> = shard_results
            .iter()
            .flat_map(|r| r.iter().cloned())
            .collect();
        // Descending by score; NaN (which BM25 should never produce) ties.
        all_hits.sort_by(|a, b| {
            b.bm25_score
                .partial_cmp(&a.bm25_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        all_hits.truncate(top_k);
        all_hits
    }
}
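
// A minimal shard-side scoring sketch (not part of the coordinator above):
// once `GlobalIdf` is broadcast, each shard can combine it with its local
// term frequencies using the standard Okapi BM25 formula. `K1` and `B` are
// the usual defaults, and `term_freqs`/`doc_len` stand in for whatever
// per-document index a shard actually keeps; both are assumptions here.
pub fn bm25_score_doc(
    global: &GlobalIdf,
    term_freqs: &HashMap<String, u64>,
    doc_len: u64,
) -> f64 {
    const K1: f64 = 1.2;
    const B: f64 = 0.75;
    // Length normalization: documents longer than the cluster-wide average
    // are penalized, shorter ones boosted.
    let len_norm = K1 * (1.0 - B + B * doc_len as f64 / global.avg_doc_len);
    global
        .term_idfs
        .iter()
        .map(|(term, idf)| {
            let tf = *term_freqs.get(term).unwrap_or(&0) as f64;
            // Saturating term-frequency component of Okapi BM25.
            idf * (tf * (K1 + 1.0)) / (tf + len_norm)
        })
        .sum()
}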

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn global_idf_two_shards() {
        let mut coord = GlobalIdfCoordinator::new(vec!["rust".into(), "database".into()], 2);

        coord.add_df_report(ShardDfReport {
            shard_id: 0,
            total_docs: 1000,
            total_token_sum: 100_000,
            term_dfs: HashMap::from([("rust".into(), 50), ("database".into(), 200)]),
        });
        coord.add_df_report(ShardDfReport {
            shard_id: 1,
            total_docs: 1000,
            total_token_sum: 120_000,
            term_dfs: HashMap::from([("rust".into(), 30), ("database".into(), 300)]),
        });

        assert!(coord.phase1_complete());
        let idf = coord.compute_global_idf();

        assert_eq!(idf.total_docs, 2000);
        // 220_000 tokens over 2_000 docs.
        assert!((idf.avg_doc_len - 110.0).abs() < f64::EPSILON);
        // "rust" (df = 80) is rarer than "database" (df = 500), so it gets
        // the higher IDF.
        assert!(idf.term_idfs["rust"] > 3.0);
        assert!(idf.term_idfs["database"] > 1.0);
        assert!(idf.term_idfs["database"] < idf.term_idfs["rust"]);
    }

    #[test]
    fn merge_scored_hits() {
        let shard_a = vec![
            ScoredHit {
                doc_id: "a1".into(),
                bm25_score: 5.0,
                shard_id: 0,
            },
            ScoredHit {
                doc_id: "a2".into(),
                bm25_score: 3.0,
                shard_id: 0,
            },
        ];
        let shard_b = vec![
            ScoredHit {
                doc_id: "b1".into(),
                bm25_score: 4.5,
                shard_id: 1,
            },
            ScoredHit {
                doc_id: "b2".into(),
                bm25_score: 2.0,
                shard_id: 1,
            },
        ];

        let merged = GlobalIdfCoordinator::merge_scored_hits(&[shard_a, shard_b], 3);
        assert_eq!(merged.len(), 3);
        // Global order by score: a1 (5.0), b1 (4.5), a2 (3.0); b2 is cut by top_k.
        assert_eq!(merged[0].doc_id, "a1");
        assert_eq!(merged[1].doc_id, "b1");
        assert_eq!(merged[2].doc_id, "a2");
    }

    #[test]
    fn rare_term_has_higher_idf() {
        let mut coord = GlobalIdfCoordinator::new(vec!["rare".into(), "common".into()], 1);
        coord.add_df_report(ShardDfReport {
            shard_id: 0,
            total_docs: 10_000,
            total_token_sum: 1_000_000,
            term_dfs: HashMap::from([("rare".into(), 5), ("common".into(), 9000)]),
        });
        let idf = coord.compute_global_idf();
        assert!(idf.term_idfs["rare"] > idf.term_idfs["common"]);
    }
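
    // Usage sketch for the hypothetical `bm25_score_doc` above: with equal
    // term frequency and document length, matching the rarer term must
    // yield the higher score.
    #[test]
    fn bm25_sketch_prefers_rare_term() {
        let mut coord = GlobalIdfCoordinator::new(vec!["rare".into(), "common".into()], 1);
        coord.add_df_report(ShardDfReport {
            shard_id: 0,
            total_docs: 10_000,
            total_token_sum: 1_000_000,
            term_dfs: HashMap::from([("rare".into(), 5), ("common".into(), 9000)]),
        });
        let idf = coord.compute_global_idf();
        let rare_doc = HashMap::from([("rare".to_string(), 3u64)]);
        let common_doc = HashMap::from([("common".to_string(), 3u64)]);
        assert!(bm25_score_doc(idf, &rare_doc, 100) > bm25_score_doc(idf, &common_doc, 100));
    }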
}