use std::time::Duration;
/// Strength of the correctness guarantee a search is asked to provide.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum GuaranteeMode {
    /// Best-effort results; no error envelopes are used.
    Approximate,
    /// Statistical guarantee described by a recall target and a confidence level.
    Calibrated {
        recall_target: f32,
        confidence: f32,
    },
    /// Strongest mode: uses error envelopes and asks for a rerank.
    Certified,
}

impl Default for GuaranteeMode {
    /// Calibrated mode with a 0.95 recall target at 0.99 confidence.
    fn default() -> Self {
        Self::calibrated(0.95, 0.99)
    }
}

impl GuaranteeMode {
    /// Shorthand constructor for [`GuaranteeMode::Calibrated`].
    pub fn calibrated(recall_target: f32, confidence: f32) -> Self {
        GuaranteeMode::Calibrated {
            recall_target,
            confidence,
        }
    }

    /// Whether this mode asks for a rerank pass (only `Certified` does).
    pub fn requires_rerank(&self) -> bool {
        // `Certified` carries no payload, so equality is the same as a variant check.
        *self == GuaranteeMode::Certified
    }

    /// Whether this mode relies on score error envelopes (every mode except `Approximate`).
    pub fn uses_error_envelopes(&self) -> bool {
        match self {
            GuaranteeMode::Approximate => false,
            GuaranteeMode::Calibrated { .. } | GuaranteeMode::Certified => true,
        }
    }

    /// Error quantile associated with the mode.
    ///
    /// Returns `None` for `Approximate`, the configured `confidence` for `Calibrated`,
    /// and `1.0` for `Certified`.
    pub fn error_quantile(&self) -> Option<f32> {
        match *self {
            GuaranteeMode::Approximate => None,
            GuaranteeMode::Calibrated { confidence, .. } => Some(confidence),
            GuaranteeMode::Certified => Some(1.0),
        }
    }
}
/// Policy deciding when a probing search may stop; evaluated by `StoppingEvaluator`.
#[derive(Debug, Clone)]
pub enum StoppingRule {
    /// Stop once `n_probes` probes have been recorded.
    FixedProbes { n_probes: u32 },
    /// Probe at least `min_probes` and at most `max_probes` times. In between, stop once
    /// the k-th result's proxy score exceeds the best remaining bound.
    BoundBased {
        min_probes: u32,
        max_probes: u32,
    },
    /// Probe at least `min_probes` and at most `max_probes` times. In between, stop once
    /// the estimated miss probability drops below `probability_threshold`.
    ProbabilisticBound {
        probability_threshold: f32,
        /// Stored for reference; not currently read by `StoppingEvaluator`.
        error_quantile: f32,
        min_probes: u32,
        max_probes: u32,
    },
    /// Stop once the k-th result's lower bound exceeds the best remaining bound by more
    /// than `max_error`. This rule has no probe limit of its own.
    DeterministicBound {
        max_error: f32,
    },
    /// Stop when recorded RAM exceeds `max_ram_bytes` or elapsed time exceeds
    /// `max_latency`; otherwise defer to `inner`.
    BudgetConstrained {
        inner: Box<StoppingRule>,
        max_ram_bytes: u64,
        max_latency: Duration,
    },
}

impl StoppingRule {
    /// Builds the default stopping rule for `mode`.
    ///
    /// * `Approximate` maps to `FixedProbes` with `default_probes`.
    /// * `Calibrated` maps to `ProbabilisticBound`, probing between `default_probes / 4` and
    ///   `default_probes * 4` (saturating at `u32::MAX`), with the mode's `confidence` as
    ///   the error quantile.
    /// * `Certified` maps to `DeterministicBound` with zero tolerance.
    pub fn for_mode(mode: &GuaranteeMode, default_probes: u32) -> Self {
        match mode {
            GuaranteeMode::Approximate => Self::FixedProbes {
                n_probes: default_probes,
            },
            GuaranteeMode::Calibrated { confidence, .. } => Self::ProbabilisticBound {
                // NOTE(review): fixed at 0.01 regardless of `confidence` / `recall_target`;
                // confirm whether it should be derived from the mode.
                probability_threshold: 0.01,
                error_quantile: *confidence,
                min_probes: default_probes / 4,
                // A plain `* 4` panics in debug builds and wraps in release builds for huge
                // inputs. Wrapping would even make `max_probes` smaller than `min_probes`.
                max_probes: default_probes.saturating_mul(4),
            },
            GuaranteeMode::Certified => Self::DeterministicBound { max_error: 0.0 },
        }
    }

