1use std::sync::Arc;
13
14use parking_lot::RwLock;
15use rustc_hash::FxHashMap;
16
17use crate::dsl::Field;
18use crate::segment::SegmentReader;
19
20pub struct LazyGlobalStats {
26 segments: Vec<Arc<SegmentReader>>,
28 total_docs: u64,
30 sparse_idf_cache: RwLock<FxHashMap<u32, FxHashMap<u32, f32>>>,
32 text_idf_cache: RwLock<FxHashMap<u32, FxHashMap<String, f32>>>,
34 avg_field_len_cache: RwLock<FxHashMap<u32, f32>>,
36}
37
38impl LazyGlobalStats {
39 pub fn new(segments: Vec<Arc<SegmentReader>>) -> Self {
41 let total_docs: u64 = segments.iter().map(|s| s.num_docs() as u64).sum();
42 Self {
43 segments,
44 total_docs,
45 sparse_idf_cache: RwLock::new(FxHashMap::default()),
46 text_idf_cache: RwLock::new(FxHashMap::default()),
47 avg_field_len_cache: RwLock::new(FxHashMap::default()),
48 }
49 }
50
51 #[inline]
53 pub fn total_docs(&self) -> u64 {
54 self.total_docs
55 }
56
57 pub fn sparse_idf(&self, field: Field, dim_id: u32) -> f32 {
61 {
63 let cache = self.sparse_idf_cache.read();
64 if let Some(field_cache) = cache.get(&field.0)
65 && let Some(&idf) = field_cache.get(&dim_id)
66 {
67 return idf;
68 }
69 }
70
71 let df = self.compute_sparse_df(field, dim_id);
73 let total_vectors = self.compute_sparse_total_vectors(field);
76 let n = total_vectors.max(self.total_docs);
77 let idf = if df > 0 && n > 0 {
78 (n as f32 / df as f32).ln().max(0.0)
79 } else {
80 0.0
81 };
82
83 {
85 let mut cache = self.sparse_idf_cache.write();
86 cache.entry(field.0).or_default().insert(dim_id, idf);
87 }
88
89 idf
90 }
91
92 pub fn sparse_idf_weights(&self, field: Field, dim_ids: &[u32]) -> Vec<f32> {
94 dim_ids.iter().map(|&d| self.sparse_idf(field, d)).collect()
95 }
96
97 pub fn text_idf(&self, field: Field, term: &str) -> f32 {
101 {
103 let cache = self.text_idf_cache.read();
104 if let Some(field_cache) = cache.get(&field.0)
105 && let Some(&idf) = field_cache.get(term)
106 {
107 return idf;
108 }
109 }
110
111 let df = self.compute_text_df(field, term);
113 let n = self.total_docs as f32;
114 let df_f = df as f32;
115 let idf = if df > 0 {
116 ((n - df_f + 0.5) / (df_f + 0.5) + 1.0).ln()
117 } else {
118 0.0
119 };
120
121 {
123 let mut cache = self.text_idf_cache.write();
124 cache
125 .entry(field.0)
126 .or_default()
127 .insert(term.to_string(), idf);
128 }
129
130 idf
131 }
132
133 pub fn avg_field_len(&self, field: Field) -> f32 {
135 {
137 let cache = self.avg_field_len_cache.read();
138 if let Some(&avg) = cache.get(&field.0) {
139 return avg;
140 }
141 }
142
143 let mut weighted_sum = 0.0f64;
145 let mut total_weight = 0u64;
146
147 for segment in &self.segments {
148 let avg_len = segment.avg_field_len(field);
149 let doc_count = segment.num_docs() as u64;
150 if avg_len > 0.0 && doc_count > 0 {
151 weighted_sum += avg_len as f64 * doc_count as f64;
152 total_weight += doc_count;
153 }
154 }
155
156 let avg = if total_weight > 0 {
157 (weighted_sum / total_weight as f64) as f32
158 } else {
159 1.0
160 };
161
162 {
164 let mut cache = self.avg_field_len_cache.write();
165 cache.insert(field.0, avg);
166 }
167
168 avg
169 }
170
171 fn compute_sparse_df(&self, field: Field, dim_id: u32) -> u64 {
174 let mut df = 0u64;
175 for segment in &self.segments {
176 if let Some(sparse_index) = segment.sparse_indexes().get(&field.0) {
177 df += sparse_index.doc_count(dim_id) as u64;
178 }
179 }
180 df
181 }
182
183 fn compute_sparse_total_vectors(&self, field: Field) -> u64 {
186 let mut total = 0u64;
187 for segment in &self.segments {
188 if let Some(sparse_index) = segment.sparse_indexes().get(&field.0) {
189 total += sparse_index.total_vectors as u64;
190 }
191 }
192 total
193 }
194
195 fn compute_text_df(&self, _field: Field, _term: &str) -> u64 {
200 0
204 }
205
206 pub fn num_segments(&self) -> usize {
208 self.segments.len()
209 }
210}
211
212impl std::fmt::Debug for LazyGlobalStats {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 f.debug_struct("LazyGlobalStats")
215 .field("total_docs", &self.total_docs)
216 .field("num_segments", &self.segments.len())
217 .field("sparse_cache_fields", &self.sparse_idf_cache.read().len())
218 .field("text_cache_fields", &self.text_idf_cache.read().len())
219 .finish()
220 }
221}
222
223#[derive(Debug)]
227pub struct GlobalStats {
228 total_docs: u64,
230 sparse_stats: FxHashMap<u32, SparseFieldStats>,
232 text_stats: FxHashMap<u32, TextFieldStats>,
234 generation: u64,
236}
237
238#[derive(Debug, Default)]
240pub struct SparseFieldStats {
241 pub doc_freqs: FxHashMap<u32, u64>,
243}
244
245#[derive(Debug, Default)]
247pub struct TextFieldStats {
248 pub doc_freqs: FxHashMap<String, u64>,
250 pub avg_field_len: f32,
252}
253
254impl GlobalStats {
255 pub fn new() -> Self {
257 Self {
258 total_docs: 0,
259 sparse_stats: FxHashMap::default(),
260 text_stats: FxHashMap::default(),
261 generation: 0,
262 }
263 }
264
265 #[inline]
267 pub fn total_docs(&self) -> u64 {
268 self.total_docs
269 }
270
271 #[inline]
273 pub fn sparse_idf(&self, field: Field, dim_id: u32) -> f32 {
274 if let Some(stats) = self.sparse_stats.get(&field.0)
275 && let Some(&df) = stats.doc_freqs.get(&dim_id)
276 && df > 0
277 {
278 return (self.total_docs as f32 / df as f32).ln();
279 }
280 0.0
281 }
282
283 pub fn sparse_idf_weights(&self, field: Field, dim_ids: &[u32]) -> Vec<f32> {
285 dim_ids.iter().map(|&d| self.sparse_idf(field, d)).collect()
286 }
287
288 #[inline]
290 pub fn text_idf(&self, field: Field, term: &str) -> f32 {
291 if let Some(stats) = self.text_stats.get(&field.0)
292 && let Some(&df) = stats.doc_freqs.get(term)
293 {
294 let n = self.total_docs as f32;
295 let df = df as f32;
296 return ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
297 }
298 0.0
299 }
300
301 #[inline]
303 pub fn avg_field_len(&self, field: Field) -> f32 {
304 self.text_stats
305 .get(&field.0)
306 .map(|s| s.avg_field_len)
307 .unwrap_or(1.0)
308 }
309
310 #[inline]
312 pub fn generation(&self) -> u64 {
313 self.generation
314 }
315}
316
317impl Default for GlobalStats {
318 fn default() -> Self {
319 Self::new()
320 }
321}
322
323pub struct GlobalStatsBuilder {
325 pub total_docs: u64,
327 sparse_stats: FxHashMap<u32, SparseFieldStats>,
328 text_stats: FxHashMap<u32, TextFieldStats>,
329}
330
331impl GlobalStatsBuilder {
332 pub fn new() -> Self {
334 Self {
335 total_docs: 0,
336 sparse_stats: FxHashMap::default(),
337 text_stats: FxHashMap::default(),
338 }
339 }
340
341 pub fn add_segment(&mut self, reader: &SegmentReader) {
343 self.total_docs += reader.num_docs() as u64;
344
345 }
348
349 pub fn add_sparse_df(&mut self, field: Field, dim_id: u32, doc_count: u64) {
351 let stats = self.sparse_stats.entry(field.0).or_default();
352 *stats.doc_freqs.entry(dim_id).or_insert(0) += doc_count;
353 }
354
355 pub fn add_text_df(&mut self, field: Field, term: String, doc_count: u64) {
357 let stats = self.text_stats.entry(field.0).or_default();
358 *stats.doc_freqs.entry(term).or_insert(0) += doc_count;
359 }
360
361 pub fn set_avg_field_len(&mut self, field: Field, avg_len: f32) {
363 let stats = self.text_stats.entry(field.0).or_default();
364 stats.avg_field_len = avg_len;
365 }
366
367 pub fn build(self, generation: u64) -> GlobalStats {
369 GlobalStats {
370 total_docs: self.total_docs,
371 sparse_stats: self.sparse_stats,
372 text_stats: self.text_stats,
373 generation,
374 }
375 }
376}
377
378impl Default for GlobalStatsBuilder {
379 fn default() -> Self {
380 Self::new()
381 }
382}
383
384pub struct GlobalStatsCache {
389 stats: RwLock<Option<Arc<GlobalStats>>>,
391 generation: RwLock<u64>,
393}
394
395impl GlobalStatsCache {
396 pub fn new() -> Self {
398 Self {
399 stats: RwLock::new(None),
400 generation: RwLock::new(0),
401 }
402 }
403
404 pub fn invalidate(&self) {
406 let mut current_gen = self.generation.write();
407 *current_gen += 1;
408 let mut stats = self.stats.write();
409 *stats = None;
410 }
411
412 pub fn generation(&self) -> u64 {
414 *self.generation.read()
415 }
416
417 pub fn get(&self) -> Option<Arc<GlobalStats>> {
419 self.stats.read().clone()
420 }
421
422 pub fn set(&self, stats: GlobalStats) {
424 let mut cached = self.stats.write();
425 *cached = Some(Arc::new(stats));
426 }
427
428 pub fn get_or_compute<F>(&self, compute: F) -> Arc<GlobalStats>
432 where
433 F: FnOnce(&mut GlobalStatsBuilder),
434 {
435 if let Some(stats) = self.get() {
437 return stats;
438 }
439
440 let current_gen = self.generation();
442 let mut builder = GlobalStatsBuilder::new();
443 compute(&mut builder);
444 let stats = Arc::new(builder.build(current_gen));
445
446 let mut cached = self.stats.write();
448 *cached = Some(Arc::clone(&stats));
449
450 stats
451 }
452
453 pub fn needs_rebuild(&self) -> bool {
455 self.stats.read().is_none()
456 }
457
458 pub fn set_stats(&self, stats: GlobalStats) {
460 let mut cached = self.stats.write();
461 *cached = Some(Arc::new(stats));
462 }
463}
464
465impl Default for GlobalStatsCache {
466 fn default() -> Self {
467 Self::new()
468 }
469}
470
#[cfg(test)]
mod tests {
    use super::*;

