// hermes_core/query/global_stats.rs
//! Lazy global statistics for cross-segment IDF computation
//!
//! Provides lazily computed and cached statistics across multiple segments for:
//! - Sparse vector dimensions (for sparse vector queries)
//! - Full-text terms (for BM25/TF-IDF scoring)
//!
//! Key design principles:
//! - **Lazy computation**: IDF values computed on first access, not upfront
//! - **Per-term caching**: Each term/dimension's IDF is cached independently
//! - **Bound to Searcher**: Stats tied to segment snapshot lifetime

use std::sync::Arc;

use parking_lot::RwLock;
use rustc_hash::FxHashMap;

use crate::dsl::Field;
use crate::segment::SegmentReader;
/// Lazy global statistics bound to a fixed set of segments
///
/// Computes IDF values lazily on first access and caches them.
/// Lifetime is bound to the Searcher that created it, ensuring
/// statistics always match the current segment set.
///
/// All caches use interior mutability (`RwLock`) so lookups work
/// through a shared reference.
pub struct LazyGlobalStats {
    /// Segment readers (Arc for shared ownership with Searcher)
    segments: Vec<Arc<SegmentReader>>,
    /// Total documents across all segments (computed once on construction)
    total_docs: u64,
    /// Cached sparse IDF values: field_id -> (dim_id -> idf)
    sparse_idf_cache: RwLock<FxHashMap<u32, FxHashMap<u32, f32>>>,
    /// Cached text IDF values: field_id -> (term -> idf)
    text_idf_cache: RwLock<FxHashMap<u32, FxHashMap<String, f32>>>,
    /// Cached average field lengths for BM25: field_id -> avg_len
    avg_field_len_cache: RwLock<FxHashMap<u32, f32>>,
}
37
38impl LazyGlobalStats {
39    /// Create new lazy stats bound to a set of segments
40    pub fn new(segments: Vec<Arc<SegmentReader>>) -> Self {
41        let total_docs: u64 = segments.iter().map(|s| s.num_docs() as u64).sum();
42        Self {
43            segments,
44            total_docs,
45            sparse_idf_cache: RwLock::new(FxHashMap::default()),
46            text_idf_cache: RwLock::new(FxHashMap::default()),
47            avg_field_len_cache: RwLock::new(FxHashMap::default()),
48        }
49    }
50
51    /// Total documents across all segments
52    #[inline]
53    pub fn total_docs(&self) -> u64 {
54        self.total_docs
55    }
56
57    /// Get or compute IDF for a sparse vector dimension (lazy + cached)
58    ///
59    /// IDF = ln(N / df) where N = total docs, df = docs containing dimension
60    pub fn sparse_idf(&self, field: Field, dim_id: u32) -> f32 {
61        // Fast path: check cache
62        {
63            let cache = self.sparse_idf_cache.read();
64            if let Some(field_cache) = cache.get(&field.0)
65                && let Some(&idf) = field_cache.get(&dim_id)
66            {
67                return idf;
68            }
69        }
70
71        // Slow path: compute and cache
72        let df = self.compute_sparse_df(field, dim_id);
73        let idf = if df > 0 && self.total_docs > 0 {
74            (self.total_docs as f32 / df as f32).ln()
75        } else {
76            0.0
77        };
78
79        // Cache the result
80        {
81            let mut cache = self.sparse_idf_cache.write();
82            cache.entry(field.0).or_default().insert(dim_id, idf);
83        }
84
85        idf
86    }
87
88    /// Compute IDF weights for multiple sparse dimensions (batch, uses cache)
89    pub fn sparse_idf_weights(&self, field: Field, dim_ids: &[u32]) -> Vec<f32> {
90        dim_ids.iter().map(|&d| self.sparse_idf(field, d)).collect()
91    }
92
93    /// Get or compute IDF for a full-text term (lazy + cached)
94    ///
95    /// IDF = ln((N - df + 0.5) / (df + 0.5) + 1) (BM25 variant)
96    pub fn text_idf(&self, field: Field, term: &str) -> f32 {
97        // Fast path: check cache
98        {
99            let cache = self.text_idf_cache.read();
100            if let Some(field_cache) = cache.get(&field.0)
101                && let Some(&idf) = field_cache.get(term)
102            {
103                return idf;
104            }
105        }
106
107        // Slow path: compute and cache
108        let df = self.compute_text_df(field, term);
109        let n = self.total_docs as f32;
110        let df_f = df as f32;
111        let idf = if df > 0 {
112            ((n - df_f + 0.5) / (df_f + 0.5) + 1.0).ln()
113        } else {
114            0.0
115        };
116
117        // Cache the result
118        {
119            let mut cache = self.text_idf_cache.write();
120            cache
121                .entry(field.0)
122                .or_default()
123                .insert(term.to_string(), idf);
124        }
125
126        idf
127    }
128
129    /// Get or compute average field length for BM25 (lazy + cached)
130    pub fn avg_field_len(&self, field: Field) -> f32 {
131        // Fast path: check cache
132        {
133            let cache = self.avg_field_len_cache.read();
134            if let Some(&avg) = cache.get(&field.0) {
135                return avg;
136            }
137        }
138
139        // Slow path: compute weighted average across segments
140        let mut weighted_sum = 0.0f64;
141        let mut total_weight = 0u64;
142
143        for segment in &self.segments {
144            let avg_len = segment.avg_field_len(field);
145            let doc_count = segment.num_docs() as u64;
146            if avg_len > 0.0 && doc_count > 0 {
147                weighted_sum += avg_len as f64 * doc_count as f64;
148                total_weight += doc_count;
149            }
150        }
151
152        let avg = if total_weight > 0 {
153            (weighted_sum / total_weight as f64) as f32
154        } else {
155            1.0
156        };
157
158        // Cache the result
159        {
160            let mut cache = self.avg_field_len_cache.write();
161            cache.insert(field.0, avg);
162        }
163
164        avg
165    }
166
167    /// Compute document frequency for a sparse dimension (not cached - internal)
168    fn compute_sparse_df(&self, field: Field, dim_id: u32) -> u64 {
169        let mut df = 0u64;
170        for segment in &self.segments {
171            if let Some(sparse_index) = segment.sparse_indexes().get(&field.0)
172                && let Some(Some(posting)) = sparse_index.postings.get(dim_id as usize)
173            {
174                df += posting.doc_count() as u64;
175            }
176        }
177        df
178    }
179
180    /// Compute document frequency for a text term (not cached - internal)
181    ///
182    /// Note: This is expensive as it requires async term lookup.
183    /// For now, returns 0 - text IDF should be computed via term dictionary.
184    fn compute_text_df(&self, _field: Field, _term: &str) -> u64 {
185        // Text term lookup requires async access to term dictionary
186        // For now, this is a placeholder - actual implementation would
187        // need to be async or use pre-computed stats
188        0
189    }
190
191    /// Number of segments
192    pub fn num_segments(&self) -> usize {
193        self.segments.len()
194    }
195}
196
197impl std::fmt::Debug for LazyGlobalStats {
198    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199        f.debug_struct("LazyGlobalStats")
200            .field("total_docs", &self.total_docs)
201            .field("num_segments", &self.segments.len())
202            .field("sparse_cache_fields", &self.sparse_idf_cache.read().len())
203            .field("text_cache_fields", &self.text_idf_cache.read().len())
204            .finish()
205    }
206}
207
// The legacy types below are kept for backwards compatibility during the
// transition to LazyGlobalStats.