    /// Wraps this rule in a RAM and latency budget (see [`StoppingRule::BudgetConstrained`]).
    pub fn with_budget(self, max_ram_bytes: u64, max_latency: Duration) -> Self {
        Self::BudgetConstrained {
            inner: Box::new(self),
            max_ram_bytes,
            max_latency,
        }
    }
}
/// Result of a single `StoppingEvaluator::evaluate` call.
#[derive(Debug, Clone)]
pub struct StopDecision {
    /// `true` when the search may stop probing now.
    pub should_stop: bool,
    /// Why the evaluator reached this decision.
    pub reason: StopReason,
    /// Miss probability reported by the rule; `None` when the rule does not estimate one.
    pub miss_probability: Option<f32>,
    /// Count of uncertain candidates. Every rule in `StoppingEvaluator` currently reports 0.
    pub uncertain_candidates: u32,
}
/// Reason attached to a [`StopDecision`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StopReason {
    /// The rule's probe limit (`n_probes` or `max_probes`) was reached.
    ProbesExhausted,
    /// `BoundBased`: the k-th proxy score exceeded the best remaining bound.
    BoundSatisfied,
    /// `ProbabilisticBound`: the estimated miss probability fell below the threshold.
    ProbabilityThreshold,
    /// `DeterministicBound`: the k-th lower bound cleared the remaining bound plus tolerance.
    DeterministicComplete,
    /// The RAM or latency budget was exceeded.
    BudgetExhausted,
    /// No stopping condition has been met yet.
    Continuing,
}
/// A proxy score paired with an interval `[lower_bound, upper_bound]` around it.
#[derive(Debug, Clone, Copy)]
pub struct ScoreEnvelope {
    /// Score as produced by the proxy computation.
    pub proxy: f32,
    /// Lower end of the interval.
    pub lower_bound: f32,
    /// Upper end of the interval.
    pub upper_bound: f32,
    /// Half-width of the interval.
    pub quantile_error: f32,
}

impl ScoreEnvelope {
    /// Symmetric envelope `proxy ± max_error`.
    pub fn new(proxy: f32, max_error: f32) -> Self {
        let lower_bound = proxy - max_error;
        let upper_bound = proxy + max_error;
        ScoreEnvelope {
            proxy,
            lower_bound,
            upper_bound,
            quantile_error: max_error,
        }
    }

    /// Envelope with explicit bounds; `quantile_error` becomes half the interval width.
    pub fn with_bounds(proxy: f32, lower_bound: f32, upper_bound: f32) -> Self {
        let half_width = (upper_bound - lower_bound) / 2.0;
        ScoreEnvelope {
            proxy,
            lower_bound,
            upper_bound,
            quantile_error: half_width,
        }
    }

    /// `true` when every value in `self`'s interval is strictly above `other`'s interval.
    pub fn definitely_beats(&self, other: &ScoreEnvelope) -> bool {
        other.upper_bound < self.lower_bound
    }

    /// `true` when `self`'s upper end lies strictly above `other`'s lower end.
    pub fn might_beat(&self, other: &ScoreEnvelope) -> bool {
        other.lower_bound < self.upper_bound
    }