    /// Sparse IDF is `ln(N / df)`, so the rarer dimension must weigh more.
    #[test]
    fn test_sparse_idf_computation() {
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 1000;
        builder.add_sparse_df(Field(0), 42, 100);
        builder.add_sparse_df(Field(0), 43, 10);
        let stats = builder.build(1);

        let frequent = stats.sparse_idf(Field(0), 42);
        let rare = stats.sparse_idf(Field(0), 43);

        let expected_frequent = (1000.0_f32 / 100.0).ln();
        let expected_rare = (1000.0_f32 / 10.0).ln();

        assert!(rare > frequent);
        assert!((frequent - expected_frequent).abs() < 0.001);
        assert!((rare - expected_rare).abs() < 0.001);
    }

    /// A rare term must receive a higher text IDF than a common one.
    #[test]
    fn test_text_idf_computation() {
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 10000;
        builder.add_text_df(Field(0), "common".to_string(), 5000);
        builder.add_text_df(Field(0), "rare".to_string(), 10);
        let stats = builder.build(1);

        let common = stats.text_idf(Field(0), "common");
        let rare = stats.text_idf(Field(0), "rare");

        assert!(rare > common);
    }

    /// The cache starts empty, is filled by `get_or_compute`, and is emptied
    /// again by `invalidate`.
    #[test]
    fn test_cache_invalidation() {
        let cache = GlobalStatsCache::new();
        assert!(cache.get().is_none());

        let computed = cache.get_or_compute(|builder| {
            builder.total_docs = 100;
        });
        assert_eq!(computed.total_docs(), 100);
        assert!(cache.get().is_some());

        cache.invalidate();
        assert!(cache.get().is_none());
    }
}