1use crate::optimizer::cost_model::{CostWeights, IndexFamily};
15use anyhow::{anyhow, Context, Result};
16use serde::{Deserialize, Serialize};
17use std::collections::BTreeMap;
18use std::fs;
19use std::path::{Path, PathBuf};
20
/// One query outcome, as reported to the statistics collector.
///
/// Observations are folded into the [`FamilyStats`] entry of their `family`
/// by [`QueryStats::record`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct QueryObservation {
    /// Index family the observation is attributed to.
    pub family: IndexFamily,
    /// Whether the query counts as a hit; hits are tallied in `FamilyStats::hits`.
    pub hit: bool,
    /// Observed latency (microseconds, going by the `_us` suffix).
    pub latency_us: f64,
    /// Measured recall, when one is available. `None` leaves the running
    /// recall mean untouched.
    pub recall: Option<f32>,
    /// Cost that was predicted for this query.
    // NOTE(review): presumably produced by the cost model — confirm with the caller.
    pub predicted_cost: f64,
}
35
36impl QueryObservation {
37 pub fn new(
39 family: IndexFamily,
40 hit: bool,
41 latency_us: f64,
42 recall: Option<f32>,
43 predicted_cost: f64,
44 ) -> Self {
45 Self {
46 family,
47 hit,
48 latency_us,
49 recall,
50 predicted_cost,
51 }
52 }
53}
54
/// Running aggregates for one index family.
///
/// The derived `Default` yields an all-zero record.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct FamilyStats {
    /// Number of observations folded into this record.
    pub queries: u64,
    /// Number of those observations whose `hit` flag was set.
    pub hits: u64,
    /// Sum of `latency_us` over all observations.
    pub total_latency_us: f64,
    /// Running mean of the recall values seen so far; stays `0.0` until the
    /// first observation that carries a recall.
    pub mean_recall: f64,
    /// Number of observations that carried a recall value.
    pub recall_samples: u64,
    /// Running mean of `predicted_cost` over all observations.
    pub mean_predicted_cost: f64,
}
71
72impl FamilyStats {
73 pub fn mean_latency_us(&self) -> f64 {
75 if self.queries == 0 {
76 0.0
77 } else {
78 self.total_latency_us / self.queries as f64
79 }
80 }
81
82 pub fn hit_rate(&self) -> f64 {
84 if self.queries == 0 {
85 1.0
86 } else {
87 self.hits as f64 / self.queries as f64
88 }
89 }
90
91 fn update(&mut self, obs: &QueryObservation) {
94 self.queries += 1;
95 if obs.hit {
96 self.hits += 1;
97 }
98 self.total_latency_us += obs.latency_us;
99
100 let n = self.queries as f64;
102 self.mean_predicted_cost =
103 self.mean_predicted_cost + (obs.predicted_cost - self.mean_predicted_cost) / n;
104
105 if let Some(r) = obs.recall {
107 self.recall_samples += 1;
108 let m = self.recall_samples as f64;
109 self.mean_recall = self.mean_recall + (r as f64 - self.mean_recall) / m;
110 }
111 }
112}
113
/// Persistent collection of per-family query statistics.
///
/// Serialised to and from JSON by [`QueryStats::save`] and
/// [`QueryStats::load`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct QueryStats {
    /// Schema version of this record; `load` rejects values greater than
    /// [`QueryStats::CURRENT_VERSION`].
    pub version: u32,
    /// Aggregates keyed by index family. A family may be absent (for example
    /// in a hand-built or loaded value); `family_stats` then falls back to a
    /// zeroed record.
    pub families: BTreeMap<IndexFamily, FamilyStats>,
    /// Total number of observations recorded across all families.
    pub total_observations: u64,
}
127
128impl Default for QueryStats {
129 fn default() -> Self {
130 let mut families = BTreeMap::new();
131 for fam in IndexFamily::all() {
132 families.insert(fam, FamilyStats::default());
133 }
134 Self {
135 version: 1,
136 families,
137 total_observations: 0,
138 }
139 }
140}
141
142impl QueryStats {
143 pub const CURRENT_VERSION: u32 = 1;
145
146 pub fn new() -> Self {
148 Self::default()
149 }
150
151 pub fn family_stats(&self, family: IndexFamily) -> &FamilyStats {
153 self.families.get(&family).unwrap_or(&FALLBACK_FAMILY_STATS)
157 }
158
159 pub fn record(&mut self, obs: QueryObservation) {
161 let family = obs.family;
162 let entry = self.families.entry(family).or_default();
163 entry.update(&obs);
164 self.total_observations += 1;
165 }
166
167 pub fn recommended_weights(&self, prior: &CostWeights) -> CostWeights {
177 let mut next = prior.clone();
178 for fam in IndexFamily::all() {
179 if let Some(stats) = self.families.get(&fam) {
180 if stats.queries == 0 || stats.mean_predicted_cost <= 0.0 {
181 continue;
182 }
183 let mean_lat = stats.mean_latency_us();
184 if mean_lat <= 0.0 {
185 continue;
186 }
187 let new_weight = mean_lat / stats.mean_predicted_cost;
188 next.set(fam, new_weight);
189 }
190 }
191 next
192 }
193
194 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
196 let path = path.as_ref();
197 if let Some(parent) = path.parent() {
198 if !parent.as_os_str().is_empty() {
199 fs::create_dir_all(parent).with_context(|| {
200 format!("QueryStats::save: failed to create parent dir {:?}", parent)
201 })?;
202 }
203 }
204 let tmp_path = tmp_sibling(path);
205 let json = serde_json::to_string_pretty(self)
206 .context("QueryStats::save: serde_json encode failed")?;
207 fs::write(&tmp_path, json).with_context(|| {
208 format!("QueryStats::save: write to temp file {:?} failed", tmp_path)
209 })?;
210 fs::rename(&tmp_path, path).with_context(|| {
211 format!(
212 "QueryStats::save: rename {:?} -> {:?} failed",
213 tmp_path, path
214 )
215 })?;
216 Ok(())
217 }
218
219 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
221 let path = path.as_ref();
222 let bytes =
223 fs::read(path).with_context(|| format!("QueryStats::load: read {:?} failed", path))?;
224 let stats: QueryStats = serde_json::from_slice(&bytes)
225 .with_context(|| format!("QueryStats::load: parse {:?} failed", path))?;
226 if stats.version > Self::CURRENT_VERSION {
227 return Err(anyhow!(
228 "QueryStats::load: version {} is newer than this build's {}",
229 stats.version,
230 Self::CURRENT_VERSION
231 ));
232 }
233 Ok(stats)
234 }
235}
236
/// All-zero record handed out by `QueryStats::family_stats` when a family has
/// no entry of its own.
///
/// The fields are spelled out because the derived `Default::default()` is not
/// callable in a `static` initializer.
static FALLBACK_FAMILY_STATS: FamilyStats = FamilyStats {
    queries: 0,
    hits: 0,
    total_latency_us: 0.0,
    mean_recall: 0.0,
    recall_samples: 0,
    mean_predicted_cost: 0.0,
};
247
/// Returns the path of the temp file used while saving `path`: the same
/// directory, with `.tmp` appended to the file name.
///
/// The name is built as an `OsString`, so file names that are not valid
/// UTF-8 are carried over byte-for-byte instead of being rewritten with
/// replacement characters. When `path` has no file name component, the name
/// `query_stats.tmp` is used.
fn tmp_sibling(path: &Path) -> PathBuf {
    let mut name = path
        .file_name()
        .map(std::ffi::OsStr::to_os_string)
        .unwrap_or_else(|| std::ffi::OsString::from("query_stats"));
    name.push(".tmp");
    path.with_file_name(name)
}
258
#[cfg(test)]
mod tests {
    use super::*;
    use std::env::temp_dir;