    /// Midpoint of the interval.
    pub fn estimated_true(&self) -> f32 {
        let span_sum = self.lower_bound + self.upper_bound;
        span_sum / 2.0
    }
}
pub struct StoppingEvaluator {
rule: StoppingRule,
probes_done: u32,
ram_bytes_used: u64,
start_time: std::time::Instant,
}
impl StoppingEvaluator {
pub fn new(rule: StoppingRule) -> Self {
Self {
rule,
probes_done: 0,
ram_bytes_used: 0,
start_time: std::time::Instant::now(),
}
}
pub fn record_probe(&mut self, ram_bytes: u64) {
self.probes_done += 1;
self.ram_bytes_used += ram_bytes;
}
pub fn evaluate(
&self,
kth_score: Option<&ScoreEnvelope>,
best_remaining_bound: Option<f32>,
) -> StopDecision {
match &self.rule {
StoppingRule::FixedProbes { n_probes } => {
if self.probes_done >= *n_probes {
StopDecision {
should_stop: true,
reason: StopReason::ProbesExhausted,
miss_probability: None,
uncertain_candidates: 0,
}
} else {
StopDecision {
should_stop: false,
reason: StopReason::Continuing,
miss_probability: None,
uncertain_candidates: 0,
}
}
}
StoppingRule::BoundBased {
min_probes,
max_probes,
} => {
if self.probes_done >= *max_probes {
return StopDecision {
should_stop: true,
reason: StopReason::ProbesExhausted,
miss_probability: None,
uncertain_candidates: 0,
};
}
if self.probes_done < *min_probes {
return StopDecision {
should_stop: false,
reason: StopReason::Continuing,
miss_probability: None,
uncertain_candidates: 0,
};
}
if let (Some(kth), Some(bound)) = (kth_score, best_remaining_bound) {
if kth.proxy > bound {
return StopDecision {
should_stop: true,
reason: StopReason::BoundSatisfied,
miss_probability: None,
uncertain_candidates: 0,
};
}
}
StopDecision {
should_stop: false,
reason: StopReason::Continuing,
miss_probability: None,
uncertain_candidates: 0,
}
}
StoppingRule::ProbabilisticBound {
probability_threshold,
error_quantile: _,
min_probes,
max_probes,
} => {
if self.probes_done >= *max_probes {
return StopDecision {
should_stop: true,
reason: StopReason::ProbesExhausted,
miss_probability: Some(0.0),
uncertain_candidates: 0,
};
}
if self.probes_done < *min_probes {
return StopDecision {
should_stop: false,
reason: StopReason::Continuing,
miss_probability: Some(1.0),
uncertain_candidates: 0,
};
}
if let (Some(kth), Some(bound)) = (kth_score, best_remaining_bound) {
let margin = kth.lower_bound - bound;
let error_margin = kth.quantile_error;
let miss_prob = if margin > error_margin {
0.0
} else if margin < -error_margin {
1.0
} else {
0.5 - (margin / (2.0 * error_margin))
};
if miss_prob < *probability_threshold {
return StopDecision {
should_stop: true,
reason: StopReason::ProbabilityThreshold,
miss_probability: Some(miss_prob),
uncertain_candidates: 0,
};
}
return StopDecision {
should_stop: false,
reason: StopReason::Continuing,
miss_probability: Some(miss_prob),
uncertain_candidates: 0,
};
}
StopDecision {
should_stop: false,
reason: StopReason::Continuing,
miss_probability: Some(1.0),
uncertain_candidates: 0,
}
}
StoppingRule::DeterministicBound { max_error } => {
if let (Some(kth), Some(bound)) = (kth_score, best_remaining_bound) {
if kth.lower_bound > bound + *max_error {
return StopDecision {
should_stop: true,
reason: StopReason::DeterministicComplete,
miss_probability: Some(0.0),
uncertain_candidates: 0,
};
}
}
StopDecision {
should_stop: false,
reason: StopReason::Continuing,
miss_probability: None,
uncertain_candidates: 0,
}
}
StoppingRule::BudgetConstrained {
inner,
max_ram_bytes,
max_latency,
} => {
if self.ram_bytes_used > *max_ram_bytes {
return StopDecision {
should_stop: true,
reason: StopReason::BudgetExhausted,
miss_probability: None,
uncertain_candidates: 0,
};
}
if self.start_time.elapsed() > *max_latency {
return StopDecision {
should_stop: true,
reason: StopReason::BudgetExhausted,
miss_probability: None,
uncertain_candidates: 0,
};
}
let inner_eval = StoppingEvaluator {
rule: (**inner).clone(),
probes_done: self.probes_done,
ram_bytes_used: self.ram_bytes_used,
start_time: self.start_time,
};
inner_eval.evaluate(kth_score, best_remaining_bound)
}
}
}
}
/// Bundles the guarantee mode, result count and stopping rule for one search.
#[derive(Debug, Clone)]
pub struct SearchContract {
    /// Requested guarantee level.
    pub mode: GuaranteeMode,
    /// Result count (the `k` in top-k).
    pub k: usize,
    /// Rule deciding when probing may stop.
    pub stopping_rule: StoppingRule,
    /// Whether score envelopes are requested; only the approximate constructor sets `false`.
    pub include_envelopes: bool,
}

