hermes_core/query/global_stats.rs

//! Lazy global statistics for cross-segment IDF computation
//!
//! Provides lazily computed and cached statistics across multiple segments for:
//! - Sparse vector dimensions (for sparse vector queries)
//! - Full-text terms (for BM25/TF-IDF scoring)
//!
//! Key design principles:
//! - **Lazy computation**: IDF values computed on first access, not upfront
//! - **Per-term caching**: Each term/dimension's IDF is cached independently
//! - **Bound to Searcher**: Stats tied to segment snapshot lifetime
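//!
//! A minimal usage sketch (illustrative; assumes a searcher handing out its
//! current `Vec<Arc<SegmentReader>>` snapshot):
//!
//! ```ignore
//! // Bind the stats to the searcher's segment snapshot.
//! let stats = LazyGlobalStats::new(segments.clone());
//!
//! // Values are computed on first access and served from cache afterwards.
//! let idf = stats.sparse_idf(Field(0), 42);
//! let weights = stats.sparse_idf_weights(Field(0), &[42, 43, 44]);
//! let avg_len = stats.avg_field_len(Field(1));
//! ```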

use std::sync::Arc;

use parking_lot::RwLock;
use rustc_hash::FxHashMap;

use crate::dsl::Field;
use crate::segment::SegmentReader;

/// Lazy global statistics bound to a fixed set of segments
///
/// Computes IDF values lazily on first access and caches them.
/// Lifetime is bound to the Searcher that created it, ensuring
/// statistics always match the current segment set.
pub struct LazyGlobalStats {
    /// Segment readers (Arc for shared ownership with Searcher)
    segments: Vec<Arc<SegmentReader>>,
    /// Total documents (computed once on construction)
    total_docs: u64,
    /// Cached sparse IDF values: field_id -> (dim_id -> idf)
    sparse_idf_cache: RwLock<FxHashMap<u32, FxHashMap<u32, f32>>>,
    /// Cached total_vectors per sparse field (computed once per field, not per dim)
    sparse_total_vectors_cache: RwLock<FxHashMap<u32, u64>>,
    /// Cached text IDF values: field_id -> (term -> idf)
    text_idf_cache: RwLock<FxHashMap<u32, FxHashMap<String, f32>>>,
    /// Cached average field lengths: field_id -> avg_len
    avg_field_len_cache: RwLock<FxHashMap<u32, f32>>,
}

impl LazyGlobalStats {
    /// Create new lazy stats bound to a set of segments
    pub fn new(segments: Vec<Arc<SegmentReader>>) -> Self {
        let total_docs: u64 = segments.iter().map(|s| s.num_docs() as u64).sum();
        Self {
            segments,
            total_docs,
            sparse_idf_cache: RwLock::new(FxHashMap::default()),
            sparse_total_vectors_cache: RwLock::new(FxHashMap::default()),
            text_idf_cache: RwLock::new(FxHashMap::default()),
            avg_field_len_cache: RwLock::new(FxHashMap::default()),
        }
    }

    /// Total documents across all segments
    #[inline]
    pub fn total_docs(&self) -> u64 {
        self.total_docs
    }

    /// Get or compute IDF for a sparse vector dimension (lazy + cached)
    ///
    /// IDF = ln(N / df) where N = total docs, df = docs containing dimension
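    /// (e.g. N = 1000, df = 100 gives IDF = ln(10) ≈ 2.303; df = 0 yields 0.0)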
    pub fn sparse_idf(&self, field: Field, dim_id: u32) -> f32 {
        // Fast path: check cache
        {
            let cache = self.sparse_idf_cache.read();
            if let Some(field_cache) = cache.get(&field.0)
                && let Some(&idf) = field_cache.get(&dim_id)
            {
                return idf;
            }
        }

        // Slow path: compute and cache
        let df = self.compute_sparse_df(field, dim_id);
        let n = self.cached_sparse_n(field);
        let idf = if df > 0 && n > 0 {
            (n as f32 / df as f32).ln().max(0.0)
        } else {
            0.0
        };

        // Cache the result
        {
            let mut cache = self.sparse_idf_cache.write();
            cache.entry(field.0).or_default().insert(dim_id, idf);
        }

        idf
    }

    /// Compute IDF weights for multiple sparse dimensions (batch, uses cache)
    ///
    /// More efficient than calling `sparse_idf()` per dimension: resolves
    /// total_vectors once and acquires write lock once for all cache misses.
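    ///
    /// A minimal sketch (illustrative field and dimension ids):
    /// ```ignore
    /// let weights = stats.sparse_idf_weights(Field(0), &[17, 42, 1024]);
    /// assert_eq!(weights.len(), 3);
    /// ```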
    pub fn sparse_idf_weights(&self, field: Field, dim_ids: &[u32]) -> Vec<f32> {
        // Fast path: check how many are already cached
        let mut result = vec![0.0f32; dim_ids.len()];
        let mut misses: Vec<usize> = Vec::new();
        {
            let cache = self.sparse_idf_cache.read();
            if let Some(field_cache) = cache.get(&field.0) {
                for (i, &dim_id) in dim_ids.iter().enumerate() {
                    if let Some(&idf) = field_cache.get(&dim_id) {
                        result[i] = idf;
                    } else {
                        misses.push(i);
                    }
                }
            } else {
                misses.extend(0..dim_ids.len());
            }
        }

        if misses.is_empty() {
            return result;
        }

        // Compute N once for all misses (was previously per-dimension)
        let n = self.cached_sparse_n(field);

        // Compute missing IDF values
        let mut new_entries: Vec<(u32, f32)> = Vec::with_capacity(misses.len());
        for &i in &misses {
            let dim_id = dim_ids[i];
            let df = self.compute_sparse_df(field, dim_id);
            let idf = if df > 0 && n > 0 {
                (n as f32 / df as f32).ln().max(0.0)
            } else {
                0.0
            };
            result[i] = idf;
            new_entries.push((dim_id, idf));
        }

        // Batch-insert into cache with single write lock
        {
            let mut cache = self.sparse_idf_cache.write();
            let field_cache = cache.entry(field.0).or_default();
            for (dim_id, idf) in new_entries {
                field_cache.insert(dim_id, idf);
            }
        }

        result
    }

    /// Get cached N = max(total_vectors, total_docs) for a sparse field.
    /// Computed once per field and cached.
    fn cached_sparse_n(&self, field: Field) -> u64 {
        // Fast path
        {
            let cache = self.sparse_total_vectors_cache.read();
            if let Some(&tv) = cache.get(&field.0) {
                return tv.max(self.total_docs);
            }
        }
        // Slow path: compute and cache
        let tv = self.compute_sparse_total_vectors(field);
        self.sparse_total_vectors_cache.write().insert(field.0, tv);
        tv.max(self.total_docs)
    }

    /// Get or compute IDF for a full-text term (lazy + cached)
    ///
    /// IDF = ln((N - df + 0.5) / (df + 0.5) + 1) (BM25 variant)
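    /// (e.g. N = 10_000, df = 5_000 gives IDF = ln(5000.5 / 5000.5 + 1) = ln(2) ≈ 0.693)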
    pub fn text_idf(&self, field: Field, term: &str) -> f32 {
        // Fast path: check cache
        {
            let cache = self.text_idf_cache.read();
            if let Some(field_cache) = cache.get(&field.0)
                && let Some(&idf) = field_cache.get(term)
            {
                return idf;
            }
        }

        // Slow path: compute and cache
        let df = self.compute_text_df(field, term);
        let n = self.total_docs as f32;
        let df_f = df as f32;
        let idf = if df > 0 {
            ((n - df_f + 0.5) / (df_f + 0.5) + 1.0).ln()
        } else {
            0.0
        };

        // Cache the result
        {
            let mut cache = self.text_idf_cache.write();
            cache
                .entry(field.0)
                .or_default()
                .insert(term.to_string(), idf);
        }

        idf
    }

    /// Get or compute average field length for BM25 (lazy + cached)
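    ///
    /// The average is weighted by segment doc count: e.g. 10.0 over 100 docs and
    /// 20.0 over 300 docs gives (10.0 * 100 + 20.0 * 300) / 400 = 17.5.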
    pub fn avg_field_len(&self, field: Field) -> f32 {
        // Fast path: check cache
        {
            let cache = self.avg_field_len_cache.read();
            if let Some(&avg) = cache.get(&field.0) {
                return avg;
            }
        }

        // Slow path: compute weighted average across segments
        let mut weighted_sum = 0.0f64;
        let mut total_weight = 0u64;

        for segment in &self.segments {
            let avg_len = segment.avg_field_len(field);
            let doc_count = segment.num_docs() as u64;
            if avg_len > 0.0 && doc_count > 0 {
                weighted_sum += avg_len as f64 * doc_count as f64;
                total_weight += doc_count;
            }
        }

        let avg = if total_weight > 0 {
            (weighted_sum / total_weight as f64) as f32
        } else {
            1.0
        };

        // Cache the result
        {
            let mut cache = self.avg_field_len_cache.write();
            cache.insert(field.0, avg);
        }

        avg
    }

    /// Compute document frequency for a sparse dimension (not cached - internal)
    /// Uses skip list metadata - no I/O needed
    fn compute_sparse_df(&self, field: Field, dim_id: u32) -> u64 {
        let mut df = 0u64;
        for segment in &self.segments {
            if let Some(sparse_index) = segment.sparse_indexes().get(&field.0) {
                df += sparse_index.doc_count(dim_id) as u64;
            }
        }
        df
    }

    /// Compute total sparse vectors for a field across all segments
    /// For multi-valued fields, this may exceed total_docs
    fn compute_sparse_total_vectors(&self, field: Field) -> u64 {
        let mut total = 0u64;
        for segment in &self.segments {
            if let Some(sparse_index) = segment.sparse_indexes().get(&field.0) {
                total += sparse_index.total_vectors as u64;
            }
        }
        total
    }

    /// Compute document frequency for a text term (not cached - internal)
    ///
    /// Note: a real lookup would be expensive, since it requires async access to
    /// the term dictionary. For now this returns 0; text IDF should be computed
    /// via the term dictionary instead.
    fn compute_text_df(&self, _field: Field, _term: &str) -> u64 {
        // Text term lookup requires async access to the term dictionary.
        // For now this is a placeholder - the actual implementation would
        // need to be async or use pre-computed stats.
        0
    }

    /// Number of segments
    pub fn num_segments(&self) -> usize {
        self.segments.len()
    }
}

impl std::fmt::Debug for LazyGlobalStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LazyGlobalStats")
            .field("total_docs", &self.total_docs)
            .field("num_segments", &self.segments.len())
            .field("sparse_cache_fields", &self.sparse_idf_cache.read().len())
            .field("text_cache_fields", &self.text_idf_cache.read().len())
            .finish()
    }
}

// Keep old types for backwards compatibility during transition

/// Global statistics aggregated across all segments (legacy)
#[derive(Debug)]
pub struct GlobalStats {
    /// Total documents across all segments
    total_docs: u64,
    /// Sparse vector statistics per field: field_id -> dimension stats
    sparse_stats: FxHashMap<u32, SparseFieldStats>,
    /// Full-text statistics per field: field_id -> term stats
    text_stats: FxHashMap<u32, TextFieldStats>,
    /// Generation counter for cache invalidation
    generation: u64,
}

/// Statistics for a sparse vector field
#[derive(Debug, Default)]
pub struct SparseFieldStats {
    /// Document frequency per dimension: dim_id -> doc_count
    pub doc_freqs: FxHashMap<u32, u64>,
}

/// Statistics for a full-text field
#[derive(Debug, Default)]
pub struct TextFieldStats {
    /// Document frequency per term: term -> doc_count
    pub doc_freqs: FxHashMap<String, u64>,
    /// Average field length (for BM25)
    pub avg_field_len: f32,
}

impl GlobalStats {
    /// Create empty stats
    pub fn new() -> Self {
        Self {
            total_docs: 0,
            sparse_stats: FxHashMap::default(),
            text_stats: FxHashMap::default(),
            generation: 0,
        }
    }

    /// Total documents in the index
    #[inline]
    pub fn total_docs(&self) -> u64 {
        self.total_docs
    }

    /// Compute IDF for a sparse vector dimension
    #[inline]
    pub fn sparse_idf(&self, field: Field, dim_id: u32) -> f32 {
        if let Some(stats) = self.sparse_stats.get(&field.0)
            && let Some(&df) = stats.doc_freqs.get(&dim_id)
            && df > 0
        {
            return (self.total_docs as f32 / df as f32).ln();
        }
        0.0
    }

    /// Compute IDF weights for multiple sparse dimensions
    pub fn sparse_idf_weights(&self, field: Field, dim_ids: &[u32]) -> Vec<f32> {
        dim_ids.iter().map(|&d| self.sparse_idf(field, d)).collect()
    }

    /// Compute IDF for a full-text term
    #[inline]
    pub fn text_idf(&self, field: Field, term: &str) -> f32 {
        if let Some(stats) = self.text_stats.get(&field.0)
            && let Some(&df) = stats.doc_freqs.get(term)
        {
            let n = self.total_docs as f32;
            let df = df as f32;
            return ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
        }
        0.0
    }

    /// Get average field length for BM25
    #[inline]
    pub fn avg_field_len(&self, field: Field) -> f32 {
        self.text_stats
            .get(&field.0)
            .map(|s| s.avg_field_len)
            .unwrap_or(1.0)
    }

    /// Current generation
    #[inline]
    pub fn generation(&self) -> u64 {
        self.generation
    }
}

impl Default for GlobalStats {
    fn default() -> Self {
        Self::new()
    }
}

/// Builder for aggregating statistics from multiple segments
pub struct GlobalStatsBuilder {
    /// Total documents across all segments
    pub total_docs: u64,
    sparse_stats: FxHashMap<u32, SparseFieldStats>,
    text_stats: FxHashMap<u32, TextFieldStats>,
}

impl GlobalStatsBuilder {
    /// Create a new builder
    pub fn new() -> Self {
        Self {
            total_docs: 0,
            sparse_stats: FxHashMap::default(),
            text_stats: FxHashMap::default(),
        }
    }

    /// Add statistics from a segment reader
    pub fn add_segment(&mut self, reader: &SegmentReader) {
        self.total_docs += reader.num_docs() as u64;

        // Aggregate sparse vector statistics
        // Note: This requires access to sparse_indexes which may need to be exposed
    }

    /// Add sparse dimension document frequency
    pub fn add_sparse_df(&mut self, field: Field, dim_id: u32, doc_count: u64) {
        let stats = self.sparse_stats.entry(field.0).or_default();
        *stats.doc_freqs.entry(dim_id).or_insert(0) += doc_count;
    }

    /// Add text term document frequency
    pub fn add_text_df(&mut self, field: Field, term: String, doc_count: u64) {
        let stats = self.text_stats.entry(field.0).or_default();
        *stats.doc_freqs.entry(term).or_insert(0) += doc_count;
    }

    /// Set average field length for a text field
    pub fn set_avg_field_len(&mut self, field: Field, avg_len: f32) {
        let stats = self.text_stats.entry(field.0).or_default();
        stats.avg_field_len = avg_len;
    }

    /// Build the final GlobalStats
    pub fn build(self, generation: u64) -> GlobalStats {
        GlobalStats {
            total_docs: self.total_docs,
            sparse_stats: self.sparse_stats,
            text_stats: self.text_stats,
            generation,
        }
    }
}

impl Default for GlobalStatsBuilder {
    fn default() -> Self {
        Self::new()
    }
}

/// Cached global statistics with automatic invalidation
///
/// This is the main entry point for getting global IDF values.
/// It caches statistics and rebuilds them when the segment list changes.
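///
/// A minimal usage sketch (mirrors the cache-invalidation test at the bottom of this file):
/// ```ignore
/// let cache = GlobalStatsCache::new();
/// let stats = cache.get_or_compute(|builder| {
///     builder.total_docs = 100;
/// });
/// assert_eq!(stats.total_docs(), 100);
/// cache.invalidate(); // the next get_or_compute() rebuilds
/// ```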
pub struct GlobalStatsCache {
    /// Cached statistics
    stats: RwLock<Option<Arc<GlobalStats>>>,
    /// Current generation (incremented when segments change)
    generation: RwLock<u64>,
}

impl GlobalStatsCache {
    /// Create a new cache
    pub fn new() -> Self {
        Self {
            stats: RwLock::new(None),
            generation: RwLock::new(0),
        }
    }

    /// Invalidate the cache (call when segments are added/removed/merged)
    pub fn invalidate(&self) {
        let mut current_gen = self.generation.write();
        *current_gen += 1;
        let mut stats = self.stats.write();
        *stats = None;
    }

    /// Get current generation
    pub fn generation(&self) -> u64 {
        *self.generation.read()
    }

    /// Get cached stats if valid, or None if needs rebuild
    pub fn get(&self) -> Option<Arc<GlobalStats>> {
        self.stats.read().clone()
    }

    /// Update the cache with new stats
    pub fn set(&self, stats: GlobalStats) {
        let mut cached = self.stats.write();
        *cached = Some(Arc::new(stats));
    }

    /// Get or compute stats using the provided builder function (sync version)
    ///
    /// For basic stats that don't require async iteration.
    pub fn get_or_compute<F>(&self, compute: F) -> Arc<GlobalStats>
    where
        F: FnOnce(&mut GlobalStatsBuilder),
    {
        // Fast path: return cached if available
        if let Some(stats) = self.get() {
            return stats;
        }

        // Slow path: compute new stats
        let current_gen = self.generation();
        let mut builder = GlobalStatsBuilder::new();
        compute(&mut builder);
        let stats = Arc::new(builder.build(current_gen));

        // Cache the result
        let mut cached = self.stats.write();
        *cached = Some(Arc::clone(&stats));

        stats
    }

    /// Check if stats need to be rebuilt
    pub fn needs_rebuild(&self) -> bool {
        self.stats.read().is_none()
    }

    /// Set pre-built stats (for async computation)
    pub fn set_stats(&self, stats: GlobalStats) {
        let mut cached = self.stats.write();
        *cached = Some(Arc::new(stats));
    }
}

impl Default for GlobalStatsCache {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_idf_computation() {
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 1000;
        builder.add_sparse_df(Field(0), 42, 100); // dim 42 appears in 100 docs
        builder.add_sparse_df(Field(0), 43, 10); // dim 43 appears in 10 docs

        let stats = builder.build(1);

        // IDF = ln(N/df)
        let idf_42 = stats.sparse_idf(Field(0), 42);
        let idf_43 = stats.sparse_idf(Field(0), 43);

        // dim 43 should have higher IDF (rarer)
        assert!(idf_43 > idf_42);
        assert!((idf_42 - (1000.0_f32 / 100.0).ln()).abs() < 0.001);
        assert!((idf_43 - (1000.0_f32 / 10.0).ln()).abs() < 0.001);
    }

    #[test]
    fn test_text_idf_computation() {
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 10000;
        builder.add_text_df(Field(0), "common".to_string(), 5000);
        builder.add_text_df(Field(0), "rare".to_string(), 10);

        let stats = builder.build(1);

        let idf_common = stats.text_idf(Field(0), "common");
        let idf_rare = stats.text_idf(Field(0), "rare");

        // Rare term should have higher IDF
        assert!(idf_rare > idf_common);
    }

    #[test]
    fn test_cache_invalidation() {
        let cache = GlobalStatsCache::new();

        // Initially no stats
        assert!(cache.get().is_none());

        // Compute stats
        let stats = cache.get_or_compute(|builder| {
            builder.total_docs = 100;
        });
        assert_eq!(stats.total_docs(), 100);

        // Should be cached now
        assert!(cache.get().is_some());

        // Invalidate
        cache.invalidate();
        assert!(cache.get().is_none());
    }
}