Skip to main content

oxirs_vec/optimizer/
query_stats.rs

//! Persistent runtime statistics for online cost-model adaptation.
//!
//! Each query through the optimizer reports a [`QueryObservation`] capturing
//! the family it dispatched to, the wall-clock latency, and the observed
//! recall (if measurable).  [`QueryStats`] aggregates these observations into
//! per-family running averages and produces [`recommended_weights`] that the
//! cost model uses to correct systematic over/underestimates.
//!
//! Persistence uses **`serde_json`** so the file is grep-able for operators;
//! atomic writes go through a temporary `.tmp` sibling and a rename.
//!
//! [`recommended_weights`]: QueryStats::recommended_weights
13
14use crate::optimizer::cost_model::{CostWeights, IndexFamily};
15use anyhow::{anyhow, Context, Result};
16use serde::{Deserialize, Serialize};
17use std::collections::BTreeMap;
18use std::fs;
19use std::path::{Path, PathBuf};
20
21/// Single query observation reported back by the dispatcher after execution.
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
23pub struct QueryObservation {
24    /// Family that served the query.
25    pub family: IndexFamily,
26    /// Whether the query produced any results (used for hit/miss counting).
27    pub hit: bool,
28    /// Observed wall-clock latency in microseconds.
29    pub latency_us: f64,
30    /// Observed recall in `[0.0, 1.0]`, or `None` when ground truth is unknown.
31    pub recall: Option<f32>,
32    /// The cost the model predicted for this family at dispatch time.
33    pub predicted_cost: f64,
34}
35
36impl QueryObservation {
37    /// Convenience constructor.
38    pub fn new(
39        family: IndexFamily,
40        hit: bool,
41        latency_us: f64,
42        recall: Option<f32>,
43        predicted_cost: f64,
44    ) -> Self {
45        Self {
46            family,
47            hit,
48            latency_us,
49            recall,
50            predicted_cost,
51        }
52    }
53}
54
55/// Per-family aggregate running statistics.
56#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
57pub struct FamilyStats {
58    /// Total number of queries dispatched to this family.
59    pub queries: u64,
60    /// Number of queries that returned at least one result.
61    pub hits: u64,
62    /// Sum of observed latencies (microseconds) — divide by `queries` for mean.
63    pub total_latency_us: f64,
64    /// Mean observed recall across queries that reported a recall.
65    pub mean_recall: f64,
66    /// Number of queries that contributed to `mean_recall`.
67    pub recall_samples: u64,
68    /// Mean predicted cost at dispatch time.
69    pub mean_predicted_cost: f64,
70}
71
72impl FamilyStats {
73    /// Mean latency in microseconds (returns 0.0 with no samples).
74    pub fn mean_latency_us(&self) -> f64 {
75        if self.queries == 0 {
76            0.0
77        } else {
78            self.total_latency_us / self.queries as f64
79        }
80    }
81
82    /// Hit rate in `[0.0, 1.0]` (returns 1.0 with no samples — assume best).
83    pub fn hit_rate(&self) -> f64 {
84        if self.queries == 0 {
85            1.0
86        } else {
87            self.hits as f64 / self.queries as f64
88        }
89    }
90
91    /// Update this aggregate with a single observation using running mean
92    /// formulas (no buffer growth).
93    fn update(&mut self, obs: &QueryObservation) {
94        self.queries += 1;
95        if obs.hit {
96            self.hits += 1;
97        }
98        self.total_latency_us += obs.latency_us;
99
100        // Running mean for predicted cost.
101        let n = self.queries as f64;
102        self.mean_predicted_cost =
103            self.mean_predicted_cost + (obs.predicted_cost - self.mean_predicted_cost) / n;
104
105        // Recall mean is updated only when the observation reports a recall.
106        if let Some(r) = obs.recall {
107            self.recall_samples += 1;
108            let m = self.recall_samples as f64;
109            self.mean_recall = self.mean_recall + (r as f64 - self.mean_recall) / m;
110        }
111    }
112}
113
114/// Aggregated statistics across all index families.
115///
116/// `version` is bumped when the on-disk layout changes; loaders refuse to
117/// read incompatible versions.
118#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
119pub struct QueryStats {
120    /// Storage format version.
121    pub version: u32,
122    /// Per-family running aggregates.
123    pub families: BTreeMap<IndexFamily, FamilyStats>,
124    /// Total number of observations recorded since creation.
125    pub total_observations: u64,
126}
127
128impl Default for QueryStats {
129    fn default() -> Self {
130        let mut families = BTreeMap::new();
131        for fam in IndexFamily::all() {
132            families.insert(fam, FamilyStats::default());
133        }
134        Self {
135            version: 1,
136            families,
137            total_observations: 0,
138        }
139    }
140}
141
142impl QueryStats {
143    /// On-disk format version this build emits.
144    pub const CURRENT_VERSION: u32 = 1;
145
146    /// Construct an empty stats container.
147    pub fn new() -> Self {
148        Self::default()
149    }
150
151    /// Borrow stats for a specific family (always present after `default()`).
152    pub fn family_stats(&self, family: IndexFamily) -> &FamilyStats {
153        // BTreeMap was populated by Default::default() with all families;
154        // if a deserialized file is missing one we fall back to a stable
155        // pointer into a static cell to avoid panicking.
156        self.families.get(&family).unwrap_or(&FALLBACK_FAMILY_STATS)
157    }
158
159    /// Record a new observation, updating running aggregates.
160    pub fn record(&mut self, obs: QueryObservation) {
161        let family = obs.family;
162        let entry = self.families.entry(family).or_default();
163        entry.update(&obs);
164        self.total_observations += 1;
165    }
166
167    /// Recommend cost-model weights from accumulated observations.
168    ///
169    /// The weight for a family is set to `mean_observed_latency_us /
170    /// mean_predicted_cost`, capped to the safe range enforced by
171    /// [`CostWeights::set`].  Families with no observations keep their
172    /// previous weights.
173    ///
174    /// Pass the current weights as `prior` so untouched families retain
175    /// their existing values.
176    pub fn recommended_weights(&self, prior: &CostWeights) -> CostWeights {
177        let mut next = prior.clone();
178        for fam in IndexFamily::all() {
179            if let Some(stats) = self.families.get(&fam) {
180                if stats.queries == 0 || stats.mean_predicted_cost <= 0.0 {
181                    continue;
182                }
183                let mean_lat = stats.mean_latency_us();
184                if mean_lat <= 0.0 {
185                    continue;
186                }
187                let new_weight = mean_lat / stats.mean_predicted_cost;
188                next.set(fam, new_weight);
189            }
190        }
191        next
192    }
193
194    /// Serialise to a JSON file, atomically replacing any existing copy.
195    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
196        let path = path.as_ref();
197        if let Some(parent) = path.parent() {
198            if !parent.as_os_str().is_empty() {
199                fs::create_dir_all(parent).with_context(|| {
200                    format!("QueryStats::save: failed to create parent dir {:?}", parent)
201                })?;
202            }
203        }
204        let tmp_path = tmp_sibling(path);
205        let json = serde_json::to_string_pretty(self)
206            .context("QueryStats::save: serde_json encode failed")?;
207        fs::write(&tmp_path, json).with_context(|| {
208            format!("QueryStats::save: write to temp file {:?} failed", tmp_path)
209        })?;
210        fs::rename(&tmp_path, path).with_context(|| {
211            format!(
212                "QueryStats::save: rename {:?} -> {:?} failed",
213                tmp_path, path
214            )
215        })?;
216        Ok(())
217    }
218
219    /// Load from a JSON file.  Refuses to read versions newer than this build.
220    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
221        let path = path.as_ref();
222        let bytes =
223            fs::read(path).with_context(|| format!("QueryStats::load: read {:?} failed", path))?;
224        let stats: QueryStats = serde_json::from_slice(&bytes)
225            .with_context(|| format!("QueryStats::load: parse {:?} failed", path))?;
226        if stats.version > Self::CURRENT_VERSION {
227            return Err(anyhow!(
228                "QueryStats::load: version {} is newer than this build's {}",
229                stats.version,
230                Self::CURRENT_VERSION
231            ));
232        }
233        Ok(stats)
234    }
235}
236
237/// Fallback value returned by `family_stats()` when a family is absent
238/// from a deserialized file (defensive — `default()` populates all families).
239static FALLBACK_FAMILY_STATS: FamilyStats = FamilyStats {
240    queries: 0,
241    hits: 0,
242    total_latency_us: 0.0,
243    mean_recall: 0.0,
244    recall_samples: 0,
245    mean_predicted_cost: 0.0,
246};
247
/// Compute the temporary sibling file path used during atomic writes.
///
/// The result lives in the same directory as `path` and has `.tmp` appended
/// to the complete file name (`stats.json` → `stats.json.tmp`).  When `path`
/// has no file-name component, `query_stats.tmp` is pushed onto it instead.
///
/// The name is assembled as an `OsString`, so file names that are not valid
/// UTF-8 are carried over byte-for-byte rather than being lossily converted.
fn tmp_sibling(path: &Path) -> PathBuf {
    let mut tmp_name = path
        .file_name()
        .map(|name| name.to_os_string())
        .unwrap_or_else(|| std::ffi::OsString::from("query_stats"));
    tmp_name.push(".tmp");
    path.with_file_name(tmp_name)
}
258
#[cfg(test)]
mod tests {
    use super::*;
    use std::env::temp_dir;