/// Global statistics aggregated across all segments (legacy)
///
/// An eagerly-built snapshot of per-field document frequencies, produced by
/// `GlobalStatsBuilder` and cached by `GlobalStatsCache`.
#[derive(Debug)]
pub struct GlobalStats {
    /// Total documents across all segments
    total_docs: u64,
    /// Sparse vector statistics per field: field_id -> dimension stats
    sparse_stats: FxHashMap<u32, SparseFieldStats>,
    /// Full-text statistics per field: field_id -> term stats
    text_stats: FxHashMap<u32, TextFieldStats>,
    /// Generation counter for cache invalidation
    generation: u64,
}
222
/// Statistics for a sparse vector field
#[derive(Debug, Default)]
pub struct SparseFieldStats {
    /// Document frequency per dimension: dim_id -> number of docs containing it
    pub doc_freqs: FxHashMap<u32, u64>,
}
229
/// Statistics for a full-text field
#[derive(Debug, Default)]
pub struct TextFieldStats {
    /// Document frequency per term: term -> number of docs containing it
    pub doc_freqs: FxHashMap<String, u64>,
    /// Average field length (for BM25 length normalization)
    pub avg_field_len: f32,
}
238
239impl GlobalStats {
240    /// Create empty stats
241    pub fn new() -> Self {
242        Self {
243            total_docs: 0,
244            sparse_stats: FxHashMap::default(),
245            text_stats: FxHashMap::default(),
246            generation: 0,
247        }
248    }
249
250    /// Total documents in the index
251    #[inline]
252    pub fn total_docs(&self) -> u64 {
253        self.total_docs
254    }
255
256    /// Compute IDF for a sparse vector dimension
257    #[inline]
258    pub fn sparse_idf(&self, field: Field, dim_id: u32) -> f32 {
259        if let Some(stats) = self.sparse_stats.get(&field.0)
260            && let Some(&df) = stats.doc_freqs.get(&dim_id)
261            && df > 0
262        {
263            return (self.total_docs as f32 / df as f32).ln();
264        }
265        0.0
266    }
267
268    /// Compute IDF weights for multiple sparse dimensions
269    pub fn sparse_idf_weights(&self, field: Field, dim_ids: &[u32]) -> Vec<f32> {
270        dim_ids.iter().map(|&d| self.sparse_idf(field, d)).collect()
271    }
272
273    /// Compute IDF for a full-text term
274    #[inline]
275    pub fn text_idf(&self, field: Field, term: &str) -> f32 {
276        if let Some(stats) = self.text_stats.get(&field.0)
277            && let Some(&df) = stats.doc_freqs.get(term)
278        {
279            let n = self.total_docs as f32;
280            let df = df as f32;
281            return ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
282        }
283        0.0
284    }
285
286    /// Get average field length for BM25
287    #[inline]
288    pub fn avg_field_len(&self, field: Field) -> f32 {
289        self.text_stats
290            .get(&field.0)
291            .map(|s| s.avg_field_len)
292            .unwrap_or(1.0)
293    }
294
295    /// Current generation
296    #[inline]
297    pub fn generation(&self) -> u64 {
298        self.generation
299    }
300}
301
302impl Default for GlobalStats {
303    fn default() -> Self {
304        Self::new()
305    }
306}
307
/// Builder for aggregating statistics from multiple segments
///
/// Accumulate per-segment doc counts and frequencies, then call `build`
/// to freeze into an immutable `GlobalStats`.
pub struct GlobalStatsBuilder {
    /// Total documents across all segments (public so callers can set it directly)
    pub total_docs: u64,
    // Accumulated sparse-dimension doc frequencies per field
    sparse_stats: FxHashMap<u32, SparseFieldStats>,
    // Accumulated text-term doc frequencies and avg lengths per field
    text_stats: FxHashMap<u32, TextFieldStats>,
}
315
316impl GlobalStatsBuilder {
317    /// Create a new builder
318    pub fn new() -> Self {
319        Self {
320            total_docs: 0,
321            sparse_stats: FxHashMap::default(),
322            text_stats: FxHashMap::default(),
323        }
324    }
325
326    /// Add statistics from a segment reader
327    pub fn add_segment(&mut self, reader: &SegmentReader) {
328        self.total_docs += reader.num_docs() as u64;
329
330        // Aggregate sparse vector statistics
331        // Note: This requires access to sparse_indexes which may need to be exposed
332    }
333
334    /// Add sparse dimension document frequency
335    pub fn add_sparse_df(&mut self, field: Field, dim_id: u32, doc_count: u64) {
336        let stats = self.sparse_stats.entry(field.0).or_default();
337        *stats.doc_freqs.entry(dim_id).or_insert(0) += doc_count;
338    }
339
340    /// Add text term document frequency
341    pub fn add_text_df(&mut self, field: Field, term: String, doc_count: u64) {
342        let stats = self.text_stats.entry(field.0).or_default();
343        *stats.doc_freqs.entry(term).or_insert(0) += doc_count;
344    }
345
346    /// Set average field length for a text field
347    pub fn set_avg_field_len(&mut self, field: Field, avg_len: f32) {
348        let stats = self.text_stats.entry(field.0).or_default();
349        stats.avg_field_len = avg_len;
350    }
351
352    /// Build the final GlobalStats
353    pub fn build(self, generation: u64) -> GlobalStats {
354        GlobalStats {
355            total_docs: self.total_docs,
356            sparse_stats: self.sparse_stats,
357            text_stats: self.text_stats,
358            generation,
359        }
360    }
361}
362
363impl Default for GlobalStatsBuilder {
364    fn default() -> Self {
365        Self::new()
366    }
367}
368
/// Cached global statistics with automatic invalidation
///
/// This is the main entry point for getting global IDF values.
/// It caches statistics and rebuilds them when the segment list changes.
pub struct GlobalStatsCache {
    /// Cached statistics (None until computed, or after invalidation)
    stats: RwLock<Option<Arc<GlobalStats>>>,
    /// Current generation (incremented when segments change)
    generation: RwLock<u64>,
}
379
380impl GlobalStatsCache {
381    /// Create a new cache
382    pub fn new() -> Self {
383        Self {
384            stats: RwLock::new(None),
385            generation: RwLock::new(0),
386        }
387    }
388
389    /// Invalidate the cache (call when segments are added/removed/merged)
390    pub fn invalidate(&self) {
391        let mut current_gen = self.generation.write();
392        *current_gen += 1;
393        let mut stats = self.stats.write();
394        *stats = None;
395    }
396
397    /// Get current generation
398    pub fn generation(&self) -> u64 {
399        *self.generation.read()
400    }
401
402    /// Get cached stats if valid, or None if needs rebuild
403    pub fn get(&self) -> Option<Arc<GlobalStats>> {
404        self.stats.read().clone()
405    }
406
407    /// Update the cache with new stats
408    pub fn set(&self, stats: GlobalStats) {
409        let mut cached = self.stats.write();
410        *cached = Some(Arc::new(stats));
411    }
412
413    /// Get or compute stats using the provided builder function (sync version)
414    ///
415    /// For basic stats that don't require async iteration.
416    pub fn get_or_compute<F>(&self, compute: F) -> Arc<GlobalStats>
417    where
418        F: FnOnce(&mut GlobalStatsBuilder),
419    {
420        // Fast path: return cached if available
421        if let Some(stats) = self.get() {
422            return stats;
423        }
424
425        // Slow path: compute new stats
426        let current_gen = self.generation();
427        let mut builder = GlobalStatsBuilder::new();
428        compute(&mut builder);
429        let stats = Arc::new(builder.build(current_gen));
430
431        // Cache the result
432        let mut cached = self.stats.write();
433        *cached = Some(Arc::clone(&stats));
434
435        stats
436    }
437
438    /// Check if stats need to be rebuilt
439    pub fn needs_rebuild(&self) -> bool {
440        self.stats.read().is_none()
441    }
442
443    /// Set pre-built stats (for async computation)
444    pub fn set_stats(&self, stats: GlobalStats) {
445        let mut cached = self.stats.write();
446        *cached = Some(Arc::new(stats));
447    }
448}
449
450impl Default for GlobalStatsCache {
451    fn default() -> Self {
452        Self::new()
453    }
454}
455
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_idf_computation() {
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 1000;
        // dim 42 appears in 100 docs; dim 43 in only 10
        builder.add_sparse_df(Field(0), 42, 100);
        builder.add_sparse_df(Field(0), 43, 10);

        let stats = builder.build(1);

        let common = stats.sparse_idf(Field(0), 42);
        let rare = stats.sparse_idf(Field(0), 43);

        // The rarer dimension must score higher, and both follow IDF = ln(N/df).
        assert!(rare > common);
        assert!((common - (1000.0_f32 / 100.0).ln()).abs() < 0.001);
        assert!((rare - (1000.0_f32 / 10.0).ln()).abs() < 0.001);
    }

    #[test]
    fn test_text_idf_computation() {
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 10000;
        builder.add_text_df(Field(0), String::from("common"), 5000);
        builder.add_text_df(Field(0), String::from("rare"), 10);

        let stats = builder.build(1);

        // Rare term should have higher IDF than the common one.
        assert!(stats.text_idf(Field(0), "rare") > stats.text_idf(Field(0), "common"));
    }

    #[test]
    fn test_cache_invalidation() {
        let cache = GlobalStatsCache::new();

        // Fresh cache holds nothing.
        assert!(cache.get().is_none());

        // First access computes and caches.
        let stats = cache.get_or_compute(|b| b.total_docs = 100);
        assert_eq!(stats.total_docs(), 100);
        assert!(cache.get().is_some());

        // Invalidation drops the cached stats.
        cache.invalidate();
        assert!(cache.get().is_none());
    }
}