// hermes_core/query/global_stats.rs
//! Lazy global statistics for cross-segment IDF computation
//!
//! Provides lazily computed and cached statistics across multiple segments for:
//! - Sparse vector dimensions (for sparse vector queries)
//! - Full-text terms (for BM25/TF-IDF scoring)
//!
//! Key design principles:
//! - **Lazy computation**: IDF values computed on first access, not upfront
//! - **Per-term caching**: Each term/dimension's IDF is cached independently
//! - **Bound to Searcher**: Stats tied to segment snapshot lifetime

12use std::sync::Arc;
13
14use parking_lot::RwLock;
15use rustc_hash::FxHashMap;
16
17use crate::dsl::Field;
18use crate::segment::SegmentReader;
19
/// Lazy global statistics bound to a fixed set of segments
///
/// Computes IDF values lazily on first access and caches them.
/// Lifetime is bound to the Searcher that created it, ensuring
/// statistics always match the current segment set.
pub struct LazyGlobalStats {
    /// Segment readers (Arc for shared ownership with Searcher)
    segments: Vec<Arc<SegmentReader>>,
    /// Total documents across all segments (summed once in `new`, so
    /// reads never have to walk the segment list)
    total_docs: u64,
    /// Cached sparse IDF values: field_id -> (dim_id -> idf).
    /// RwLock keeps the hot cache-hit path to a shared read lock.
    sparse_idf_cache: RwLock<FxHashMap<u32, FxHashMap<u32, f32>>>,
    /// Cached text IDF values: field_id -> (term -> idf)
    text_idf_cache: RwLock<FxHashMap<u32, FxHashMap<String, f32>>>,
    /// Cached average field lengths for BM25: field_id -> avg_len
    avg_field_len_cache: RwLock<FxHashMap<u32, f32>>,
}
37
impl LazyGlobalStats {
    /// Create new lazy stats bound to a set of segments
    ///
    /// Only `total_docs` is computed eagerly; every IDF and average-length
    /// value is computed on first access and cached.
    pub fn new(segments: Vec<Arc<SegmentReader>>) -> Self {
        let total_docs: u64 = segments.iter().map(|s| s.num_docs() as u64).sum();
        Self {
            segments,
            total_docs,
            sparse_idf_cache: RwLock::new(FxHashMap::default()),
            text_idf_cache: RwLock::new(FxHashMap::default()),
            avg_field_len_cache: RwLock::new(FxHashMap::default()),
        }
    }

    /// Total documents across all segments
    #[inline]
    pub fn total_docs(&self) -> u64 {
        self.total_docs
    }

    /// Get or compute IDF for a sparse vector dimension (lazy + cached)
    ///
    /// IDF = ln(N / df) where N = total docs, df = docs containing dimension
    ///
    /// Two threads that miss the cache may both compute the value; the result
    /// is deterministic for a fixed segment set, so the duplicate insert is
    /// harmless (last writer wins with an identical value).
    pub fn sparse_idf(&self, field: Field, dim_id: u32) -> f32 {
        // Fast path: check cache under a shared read lock
        {
            let cache = self.sparse_idf_cache.read();
            if let Some(field_cache) = cache.get(&field.0)
                && let Some(&idf) = field_cache.get(&dim_id)
            {
                return idf;
            }
        }

        // Slow path: compute and cache
        let df = self.compute_sparse_df(field, dim_id);
        // Use total_vectors for proper IDF with multi-valued fields
        // This ensures df <= N even when documents have multiple sparse vectors
        // NOTE(review): total_vectors is a per-field quantity but is recomputed
        // for every uncached dimension — candidate for its own per-field cache.
        let total_vectors = self.compute_sparse_total_vectors(field);
        let n = total_vectors.max(self.total_docs);
        // .max(0.0) guards against a negative weight if df ever exceeds n
        let idf = if df > 0 && n > 0 {
            (n as f32 / df as f32).ln().max(0.0)
        } else {
            0.0
        };

        // Cache the result (write lock held only for the insert)
        {
            let mut cache = self.sparse_idf_cache.write();
            cache.entry(field.0).or_default().insert(dim_id, idf);
        }

        idf
    }

