hermes_core/query/
global_stats.rs1use std::sync::Arc;
11
12use parking_lot::RwLock;
13use rustc_hash::FxHashMap;
14
15use crate::dsl::Field;
16use crate::segment::SegmentReader;
17
18#[derive(Debug)]
23pub struct GlobalStats {
24 total_docs: u64,
26 sparse_stats: FxHashMap<u32, SparseFieldStats>,
28 text_stats: FxHashMap<u32, TextFieldStats>,
30 generation: u64,
32}
33
34#[derive(Debug, Default)]
36pub struct SparseFieldStats {
37 pub doc_freqs: FxHashMap<u32, u64>,
39}
40
41#[derive(Debug, Default)]
43pub struct TextFieldStats {
44 pub doc_freqs: FxHashMap<String, u64>,
46 pub avg_field_len: f32,
48}
49
50impl GlobalStats {
51 pub fn new() -> Self {
53 Self {
54 total_docs: 0,
55 sparse_stats: FxHashMap::default(),
56 text_stats: FxHashMap::default(),
57 generation: 0,
58 }
59 }
60
61 #[inline]
63 pub fn total_docs(&self) -> u64 {
64 self.total_docs
65 }
66
67 #[inline]
71 pub fn sparse_idf(&self, field: Field, dim_id: u32) -> f32 {
72 if let Some(stats) = self.sparse_stats.get(&field.0)
73 && let Some(&df) = stats.doc_freqs.get(&dim_id)
74 && df > 0
75 {
76 return (self.total_docs as f32 / df as f32).ln();
77 }
78 0.0
79 }
80
81 pub fn sparse_idf_weights(&self, field: Field, dim_ids: &[u32]) -> Vec<f32> {
83 dim_ids.iter().map(|&d| self.sparse_idf(field, d)).collect()
84 }
85
86 #[inline]
90 pub fn text_idf(&self, field: Field, term: &str) -> f32 {
91 if let Some(stats) = self.text_stats.get(&field.0)
92 && let Some(&df) = stats.doc_freqs.get(term)
93 {
94 let n = self.total_docs as f32;
95 let df = df as f32;
96 return ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
98 }
99 0.0
100 }
101
102 #[inline]
104 pub fn avg_field_len(&self, field: Field) -> f32 {
105 self.text_stats
106 .get(&field.0)
107 .map(|s| s.avg_field_len)
108 .unwrap_or(1.0)
109 }
110
111 #[inline]
113 pub fn generation(&self) -> u64 {
114 self.generation
115 }
116}
117
118impl Default for GlobalStats {
119 fn default() -> Self {
120 Self::new()
121 }
122}
123
124pub struct GlobalStatsBuilder {
126 pub total_docs: u64,
128 sparse_stats: FxHashMap<u32, SparseFieldStats>,
129 text_stats: FxHashMap<u32, TextFieldStats>,
130}
131
132impl GlobalStatsBuilder {
133 pub fn new() -> Self {
135 Self {
136 total_docs: 0,
137 sparse_stats: FxHashMap::default(),
138 text_stats: FxHashMap::default(),
139 }
140 }
141
142 pub fn add_segment(&mut self, reader: &SegmentReader) {
144 self.total_docs += reader.num_docs() as u64;
145
146 }
149
150 pub fn add_sparse_df(&mut self, field: Field, dim_id: u32, doc_count: u64) {
152 let stats = self.sparse_stats.entry(field.0).or_default();
153 *stats.doc_freqs.entry(dim_id).or_insert(0) += doc_count;
154 }
155
156 pub fn add_text_df(&mut self, field: Field, term: String, doc_count: u64) {
158 let stats = self.text_stats.entry(field.0).or_default();
159 *stats.doc_freqs.entry(term).or_insert(0) += doc_count;
160 }
161
162 pub fn set_avg_field_len(&mut self, field: Field, avg_len: f32) {
164 let stats = self.text_stats.entry(field.0).or_default();
165 stats.avg_field_len = avg_len;
166 }
167
168 pub fn build(self, generation: u64) -> GlobalStats {
170 GlobalStats {
171 total_docs: self.total_docs,
172 sparse_stats: self.sparse_stats,
173 text_stats: self.text_stats,
174 generation,
175 }
176 }
177}
178
179impl Default for GlobalStatsBuilder {
180 fn default() -> Self {
181 Self::new()
182 }
183}
184
185pub struct GlobalStatsCache {
190 stats: RwLock<Option<Arc<GlobalStats>>>,
192 generation: RwLock<u64>,
194}
195
196impl GlobalStatsCache {
197 pub fn new() -> Self {
199 Self {
200 stats: RwLock::new(None),
201 generation: RwLock::new(0),
202 }
203 }
204
205 pub fn invalidate(&self) {
207 let mut current_gen = self.generation.write();
208 *current_gen += 1;
209 let mut stats = self.stats.write();
210 *stats = None;
211 }
212
213 pub fn generation(&self) -> u64 {
215 *self.generation.read()
216 }
217
218 pub fn get(&self) -> Option<Arc<GlobalStats>> {
220 self.stats.read().clone()
221 }
222
223 pub fn set(&self, stats: GlobalStats) {
225 let mut cached = self.stats.write();
226 *cached = Some(Arc::new(stats));
227 }
228
229 pub fn get_or_compute<F>(&self, compute: F) -> Arc<GlobalStats>
233 where
234 F: FnOnce(&mut GlobalStatsBuilder),
235 {
236 if let Some(stats) = self.get() {
238 return stats;
239 }
240
241 let current_gen = self.generation();
243 let mut builder = GlobalStatsBuilder::new();
244 compute(&mut builder);
245 let stats = Arc::new(builder.build(current_gen));
246
247 let mut cached = self.stats.write();
249 *cached = Some(Arc::clone(&stats));
250
251 stats
252 }
253
254 pub fn needs_rebuild(&self) -> bool {
256 self.stats.read().is_none()
257 }
258
259 pub fn set_stats(&self, stats: GlobalStats) {
261 let mut cached = self.stats.write();
262 *cached = Some(Arc::new(stats));
263 }
264}
265
266impl Default for GlobalStatsCache {
267 fn default() -> Self {
268 Self::new()
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` is within 1e-3 of `expected`.
    fn approx(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 0.001
    }

    #[test]
    fn test_sparse_idf_computation() {
        let mut b = GlobalStatsBuilder::new();
        b.total_docs = 1000;
        b.add_sparse_df(Field(0), 42, 100);
        b.add_sparse_df(Field(0), 43, 10);
        let snapshot = b.build(1);

        let frequent = snapshot.sparse_idf(Field(0), 42);
        let rare = snapshot.sparse_idf(Field(0), 43);

        // A dimension present in fewer documents must carry more weight.
        assert!(rare > frequent);
        assert!(approx(frequent, (1000.0_f32 / 100.0).ln()));
        assert!(approx(rare, (1000.0_f32 / 10.0).ln()));
    }

    #[test]
    fn test_text_idf_computation() {
        let mut b = GlobalStatsBuilder::new();
        b.total_docs = 10000;
        for (term, df) in [("common", 5000), ("rare", 10)] {
            b.add_text_df(Field(0), term.to_string(), df);
        }
        let snapshot = b.build(1);

        let common = snapshot.text_idf(Field(0), "common");
        let rare = snapshot.text_idf(Field(0), "rare");
        assert!(rare > common);
    }

    #[test]
    fn test_cache_invalidation() {
        let cache = GlobalStatsCache::new();
        assert!(cache.get().is_none());

        let computed = cache.get_or_compute(|b| b.total_docs = 100);
        assert_eq!(computed.total_docs(), 100);
        assert!(cache.get().is_some());

        cache.invalidate();
        assert!(cache.get().is_none());
    }
}