use std::sync::Arc;

use parking_lot::RwLock;
use rustc_hash::FxHashMap;

use crate::dsl::Field;
use crate::segment::SegmentReader;

/// Lazily computed cross-segment statistics: IDF values and average field
/// lengths are derived from the segment readers on first request and cached.
pub struct LazyGlobalStats {
    segments: Vec<Arc<SegmentReader>>,
    /// Total number of documents across all segments.
    total_docs: u64,
    /// field id -> (sparse dimension id -> idf)
    sparse_idf_cache: RwLock<FxHashMap<u32, FxHashMap<u32, f32>>>,
    /// field id -> (term -> idf)
    text_idf_cache: RwLock<FxHashMap<u32, FxHashMap<String, f32>>>,
    /// field id -> average field length
    avg_field_len_cache: RwLock<FxHashMap<u32, f32>>,
}

impl LazyGlobalStats {
    pub fn new(segments: Vec<Arc<SegmentReader>>) -> Self {
        // Sum document counts across all segments up front.
        let total_docs: u64 = segments.iter().map(|s| s.num_docs() as u64).sum();
        Self {
            segments,
            total_docs,
            sparse_idf_cache: RwLock::new(FxHashMap::default()),
            text_idf_cache: RwLock::new(FxHashMap::default()),
            avg_field_len_cache: RwLock::new(FxHashMap::default()),
        }
    }

    /// Total number of documents across all segments.
    #[inline]
    pub fn total_docs(&self) -> u64 {
        self.total_docs
    }

    /// IDF for a sparse-vector dimension: ln(N / df), clamped at 0.0.
    /// Computed on first use and cached per (field, dimension).
    pub fn sparse_idf(&self, field: Field, dim_id: u32) -> f32 {
        // Fast path: return a previously computed value under the read lock.
        {
            let cache = self.sparse_idf_cache.read();
            if let Some(field_cache) = cache.get(&field.0)
                && let Some(&idf) = field_cache.get(&dim_id)
            {
                return idf;
            }
        }

        // Slow path: aggregate document frequency across segments.
        let df = self.compute_sparse_df(field, dim_id);
        // Use the larger of total vectors and total docs as the collection size.
        let total_vectors = self.compute_sparse_total_vectors(field);
        let n = total_vectors.max(self.total_docs);
        let idf = if df > 0 && n > 0 {
            (n as f32 / df as f32).ln().max(0.0)
        } else {
            0.0
        };

        // Publish the computed value for subsequent lookups.
        {
            let mut cache = self.sparse_idf_cache.write();
            cache.entry(field.0).or_default().insert(dim_id, idf);
        }

        idf
    }

    /// IDF weights for a batch of sparse dimensions.
    pub fn sparse_idf_weights(&self, field: Field, dim_ids: &[u32]) -> Vec<f32> {
        dim_ids.iter().map(|&d| self.sparse_idf(field, d)).collect()
    }

    /// BM25-style IDF for a text term, computed on first use and cached.
    pub fn text_idf(&self, field: Field, term: &str) -> f32 {
        // Fast path: return a previously computed value under the read lock.
        {
            let cache = self.text_idf_cache.read();
            if let Some(field_cache) = cache.get(&field.0)
                && let Some(&idf) = field_cache.get(term)
            {
                return idf;
            }
        }

        // Slow path: idf = ln((N - df + 0.5) / (df + 0.5) + 1).
        let df = self.compute_text_df(field, term);
        let n = self.total_docs as f32;
        let df_f = df as f32;
        let idf = if df > 0 {
            ((n - df_f + 0.5) / (df_f + 0.5) + 1.0).ln()
        } else {
            0.0
        };

        // Publish the computed value for subsequent lookups.
        {
            let mut cache = self.text_idf_cache.write();
            cache
                .entry(field.0)
                .or_default()
                .insert(term.to_string(), idf);
        }

        idf
    }

    /// Average field length across all segments, weighted by segment doc count.
    pub fn avg_field_len(&self, field: Field) -> f32 {
        // Fast path: cached value.
        {
            let cache = self.avg_field_len_cache.read();
            if let Some(&avg) = cache.get(&field.0) {
                return avg;
            }
        }

        // Weighted average: each segment contributes avg_len * doc_count.
        let mut weighted_sum = 0.0f64;
        let mut total_weight = 0u64;

        for segment in &self.segments {
            let avg_len = segment.avg_field_len(field);
            let doc_count = segment.num_docs() as u64;
            if avg_len > 0.0 && doc_count > 0 {
                weighted_sum += avg_len as f64 * doc_count as f64;
                total_weight += doc_count;
            }
        }

        let avg = if total_weight > 0 {
            (weighted_sum / total_weight as f64) as f32
        } else {
            1.0
        };

        {
            let mut cache = self.avg_field_len_cache.write();
            cache.insert(field.0, avg);
        }

        avg
    }

    /// Sum of per-segment document frequencies for a sparse dimension.
    fn compute_sparse_df(&self, field: Field, dim_id: u32) -> u64 {
        let mut df = 0u64;
        for segment in &self.segments {
            if let Some(sparse_index) = segment.sparse_indexes().get(&field.0)
                && let Some(Some(posting)) = sparse_index.postings.get(dim_id as usize)
            {
                df += posting.doc_count() as u64;
            }
        }
        df
    }

    /// Total number of sparse vectors indexed for the field across all segments.
    fn compute_sparse_total_vectors(&self, field: Field) -> u64 {
        let mut total = 0u64;
        for segment in &self.segments {
            if let Some(sparse_index) = segment.sparse_indexes().get(&field.0) {
                total += sparse_index.total_vectors as u64;
            }
        }
        total
    }

    /// Text document frequency is not computed here: this always reports zero,
    /// so `text_idf` yields 0.0 for terms that are not already cached.
    fn compute_text_df(&self, _field: Field, _term: &str) -> u64 {
        0
    }

    pub fn num_segments(&self) -> usize {
        self.segments.len()
    }
}

impl std::fmt::Debug for LazyGlobalStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LazyGlobalStats")
            .field("total_docs", &self.total_docs)
            .field("num_segments", &self.segments.len())
            .field("sparse_cache_fields", &self.sparse_idf_cache.read().len())
            .field("text_cache_fields", &self.text_idf_cache.read().len())
            .finish()
    }
}

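// Illustrative usage (a hedged sketch, not part of the API surface): the
// `segments` vector is assumed to come from whatever reader owns the
// `Arc<SegmentReader>`s; its construction is not shown here.
//
//     let stats = LazyGlobalStats::new(segments);
//     let idf = stats.sparse_idf(Field(0), 42);              // computed once, then cached
//     let weights = stats.sparse_idf_weights(Field(0), &[1, 2, 3]);
//     let avg_len = stats.avg_field_len(Field(1));           // doc-count-weighted across segments
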
/// Fully materialized global statistics, built via `GlobalStatsBuilder`.
#[derive(Debug)]
pub struct GlobalStats {
    total_docs: u64,
    /// Per-field sparse-vector statistics, keyed by field id.
    sparse_stats: FxHashMap<u32, SparseFieldStats>,
    /// Per-field text statistics, keyed by field id.
    text_stats: FxHashMap<u32, TextFieldStats>,
    /// Generation counter recorded at build time.
    generation: u64,
}

/// Document frequencies per sparse dimension for one field.
#[derive(Debug, Default)]
pub struct SparseFieldStats {
    pub doc_freqs: FxHashMap<u32, u64>,
}

/// Document frequencies per term and average field length for one text field.
#[derive(Debug, Default)]
pub struct TextFieldStats {
    pub doc_freqs: FxHashMap<String, u64>,
    pub avg_field_len: f32,
}

impl GlobalStats {
    pub fn new() -> Self {
        Self {
            total_docs: 0,
            sparse_stats: FxHashMap::default(),
            text_stats: FxHashMap::default(),
            generation: 0,
        }
    }

    /// Total number of documents the statistics were built over.
    #[inline]
    pub fn total_docs(&self) -> u64 {
        self.total_docs
    }

    /// IDF for a sparse dimension: ln(N / df), or 0.0 if the dimension is unknown.
    #[inline]
    pub fn sparse_idf(&self, field: Field, dim_id: u32) -> f32 {
        if let Some(stats) = self.sparse_stats.get(&field.0)
            && let Some(&df) = stats.doc_freqs.get(&dim_id)
            && df > 0
        {
            return (self.total_docs as f32 / df as f32).ln();
        }
        0.0
    }

    /// IDF weights for a batch of sparse dimensions.
    pub fn sparse_idf_weights(&self, field: Field, dim_ids: &[u32]) -> Vec<f32> {
        dim_ids.iter().map(|&d| self.sparse_idf(field, d)).collect()
    }

    /// BM25-style IDF for a text term: ln((N - df + 0.5) / (df + 0.5) + 1).
    #[inline]
    pub fn text_idf(&self, field: Field, term: &str) -> f32 {
        if let Some(stats) = self.text_stats.get(&field.0)
            && let Some(&df) = stats.doc_freqs.get(term)
        {
            let n = self.total_docs as f32;
            let df = df as f32;
            return ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
        }
        0.0
    }

    /// Average field length for a text field, defaulting to 1.0 when unknown.
    #[inline]
    pub fn avg_field_len(&self, field: Field) -> f32 {
        self.text_stats
            .get(&field.0)
            .map(|s| s.avg_field_len)
            .unwrap_or(1.0)
    }

    /// Generation recorded when the statistics were built.
    #[inline]
    pub fn generation(&self) -> u64 {
        self.generation
    }
}

impl Default for GlobalStats {
    fn default() -> Self {
        Self::new()
    }
}

/// Accumulates statistics from segments, then builds an immutable `GlobalStats`.
pub struct GlobalStatsBuilder {
    pub total_docs: u64,
    sparse_stats: FxHashMap<u32, SparseFieldStats>,
    text_stats: FxHashMap<u32, TextFieldStats>,
}

impl GlobalStatsBuilder {
    pub fn new() -> Self {
        Self {
            total_docs: 0,
            sparse_stats: FxHashMap::default(),
            text_stats: FxHashMap::default(),
        }
    }

    /// Accumulate a segment's document count. Per-field statistics are added
    /// separately via `add_sparse_df`, `add_text_df` and `set_avg_field_len`.
    pub fn add_segment(&mut self, reader: &SegmentReader) {
        self.total_docs += reader.num_docs() as u64;
    }

    /// Add document frequency for a sparse dimension of a field.
    pub fn add_sparse_df(&mut self, field: Field, dim_id: u32, doc_count: u64) {
        let stats = self.sparse_stats.entry(field.0).or_default();
        *stats.doc_freqs.entry(dim_id).or_insert(0) += doc_count;
    }

    /// Add document frequency for a text term of a field.
    pub fn add_text_df(&mut self, field: Field, term: String, doc_count: u64) {
        let stats = self.text_stats.entry(field.0).or_default();
        *stats.doc_freqs.entry(term).or_insert(0) += doc_count;
    }

    /// Set the average field length for a text field.
    pub fn set_avg_field_len(&mut self, field: Field, avg_len: f32) {
        let stats = self.text_stats.entry(field.0).or_default();
        stats.avg_field_len = avg_len;
    }

    /// Finalize into an immutable `GlobalStats` tagged with `generation`.
    pub fn build(self, generation: u64) -> GlobalStats {
        GlobalStats {
            total_docs: self.total_docs,
            sparse_stats: self.sparse_stats,
            text_stats: self.text_stats,
            generation,
        }
    }
}

impl Default for GlobalStatsBuilder {
    fn default() -> Self {
        Self::new()
    }
}

/// Shared cache holding the current `GlobalStats`, cleared whenever the set of
/// segments changes.
pub struct GlobalStatsCache {
    stats: RwLock<Option<Arc<GlobalStats>>>,
    generation: RwLock<u64>,
}

impl GlobalStatsCache {
    pub fn new() -> Self {
        Self {
            stats: RwLock::new(None),
            generation: RwLock::new(0),
        }
    }

    /// Drop the cached stats and bump the generation counter.
    pub fn invalidate(&self) {
        let mut current_gen = self.generation.write();
        *current_gen += 1;
        let mut stats = self.stats.write();
        *stats = None;
    }

    /// Current generation counter.
    pub fn generation(&self) -> u64 {
        *self.generation.read()
    }

    /// Cached stats, if any.
    pub fn get(&self) -> Option<Arc<GlobalStats>> {
        self.stats.read().clone()
    }

    /// Replace the cached stats.
    pub fn set(&self, stats: GlobalStats) {
        let mut cached = self.stats.write();
        *cached = Some(Arc::new(stats));
    }

    /// Return the cached stats, or build them with `compute` and cache the result.
    pub fn get_or_compute<F>(&self, compute: F) -> Arc<GlobalStats>
    where
        F: FnOnce(&mut GlobalStatsBuilder),
    {
        // Fast path: reuse previously built stats.
        if let Some(stats) = self.get() {
            return stats;
        }

        // Build fresh stats tagged with the current generation, then cache them.
        let current_gen = self.generation();
        let mut builder = GlobalStatsBuilder::new();
        compute(&mut builder);
        let stats = Arc::new(builder.build(current_gen));

        let mut cached = self.stats.write();
        *cached = Some(Arc::clone(&stats));

        stats
    }

    /// True when no stats are currently cached.
    pub fn needs_rebuild(&self) -> bool {
        self.stats.read().is_none()
    }

    /// Replace the cached stats (same behavior as `set`).
    pub fn set_stats(&self, stats: GlobalStats) {
        let mut cached = self.stats.write();
        *cached = Some(Arc::new(stats));
    }
}

impl Default for GlobalStatsCache {
    fn default() -> Self {
        Self::new()
    }
}

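// Typical lifecycle (a hedged sketch; the writer/search components referenced
// here are assumptions, not defined in this module):
//
//     // after segments are added or merged:
//     stats_cache.invalidate();
//
//     // before scoring a query:
//     let stats = stats_cache.get_or_compute(|builder| {
//         // populate `builder` from the current segment readers
//     });
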
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_idf_computation() {
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 1000;
        builder.add_sparse_df(Field(0), 42, 100);
        builder.add_sparse_df(Field(0), 43, 10);
        let stats = builder.build(1);

        let idf_42 = stats.sparse_idf(Field(0), 42);
        let idf_43 = stats.sparse_idf(Field(0), 43);

        assert!(idf_43 > idf_42);
        assert!((idf_42 - (1000.0_f32 / 100.0).ln()).abs() < 0.001);
        assert!((idf_43 - (1000.0_f32 / 10.0).ln()).abs() < 0.001);
    }

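    // A hedged sketch exercising the batch helper against the single-dimension
    // path; the field/dimension ids and counts below are illustrative only.
    #[test]
    fn test_sparse_idf_weights_batch() {
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 1000;
        builder.add_sparse_df(Field(0), 1, 50);
        builder.add_sparse_df(Field(0), 2, 500);
        let stats = builder.build(1);

        let weights = stats.sparse_idf_weights(Field(0), &[1, 2, 7]);
        assert_eq!(weights.len(), 3);
        // Each batch entry matches the corresponding single-dimension lookup.
        assert!((weights[0] - stats.sparse_idf(Field(0), 1)).abs() < 1e-6);
        assert!((weights[1] - stats.sparse_idf(Field(0), 2)).abs() < 1e-6);
        // Unknown dimensions contribute a zero weight.
        assert_eq!(weights[2], 0.0);
    }
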
    #[test]
    fn test_text_idf_computation() {
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 10000;
        builder.add_text_df(Field(0), "common".to_string(), 5000);
        builder.add_text_df(Field(0), "rare".to_string(), 10);

        let stats = builder.build(1);

        let idf_common = stats.text_idf(Field(0), "common");
        let idf_rare = stats.text_idf(Field(0), "rare");

        assert!(idf_rare > idf_common);
    }

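    // A hedged sketch pinning the BM25-style formula to a hand-computed value;
    // the document counts are illustrative only.
    #[test]
    fn test_text_idf_reference_value() {
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 1000;
        builder.add_text_df(Field(0), "term".to_string(), 100);
        let stats = builder.build(1);

        // idf = ln((N - df + 0.5) / (df + 0.5) + 1) with N = 1000, df = 100.
        let expected = ((1000.0_f32 - 100.0 + 0.5) / (100.0 + 0.5) + 1.0).ln();
        assert!((stats.text_idf(Field(0), "term") - expected).abs() < 1e-6);

        // Terms that were never added get an IDF of 0.0.
        assert_eq!(stats.text_idf(Field(0), "missing"), 0.0);
    }
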
    #[test]
    fn test_cache_invalidation() {
        let cache = GlobalStatsCache::new();

        assert!(cache.get().is_none());

        let stats = cache.get_or_compute(|builder| {
            builder.total_docs = 100;
        });
        assert_eq!(stats.total_docs(), 100);

        assert!(cache.get().is_some());

        cache.invalidate();
        assert!(cache.get().is_none());
    }
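
    // A hedged sketch of the remaining builder accessors; the field ids and
    // lengths below are illustrative only.
    #[test]
    fn test_avg_field_len_and_generation() {
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 10;
        builder.set_avg_field_len(Field(2), 42.5);
        let stats = builder.build(7);

        assert_eq!(stats.avg_field_len(Field(2)), 42.5);
        // Fields without a recorded length fall back to 1.0.
        assert_eq!(stats.avg_field_len(Field(9)), 1.0);
        assert_eq!(stats.generation(), 7);
    }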
}