    /// Builds a path in the system temp directory whose name combines
    /// `label` with the current time in nanoseconds since the Unix epoch
    /// (or `0` if the clock reads earlier than the epoch).
    fn unique_path(label: &str) -> PathBuf {
        let nanos = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_nanos(),
            Err(_) => 0,
        };
        temp_dir().join(format!("oxirs_vec_optstats_{}_{}.json", label, nanos))
    }

    #[test]
    fn family_stats_default_is_zeroed() {
        let fresh = FamilyStats::default();
        assert_eq!(fresh.queries, 0);
        assert_eq!(fresh.hits, 0);
        assert_eq!(fresh.total_latency_us, 0.0);
        assert!(fresh.mean_recall.abs() < 1e-12);
        // An empty record reports a perfect hit rate.
        assert!(fresh.hit_rate() == 1.0);
    }

    #[test]
    fn record_updates_running_means() {
        let mut stats = QueryStats::new();
        for (latency_us, recall) in [(100.0, 0.95), (200.0, 0.93)] {
            stats.record(QueryObservation::new(
                IndexFamily::Hnsw,
                true,
                latency_us,
                Some(recall),
                80.0,
            ));
        }
        let hnsw = stats.family_stats(IndexFamily::Hnsw);
        assert_eq!(hnsw.queries, 2);
        assert_eq!(hnsw.hits, 2);
        assert!((hnsw.mean_latency_us() - 150.0).abs() < 1e-6);
        assert!((hnsw.mean_recall - 0.94).abs() < 1e-3);
        assert_eq!(stats.total_observations, 2);
    }

    #[test]
    fn record_handles_missing_recall() {
        let mut stats = QueryStats::new();
        let without_recall = QueryObservation::new(IndexFamily::Lsh, true, 50.0, None, 40.0);
        stats.record(without_recall);
        let lsh = stats.family_stats(IndexFamily::Lsh);
        assert_eq!(lsh.queries, 1);
        assert_eq!(lsh.recall_samples, 0);
        assert!(lsh.mean_recall.abs() < 1e-12);
    }

    #[test]
    fn hit_rate_reflects_misses() {
        let mut stats = QueryStats::new();
        // One hit followed by two misses.
        for (hit, latency_us) in [(true, 10.0), (false, 12.0), (false, 14.0)] {
            stats.record(QueryObservation::new(
                IndexFamily::Pq,
                hit,
                latency_us,
                None,
                10.0,
            ));
        }
        let rate = stats.family_stats(IndexFamily::Pq).hit_rate();
        assert!((rate - (1.0 / 3.0)).abs() < 1e-9);
    }

    #[test]
    fn recommended_weights_derive_from_observed_vs_predicted() {
        let mut stats = QueryStats::new();
        // Ten identical samples: observed 200us against a predicted cost of 100.
        (0..10).for_each(|_| {
            stats.record(QueryObservation::new(
                IndexFamily::Hnsw,
                true,
                200.0,
                Some(0.95),
                100.0,
            ));
        });
        let weights = stats.recommended_weights(&CostWeights::default());
        assert!((weights.get(IndexFamily::Hnsw) - 2.0).abs() < 1e-6);
        // A family without observations is expected to stay at weight 1.0.
        assert!((weights.get(IndexFamily::Ivf) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn recommended_weights_clamped_for_outliers() {
        let mut stats = QueryStats::new();
        // Observed/predicted ratio of 5,000,000; the weight must be capped at 20.0.
        let outlier = QueryObservation::new(IndexFamily::Lsh, true, 5_000.0, None, 0.001);
        stats.record(outlier);
        let weights = stats.recommended_weights(&CostWeights::default());
        assert!((weights.get(IndexFamily::Lsh) - 20.0).abs() < 1e-6);
    }

    #[test]
    fn save_load_roundtrip() -> Result<()> {
        let path = unique_path("roundtrip");
        let mut original = QueryStats::new();
        let sample = QueryObservation::new(IndexFamily::Ivf, true, 150.0, Some(0.91), 120.0);
        original.record(sample);
        original.save(&path)?;
        let loaded = QueryStats::load(&path)?;
        assert_eq!(loaded.version, original.version);
        assert_eq!(loaded.total_observations, original.total_observations);
        let got = loaded.family_stats(IndexFamily::Ivf);
        let want = original.family_stats(IndexFamily::Ivf);
        assert_eq!(got.queries, want.queries);
        assert_eq!(got.hits, want.hits);
        assert!((got.total_latency_us - want.total_latency_us).abs() < 1e-9);
        assert!((got.mean_recall - want.mean_recall).abs() < 1e-6);
        assert_eq!(got.recall_samples, want.recall_samples);
        assert!((got.mean_predicted_cost - want.mean_predicted_cost).abs() < 1e-9);
        let _ = fs::remove_file(&path);
        Ok(())
    }

    #[test]
    fn load_rejects_future_version() -> Result<()> {
        let path = unique_path("future");
        let mut from_the_future = QueryStats::new();
        from_the_future.version = QueryStats::CURRENT_VERSION + 1;
        let encoded = serde_json::to_string_pretty(&from_the_future)?;
        fs::write(&path, encoded)?;
        let outcome = QueryStats::load(&path);
        assert!(outcome.is_err(), "future version must be rejected");
        let _ = fs::remove_file(&path);
        Ok(())
    }

    #[test]
    fn load_rejects_corrupt_json() {
        let path = unique_path("corrupt");
        fs::write(&path, b"{not json}").expect("temp write");
        let outcome = QueryStats::load(&path);
        assert!(outcome.is_err());
        let _ = fs::remove_file(&path);
    }

    #[test]
    fn fallback_returned_for_missing_family() {
        // Built by hand so that the family map really is empty.
        let empty = QueryStats {
            version: 1,
            families: BTreeMap::new(),
            total_observations: 0,
        };
        let hnsw = empty.family_stats(IndexFamily::Hnsw);
        assert_eq!(hnsw.queries, 0);
    }
}