//! Global corpus statistics (document frequencies, average field lengths)
//! aggregated across segments and cached for scoring.

use std::sync::Arc;

use parking_lot::RwLock;
use rustc_hash::FxHashMap;

use crate::dsl::Field;
use crate::segment::SegmentReader;

/// Corpus-level statistics computed lazily on first access and cached.
pub struct LazyGlobalStats {
    /// Segment readers the statistics are aggregated over.
    segments: Vec<Arc<SegmentReader>>,
    /// Total number of documents across all segments.
    total_docs: u64,
    /// Per-field cache of sparse IDF values, keyed by dimension id.
    sparse_idf_cache: RwLock<FxHashMap<u32, FxHashMap<u32, f32>>>,
    /// Per-field cache of text IDF values, keyed by term.
    text_idf_cache: RwLock<FxHashMap<u32, FxHashMap<String, f32>>>,
    /// Per-field cache of the document-weighted average field length.
    avg_field_len_cache: RwLock<FxHashMap<u32, f32>>,
}

impl LazyGlobalStats {
    /// Builds lazy statistics over the given segments, summing their document counts.
    pub fn new(segments: Vec<Arc<SegmentReader>>) -> Self {
        let total_docs: u64 = segments.iter().map(|s| s.num_docs() as u64).sum();
        Self {
            segments,
            total_docs,
            sparse_idf_cache: RwLock::new(FxHashMap::default()),
            text_idf_cache: RwLock::new(FxHashMap::default()),
            avg_field_len_cache: RwLock::new(FxHashMap::default()),
        }
    }

    /// Total number of documents across all segments.
    #[inline]
    pub fn total_docs(&self) -> u64 {
        self.total_docs
    }

    /// IDF for a sparse dimension: `ln(total_docs / df)`, computed on first
    /// request and cached per field.
    pub fn sparse_idf(&self, field: Field, dim_id: u32) -> f32 {
        // Fast path: return the cached value if present.
        {
            let cache = self.sparse_idf_cache.read();
            if let Some(field_cache) = cache.get(&field.0)
                && let Some(&idf) = field_cache.get(&dim_id)
            {
                return idf;
            }
        }

        // Slow path: aggregate the document frequency across segments.
        let df = self.compute_sparse_df(field, dim_id);
        let idf = if df > 0 && self.total_docs > 0 {
            (self.total_docs as f32 / df as f32).ln()
        } else {
            0.0
        };

        {
            let mut cache = self.sparse_idf_cache.write();
            cache.entry(field.0).or_default().insert(dim_id, idf);
        }

        idf
    }

    /// Batch variant of [`Self::sparse_idf`] for a slice of dimension ids.
    pub fn sparse_idf_weights(&self, field: Field, dim_ids: &[u32]) -> Vec<f32> {
        dim_ids.iter().map(|&d| self.sparse_idf(field, d)).collect()
    }

    /// BM25-style IDF for a text term: `ln((N - df + 0.5) / (df + 0.5) + 1)`,
    /// computed on first request and cached per field.
    pub fn text_idf(&self, field: Field, term: &str) -> f32 {
        // Fast path: return the cached value if present.
        {
            let cache = self.text_idf_cache.read();
            if let Some(field_cache) = cache.get(&field.0)
                && let Some(&idf) = field_cache.get(term)
            {
                return idf;
            }
        }

        let df = self.compute_text_df(field, term);
        let n = self.total_docs as f32;
        let df_f = df as f32;
        let idf = if df > 0 {
            ((n - df_f + 0.5) / (df_f + 0.5) + 1.0).ln()
        } else {
            0.0
        };

        {
            let mut cache = self.text_idf_cache.write();
            cache
                .entry(field.0)
                .or_default()
                .insert(term.to_string(), idf);
        }

        idf
    }

    /// Average field length across segments, weighted by each segment's
    /// document count. Falls back to 1.0 when no segment reports a length.
    pub fn avg_field_len(&self, field: Field) -> f32 {
        {
            let cache = self.avg_field_len_cache.read();
            if let Some(&avg) = cache.get(&field.0) {
                return avg;
            }
        }

        let mut weighted_sum = 0.0f64;
        let mut total_weight = 0u64;

        for segment in &self.segments {
            let avg_len = segment.avg_field_len(field);
            let doc_count = segment.num_docs() as u64;
            if avg_len > 0.0 && doc_count > 0 {
                weighted_sum += avg_len as f64 * doc_count as f64;
                total_weight += doc_count;
            }
        }

        let avg = if total_weight > 0 {
            (weighted_sum / total_weight as f64) as f32
        } else {
            1.0
        };

        {
            let mut cache = self.avg_field_len_cache.write();
            cache.insert(field.0, avg);
        }

        avg
    }

    /// Sums the document frequency of a sparse dimension over all segments.
    fn compute_sparse_df(&self, field: Field, dim_id: u32) -> u64 {
        let mut df = 0u64;
        for segment in &self.segments {
            if let Some(sparse_index) = segment.sparse_indexes().get(&field.0)
                && let Some(Some(posting)) = sparse_index.postings.get(dim_id as usize)
            {
                df += posting.doc_count() as u64;
            }
        }
        df
    }

    /// Document frequency for a text term. Not implemented yet: this always
    /// returns 0, so the lazy path yields an IDF of 0.0 for text terms.
    fn compute_text_df(&self, _field: Field, _term: &str) -> u64 {
        0
    }

    pub fn num_segments(&self) -> usize {
        self.segments.len()
    }
}

impl std::fmt::Debug for LazyGlobalStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LazyGlobalStats")
            .field("total_docs", &self.total_docs)
            .field("num_segments", &self.segments.len())
            .field("sparse_cache_fields", &self.sparse_idf_cache.read().len())
            .field("text_cache_fields", &self.text_idf_cache.read().len())
            .finish()
    }
}

/// Fully materialized global statistics, built once by [`GlobalStatsBuilder`].
#[derive(Debug)]
pub struct GlobalStats {
    /// Total number of documents across all segments.
    total_docs: u64,
    /// Per-field sparse statistics, keyed by field id.
    sparse_stats: FxHashMap<u32, SparseFieldStats>,
    /// Per-field text statistics, keyed by field id.
    text_stats: FxHashMap<u32, TextFieldStats>,
    /// Generation of the index these statistics were built for.
    generation: u64,
}

/// Document frequencies for a sparse vector field, keyed by dimension id.
#[derive(Debug, Default)]
pub struct SparseFieldStats {
    pub doc_freqs: FxHashMap<u32, u64>,
}

/// Document frequencies and average length for a text field.
#[derive(Debug, Default)]
pub struct TextFieldStats {
    pub doc_freqs: FxHashMap<String, u64>,
    pub avg_field_len: f32,
}

impl GlobalStats {
    pub fn new() -> Self {
        Self {
            total_docs: 0,
            sparse_stats: FxHashMap::default(),
            text_stats: FxHashMap::default(),
            generation: 0,
        }
    }

    /// Total number of documents across all segments.
    #[inline]
    pub fn total_docs(&self) -> u64 {
        self.total_docs
    }

    /// IDF for a sparse dimension: `ln(total_docs / df)`, or 0.0 when the
    /// dimension was never observed.
    #[inline]
    pub fn sparse_idf(&self, field: Field, dim_id: u32) -> f32 {
        if let Some(stats) = self.sparse_stats.get(&field.0)
            && let Some(&df) = stats.doc_freqs.get(&dim_id)
            && df > 0
        {
            return (self.total_docs as f32 / df as f32).ln();
        }
        0.0
    }

    /// Batch variant of [`Self::sparse_idf`] for a slice of dimension ids.
    pub fn sparse_idf_weights(&self, field: Field, dim_ids: &[u32]) -> Vec<f32> {
        dim_ids.iter().map(|&d| self.sparse_idf(field, d)).collect()
    }

    /// BM25-style IDF for a text term: `ln((N - df + 0.5) / (df + 0.5) + 1)`,
    /// or 0.0 when the term was never observed.
    #[inline]
    pub fn text_idf(&self, field: Field, term: &str) -> f32 {
        if let Some(stats) = self.text_stats.get(&field.0)
            && let Some(&df) = stats.doc_freqs.get(term)
        {
            let n = self.total_docs as f32;
            let df = df as f32;
            return ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
        }
        0.0
    }

    /// Average length of a text field, defaulting to 1.0 for unknown fields.
    #[inline]
    pub fn avg_field_len(&self, field: Field) -> f32 {
        self.text_stats
            .get(&field.0)
            .map(|s| s.avg_field_len)
            .unwrap_or(1.0)
    }

    /// Generation of the index these statistics were built for.
    #[inline]
    pub fn generation(&self) -> u64 {
        self.generation
    }
}

impl Default for GlobalStats {
    fn default() -> Self {
        Self::new()
    }
}

/// Incrementally accumulates statistics from segments before freezing them
/// into a [`GlobalStats`].
pub struct GlobalStatsBuilder {
    pub total_docs: u64,
    sparse_stats: FxHashMap<u32, SparseFieldStats>,
    text_stats: FxHashMap<u32, TextFieldStats>,
}

impl GlobalStatsBuilder {
    pub fn new() -> Self {
        Self {
            total_docs: 0,
            sparse_stats: FxHashMap::default(),
            text_stats: FxHashMap::default(),
        }
    }

    /// Accumulates a segment's document count into the total.
    pub fn add_segment(&mut self, reader: &SegmentReader) {
        self.total_docs += reader.num_docs() as u64;
    }

    pub fn add_sparse_df(&mut self, field: Field, dim_id: u32, doc_count: u64) {
        let stats = self.sparse_stats.entry(field.0).or_default();
        *stats.doc_freqs.entry(dim_id).or_insert(0) += doc_count;
    }

    pub fn add_text_df(&mut self, field: Field, term: String, doc_count: u64) {
        let stats = self.text_stats.entry(field.0).or_default();
        *stats.doc_freqs.entry(term).or_insert(0) += doc_count;
    }

    pub fn set_avg_field_len(&mut self, field: Field, avg_len: f32) {
        let stats = self.text_stats.entry(field.0).or_default();
        stats.avg_field_len = avg_len;
    }

    pub fn build(self, generation: u64) -> GlobalStats {
        GlobalStats {
            total_docs: self.total_docs,
            sparse_stats: self.sparse_stats,
            text_stats: self.text_stats,
            generation,
        }
    }
}

impl Default for GlobalStatsBuilder {
    fn default() -> Self {
        Self::new()
    }
}

/// Thread-safe cache holding the current [`GlobalStats`] together with a
/// generation counter that is bumped on every invalidation.
pub struct GlobalStatsCache {
    /// The cached statistics, or `None` when a rebuild is needed.
    stats: RwLock<Option<Arc<GlobalStats>>>,
    /// Monotonically increasing generation counter.
    generation: RwLock<u64>,
}

impl GlobalStatsCache {
    pub fn new() -> Self {
        Self {
            stats: RwLock::new(None),
            generation: RwLock::new(0),
        }
    }

    pub fn invalidate(&self) {
        let mut current_gen = self.generation.write();
        *current_gen += 1;
        let mut stats = self.stats.write();
        *stats = None;
    }

    pub fn generation(&self) -> u64 {
        *self.generation.read()
    }

    pub fn get(&self) -> Option<Arc<GlobalStats>> {
        self.stats.read().clone()
    }

    pub fn set(&self, stats: GlobalStats) {
        let mut cached = self.stats.write();
        *cached = Some(Arc::new(stats));
    }

    /// Returns the cached stats, or builds them with `compute` and caches the
    /// result tagged with the current generation. Concurrent callers may each
    /// run `compute` when the cache is empty; the last writer wins.
    pub fn get_or_compute<F>(&self, compute: F) -> Arc<GlobalStats>
    where
        F: FnOnce(&mut GlobalStatsBuilder),
    {
        if let Some(stats) = self.get() {
            return stats;
        }

        let current_gen = self.generation();
        let mut builder = GlobalStatsBuilder::new();
        compute(&mut builder);
        let stats = Arc::new(builder.build(current_gen));

        let mut cached = self.stats.write();
        *cached = Some(Arc::clone(&stats));

        stats
    }

    /// True when no statistics are cached and a rebuild is required.
    pub fn needs_rebuild(&self) -> bool {
        self.stats.read().is_none()
    }

    /// Stores freshly built statistics; equivalent to [`Self::set`].
    pub fn set_stats(&self, stats: GlobalStats) {
        let mut cached = self.stats.write();
        *cached = Some(Arc::new(stats));
    }
}

impl Default for GlobalStatsCache {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_idf_computation() {
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 1000;
        builder.add_sparse_df(Field(0), 42, 100);
        builder.add_sparse_df(Field(0), 43, 10);

        let stats = builder.build(1);

        let idf_42 = stats.sparse_idf(Field(0), 42);
        let idf_43 = stats.sparse_idf(Field(0), 43);

        assert!(idf_43 > idf_42);
        assert!((idf_42 - (1000.0_f32 / 100.0).ln()).abs() < 0.001);
        assert!((idf_43 - (1000.0_f32 / 10.0).ln()).abs() < 0.001);
    }
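
    // Sketch of a further test: checks that sparse_idf_weights returns one IDF
    // per requested dimension, matching the scalar path, and falls back to 0.0
    // for a dimension that was never observed. Values are illustrative.
    #[test]
    fn test_sparse_idf_weights_batch() {
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 1000;
        builder.add_sparse_df(Field(0), 1, 500);
        builder.add_sparse_df(Field(0), 2, 5);
        let stats = builder.build(1);

        let weights = stats.sparse_idf_weights(Field(0), &[1, 2, 3]);
        assert_eq!(weights.len(), 3);
        assert!((weights[0] - stats.sparse_idf(Field(0), 1)).abs() < 1e-6);
        assert!((weights[1] - stats.sparse_idf(Field(0), 2)).abs() < 1e-6);
        // Dimension 3 has no recorded document frequency, so its IDF is 0.0.
        assert_eq!(weights[2], 0.0);
    }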

    #[test]
    fn test_text_idf_computation() {
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 10000;
        builder.add_text_df(Field(0), "common".to_string(), 5000);
        builder.add_text_df(Field(0), "rare".to_string(), 10);

        let stats = builder.build(1);

        let idf_common = stats.text_idf(Field(0), "common");
        let idf_rare = stats.text_idf(Field(0), "rare");

        assert!(idf_rare > idf_common);
    }
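
    // Sketch of a further test: avg_field_len should return the value recorded
    // through the builder and fall back to the neutral default of 1.0 for a
    // field with no text statistics. Field ids and lengths are illustrative.
    #[test]
    fn test_avg_field_len_default_and_set() {
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 10;
        builder.set_avg_field_len(Field(1), 12.5);
        let stats = builder.build(1);

        assert_eq!(stats.avg_field_len(Field(1)), 12.5);
        assert_eq!(stats.avg_field_len(Field(2)), 1.0);
    }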

    #[test]
    fn test_cache_invalidation() {
        let cache = GlobalStatsCache::new();

        assert!(cache.get().is_none());

        let stats = cache.get_or_compute(|builder| {
            builder.total_docs = 100;
        });
        assert_eq!(stats.total_docs(), 100);

        assert!(cache.get().is_some());

        cache.invalidate();
        assert!(cache.get().is_none());
    }
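
    // Sketch of a further test: each invalidation bumps the generation counter,
    // and stats produced by get_or_compute carry the generation that was
    // current when they were built. Document counts are illustrative.
    #[test]
    fn test_cache_generation_tracking() {
        let cache = GlobalStatsCache::new();
        assert_eq!(cache.generation(), 0);

        cache.invalidate();
        cache.invalidate();
        assert_eq!(cache.generation(), 2);

        let stats = cache.get_or_compute(|builder| {
            builder.total_docs = 7;
        });
        assert_eq!(stats.generation(), 2);
        assert!(!cache.needs_rebuild());
    }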
}