hermes_core/query/
global_stats.rs

1//! Global statistics for cross-segment IDF computation
2//!
3//! Provides cached aggregated statistics across multiple segments for:
4//! - Sparse vector dimensions (for sparse vector queries)
5//! - Full-text terms (for BM25/TF-IDF scoring)
6//!
7//! This implements a coordinator-style approach where statistics are gathered
8//! from all segments, cached, and used for consistent IDF scoring.
9
10use std::sync::Arc;
11
12use parking_lot::RwLock;
13use rustc_hash::FxHashMap;
14
15use crate::dsl::Field;
16use crate::segment::SegmentReader;
17
/// Global statistics aggregated across all segments
///
/// Used for consistent IDF computation in multi-segment indexes.
/// Statistics are cached and invalidated when segments change.
///
/// Instances are normally produced by `GlobalStatsBuilder::build` and are read-only
/// afterwards (all fields are private and only getters are exposed).
#[derive(Debug)]
pub struct GlobalStats {
    /// Total documents across all segments (the `N` in the IDF formulas)
    total_docs: u64,
    /// Sparse vector statistics per field: field_id (`Field.0`) -> dimension stats
    sparse_stats: FxHashMap<u32, SparseFieldStats>,
    /// Full-text statistics per field: field_id (`Field.0`) -> term stats
    text_stats: FxHashMap<u32, TextFieldStats>,
    /// Cache generation these stats were built for; compared against
    /// `GlobalStatsCache::generation` to detect staleness
    generation: u64,
}
33
/// Statistics for a sparse vector field
///
/// Frequencies are summed across segments by `GlobalStatsBuilder::add_sparse_df`.
#[derive(Debug, Default)]
pub struct SparseFieldStats {
    /// Document frequency per dimension: dim_id -> number of documents with a
    /// non-empty entry for that dimension (the `df` in `ln(N / df)`)
    pub doc_freqs: FxHashMap<u32, u64>,
}
40
/// Statistics for a full-text field
///
/// Frequencies are summed across segments by `GlobalStatsBuilder::add_text_df`.
#[derive(Debug, Default)]
pub struct TextFieldStats {
    /// Document frequency per term: term -> number of documents containing it
    pub doc_freqs: FxHashMap<String, u64>,
    /// Average field length (for BM25)
    ///
    /// Note: the derived `Default` leaves this at `0.0`, so a field whose stats were
    /// created only via `add_text_df` (without `set_avg_field_len`) reports 0.0 here;
    /// treat non-positive values as "unset".
    pub avg_field_len: f32,
}
49
50impl GlobalStats {
51    /// Create empty stats
52    pub fn new() -> Self {
53        Self {
54            total_docs: 0,
55            sparse_stats: FxHashMap::default(),
56            text_stats: FxHashMap::default(),
57            generation: 0,
58        }
59    }
60
61    /// Total documents in the index
62    #[inline]
63    pub fn total_docs(&self) -> u64 {
64        self.total_docs
65    }
66
67    /// Compute IDF for a sparse vector dimension
68    ///
69    /// IDF = ln(N / df) where N = total docs, df = docs containing dimension
70    #[inline]
71    pub fn sparse_idf(&self, field: Field, dim_id: u32) -> f32 {
72        if let Some(stats) = self.sparse_stats.get(&field.0)
73            && let Some(&df) = stats.doc_freqs.get(&dim_id)
74            && df > 0
75        {
76            return (self.total_docs as f32 / df as f32).ln();
77        }
78        0.0
79    }
80
81    /// Compute IDF weights for multiple sparse dimensions
82    pub fn sparse_idf_weights(&self, field: Field, dim_ids: &[u32]) -> Vec<f32> {
83        dim_ids.iter().map(|&d| self.sparse_idf(field, d)).collect()
84    }
85
86    /// Compute IDF for a full-text term
87    ///
88    /// IDF = ln((N - df + 0.5) / (df + 0.5) + 1) (BM25 variant)
89    #[inline]
90    pub fn text_idf(&self, field: Field, term: &str) -> f32 {
91        if let Some(stats) = self.text_stats.get(&field.0)
92            && let Some(&df) = stats.doc_freqs.get(term)
93        {
94            let n = self.total_docs as f32;
95            let df = df as f32;
96            // BM25 IDF formula
97            return ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
98        }
99        0.0
100    }
101
102    /// Get average field length for BM25
103    #[inline]
104    pub fn avg_field_len(&self, field: Field) -> f32 {
105        self.text_stats
106            .get(&field.0)
107            .map(|s| s.avg_field_len)
108            .unwrap_or(1.0)
109    }
110
111    /// Current generation (for cache invalidation)
112    #[inline]
113    pub fn generation(&self) -> u64 {
114        self.generation
115    }
116}
117
118impl Default for GlobalStats {
119    fn default() -> Self {
120        Self::new()
121    }
122}
123
/// Builder for aggregating statistics from multiple segments
///
/// Accumulates document counts and per-field frequencies, then is consumed by
/// `build` to produce an immutable `GlobalStats`.
pub struct GlobalStatsBuilder {
    /// Total documents across all segments (incremented by `add_segment`; public so
    /// callers can also set it directly)
    pub total_docs: u64,
    /// Per-field sparse dimension frequencies, keyed by `Field.0`
    sparse_stats: FxHashMap<u32, SparseFieldStats>,
    /// Per-field text term frequencies and average lengths, keyed by `Field.0`
    text_stats: FxHashMap<u32, TextFieldStats>,
}
131
132impl GlobalStatsBuilder {
133    /// Create a new builder
134    pub fn new() -> Self {
135        Self {
136            total_docs: 0,
137            sparse_stats: FxHashMap::default(),
138            text_stats: FxHashMap::default(),
139        }
140    }
141
142    /// Add statistics from a segment reader
143    pub fn add_segment(&mut self, reader: &SegmentReader) {
144        self.total_docs += reader.num_docs() as u64;
145
146        // Aggregate sparse vector statistics
147        // Note: This requires access to sparse_indexes which may need to be exposed
148    }
149
150    /// Add sparse dimension document frequency
151    pub fn add_sparse_df(&mut self, field: Field, dim_id: u32, doc_count: u64) {
152        let stats = self.sparse_stats.entry(field.0).or_default();
153        *stats.doc_freqs.entry(dim_id).or_insert(0) += doc_count;
154    }
155
156    /// Add text term document frequency
157    pub fn add_text_df(&mut self, field: Field, term: String, doc_count: u64) {
158        let stats = self.text_stats.entry(field.0).or_default();
159        *stats.doc_freqs.entry(term).or_insert(0) += doc_count;
160    }
161
162    /// Set average field length for a text field
163    pub fn set_avg_field_len(&mut self, field: Field, avg_len: f32) {
164        let stats = self.text_stats.entry(field.0).or_default();
165        stats.avg_field_len = avg_len;
166    }
167
168    /// Build the final GlobalStats
169    pub fn build(self, generation: u64) -> GlobalStats {
170        GlobalStats {
171            total_docs: self.total_docs,
172            sparse_stats: self.sparse_stats,
173            text_stats: self.text_stats,
174            generation,
175        }
176    }
177}
178
179impl Default for GlobalStatsBuilder {
180    fn default() -> Self {
181        Self::new()
182    }
183}
184
/// Cached global statistics with automatic invalidation
///
/// This is the main entry point for getting global IDF values.
/// It caches statistics and rebuilds them when the segment list changes.
/// Rebuilding is lazy: `invalidate` clears the cached value, and the next
/// `get_or_compute` (or an explicit `set`/`set_stats`) repopulates it.
pub struct GlobalStatsCache {
    /// Cached statistics (`None` until computed, and again after `invalidate`)
    stats: RwLock<Option<Arc<GlobalStats>>>,
    /// Current generation (incremented by `invalidate` when segments change)
    generation: RwLock<u64>,
}
195
196impl GlobalStatsCache {
197    /// Create a new cache
198    pub fn new() -> Self {
199        Self {
200            stats: RwLock::new(None),
201            generation: RwLock::new(0),
202        }
203    }
204
205    /// Invalidate the cache (call when segments are added/removed/merged)
206    pub fn invalidate(&self) {
207        let mut current_gen = self.generation.write();
208        *current_gen += 1;
209        let mut stats = self.stats.write();
210        *stats = None;
211    }
212
213    /// Get current generation
214    pub fn generation(&self) -> u64 {
215        *self.generation.read()
216    }
217
218    /// Get cached stats if valid, or None if needs rebuild
219    pub fn get(&self) -> Option<Arc<GlobalStats>> {
220        self.stats.read().clone()
221    }
222
223    /// Update the cache with new stats
224    pub fn set(&self, stats: GlobalStats) {
225        let mut cached = self.stats.write();
226        *cached = Some(Arc::new(stats));
227    }
228
229    /// Get or compute stats using the provided builder function (sync version)
230    ///
231    /// For basic stats that don't require async iteration.
232    pub fn get_or_compute<F>(&self, compute: F) -> Arc<GlobalStats>
233    where
234        F: FnOnce(&mut GlobalStatsBuilder),
235    {
236        // Fast path: return cached if available
237        if let Some(stats) = self.get() {
238            return stats;
239        }
240
241        // Slow path: compute new stats
242        let current_gen = self.generation();
243        let mut builder = GlobalStatsBuilder::new();
244        compute(&mut builder);
245        let stats = Arc::new(builder.build(current_gen));
246
247        // Cache the result
248        let mut cached = self.stats.write();
249        *cached = Some(Arc::clone(&stats));
250
251        stats
252    }
253
254    /// Check if stats need to be rebuilt
255    pub fn needs_rebuild(&self) -> bool {
256        self.stats.read().is_none()
257    }
258
259    /// Set pre-built stats (for async computation)
260    pub fn set_stats(&self, stats: GlobalStats) {
261        let mut cached = self.stats.write();
262        *cached = Some(Arc::new(stats));
263    }
264}
265
266impl Default for GlobalStatsCache {
267    fn default() -> Self {
268        Self::new()
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Sparse IDF must follow ln(N / df) and rank rarer dimensions higher.
    #[test]
    fn test_sparse_idf_computation() {
        let field = Field(0);
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 1000;
        // (dimension, number of documents containing it)
        for (dim, df) in [(42, 100), (43, 10)] {
            builder.add_sparse_df(field, dim, df);
        }
        let stats = builder.build(1);

        let idf_frequent = stats.sparse_idf(field, 42);
        let idf_rare = stats.sparse_idf(field, 43);

        // The rarer dimension carries more weight.
        assert!(idf_rare > idf_frequent);

        // Both match the closed-form IDF within tolerance.
        let close = |actual: f32, expected: f32| (actual - expected).abs() < 0.001;
        assert!(close(idf_frequent, (1000.0_f32 / 100.0).ln()));
        assert!(close(idf_rare, (1000.0_f32 / 10.0).ln()));
    }

    /// BM25 text IDF must rank a rare term above a very common one.
    #[test]
    fn test_text_idf_computation() {
        let field = Field(0);
        let mut builder = GlobalStatsBuilder::new();
        builder.total_docs = 10000;
        for (term, df) in [("common", 5000), ("rare", 10)] {
            builder.add_text_df(field, term.to_string(), df);
        }
        let stats = builder.build(1);

        assert!(stats.text_idf(field, "rare") > stats.text_idf(field, "common"));
    }

    /// Computed stats are cached until `invalidate` clears them.
    #[test]
    fn test_cache_invalidation() {
        let cache = GlobalStatsCache::new();

        // A fresh cache holds nothing.
        assert!(cache.get().is_none());

        let stats = cache.get_or_compute(|builder| builder.total_docs = 100);
        assert_eq!(stats.total_docs(), 100);

        // The computed snapshot is now cached ...
        assert!(cache.get().is_some());

        // ... until the cache is invalidated.
        cache.invalidate();
        assert!(cache.get().is_none());
    }
}