    /// Compute IDF weights for multiple sparse dimensions (batch, uses cache)
    ///
    /// Output order matches `dim_ids`.
    pub fn sparse_idf_weights(&self, field: Field, dim_ids: &[u32]) -> Vec<f32> {
        dim_ids.iter().map(|&d| self.sparse_idf(field, d)).collect()
    }

    /// Get or compute IDF for a full-text term (lazy + cached)
    ///
    /// IDF = ln((N - df + 0.5) / (df + 0.5) + 1) (BM25 variant)
    ///
    /// Returns 0.0 when the term's df is 0 (see `compute_text_df`, which is
    /// currently a placeholder that always reports 0).
    pub fn text_idf(&self, field: Field, term: &str) -> f32 {
        // Fast path: check cache under a shared read lock
        {
            let cache = self.text_idf_cache.read();
            if let Some(field_cache) = cache.get(&field.0)
                && let Some(&idf) = field_cache.get(term)
            {
                return idf;
            }
        }

        // Slow path: compute and cache
        let df = self.compute_text_df(field, term);
        let n = self.total_docs as f32;
        let df_f = df as f32;
        let idf = if df > 0 {
            ((n - df_f + 0.5) / (df_f + 0.5) + 1.0).ln()
        } else {
            0.0
        };

        // Cache the result (term is cloned only on this one-time miss path)
        {
            let mut cache = self.text_idf_cache.write();
            cache
                .entry(field.0)
                .or_default()
                .insert(term.to_string(), idf);
        }

        idf
    }

    /// Get or compute average field length for BM25 (lazy + cached)
    ///
    /// Doc-count-weighted mean of each segment's per-field average; segments
    /// reporting a zero average or zero docs are excluded. Falls back to 1.0
    /// when no segment contributes (avoids divide-by-zero in BM25 norms).
    pub fn avg_field_len(&self, field: Field) -> f32 {
        // Fast path: check cache
        {
            let cache = self.avg_field_len_cache.read();
            if let Some(&avg) = cache.get(&field.0) {
                return avg;
            }
        }

        // Slow path: compute weighted average across segments
        // (f64 accumulator to limit rounding error over many segments)
        let mut weighted_sum = 0.0f64;
        let mut total_weight = 0u64;

        for segment in &self.segments {
            let avg_len = segment.avg_field_len(field);
            let doc_count = segment.num_docs() as u64;
            if avg_len > 0.0 && doc_count > 0 {
                weighted_sum += avg_len as f64 * doc_count as f64;
                total_weight += doc_count;
            }
        }

        let avg = if total_weight > 0 {
            (weighted_sum / total_weight as f64) as f32
        } else {
            1.0
        };

        // Cache the result
        {
            let mut cache = self.avg_field_len_cache.write();
            cache.insert(field.0, avg);
        }

        avg
    }

    /// Compute document frequency for a sparse dimension (not cached - internal)
    ///
    /// Native build: may block on posting I/O via `get_posting_blocking`.
    /// I/O errors and absent postings are treated as df contribution 0.
    #[cfg(feature = "native")]
    fn compute_sparse_df(&self, field: Field, dim_id: u32) -> u64 {
        let mut df = 0u64;
        for segment in &self.segments {
            if let Some(sparse_index) = segment.sparse_indexes().get(&field.0)
                && let Ok(Some(posting)) = sparse_index.get_posting_blocking(dim_id)
            {
                df += posting.doc_count() as u64;
            }
        }
        df
    }

    /// Compute document frequency for a sparse dimension (not cached - internal)
    /// WASM version uses cached postings only (no blocking)
    ///
    /// Dimensions whose postings are not already cached contribute 0, so the
    /// resulting df may undercount relative to the native build.
    #[cfg(not(feature = "native"))]
    fn compute_sparse_df(&self, field: Field, dim_id: u32) -> u64 {
        let mut df = 0u64;
        for segment in &self.segments {
            if let Some(sparse_index) = segment.sparse_indexes().get(&field.0)
                && let Some(posting) = sparse_index.get_cached(dim_id)
            {
                df += posting.doc_count() as u64;
            }
        }
        df
    }

