use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// One paired observation of a proxy score and the exact score it stands in for.
#[derive(Debug, Clone, Copy)]
pub struct ErrorSample {
    /// Score produced by the proxy (e.g. quantized) computation.
    pub proxy: f32,
    /// Exact reference score.
    pub true_score: f32,
    /// Signed error `proxy - true_score`; positive means the proxy over-estimates.
    pub error: f32,
}

impl ErrorSample {
    /// Builds a sample, precomputing its signed error `proxy - true_score`.
    pub fn new(proxy: f32, true_score: f32) -> Self {
        let error = proxy - true_score;
        Self {
            error,
            proxy,
            true_score,
        }
    }
}
/// Summary statistics of the signed error (`proxy - true_score`) observed for one list,
/// or for all lists combined.
///
/// The set is serialized with bincode, which encodes struct fields in declaration order,
/// so do not reorder these fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorEnvelope {
    /// Index of the list this envelope describes; the calibrator tags the global envelope
    /// built from samples with `u32::MAX`.
    pub list_idx: u32,
    /// Error value at each recorded quantile, keyed by `round(quantile * 10000)`
    /// (e.g. 0.95 -> 9500).
    pub quantiles: HashMap<u32, f32>,
    /// Arithmetic mean of the errors.
    pub mean_error: f32,
    /// Population standard deviation of the errors (variance divided by n).
    pub std_error: f32,
    /// Largest observed error.
    pub max_error: f32,
    /// Smallest observed error.
    pub min_error: f32,
    /// Number of samples summarised; 0 marks an empty envelope.
    pub sample_count: u32,
}
impl ErrorEnvelope {
    /// Returns the signed error (`proxy - true_score`) at `quantile`, a fraction in `[0, 1]`.
    ///
    /// A quantile present in `quantiles` is returned exactly. Otherwise the value is linearly
    /// interpolated between the nearest recorded quantiles on either side. `min_error` stands
    /// in for quantile 0 and `max_error` for quantile 1.
    ///
    /// An unrecorded request at or above 1.0 returns `max_error` instead of extrapolating
    /// past it. Negative or NaN quantiles map to key 0, the bottom of the range.
    pub fn error_at_quantile(&self, quantile: f32) -> f32 {
        // Same keying as the calibrator: the quantile scaled by 10000 and rounded.
        let key = (quantile * 10000.0).round() as u32;
        if let Some(&error) = self.quantiles.get(&key) {
            return error;
        }
        // Nothing to interpolate toward above the top of the range.
        if key >= 10_000 {
            return self.max_error;
        }
        let (mut below_key, mut below_val) = (0_u32, self.min_error);
        let (mut above_key, mut above_val) = (10_000_u32, self.max_error);
        for (&k, &v) in &self.quantiles {
            if k < key && k > below_key {
                below_key = k;
                below_val = v;
            }
            if k > key && k < above_key {
                above_key = k;
                above_val = v;
            }
        }
        // Here below_key <= key < above_key, so the denominator is non-zero.
        let t = (key - below_key) as f32 / (above_key - below_key) as f32;
        below_val + t * (above_val - below_val)
    }

    /// Lower bound on the true score behind `proxy`, holding for roughly a `confidence`
    /// fraction of the calibrated samples (true = proxy - error).
    pub fn safe_true_threshold(&self, proxy: f32, confidence: f32) -> f32 {
        proxy - self.error_at_quantile(confidence)
    }

    /// Proxy score an item must reach for its true score to be at least `true_threshold`
    /// with roughly `confidence` probability.
    pub fn safe_proxy_threshold(&self, true_threshold: f32, confidence: f32) -> f32 {
        true_threshold + self.error_at_quantile(confidence)
    }

    /// True when even the largest observed error cannot push the true score down to
    /// `true_threshold`.
    pub fn definitely_beats(&self, proxy: f32, true_threshold: f32) -> bool {
        proxy - self.max_error > true_threshold
    }