impl SearchContract {
    /// Approximate contract: a fixed number of probes and no envelopes.
    pub fn approximate(k: usize, n_probes: u32) -> Self {
        let stopping_rule = StoppingRule::FixedProbes { n_probes };
        Self {
            mode: GuaranteeMode::Approximate,
            k,
            stopping_rule,
            include_envelopes: false,
        }
    }

    /// Calibrated contract using `StoppingRule::for_mode` with a base of 16 probes.
    pub fn calibrated(k: usize, recall_target: f32, confidence: f32) -> Self {
        const BASE_PROBES: u32 = 16;
        let mode = GuaranteeMode::calibrated(recall_target, confidence);
        Self {
            stopping_rule: StoppingRule::for_mode(&mode, BASE_PROBES),
            mode,
            k,
            include_envelopes: true,
        }
    }

    /// Certified contract: zero-tolerance deterministic bound, envelopes included.
    pub fn certified(k: usize) -> Self {
        let stopping_rule = StoppingRule::DeterministicBound { max_error: 0.0 };
        Self {
            mode: GuaranteeMode::Certified,
            k,
            stopping_rule,
            include_envelopes: true,
        }
    }

    /// Wraps the current stopping rule in a RAM and latency budget.
    pub fn with_budget(self, max_ram_bytes: u64, max_latency: Duration) -> Self {
        let Self {
            mode,
            k,
            stopping_rule,
            include_envelopes,
        } = self;
        Self {
            mode,
            k,
            stopping_rule: stopping_rule.with_budget(max_ram_bytes, max_latency),
            include_envelopes,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_guarantee_modes() {
        // (mode, expected requires_rerank, expected uses_error_envelopes)
        let cases = [
            (GuaranteeMode::Approximate, false, false),
            (GuaranteeMode::calibrated(0.95, 0.99), false, true),
            (GuaranteeMode::Certified, true, true),
        ];
        for (mode, rerank, envelopes) in cases.iter().copied() {
            assert_eq!(mode.requires_rerank(), rerank, "{:?}", mode);
            assert_eq!(mode.uses_error_envelopes(), envelopes, "{:?}", mode);
        }
        assert_eq!(
            GuaranteeMode::calibrated(0.95, 0.99).error_quantile(),
            Some(0.99)
        );
    }

    #[test]
    fn test_score_envelope() {
        // Wide envelopes meet at ~0.85: no strict dominance, but overlap is possible.
        let wide_hi = ScoreEnvelope::new(0.9, 0.05);
        let wide_lo = ScoreEnvelope::new(0.8, 0.05);
        assert!(!wide_hi.definitely_beats(&wide_lo));
        assert!(wide_hi.might_beat(&wide_lo));

        // Narrow envelopes are disjoint, so the higher one strictly dominates.
        let narrow_hi = ScoreEnvelope::new(0.9, 0.02);
        let narrow_lo = ScoreEnvelope::new(0.8, 0.02);
        assert!(narrow_hi.definitely_beats(&narrow_lo));
    }

    #[test]
    fn test_fixed_probes_stopping() {
        const LIMIT: u32 = 10;
        let mut eval = StoppingEvaluator::new(StoppingRule::FixedProbes { n_probes: LIMIT });
        for probe in 1..=LIMIT {
            eval.record_probe(1000);
            let decision = eval.evaluate(None, None);
            if probe < LIMIT {
                assert!(!decision.should_stop, "stopped early at probe {}", probe);
            } else {
                assert!(decision.should_stop);
                assert_eq!(decision.reason, StopReason::ProbesExhausted);
            }
        }
    }

    #[test]
    fn test_bound_based_stopping() {
        let mut eval = StoppingEvaluator::new(StoppingRule::BoundBased {
            min_probes: 2,
            max_probes: 100,
        });
        let kth = ScoreEnvelope::new(0.9, 0.01);

        // Below `min_probes`, the bound is not consulted.
        eval.record_probe(1000);
        assert!(!eval.evaluate(Some(&kth), Some(0.8)).should_stop);

        eval.record_probe(1000);
        let decision = eval.evaluate(Some(&kth), Some(0.8));
        assert!(decision.should_stop);
        assert_eq!(decision.reason, StopReason::BoundSatisfied);
    }
}