use std::sync::Arc;

use parking_lot::RwLock;
use rustc_hash::FxHashMap;

use crate::dsl::Field;
use crate::segment::SegmentReader;

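/// Lazily computed global (cross-segment) statistics used for scoring.
///
/// Nothing is aggregated up front: IDF values, sparse vector counts, and
/// average field lengths are computed from the segment readers on first
/// access and memoized in the per-field caches below.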
pub struct LazyGlobalStats {
    segments: Vec<Arc<SegmentReader>>,
    total_docs: u64,
    sparse_idf_cache: RwLock<FxHashMap<u32, FxHashMap<u32, f32>>>,
    sparse_total_vectors_cache: RwLock<FxHashMap<u32, u64>>,
    text_idf_cache: RwLock<FxHashMap<u32, FxHashMap<String, f32>>>,
    avg_field_len_cache: RwLock<FxHashMap<u32, f32>>,
}

impl LazyGlobalStats {
    pub fn new(segments: Vec<Arc<SegmentReader>>) -> Self {
        let total_docs: u64 = segments.iter().map(|s| s.num_docs() as u64).sum();
        Self {
            segments,
            total_docs,
            sparse_idf_cache: RwLock::new(FxHashMap::default()),
            sparse_total_vectors_cache: RwLock::new(FxHashMap::default()),
            text_idf_cache: RwLock::new(FxHashMap::default()),
            avg_field_len_cache: RwLock::new(FxHashMap::default()),
        }
    }

    #[inline]
    pub fn total_docs(&self) -> u64 {
        self.total_docs
    }

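    /// IDF for a single sparse dimension: `ln(n / df)` clamped at zero, where
    /// `n` is the total number of sparse vectors for the field (never less
    /// than `total_docs`) and `df` is the dimension's document frequency.
    /// Computed on first use and cached per field.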
    pub fn sparse_idf(&self, field: Field, dim_id: u32) -> f32 {
        {
            let cache = self.sparse_idf_cache.read();
            if let Some(field_cache) = cache.get(&field.0)
                && let Some(&idf) = field_cache.get(&dim_id)
            {
                return idf;
            }
        }

        let df = self.compute_sparse_df(field, dim_id);
        let n = self.cached_sparse_n(field);
        let idf = if df > 0 && n > 0 {
            (n as f32 / df as f32).ln().max(0.0)
        } else {
            0.0
        };

        {
            let mut cache = self.sparse_idf_cache.write();
            cache.entry(field.0).or_default().insert(dim_id, idf);
        }

        idf
    }

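    /// Batch variant of `sparse_idf`: resolves cached dimensions under a
    /// single read lock, computes only the misses, and publishes the new
    /// entries with a single write lock.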
    pub fn sparse_idf_weights(&self, field: Field, dim_ids: &[u32]) -> Vec<f32> {
        let mut result = vec![0.0f32; dim_ids.len()];
        let mut misses: Vec<usize> = Vec::new();
        {
            let cache = self.sparse_idf_cache.read();
            if let Some(field_cache) = cache.get(&field.0) {
                for (i, &dim_id) in dim_ids.iter().enumerate() {
                    if let Some(&idf) = field_cache.get(&dim_id) {
                        result[i] = idf;
                    } else {
                        misses.push(i);
                    }
                }
            } else {
                misses.extend(0..dim_ids.len());
            }
        }

        if misses.is_empty() {
            return result;
        }

        let n = self.cached_sparse_n(field);

        let mut new_entries: Vec<(u32, f32)> = Vec::with_capacity(misses.len());
        for &i in &misses {
            let dim_id = dim_ids[i];
            let df = self.compute_sparse_df(field, dim_id);
            let idf = if df > 0 && n > 0 {
                (n as f32 / df as f32).ln().max(0.0)
            } else {
                0.0
            };
            result[i] = idf;
            new_entries.push((dim_id, idf));
        }

        {
            let mut cache = self.sparse_idf_cache.write();
            let field_cache = cache.entry(field.0).or_default();
            for (dim_id, idf) in new_entries {
                field_cache.insert(dim_id, idf);
            }
        }

        result
    }

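    /// Total number of sparse vectors for `field`, cached per field and never
    /// reported as less than `total_docs`.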
    fn cached_sparse_n(&self, field: Field) -> u64 {
        {
            let cache = self.sparse_total_vectors_cache.read();
            if let Some(&tv) = cache.get(&field.0) {
                return tv.max(self.total_docs);
            }
        }
        let tv = self.compute_sparse_total_vectors(field);
        self.sparse_total_vectors_cache.write().insert(field.0, tv);
        tv.max(self.total_docs)
    }

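    /// BM25-style IDF for a text term:
    /// `ln((N - df + 0.5) / (df + 0.5) + 1)`, or `0.0` for an unseen term.
    /// Computed on first use and cached per field.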
    pub fn text_idf(&self, field: Field, term: &str) -> f32 {
        {
            let cache = self.text_idf_cache.read();
            if let Some(field_cache) = cache.get(&field.0)
                && let Some(&idf) = field_cache.get(term)
            {
                return idf;
            }
        }

        let df = self.compute_text_df(field, term);
        let n = self.total_docs as f32;
        let df_f = df as f32;
        let idf = if df > 0 {
            ((n - df_f + 0.5) / (df_f + 0.5) + 1.0).ln()
        } else {
            0.0
        };

        {
            let mut cache = self.text_idf_cache.write();
            cache
                .entry(field.0)
                .or_default()
                .insert(term.to_string(), idf);
        }

        idf
    }

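    /// Average length of `field` across all segments, weighted by each
    /// segment's document count; falls back to `1.0` when no segment reports
    /// a length.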
    pub fn avg_field_len(&self, field: Field) -> f32 {
        {
            let cache = self.avg_field_len_cache.read();
            if let Some(&avg) = cache.get(&field.0) {
                return avg;
            }
        }

        let mut weighted_sum = 0.0f64;
        let mut total_weight = 0u64;

        for segment in &self.segments {
            let avg_len = segment.avg_field_len(field);
            let doc_count = segment.num_docs() as u64;
            if avg_len > 0.0 && doc_count > 0 {
                weighted_sum += avg_len as f64 * doc_count as f64;
                total_weight += doc_count;
            }
        }

        let avg = if total_weight > 0 {
            (weighted_sum / total_weight as f64) as f32
        } else {
            1.0
        };

        {
            let mut cache = self.avg_field_len_cache.write();
            cache.insert(field.0, avg);
        }

        avg
    }

    fn compute_sparse_df(&self, field: Field, dim_id: u32) -> u64 {
        let mut df = 0u64;
        for segment in &self.segments {
            if let Some(sparse_index) = segment.sparse_indexes().get(&field.0) {
                df += sparse_index.doc_count(dim_id) as u64;
            }
        }
        df
    }

    fn compute_sparse_total_vectors(&self, field: Field) -> u64 {
        let mut total = 0u64;
        for segment in &self.segments {
            if let Some(sparse_index) = segment.sparse_indexes().get(&field.0) {
                total += sparse_index.total_vectors as u64;
            }
        }
        total
    }

    fn compute_text_df(&self, _field: Field, _term: &str) -> u64 {
        // Text document frequencies are not aggregated from segments here;
        // returning 0 makes `text_idf` fall back to an IDF of 0.0.
        0
    }

    pub fn num_segments(&self) -> usize {
        self.segments.len()
    }
}

impl std::fmt::Debug for LazyGlobalStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LazyGlobalStats")
            .field("total_docs", &self.total_docs)
            .field("num_segments", &self.segments.len())
            .field("sparse_cache_fields", &self.sparse_idf_cache.read().len())
            .field("text_cache_fields", &self.text_idf_cache.read().len())
            .finish()
    }
}

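/// Fully materialized global statistics with all document frequencies
/// precomputed, typically assembled via [`GlobalStatsBuilder`]. The
/// `generation` value records which cache generation the stats were built for.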
#[derive(Debug)]
pub struct GlobalStats {
    total_docs: u64,
    sparse_stats: FxHashMap<u32, SparseFieldStats>,
    text_stats: FxHashMap<u32, TextFieldStats>,
    generation: u64,
}

#[derive(Debug, Default)]
pub struct SparseFieldStats {
    pub doc_freqs: FxHashMap<u32, u64>,
}

#[derive(Debug, Default)]
pub struct TextFieldStats {
    pub doc_freqs: FxHashMap<String, u64>,
    pub avg_field_len: f32,
}

impl GlobalStats {
    pub fn new() -> Self {
        Self {
            total_docs: 0,
            sparse_stats: FxHashMap::default(),
            text_stats: FxHashMap::default(),
            generation: 0,
        }
    }

    #[inline]
    pub fn total_docs(&self) -> u64 {
        self.total_docs
    }

    #[inline]
    pub fn sparse_idf(&self, field: Field, dim_id: u32) -> f32 {
        if let Some(stats) = self.sparse_stats.get(&field.0)
            && let Some(&df) = stats.doc_freqs.get(&dim_id)
            && df > 0
        {
            return (self.total_docs as f32 / df as f32).ln();
        }
        0.0
    }

    pub fn sparse_idf_weights(&self, field: Field, dim_ids: &[u32]) -> Vec<f32> {
        dim_ids.iter().map(|&d| self.sparse_idf(field, d)).collect()
    }

    #[inline]
    pub fn text_idf(&self, field: Field, term: &str) -> f32 {
        if let Some(stats) = self.text_stats.get(&field.0)
            && let Some(&df) = stats.doc_freqs.get(term)
        {
            let n = self.total_docs as f32;
            let df = df as f32;
            return ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
        }
        0.0
    }

    #[inline]
    pub fn avg_field_len(&self, field: Field) -> f32 {
        self.text_stats
            .get(&field.0)
            .map(|s| s.avg_field_len)
            .unwrap_or(1.0)
    }

    #[inline]
    pub fn generation(&self) -> u64 {
        self.generation
    }
}

impl Default for GlobalStats {
    fn default() -> Self {
        Self::new()
    }
}

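/// Incremental builder for [`GlobalStats`]: accumulate per-segment document
/// counts and document frequencies, then call `build` with the generation the
/// stats correspond to.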
pub struct GlobalStatsBuilder {
    pub total_docs: u64,
    sparse_stats: FxHashMap<u32, SparseFieldStats>,
    text_stats: FxHashMap<u32, TextFieldStats>,
}

impl GlobalStatsBuilder {
    pub fn new() -> Self {
        Self {
            total_docs: 0,
            sparse_stats: FxHashMap::default(),
            text_stats: FxHashMap::default(),
        }
    }

    pub fn add_segment(&mut self, reader: &SegmentReader) {
        self.total_docs += reader.num_docs() as u64;
    }

    pub fn add_sparse_df(&mut self, field: Field, dim_id: u32, doc_count: u64) {
        let stats = self.sparse_stats.entry(field.0).or_default();
        *stats.doc_freqs.entry(dim_id).or_insert(0) += doc_count;
    }

    pub fn add_text_df(&mut self, field: Field, term: String, doc_count: u64) {
        let stats = self.text_stats.entry(field.0).or_default();
        *stats.doc_freqs.entry(term).or_insert(0) += doc_count;
    }

    pub fn set_avg_field_len(&mut self, field: Field, avg_len: f32) {
        let stats = self.text_stats.entry(field.0).or_default();
        stats.avg_field_len = avg_len;
    }

    pub fn build(self, generation: u64) -> GlobalStats {
        GlobalStats {
            total_docs: self.total_docs,
            sparse_stats: self.sparse_stats,
            text_stats: self.text_stats,
            generation,
        }
    }
}

impl Default for GlobalStatsBuilder {
    fn default() -> Self {
        Self::new()
    }
}

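/// Shared cache holding the most recently built [`GlobalStats`], together
/// with a generation counter that is bumped on every invalidation.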
pub struct GlobalStatsCache {
    stats: RwLock<Option<Arc<GlobalStats>>>,
    generation: RwLock<u64>,
}

impl GlobalStatsCache {
    pub fn new() -> Self {
        Self {
            stats: RwLock::new(None),
            generation: RwLock::new(0),
        }
    }

    pub fn invalidate(&self) {
        let mut current_gen = self.generation.write();
        *current_gen += 1;
        let mut stats = self.stats.write();
        *stats = None;
    }

    pub fn generation(&self) -> u64 {
        *self.generation.read()
    }

    pub fn get(&self) -> Option<Arc<GlobalStats>> {
        self.stats.read().clone()
    }

    pub fn set(&self, stats: GlobalStats) {
        let mut cached = self.stats.write();
        *cached = Some(Arc::new(stats));
    }

    pub fn get_or_compute<F>(&self, compute: F) -> Arc<GlobalStats>
    where
        F: FnOnce(&mut GlobalStatsBuilder),
    {
        if let Some(stats) = self.get() {
            return stats;
        }

        let current_gen = self.generation();
        let mut builder = GlobalStatsBuilder::new();
        compute(&mut builder);
        let stats = Arc::new(builder.build(current_gen));

        let mut cached = self.stats.write();
        *cached = Some(Arc::clone(&stats));

        stats
    }

    pub fn needs_rebuild(&self) -> bool {
        self.stats.read().is_none()
    }

    pub fn set_stats(&self, stats: GlobalStats) {
        let mut cached = self.stats.write();
        *cached = Some(Arc::new(stats));
    }
}

impl Default for GlobalStatsCache {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_idf_computation() {
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 1000;
        builder.add_sparse_df(Field(0), 42, 100);
        builder.add_sparse_df(Field(0), 43, 10);

        let stats = builder.build(1);

        let idf_42 = stats.sparse_idf(Field(0), 42);
        let idf_43 = stats.sparse_idf(Field(0), 43);

        assert!(idf_43 > idf_42);
        assert!((idf_42 - (1000.0_f32 / 100.0).ln()).abs() < 0.001);
        assert!((idf_43 - (1000.0_f32 / 10.0).ln()).abs() < 0.001);
    }

    #[test]
    fn test_text_idf_computation() {
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 10000;
        builder.add_text_df(Field(0), "common".to_string(), 5000);
        builder.add_text_df(Field(0), "rare".to_string(), 10);

        let stats = builder.build(1);

        let idf_common = stats.text_idf(Field(0), "common");
        let idf_rare = stats.text_idf(Field(0), "rare");

        assert!(idf_rare > idf_common);
    }

    #[test]
    fn test_cache_invalidation() {
        let cache = GlobalStatsCache::new();

        assert!(cache.get().is_none());

        let stats = cache.get_or_compute(|builder| {
            builder.total_docs = 100;
        });
        assert_eq!(stats.total_docs(), 100);

        assert!(cache.get().is_some());

        cache.invalidate();
        assert!(cache.get().is_none());
    }
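
    // Additional coverage sketch: exercises the avg_field_len default and
    // override paths of GlobalStatsBuilder / GlobalStats defined above.
    #[test]
    fn test_avg_field_len_default_and_override() {
        let mut builder = GlobalStatsBuilder::new();
        builder.set_avg_field_len(Field(0), 12.5);

        let stats = builder.build(1);

        // An explicitly set field returns its average; unknown fields fall back to 1.0.
        assert!((stats.avg_field_len(Field(0)) - 12.5).abs() < 1e-6);
        assert!((stats.avg_field_len(Field(1)) - 1.0).abs() < 1e-6);
    }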
}