    /// True when the confidence-adjusted lower bound of the true score exceeds
    /// `true_threshold`.
    pub fn might_beat(&self, proxy: f32, true_threshold: f32, confidence: f32) -> bool {
        self.safe_true_threshold(proxy, confidence) > true_threshold
    }
}
impl Default for ErrorEnvelope {
fn default() -> Self {
Self {
list_idx: 0,
quantiles: HashMap::new(),
mean_error: 0.0,
std_error: 0.0,
max_error: 0.0,
min_error: 0.0,
sample_count: 0,
}
}
}
pub struct ErrorCalibrator {
samples: Vec<Vec<ErrorSample>>,
n_lists: usize,
quantiles: Vec<f32>,
}
impl ErrorCalibrator {
pub fn new(n_lists: usize) -> Self {
Self {
samples: vec![Vec::new(); n_lists],
n_lists,
quantiles: vec![0.50, 0.75, 0.90, 0.95, 0.99, 0.999],
}
}
pub fn with_quantiles(n_lists: usize, quantiles: Vec<f32>) -> Self {
Self {
samples: vec![Vec::new(); n_lists],
n_lists,
quantiles,
}
}
pub fn record_error(&mut self, list_idx: usize, proxy: f32, true_score: f32) {
if list_idx < self.n_lists {
self.samples[list_idx].push(ErrorSample::new(proxy, true_score));
}
}
pub fn record_errors(&mut self, list_idx: usize, samples: &[(f32, f32)]) {
if list_idx < self.n_lists {
for &(proxy, true_score) in samples {
self.samples[list_idx].push(ErrorSample::new(proxy, true_score));
}
}
}
pub fn finalize(&self) -> ErrorEnvelopeSet {
let envelopes: Vec<ErrorEnvelope> = (0..self.n_lists)
.map(|i| self.compute_envelope(i))
.collect();
let global = self.compute_global_envelope();
ErrorEnvelopeSet { envelopes, global }
}
fn compute_envelope(&self, list_idx: usize) -> ErrorEnvelope {
let samples = &self.samples[list_idx];
if samples.is_empty() {
return ErrorEnvelope {
list_idx: list_idx as u32,
..Default::default()
};
}
let mut errors: Vec<f32> = samples.iter().map(|s| s.error).collect();
errors.sort_by(|a, b| a.partial_cmp(b).unwrap());
let n = errors.len();
let sum: f32 = errors.iter().sum();
let mean = sum / n as f32;
let variance: f32 = errors.iter().map(|&e| (e - mean).powi(2)).sum::<f32>() / n as f32;
let std = variance.sqrt();
let mut quantiles = HashMap::new();
for &q in &self.quantiles {
let idx = ((n as f32 * q) as usize).min(n - 1);
let key = (q * 10000.0).round() as u32;
quantiles.insert(key, errors[idx]);
}
ErrorEnvelope {
list_idx: list_idx as u32,
quantiles,
mean_error: mean,
std_error: std,
max_error: errors[n - 1],
min_error: errors[0],
sample_count: n as u32,
}
}
fn compute_global_envelope(&self) -> ErrorEnvelope {
let mut all_errors: Vec<f32> = self
.samples
.iter()
.flat_map(|s| s.iter().map(|e| e.error))
.collect();
if all_errors.is_empty() {
return ErrorEnvelope::default();
}
all_errors.sort_by(|a, b| a.partial_cmp(b).unwrap());
let n = all_errors.len();
let sum: f32 = all_errors.iter().sum();
let mean = sum / n as f32;
let variance: f32 = all_errors.iter().map(|&e| (e - mean).powi(2)).sum::<f32>() / n as f32;
let std = variance.sqrt();
let mut quantiles = HashMap::new();
for &q in &self.quantiles {
let idx = ((n as f32 * q) as usize).min(n - 1);
let key = (q * 10000.0).round() as u32;
quantiles.insert(key, all_errors[idx]);
}
ErrorEnvelope {
list_idx: u32::MAX, quantiles,
mean_error: mean,
std_error: std,
max_error: all_errors[n - 1],
min_error: all_errors[0],
sample_count: n as u32,
}
}
}
/// Per-list error envelopes plus a global envelope over all lists.
///
/// Serialized with bincode, which encodes fields in declaration order, so do not reorder
/// these fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorEnvelopeSet {
    /// One envelope per list, indexed by list id.
    pub envelopes: Vec<ErrorEnvelope>,
    /// Envelope over the samples of every list; returned by `get` when a list index is out
    /// of range or that list has no samples.
    pub global: ErrorEnvelope,
}
impl ErrorEnvelopeSet {
    /// Returns the envelope for `list_idx`. Falls back to the global envelope when the index
    /// is out of range or that list recorded no samples.
    pub fn get(&self, list_idx: usize) -> &ErrorEnvelope {
        self.envelopes
            .get(list_idx)
            .filter(|envelope| envelope.sample_count > 0)
            .unwrap_or(&self.global)
    }