    /// Compute total sparse vectors for a field across all segments
    /// For multi-valued fields, this may exceed total_docs
    fn compute_sparse_total_vectors(&self, field: Field) -> u64 {
        let mut total = 0u64;
        for segment in &self.segments {
            if let Some(sparse_index) = segment.sparse_indexes().get(&field.0) {
                total += sparse_index.total_vectors as u64;
            }
        }
        total
    }

    /// Compute document frequency for a text term (not cached - internal)
    ///
    /// Note: This is expensive as it requires async term lookup.
    /// For now, returns 0 - text IDF should be computed via term dictionary.
    /// As a consequence, `text_idf` currently always returns 0.0.
    fn compute_text_df(&self, _field: Field, _term: &str) -> u64 {
        // Text term lookup requires async access to term dictionary
        // For now, this is a placeholder - actual implementation would
        // need to be async or use pre-computed stats
        0
    }

    /// Number of segments
    pub fn num_segments(&self) -> usize {
        self.segments.len()
    }
}
228
229impl std::fmt::Debug for LazyGlobalStats {
230    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231        f.debug_struct("LazyGlobalStats")
232            .field("total_docs", &self.total_docs)
233            .field("num_segments", &self.segments.len())
234            .field("sparse_cache_fields", &self.sparse_idf_cache.read().len())
235            .field("text_cache_fields", &self.text_idf_cache.read().len())
236            .finish()
237    }
238}
239
240// Keep old types for backwards compatibility during transition
241
/// Global statistics aggregated across all segments (legacy)
///
/// Eagerly-built snapshot kept during the transition to `LazyGlobalStats`.
/// Constructed via `GlobalStatsBuilder::build` and cached in
/// `GlobalStatsCache`.
#[derive(Debug)]
pub struct GlobalStats {
    /// Total documents across all segments
    total_docs: u64,
    /// Sparse vector statistics per field: field_id -> dimension stats
    sparse_stats: FxHashMap<u32, SparseFieldStats>,
    /// Full-text statistics per field: field_id -> term stats
    text_stats: FxHashMap<u32, TextFieldStats>,
    /// Generation counter for cache invalidation (set at build time from
    /// the owning cache's generation)
    generation: u64,
}
254
/// Statistics for a sparse vector field
#[derive(Debug, Default)]
pub struct SparseFieldStats {
    /// Document frequency per dimension: dim_id -> doc_count
    /// (summed across segments by `GlobalStatsBuilder::add_sparse_df`)
    pub doc_freqs: FxHashMap<u32, u64>,
}
261
/// Statistics for a full-text field
#[derive(Debug, Default)]
pub struct TextFieldStats {
    /// Document frequency per term: term -> doc_count
    pub doc_freqs: FxHashMap<String, u64>,
    /// Average field length (for BM25); 0.0 until set via
    /// `GlobalStatsBuilder::set_avg_field_len`
    pub avg_field_len: f32,
}
270
271impl GlobalStats {
272    /// Create empty stats
273    pub fn new() -> Self {
274        Self {
275            total_docs: 0,
276            sparse_stats: FxHashMap::default(),
277            text_stats: FxHashMap::default(),
278            generation: 0,
279        }
280    }
281
282    /// Total documents in the index
283    #[inline]
284    pub fn total_docs(&self) -> u64 {
285        self.total_docs
286    }
287
288    /// Compute IDF for a sparse vector dimension
289    #[inline]
290    pub fn sparse_idf(&self, field: Field, dim_id: u32) -> f32 {
291        if let Some(stats) = self.sparse_stats.get(&field.0)
292            && let Some(&df) = stats.doc_freqs.get(&dim_id)
293            && df > 0
294        {
295            return (self.total_docs as f32 / df as f32).ln();
296        }
297        0.0
298    }
299
300    /// Compute IDF weights for multiple sparse dimensions
301    pub fn sparse_idf_weights(&self, field: Field, dim_ids: &[u32]) -> Vec<f32> {
302        dim_ids.iter().map(|&d| self.sparse_idf(field, d)).collect()
303    }
304
305    /// Compute IDF for a full-text term
306    #[inline]
307    pub fn text_idf(&self, field: Field, term: &str) -> f32 {
308        if let Some(stats) = self.text_stats.get(&field.0)
309            && let Some(&df) = stats.doc_freqs.get(term)
310        {
311            let n = self.total_docs as f32;
312            let df = df as f32;
313            return ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
314        }
315        0.0
316    }
317
318    /// Get average field length for BM25
319    #[inline]
320    pub fn avg_field_len(&self, field: Field) -> f32 {
321        self.text_stats
322            .get(&field.0)
323            .map(|s| s.avg_field_len)
324            .unwrap_or(1.0)
325    }
326
327    /// Current generation
328    #[inline]
329    pub fn generation(&self) -> u64 {
330        self.generation
331    }
332}
333
334impl Default for GlobalStats {
335    fn default() -> Self {
336        Self::new()
337    }
338}
339
/// Builder for aggregating statistics from multiple segments
///
/// Accumulate with the `add_*` / `set_*` methods, then call `build` to
/// freeze into a `GlobalStats`.
pub struct GlobalStatsBuilder {
    /// Total documents across all segments (public so callers can set it
    /// directly when not using `add_segment`)
    pub total_docs: u64,
    sparse_stats: FxHashMap<u32, SparseFieldStats>,
    text_stats: FxHashMap<u32, TextFieldStats>,
}
347
348impl GlobalStatsBuilder {
349    /// Create a new builder
350    pub fn new() -> Self {
351        Self {
352            total_docs: 0,
353            sparse_stats: FxHashMap::default(),
354            text_stats: FxHashMap::default(),
355        }
356    }
357
358    /// Add statistics from a segment reader
359    pub fn add_segment(&mut self, reader: &SegmentReader) {
360        self.total_docs += reader.num_docs() as u64;
361
362        // Aggregate sparse vector statistics
363        // Note: This requires access to sparse_indexes which may need to be exposed
364    }
365
366    /// Add sparse dimension document frequency
367    pub fn add_sparse_df(&mut self, field: Field, dim_id: u32, doc_count: u64) {
368        let stats = self.sparse_stats.entry(field.0).or_default();
369        *stats.doc_freqs.entry(dim_id).or_insert(0) += doc_count;
370    }
371
372    /// Add text term document frequency
373    pub fn add_text_df(&mut self, field: Field, term: String, doc_count: u64) {
374        let stats = self.text_stats.entry(field.0).or_default();
375        *stats.doc_freqs.entry(term).or_insert(0) += doc_count;
376    }
377
378    /// Set average field length for a text field
379    pub fn set_avg_field_len(&mut self, field: Field, avg_len: f32) {
380        let stats = self.text_stats.entry(field.0).or_default();
381        stats.avg_field_len = avg_len;
382    }
383
384    /// Build the final GlobalStats
385    pub fn build(self, generation: u64) -> GlobalStats {
386        GlobalStats {
387            total_docs: self.total_docs,
388            sparse_stats: self.sparse_stats,
389            text_stats: self.text_stats,
390            generation,
391        }
392    }
393}
394
395impl Default for GlobalStatsBuilder {
396    fn default() -> Self {
397        Self::new()
398    }
399}
400
/// Cached global statistics with automatic invalidation
///
/// This is the main entry point for getting global IDF values.
/// It caches statistics and rebuilds them when the segment list changes.
pub struct GlobalStatsCache {
    /// Cached statistics (None when absent or invalidated)
    stats: RwLock<Option<Arc<GlobalStats>>>,
    /// Current generation (incremented when segments change).
    /// Lock order is generation before stats wherever both are taken.
    generation: RwLock<u64>,
}
411
412impl GlobalStatsCache {
413    /// Create a new cache
414    pub fn new() -> Self {
415        Self {
416            stats: RwLock::new(None),
417            generation: RwLock::new(0),
418        }
419    }
420
421    /// Invalidate the cache (call when segments are added/removed/merged)
422    pub fn invalidate(&self) {
423        let mut current_gen = self.generation.write();
424        *current_gen += 1;
425        let mut stats = self.stats.write();
426        *stats = None;
427    }
428
429    /// Get current generation
430    pub fn generation(&self) -> u64 {
431        *self.generation.read()
432    }
433
434    /// Get cached stats if valid, or None if needs rebuild
435    pub fn get(&self) -> Option<Arc<GlobalStats>> {
436        self.stats.read().clone()
437    }
438
439    /// Update the cache with new stats
440    pub fn set(&self, stats: GlobalStats) {
441        let mut cached = self.stats.write();
442        *cached = Some(Arc::new(stats));
443    }
444
445    /// Get or compute stats using the provided builder function (sync version)
446    ///
447    /// For basic stats that don't require async iteration.
448    pub fn get_or_compute<F>(&self, compute: F) -> Arc<GlobalStats>
449    where
450        F: FnOnce(&mut GlobalStatsBuilder),
451    {
452        // Fast path: return cached if available
453        if let Some(stats) = self.get() {
454            return stats;
455        }
456
457        // Slow path: compute new stats
458        let current_gen = self.generation();
459        let mut builder = GlobalStatsBuilder::new();
460        compute(&mut builder);
461        let stats = Arc::new(builder.build(current_gen));
462
463        // Cache the result
464        let mut cached = self.stats.write();
465        *cached = Some(Arc::clone(&stats));
466
467        stats
468    }
469
470    /// Check if stats need to be rebuilt
471    pub fn needs_rebuild(&self) -> bool {
472        self.stats.read().is_none()
473    }
474
475    /// Set pre-built stats (for async computation)
476    pub fn set_stats(&self, stats: GlobalStats) {
477        let mut cached = self.stats.write();
478        *cached = Some(Arc::new(stats));
479    }
480}
481
482impl Default for GlobalStatsCache {
483    fn default() -> Self {
484        Self::new()
485    }
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_idf_computation() {
        let mut b = GlobalStatsBuilder::new();
        b.total_docs = 1000;
        // dim 42 is common (100 docs); dim 43 is rare (10 docs)
        b.add_sparse_df(Field(0), 42, 100);
        b.add_sparse_df(Field(0), 43, 10);

        let stats = b.build(1);

        // IDF = ln(N/df)
        let common_idf = stats.sparse_idf(Field(0), 42);
        let rare_idf = stats.sparse_idf(Field(0), 43);

        // The rarer dimension must carry the larger weight
        assert!(rare_idf > common_idf);
        assert!((common_idf - (1000.0_f32 / 100.0).ln()).abs() < 0.001);
        assert!((rare_idf - (1000.0_f32 / 10.0).ln()).abs() < 0.001);
    }

    #[test]
    fn test_text_idf_computation() {
        let mut b = GlobalStatsBuilder::new();
        b.total_docs = 10000;
        b.add_text_df(Field(0), "common".to_string(), 5000);
        b.add_text_df(Field(0), "rare".to_string(), 10);

        let stats = b.build(1);

        let common_idf = stats.text_idf(Field(0), "common");
        let rare_idf = stats.text_idf(Field(0), "rare");

        // Rare term should have higher IDF
        assert!(rare_idf > common_idf);
    }

    #[test]
    fn test_cache_invalidation() {
        let cache = GlobalStatsCache::new();

        // Starts empty
        assert!(cache.get().is_none());

        // First access computes and caches
        let stats = cache.get_or_compute(|b| {
            b.total_docs = 100;
        });
        assert_eq!(stats.total_docs(), 100);
        assert!(cache.get().is_some());

        // Invalidation clears the cached value
        cache.invalidate();
        assert!(cache.get().is_none());
    }
}