1use std::sync::Arc;
13
14use parking_lot::RwLock;
15use rustc_hash::FxHashMap;
16
17use crate::dsl::Field;
18use crate::segment::SegmentReader;
19
/// Lazily computed global (cross-segment) statistics for scoring.
///
/// Per-field IDF values and average field lengths are computed on first
/// request and memoized in interior-mutable caches, so construction stays
/// cheap regardless of index size.
pub struct LazyGlobalStats {
    // Segment readers the statistics are aggregated over.
    segments: Vec<Arc<SegmentReader>>,
    // Total document count across all segments, computed eagerly in `new`.
    total_docs: u64,
    // field id -> (sparse dimension id -> IDF), filled on demand.
    sparse_idf_cache: RwLock<FxHashMap<u32, FxHashMap<u32, f32>>>,
    // field id -> (term -> IDF), filled on demand.
    text_idf_cache: RwLock<FxHashMap<u32, FxHashMap<String, f32>>>,
    // field id -> average field length, filled on demand.
    avg_field_len_cache: RwLock<FxHashMap<u32, f32>>,
}
37
impl LazyGlobalStats {
    /// Creates the stats view over `segments`.
    ///
    /// Only the total document count is computed eagerly; IDF values and
    /// average field lengths are computed (and cached) on first request.
    pub fn new(segments: Vec<Arc<SegmentReader>>) -> Self {
        let total_docs: u64 = segments.iter().map(|s| s.num_docs() as u64).sum();
        Self {
            segments,
            total_docs,
            sparse_idf_cache: RwLock::new(FxHashMap::default()),
            text_idf_cache: RwLock::new(FxHashMap::default()),
            avg_field_len_cache: RwLock::new(FxHashMap::default()),
        }
    }

    /// Total number of documents across all segments.
    #[inline]
    pub fn total_docs(&self) -> u64 {
        self.total_docs
    }

    /// IDF of sparse dimension `dim_id` in `field`, memoized per field.
    ///
    /// Computed as `ln(N / df).max(0.0)` where `N` is the larger of the
    /// field's total vector count and the global document count; 0.0 when
    /// `df` or `N` is zero. Two threads may race past the read-lock fast
    /// path and compute the same value; the duplicate insert is harmless.
    pub fn sparse_idf(&self, field: Field, dim_id: u32) -> f32 {
        // Fast path: already cached (read lock only; scoped so the guard
        // drops before we take the write lock below).
        {
            let cache = self.sparse_idf_cache.read();
            if let Some(field_cache) = cache.get(&field.0)
                && let Some(&idf) = field_cache.get(&dim_id)
            {
                return idf;
            }
        }

        let df = self.compute_sparse_df(field, dim_id);
        let total_vectors = self.compute_sparse_total_vectors(field);
        // Use the larger of the field's vector count and the global doc
        // count as N.
        let n = total_vectors.max(self.total_docs);
        let idf = if df > 0 && n > 0 {
            (n as f32 / df as f32).ln().max(0.0)
        } else {
            0.0
        };

        // Slow path: memoize under the write lock.
        {
            let mut cache = self.sparse_idf_cache.write();
            cache.entry(field.0).or_default().insert(dim_id, idf);
        }

        idf
    }

    /// Batch variant of [`Self::sparse_idf`]; output order matches `dim_ids`.
    pub fn sparse_idf_weights(&self, field: Field, dim_ids: &[u32]) -> Vec<f32> {
        dim_ids.iter().map(|&d| self.sparse_idf(field, d)).collect()
    }

    /// BM25-style IDF of `term` in `field`, memoized per field.
    ///
    /// Uses `ln((N - df + 0.5) / (df + 0.5) + 1)`; returns 0.0 when `df`
    /// is zero. NOTE(review): `compute_text_df` below is a stub that
    /// always returns 0, so this currently evaluates to (and caches) 0.0
    /// for every term until that stub is implemented.
    pub fn text_idf(&self, field: Field, term: &str) -> f32 {
        // Fast path: already cached (read lock only).
        {
            let cache = self.text_idf_cache.read();
            if let Some(field_cache) = cache.get(&field.0)
                && let Some(&idf) = field_cache.get(term)
            {
                return idf;
            }
        }

        let df = self.compute_text_df(field, term);
        let n = self.total_docs as f32;
        let df_f = df as f32;
        let idf = if df > 0 {
            ((n - df_f + 0.5) / (df_f + 0.5) + 1.0).ln()
        } else {
            0.0
        };

        // Slow path: memoize under the write lock.
        {
            let mut cache = self.text_idf_cache.write();
            cache
                .entry(field.0)
                .or_default()
                .insert(term.to_string(), idf);
        }

        idf
    }

