use crate::optimizer::cost_model::{CostWeights, IndexFamily};
use anyhow::{anyhow, Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::fs;
use std::path::{Path, PathBuf};
/// A single executed query, as reported back to the statistics collector.
///
/// One observation is folded into the [`FamilyStats`] of its `family` by
/// [`QueryStats::record`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct QueryObservation {
/// Index family the observation is attributed to.
pub family: IndexFamily,
/// Whether the query counts as a hit; hits feed `FamilyStats::hit_rate`.
pub hit: bool,
/// Observed latency of the query (presumably microseconds, going by the
/// `_us` suffix — confirm against the caller).
pub latency_us: f64,
/// Measured recall, when available. `None` leaves the recall mean untouched.
pub recall: Option<f32>,
/// Cost the cost model predicted for this query before it ran.
pub predicted_cost: f64,
}
impl QueryObservation {
pub fn new(
family: IndexFamily,
hit: bool,
latency_us: f64,
recall: Option<f32>,
predicted_cost: f64,
) -> Self {
Self {
family,
hit,
latency_us,
recall,
predicted_cost,
}
}
}
/// Running aggregates over every observation recorded for one index family.
///
/// The derived `Default` is the all-zero state (no observations yet).
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct FamilyStats {
/// Number of observations folded in so far.
pub queries: u64,
/// Number of those observations whose `hit` flag was set.
pub hits: u64,
/// Sum of the observed latencies; divide by `queries` for the mean
/// (see `mean_latency_us`).
pub total_latency_us: f64,
/// Running mean of recall over the observations that carried a recall value.
pub mean_recall: f64,
/// Number of observations that carried a recall value.
pub recall_samples: u64,
/// Running mean of the predicted cost over all observations.
pub mean_predicted_cost: f64,
}
impl FamilyStats {
    /// Mean observed latency, or `0.0` when nothing has been recorded yet.
    pub fn mean_latency_us(&self) -> f64 {
        match self.queries {
            0 => 0.0,
            count => self.total_latency_us / count as f64,
        }
    }

    /// Fraction of recorded queries that were hits.
    ///
    /// With no recorded queries this reports `1.0` rather than `0.0`.
    pub fn hit_rate(&self) -> f64 {
        match self.queries {
            0 => 1.0,
            count => self.hits as f64 / count as f64,
        }
    }