    /// Scratch-file path inside the OS temp directory.  The nanosecond
    /// timestamp (0 if the clock reads before the Unix epoch) plus the label
    /// keeps paths from different tests apart.
    fn unique_path(label: &str) -> PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |since_epoch| since_epoch.as_nanos());
        temp_dir().join(format!("oxirs_vec_optstats_{}_{}.json", label, nanos))
    }

    /// Record each `(hit, latency_us, recall, predicted_cost)` sample, in
    /// order, as an observation for `family`.
    fn feed(
        stats: &mut QueryStats,
        family: IndexFamily,
        samples: &[(bool, f64, Option<f32>, f64)],
    ) {
        for &(hit, latency_us, recall, predicted_cost) in samples {
            stats.record(QueryObservation::new(
                family,
                hit,
                latency_us,
                recall,
                predicted_cost,
            ));
        }
    }

    #[test]
    fn family_stats_default_is_zeroed() {
        let fresh = FamilyStats::default();
        assert_eq!(fresh.queries, 0);
        assert_eq!(fresh.hits, 0);
        assert_eq!(fresh.total_latency_us, 0.0);
        assert!(fresh.mean_recall.abs() < 1e-12);
        // With no data the hit rate is reported optimistically.
        assert!(fresh.hit_rate() == 1.0);
    }

    #[test]
    fn record_updates_running_means() {
        let mut stats = QueryStats::new();
        feed(
            &mut stats,
            IndexFamily::Hnsw,
            &[
                (true, 100.0, Some(0.95), 80.0),
                (true, 200.0, Some(0.93), 80.0),
            ],
        );
        let hnsw = stats.family_stats(IndexFamily::Hnsw);
        assert_eq!(hnsw.queries, 2);
        assert_eq!(hnsw.hits, 2);
        assert!((hnsw.mean_latency_us() - 150.0).abs() < 1e-6);
        assert!((hnsw.mean_recall - 0.94).abs() < 1e-3);
        assert_eq!(stats.total_observations, 2);
    }

    #[test]
    fn record_handles_missing_recall() {
        let mut stats = QueryStats::new();
        feed(&mut stats, IndexFamily::Lsh, &[(true, 50.0, None, 40.0)]);
        let lsh = stats.family_stats(IndexFamily::Lsh);
        assert_eq!(lsh.queries, 1);
        assert_eq!(lsh.recall_samples, 0);
        assert!(lsh.mean_recall.abs() < 1e-12);
    }

    #[test]
    fn hit_rate_reflects_misses() {
        let mut stats = QueryStats::new();
        feed(
            &mut stats,
            IndexFamily::Pq,
            &[
                (true, 10.0, None, 10.0),
                (false, 12.0, None, 10.0),
                (false, 14.0, None, 10.0),
            ],
        );
        let rate = stats.family_stats(IndexFamily::Pq).hit_rate();
        assert!((rate - (1.0 / 3.0)).abs() < 1e-9);
    }

    #[test]
    fn recommended_weights_derive_from_observed_vs_predicted() {
        let mut stats = QueryStats::new();
        // Ten identical samples: 200µs observed against a prediction of 100,
        // so the recommended weight is the ratio 2.0.
        feed(
            &mut stats,
            IndexFamily::Hnsw,
            &[(true, 200.0, Some(0.95), 100.0); 10],
        );
        let weights = stats.recommended_weights(&CostWeights::default());
        assert!((weights.get(IndexFamily::Hnsw) - 2.0).abs() < 1e-6);
        // A family without observations retains its prior weight of 1.0.
        assert!((weights.get(IndexFamily::Ivf) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn recommended_weights_clamped_for_outliers() {
        let mut stats = QueryStats::new();
        // A near-zero prediction makes the raw ratio huge; it has to be capped.
        feed(&mut stats, IndexFamily::Lsh, &[(true, 5_000.0, None, 0.001)]);
        let weights = stats.recommended_weights(&CostWeights::default());
        // CostWeights::set clamps at a ceiling of 20.0.
        assert!((weights.get(IndexFamily::Lsh) - 20.0).abs() < 1e-6);
    }

    #[test]
    fn save_load_roundtrip() -> Result<()> {
        let path = unique_path("roundtrip");
        let mut original = QueryStats::new();
        feed(
            &mut original,
            IndexFamily::Ivf,
            &[(true, 150.0, Some(0.91), 120.0)],
        );
        original.save(&path)?;
        let loaded = QueryStats::load(&path)?;

        // Integers must match exactly; floats are compared with a tolerance so
        // the test does not hinge on bit-exact decimal round-tripping.
        assert_eq!(loaded.version, original.version);
        assert_eq!(loaded.total_observations, original.total_observations);
        let after = loaded.family_stats(IndexFamily::Ivf);
        let before = original.family_stats(IndexFamily::Ivf);
        assert_eq!(after.queries, before.queries);
        assert_eq!(after.hits, before.hits);
        assert!((after.total_latency_us - before.total_latency_us).abs() < 1e-9);
        assert!((after.mean_recall - before.mean_recall).abs() < 1e-6);
        assert_eq!(after.recall_samples, before.recall_samples);
        assert!((after.mean_predicted_cost - before.mean_predicted_cost).abs() < 1e-9);

        let _ = fs::remove_file(&path);
        Ok(())
    }

    #[test]
    fn load_rejects_future_version() -> Result<()> {
        let path = unique_path("future");
        let mut stats = QueryStats::new();
        stats.version = QueryStats::CURRENT_VERSION + 1;
        fs::write(&path, serde_json::to_string_pretty(&stats)?)?;

        let outcome = QueryStats::load(&path);
        assert!(outcome.is_err(), "future version must be rejected");

        let _ = fs::remove_file(&path);
        Ok(())
    }

    #[test]
    fn load_rejects_corrupt_json() {
        let path = unique_path("corrupt");
        fs::write(&path, b"{not json}").expect("temp write");

        let outcome = QueryStats::load(&path);
        assert!(outcome.is_err());

        let _ = fs::remove_file(&path);
    }

    #[test]
    fn fallback_returned_for_missing_family() {
        // Built by hand (not via Default) so the map is empty, mimicking an
        // older serialization that left families out.
        let stats = QueryStats {
            version: 1,
            families: BTreeMap::new(),
            total_observations: 0,
        };
        let hnsw = stats.family_stats(IndexFamily::Hnsw);
        assert_eq!(hnsw.queries, 0);
    }
}