    /// Document-count-weighted average length of `field` across segments,
    /// memoized.
    ///
    /// Segments reporting a zero average or zero documents are skipped;
    /// returns 1.0 when no segment contributes.
    pub fn avg_field_len(&self, field: Field) -> f32 {
        // Fast path: already cached (read lock only).
        {
            let cache = self.avg_field_len_cache.read();
            if let Some(&avg) = cache.get(&field.0) {
                return avg;
            }
        }

        // Weight each segment's average by its document count; f64
        // accumulation avoids precision loss over many segments.
        let mut weighted_sum = 0.0f64;
        let mut total_weight = 0u64;

        for segment in &self.segments {
            let avg_len = segment.avg_field_len(field);
            let doc_count = segment.num_docs() as u64;
            if avg_len > 0.0 && doc_count > 0 {
                weighted_sum += avg_len as f64 * doc_count as f64;
                total_weight += doc_count;
            }
        }

        let avg = if total_weight > 0 {
            (weighted_sum / total_weight as f64) as f32
        } else {
            1.0
        };

        // Memoize under the write lock.
        {
            let mut cache = self.avg_field_len_cache.write();
            cache.insert(field.0, avg);
        }

        avg
    }

    /// Sums `dim_id`'s posting document counts across all segments
    /// (native builds: each posting is fetched with a blocking read).
    #[cfg(feature = "native")]
    fn compute_sparse_df(&self, field: Field, dim_id: u32) -> u64 {
        let mut df = 0u64;
        for segment in &self.segments {
            if let Some(sparse_index) = segment.sparse_indexes().get(&field.0)
                && let Ok(Some(posting)) = sparse_index.get_posting_blocking(dim_id)
            {
                df += posting.doc_count() as u64;
            }
        }
        df
    }

    /// Sums `dim_id`'s posting document counts across all segments
    /// (non-native builds).
    ///
    /// NOTE(review): only postings already resident in the index cache
    /// contribute here, so `df` may undercount relative to the native
    /// path — confirm this is intended for non-native targets.
    #[cfg(not(feature = "native"))]
    fn compute_sparse_df(&self, field: Field, dim_id: u32) -> u64 {
        let mut df = 0u64;
        for segment in &self.segments {
            if let Some(sparse_index) = segment.sparse_indexes().get(&field.0)
                && let Some(posting) = sparse_index.get_cached(dim_id)
            {
                df += posting.doc_count() as u64;
            }
        }
        df
    }

    /// Total number of sparse vectors stored for `field` across segments.
    fn compute_sparse_total_vectors(&self, field: Field) -> u64 {
        let mut total = 0u64;
        for segment in &self.segments {
            if let Some(sparse_index) = segment.sparse_indexes().get(&field.0) {
                total += sparse_index.total_vectors as u64;
            }
        }
        total
    }

    /// Stub: lazy aggregation of text document frequencies is not
    /// implemented, so this always returns 0 (making `text_idf` cache 0.0
    /// for every term).
    fn compute_text_df(&self, _field: Field, _term: &str) -> u64 {
        0
    }