    /// Folds one observation into the aggregates.
    ///
    /// Counters and the latency sum are plain accumulations; the two means are
    /// maintained incrementally (`mean += (x - mean) / n`) so no per-sample
    /// history is kept. Recall only contributes when the observation has one.
    fn update(&mut self, obs: &QueryObservation) {
        self.queries += 1;
        self.hits += u64::from(obs.hit);
        self.total_latency_us += obs.latency_us;

        let seen = self.queries as f64;
        self.mean_predicted_cost += (obs.predicted_cost - self.mean_predicted_cost) / seen;

        if let Some(recall) = obs.recall {
            self.recall_samples += 1;
            let sampled = self.recall_samples as f64;
            self.mean_recall += (f64::from(recall) - self.mean_recall) / sampled;
        }
    }
}
/// Persistent collection of per-family query statistics.
///
/// Serialised to and from JSON by `QueryStats::save` / `QueryStats::load`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct QueryStats {
/// Schema version of this value; `load` rejects values whose version is
/// greater than `QueryStats::CURRENT_VERSION`.
pub version: u32,
/// Aggregates keyed by index family. A family may be absent (for example
/// in a loaded file); `family_stats` then falls back to zeroed stats.
pub families: BTreeMap<IndexFamily, FamilyStats>,
/// Total number of observations recorded across all families.
pub total_observations: u64,
}
impl Default for QueryStats {
    /// Creates empty statistics: one zeroed [`FamilyStats`] entry for every
    /// family yielded by `IndexFamily::all()`, no observations, and the
    /// version set to [`QueryStats::CURRENT_VERSION`].
    fn default() -> Self {
        let mut families = BTreeMap::new();
        for family in IndexFamily::all() {
            families.insert(family, FamilyStats::default());
        }
        Self {
            // Track the constant rather than repeating the literal, so a
            // version bump cannot leave fresh values stamped with the old one.
            version: Self::CURRENT_VERSION,
            families,
            total_observations: 0,
        }
    }
}
impl QueryStats {
pub const CURRENT_VERSION: u32 = 1;
pub fn new() -> Self {
Self::default()
}
pub fn family_stats(&self, family: IndexFamily) -> &FamilyStats {
self.families.get(&family).unwrap_or(&FALLBACK_FAMILY_STATS)
}
pub fn record(&mut self, obs: QueryObservation) {
let family = obs.family;
let entry = self.families.entry(family).or_default();
entry.update(&obs);
self.total_observations += 1;
}
pub fn recommended_weights(&self, prior: &CostWeights) -> CostWeights {
let mut next = prior.clone();
for fam in IndexFamily::all() {
if let Some(stats) = self.families.get(&fam) {
if stats.queries == 0 || stats.mean_predicted_cost <= 0.0 {
continue;
}
let mean_lat = stats.mean_latency_us();
if mean_lat <= 0.0 {
continue;
}
let new_weight = mean_lat / stats.mean_predicted_cost;
next.set(fam, new_weight);
}
}
next
}
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let path = path.as_ref();
if let Some(parent) = path.parent() {
if !parent.as_os_str().is_empty() {
fs::create_dir_all(parent).with_context(|| {
format!("QueryStats::save: failed to create parent dir {:?}", parent)
})?;
}
}
let tmp_path = tmp_sibling(path);
let json = serde_json::to_string_pretty(self)
.context("QueryStats::save: serde_json encode failed")?;
fs::write(&tmp_path, json).with_context(|| {
format!("QueryStats::save: write to temp file {:?} failed", tmp_path)
})?;
fs::rename(&tmp_path, path).with_context(|| {
format!(
"QueryStats::save: rename {:?} -> {:?} failed",
tmp_path, path
)
})?;
Ok(())
}
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
let path = path.as_ref();
let bytes =
fs::read(path).with_context(|| format!("QueryStats::load: read {:?} failed", path))?;
let stats: QueryStats = serde_json::from_slice(&bytes)
.with_context(|| format!("QueryStats::load: parse {:?} failed", path))?;
if stats.version > Self::CURRENT_VERSION {
return Err(anyhow!(
"QueryStats::load: version {} is newer than this build's {}",
stats.version,
Self::CURRENT_VERSION
));
}
Ok(stats)
}
}
/// All-zero stats handed out by `QueryStats::family_stats` when the requested
/// family has no entry in the map. Being a `static`, it can be returned by
/// reference without allocating.
static FALLBACK_FAMILY_STATS: FamilyStats = FamilyStats {
queries: 0,
hits: 0,
total_latency_us: 0.0,
mean_recall: 0.0,
recall_samples: 0,
mean_predicted_cost: 0.0,
};
/// Returns the temp-file path used while saving to `path`: the same directory,
/// with `.tmp` appended to the file name.
///
/// The name is built as an `OsString`, so file names that are not valid UTF-8
/// are preserved byte-for-byte instead of being rewritten with replacement
/// characters. When `path` has no file name component (for example `/`), the
/// name `query_stats.tmp` is appended instead.
fn tmp_sibling(path: &Path) -> PathBuf {
    let mut tmp_name = match path.file_name() {
        Some(name) => name.to_os_string(),
        None => std::ffi::OsString::from("query_stats"),
    };
    tmp_name.push(".tmp");
    path.with_file_name(tmp_name)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env::temp_dir;

    /// Scratch file path in the system temp dir, made distinct by `label` and
    /// the current wall-clock time in nanoseconds (0 if the clock reads
    /// earlier than the Unix epoch).
    fn unique_path(label: &str) -> PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |since_epoch| since_epoch.as_nanos());
        temp_dir().join(format!("oxirs_vec_optstats_{}_{}.json", label, nanos))
    }

    /// Shorthand for recording one observation on `stats`.
    fn observe(
        stats: &mut QueryStats,
        family: IndexFamily,
        hit: bool,
        latency_us: f64,
        recall: Option<f32>,
        predicted_cost: f64,
    ) {
        stats.record(QueryObservation::new(
            family,
            hit,
            latency_us,
            recall,
            predicted_cost,
        ));
    }

    #[test]
    fn family_stats_default_is_zeroed() {
        let fresh = FamilyStats::default();
        assert_eq!(fresh.queries, 0);
        assert_eq!(fresh.hits, 0);
        assert_eq!(fresh.total_latency_us, 0.0);
        assert!(fresh.mean_recall.abs() < 1e-12);
        // An empty family reports a perfect hit rate, not zero.
        assert_eq!(fresh.hit_rate(), 1.0);
    }

    #[test]
    fn record_updates_running_means() {
        let mut stats = QueryStats::new();
        for (latency_us, recall) in [(100.0, 0.95_f32), (200.0, 0.93_f32)] {
            observe(
                &mut stats,
                IndexFamily::Hnsw,
                true,
                latency_us,
                Some(recall),
                80.0,
            );
        }

        let hnsw = stats.family_stats(IndexFamily::Hnsw);
        assert_eq!(hnsw.queries, 2);
        assert_eq!(hnsw.hits, 2);
        assert!((hnsw.mean_latency_us() - 150.0).abs() < 1e-6);
        assert!((hnsw.mean_recall - 0.94).abs() < 1e-3);
        assert_eq!(stats.total_observations, 2);
    }

    #[test]
    fn record_handles_missing_recall() {
        let mut stats = QueryStats::new();
        observe(&mut stats, IndexFamily::Lsh, true, 50.0, None, 40.0);

        let lsh = stats.family_stats(IndexFamily::Lsh);
        assert_eq!(lsh.queries, 1);
        assert_eq!(lsh.recall_samples, 0);
        assert!(lsh.mean_recall.abs() < 1e-12);
    }

    #[test]
    fn hit_rate_reflects_misses() {
        let mut stats = QueryStats::new();
        for (hit, latency_us) in [(true, 10.0), (false, 12.0), (false, 14.0)] {
            observe(&mut stats, IndexFamily::Pq, hit, latency_us, None, 10.0);
        }

        let rate = stats.family_stats(IndexFamily::Pq).hit_rate();
        assert!((rate - (1.0 / 3.0)).abs() < 1e-9);
    }

    #[test]
    fn recommended_weights_derive_from_observed_vs_predicted() {
        let mut stats = QueryStats::new();
        for _ in 0..10 {
            observe(
                &mut stats,
                IndexFamily::Hnsw,
                true,
                200.0,
                Some(0.95),
                100.0,
            );
        }

        let weights = stats.recommended_weights(&CostWeights::default());
        assert!((weights.get(IndexFamily::Hnsw) - 2.0).abs() < 1e-6);
        // Families without observations keep their prior weight.
        assert!((weights.get(IndexFamily::Ivf) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn recommended_weights_clamped_for_outliers() {
        let mut stats = QueryStats::new();
        observe(&mut stats, IndexFamily::Lsh, true, 5_000.0, None, 0.001);

        let weights = stats.recommended_weights(&CostWeights::default());
        assert!((weights.get(IndexFamily::Lsh) - 20.0).abs() < 1e-6);
    }

    #[test]
    fn save_load_roundtrip() -> Result<()> {
        let path = unique_path("roundtrip");
        let mut original = QueryStats::new();
        observe(
            &mut original,
            IndexFamily::Ivf,
            true,
            150.0,
            Some(0.91),
            120.0,
        );

        original.save(&path)?;
        let loaded = QueryStats::load(&path)?;

        assert_eq!(loaded.version, original.version);
        assert_eq!(loaded.total_observations, original.total_observations);

        let got = loaded.family_stats(IndexFamily::Ivf);
        let want = original.family_stats(IndexFamily::Ivf);
        assert_eq!(got.queries, want.queries);
        assert_eq!(got.hits, want.hits);
        assert!((got.total_latency_us - want.total_latency_us).abs() < 1e-9);
        assert!((got.mean_recall - want.mean_recall).abs() < 1e-6);
        assert_eq!(got.recall_samples, want.recall_samples);
        assert!((got.mean_predicted_cost - want.mean_predicted_cost).abs() < 1e-9);

        let _ = fs::remove_file(&path);
        Ok(())
    }

    #[test]
    fn load_rejects_future_version() -> Result<()> {
        let path = unique_path("future");
        let mut from_the_future = QueryStats::new();
        from_the_future.version = QueryStats::CURRENT_VERSION + 1;
        fs::write(&path, serde_json::to_string_pretty(&from_the_future)?)?;

        let outcome = QueryStats::load(&path);
        assert!(outcome.is_err(), "future version must be rejected");

        let _ = fs::remove_file(&path);
        Ok(())
    }

    #[test]
    fn load_rejects_corrupt_json() {
        let path = unique_path("corrupt");
        fs::write(&path, b"{not json}").expect("temp write");

        let outcome = QueryStats::load(&path);
        assert!(outcome.is_err());

        let _ = fs::remove_file(&path);
    }

    #[test]
    fn fallback_returned_for_missing_family() {
        let empty = QueryStats {
            version: 1,
            families: BTreeMap::new(),
            total_observations: 0,
        };
        let fallback = empty.family_stats(IndexFamily::Hnsw);
        assert_eq!(fallback.queries, 0);
    }
}