    /// Confidence-adjusted lower bound on the true score behind `proxy` in list `list_idx`.
    pub fn safe_true_threshold(&self, list_idx: usize, proxy: f32, confidence: f32) -> f32 {
        self.get(list_idx).safe_true_threshold(proxy, confidence)
    }

    /// Decides whether a top-k scan may stop visiting the remaining lists.
    ///
    /// The current k-th result's true score is lower-bounded at `confidence` using the global
    /// envelope. Each `(list_idx, bound)` pair is treated as an upper bound on the *proxy*
    /// score of that list's remaining items. Because `error = proxy - true_score`, such an
    /// item's true score is at most `bound - min_error` for that list's envelope.
    ///
    /// Termination is allowed only if every such upper bound is strictly below the k-th
    /// lower bound. An empty `remaining_list_bounds` always allows termination.
    pub fn can_terminate(
        &self,
        kth_proxy: f32,
        remaining_list_bounds: &[(usize, f32)],
        confidence: f32,
    ) -> bool {
        let kth_true_lower = self.global.safe_true_threshold(kth_proxy, confidence);
        remaining_list_bounds.iter().all(|&(list_idx, bound)| {
            // The most negative observed error (proxy under-estimating) yields the largest
            // possible true score. A margin based on max_error misses that case.
            let true_upper = bound - self.get(list_idx).min_error;
            true_upper < kth_true_lower
        })
    }

    /// Serializes the set with bincode; returns an empty `Vec` if serialization fails.
    pub fn to_bytes(&self) -> Vec<u8> {
        bincode::serialize(self).unwrap_or_default()
    }

    /// Deserializes bytes produced by `to_bytes`; returns `None` on malformed input.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        bincode::deserialize(bytes).ok()
    }
}
/// Measures proxy-vs-true scoring error over a dataset and produces an error envelope set.
pub struct CalibrationRunner {
    /// Number of lists to calibrate; list indices at or beyond this are ignored.
    n_lists: usize,
    /// Quantizer hook; not consulted by `calibrate`, which receives quantized codes directly.
    quantize_fn: Option<Box<dyn Fn(&[f32]) -> Vec<u8> + Send + Sync>>,
    /// Proxy score of a query against a quantized code. When unset, the proxy equals the
    /// true score, so every recorded error is zero.
    proxy_distance_fn: Option<Box<dyn Fn(&[f32], &[u8]) -> f32 + Send + Sync>>,
    /// Exact score of a query against a full-precision vector; dot product when unset.
    true_distance_fn: Option<Box<dyn Fn(&[f32], &[f32]) -> f32 + Send + Sync>>,
}

impl CalibrationRunner {
    /// Creates a runner for `n_lists` lists with no scoring hooks installed.
    pub fn new(n_lists: usize) -> Self {
        Self {
            n_lists,
            quantize_fn: None,
            proxy_distance_fn: None,
            true_distance_fn: None,
        }
    }

    /// Installs the proxy scoring function (query vs. quantized code) used by `calibrate`.
    pub fn with_proxy_distance_fn<F>(mut self, f: F) -> Self
    where
        F: Fn(&[f32], &[u8]) -> f32 + Send + Sync + 'static,
    {
        self.proxy_distance_fn = Some(Box::new(f));
        self
    }

    /// Installs the exact scoring function (query vs. full vector) used by `calibrate`,
    /// replacing the default dot product.
    pub fn with_true_distance_fn<F>(mut self, f: F) -> Self
    where
        F: Fn(&[f32], &[f32]) -> f32 + Send + Sync + 'static,
    {
        self.true_distance_fn = Some(Box::new(f));
        self
    }

    /// Scores every query against every `(vector, code)` pair and records the resulting
    /// proxy error under the pair's list index.
    ///
    /// `lists[i]` and `quantized_lists[i]` are zipped element-wise, so surplus lists or
    /// vectors on either side are skipped.
    pub fn calibrate(
        &self,
        queries: &[Vec<f32>],
        lists: &[Vec<Vec<f32>>],
        quantized_lists: &[Vec<Vec<u8>>],
    ) -> ErrorEnvelopeSet {
        let mut calibrator = ErrorCalibrator::new(self.n_lists);
        for query in queries {
            let query = query.as_slice();
            for (list_idx, (vectors, codes)) in lists.iter().zip(quantized_lists).enumerate() {
                for (vector, code) in vectors.iter().zip(codes) {
                    let vector = vector.as_slice();
                    let true_score = match &self.true_distance_fn {
                        Some(score) => score(query, vector),
                        None => dot_product(query, vector),
                    };
                    let proxy_score = match &self.proxy_distance_fn {
                        Some(score) => score(query, code.as_slice()),
                        None => true_score,
                    };
                    calibrator.record_error(list_idx, proxy_score, true_score);
                }
            }
        }
        calibrator.finalize()
    }