    /// Number of segments this view aggregates over.
    pub fn num_segments(&self) -> usize {
        self.segments.len()
    }
}
228
229impl std::fmt::Debug for LazyGlobalStats {
230 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231 f.debug_struct("LazyGlobalStats")
232 .field("total_docs", &self.total_docs)
233 .field("num_segments", &self.segments.len())
234 .field("sparse_cache_fields", &self.sparse_idf_cache.read().len())
235 .field("text_cache_fields", &self.text_idf_cache.read().len())
236 .finish()
237 }
238}
239
/// Immutable, fully materialized global statistics snapshot.
///
/// Unlike [`LazyGlobalStats`], every value is precomputed via a
/// `GlobalStatsBuilder`, so lookups take no locks.
#[derive(Debug)]
pub struct GlobalStats {
    // Total documents across all segments at build time.
    total_docs: u64,
    // field id -> per-dimension document frequencies.
    sparse_stats: FxHashMap<u32, SparseFieldStats>,
    // field id -> per-term document frequencies plus average length.
    text_stats: FxHashMap<u32, TextFieldStats>,
    // Generation counter this snapshot was built for.
    generation: u64,
}
254
/// Per-field document frequencies for sparse vector dimensions.
#[derive(Debug, Default)]
pub struct SparseFieldStats {
    /// dimension id -> number of documents containing that dimension.
    pub doc_freqs: FxHashMap<u32, u64>,
}
261
/// Per-field document frequencies and length statistics for text terms.
#[derive(Debug, Default)]
pub struct TextFieldStats {
    /// term -> number of documents containing that term.
    pub doc_freqs: FxHashMap<String, u64>,
    /// Average length of this field across documents (0.0 if never set).
    pub avg_field_len: f32,
}
270
271impl GlobalStats {
272 pub fn new() -> Self {
274 Self {
275 total_docs: 0,
276 sparse_stats: FxHashMap::default(),
277 text_stats: FxHashMap::default(),
278 generation: 0,
279 }
280 }
281
282 #[inline]
284 pub fn total_docs(&self) -> u64 {
285 self.total_docs
286 }
287
288 #[inline]
290 pub fn sparse_idf(&self, field: Field, dim_id: u32) -> f32 {
291 if let Some(stats) = self.sparse_stats.get(&field.0)
292 && let Some(&df) = stats.doc_freqs.get(&dim_id)
293 && df > 0
294 {
295 return (self.total_docs as f32 / df as f32).ln();
296 }
297 0.0
298 }
299
300 pub fn sparse_idf_weights(&self, field: Field, dim_ids: &[u32]) -> Vec<f32> {
302 dim_ids.iter().map(|&d| self.sparse_idf(field, d)).collect()
303 }
304
305 #[inline]
307 pub fn text_idf(&self, field: Field, term: &str) -> f32 {
308 if let Some(stats) = self.text_stats.get(&field.0)
309 && let Some(&df) = stats.doc_freqs.get(term)
310 {
311 let n = self.total_docs as f32;
312 let df = df as f32;
313 return ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
314 }
315 0.0
316 }
317
318 #[inline]
320 pub fn avg_field_len(&self, field: Field) -> f32 {
321 self.text_stats
322 .get(&field.0)
323 .map(|s| s.avg_field_len)
324 .unwrap_or(1.0)
325 }
326
327 #[inline]
329 pub fn generation(&self) -> u64 {
330 self.generation
331 }
332}
333
334impl Default for GlobalStats {
335 fn default() -> Self {
336 Self::new()
337 }
338}
339
/// Mutable accumulator that aggregates per-segment statistics into a
/// [`GlobalStats`] snapshot via `build`.
pub struct GlobalStatsBuilder {
    /// Running total of documents; public so callers can also set it
    /// directly.
    pub total_docs: u64,
    // field id -> accumulated sparse-dimension document frequencies.
    sparse_stats: FxHashMap<u32, SparseFieldStats>,
    // field id -> accumulated text-term document frequencies.
    text_stats: FxHashMap<u32, TextFieldStats>,
}
347
348impl GlobalStatsBuilder {
349 pub fn new() -> Self {
351 Self {
352 total_docs: 0,
353 sparse_stats: FxHashMap::default(),
354 text_stats: FxHashMap::default(),
355 }
356 }
357
358 pub fn add_segment(&mut self, reader: &SegmentReader) {
360 self.total_docs += reader.num_docs() as u64;
361
362 }
365
366 pub fn add_sparse_df(&mut self, field: Field, dim_id: u32, doc_count: u64) {
368 let stats = self.sparse_stats.entry(field.0).or_default();
369 *stats.doc_freqs.entry(dim_id).or_insert(0) += doc_count;
370 }
371
372 pub fn add_text_df(&mut self, field: Field, term: String, doc_count: u64) {
374 let stats = self.text_stats.entry(field.0).or_default();
375 *stats.doc_freqs.entry(term).or_insert(0) += doc_count;
376 }
377
378 pub fn set_avg_field_len(&mut self, field: Field, avg_len: f32) {
380 let stats = self.text_stats.entry(field.0).or_default();
381 stats.avg_field_len = avg_len;
382 }
383
384 pub fn build(self, generation: u64) -> GlobalStats {
386 GlobalStats {
387 total_docs: self.total_docs,
388 sparse_stats: self.sparse_stats,
389 text_stats: self.text_stats,
390 generation,
391 }
392 }
393}
394
395impl Default for GlobalStatsBuilder {
396 fn default() -> Self {
397 Self::new()
398 }
399}
400
/// Generation-tracked holder for the current [`GlobalStats`] snapshot.
///
/// NOTE(review): `stats` and `generation` live behind two separate locks,
/// so readers of one may briefly observe the other mid-update.
pub struct GlobalStatsCache {
    // Cached snapshot; `None` before the first build or after invalidation.
    stats: RwLock<Option<Arc<GlobalStats>>>,
    // Monotonically increasing invalidation counter.
    generation: RwLock<u64>,
}
411
412impl GlobalStatsCache {
413 pub fn new() -> Self {
415 Self {
416 stats: RwLock::new(None),
417 generation: RwLock::new(0),
418 }
419 }
420
421 pub fn invalidate(&self) {
423 let mut current_gen = self.generation.write();
424 *current_gen += 1;
425 let mut stats = self.stats.write();
426 *stats = None;
427 }
428
429 pub fn generation(&self) -> u64 {
431 *self.generation.read()
432 }
433
434 pub fn get(&self) -> Option<Arc<GlobalStats>> {
436 self.stats.read().clone()
437 }
438
439 pub fn set(&self, stats: GlobalStats) {
441 let mut cached = self.stats.write();
442 *cached = Some(Arc::new(stats));
443 }
444
445 pub fn get_or_compute<F>(&self, compute: F) -> Arc<GlobalStats>
449 where
450 F: FnOnce(&mut GlobalStatsBuilder),
451 {
452 if let Some(stats) = self.get() {
454 return stats;
455 }
456
457 let current_gen = self.generation();
459 let mut builder = GlobalStatsBuilder::new();
460 compute(&mut builder);
461 let stats = Arc::new(builder.build(current_gen));
462
463 let mut cached = self.stats.write();
465 *cached = Some(Arc::clone(&stats));
466
467 stats
468 }
469
470 pub fn needs_rebuild(&self) -> bool {
472 self.stats.read().is_none()
473 }
474
475 pub fn set_stats(&self, stats: GlobalStats) {
477 let mut cached = self.stats.write();
478 *cached = Some(Arc::new(stats));
479 }
480}
481
482impl Default for GlobalStatsCache {
483 fn default() -> Self {
484 Self::new()
485 }
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_idf_computation() {
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 1000;
        builder.add_sparse_df(Field(0), 42, 100);
        builder.add_sparse_df(Field(0), 43, 10);
        let stats = builder.build(1);

        let idf_common = stats.sparse_idf(Field(0), 42);
        let idf_rare = stats.sparse_idf(Field(0), 43);

        // The rarer dimension (df = 10) must outweigh the common one
        // (df = 100), and both must match ln(N / df) exactly.
        assert!(idf_rare > idf_common);
        assert!((idf_common - (1000.0_f32 / 100.0).ln()).abs() < 0.001);
        assert!((idf_rare - (1000.0_f32 / 10.0).ln()).abs() < 0.001);
    }

    #[test]
    fn test_text_idf_computation() {
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 10000;
        builder.add_text_df(Field(0), "common".to_string(), 5000);
        builder.add_text_df(Field(0), "rare".to_string(), 10);
        let stats = builder.build(1);

        // BM25 IDF decreases monotonically with document frequency.
        let idf_common = stats.text_idf(Field(0), "common");
        let idf_rare = stats.text_idf(Field(0), "rare");
        assert!(idf_rare > idf_common);
    }

    #[test]
    fn test_cache_invalidation() {
        let cache = GlobalStatsCache::new();
        assert!(cache.get().is_none());

        let stats = cache.get_or_compute(|builder| {
            builder.total_docs = 100;
        });
        assert_eq!(stats.total_docs(), 100);
        assert!(cache.get().is_some());

        // Invalidation must clear the cached snapshot.
        cache.invalidate();
        assert!(cache.get().is_none());
    }
}
547}