    /// Builds envelopes from synthetic, normally distributed errors (Box–Muller) with the
    /// given mean and standard deviation. A fixed seed makes the result reproducible.
    pub fn calibrate_synthetic(
        n_lists: usize,
        mean_error: f32,
        std_error: f32,
        samples_per_list: usize,
    ) -> ErrorEnvelopeSet {
        let mut calibrator = ErrorCalibrator::new(n_lists);
        // Fixed-seed 64-bit LCG; avoids an RNG dependency and keeps results reproducible.
        let mut rng_state: u64 = 12345;
        let mut rand = || {
            rng_state = rng_state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1);
            (rng_state >> 33) as f32 / (1u64 << 31) as f32
        };
        for list_idx in 0..n_lists {
            for _ in 0..samples_per_list {
                // ln(0) = -inf would turn the sample into inf/NaN, so keep u1 strictly > 0.
                let u1 = rand().max(f32::MIN_POSITIVE);
                let u2 = rand();
                let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
                let error = mean_error + std_error * z;
                let true_score = 0.5 + rand() * 0.5;
                let proxy_score = true_score + error;
                calibrator.record_error(list_idx, proxy_score, true_score);
            }
        }
        calibrator.finalize()
    }
}
/// Dot product of `a` and `b`, truncated to the shorter of the two slices.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The stored error is proxy minus true score.
    #[test]
    fn test_error_sample() {
        let sample = ErrorSample::new(0.92, 0.90);
        let diff = sample.error - 0.02;
        assert!(diff.abs() < 1e-6);
    }

    /// A list with only over-estimating samples gets a positive mean and a larger max.
    #[test]
    fn test_calibrator() {
        let observations: [(f32, f32); 5] = [
            (0.90, 0.88),
            (0.85, 0.82),
            (0.92, 0.91),
            (0.88, 0.85),
            (0.95, 0.90),
        ];
        let mut calibrator = ErrorCalibrator::new(3);
        for &(proxy, truth) in observations.iter() {
            calibrator.record_error(0, proxy, truth);
        }
        let set = calibrator.finalize();
        let first = &set.envelopes[0];
        assert_eq!(first.sample_count, 5);
        assert!(first.mean_error > 0.0);
        assert!(first.max_error > first.mean_error);
    }

    /// Recorded quantiles are subtracted from the proxy to form the safe threshold.
    #[test]
    fn test_envelope_threshold() {
        let mut quantiles = HashMap::new();
        quantiles.insert(9500, 0.05);
        quantiles.insert(9900, 0.08);
        let envelope = ErrorEnvelope {
            quantiles,
            mean_error: 0.03,
            std_error: 0.02,
            max_error: 0.10,
            sample_count: 100,
            ..Default::default()
        };
        let at_95 = envelope.safe_true_threshold(0.90, 0.95);
        assert!((at_95 - 0.85).abs() < 0.01);
        let at_99 = envelope.safe_true_threshold(0.90, 0.99);
        assert!((at_99 - 0.82).abs() < 0.01);
    }

    /// Termination is allowed only when remaining bounds sit well below the k-th result.
    #[test]
    fn test_can_terminate() {
        let envelopes = CalibrationRunner::calibrate_synthetic(5, 0.03, 0.01, 100);
        let kth_proxy = 0.95;
        let far_below = [(1, 0.70), (2, 0.65)];
        assert!(
            envelopes.can_terminate(kth_proxy, &far_below, 0.99),
            "Should be able to terminate with high kth and low bounds"
        );
        let close_behind = [(1, 0.94), (2, 0.93)];
        assert!(
            !envelopes.can_terminate(kth_proxy, &close_behind, 0.99),
            "Should not terminate with close bounds"
        );
    }

    /// Synthetic calibration produces one envelope per list and a global mean near the target.
    #[test]
    fn test_synthetic_calibration() {
        let set = CalibrationRunner::calibrate_synthetic(10, 0.02, 0.01, 500);
        assert_eq!(set.envelopes.len(), 10);
        let global = &set.global;
        assert!(global.sample_count > 0);
        assert!((global.mean_error - 0.02).abs() < 0.01);
    }
}