use std::collections::{HashMap, VecDeque};
use tracing::debug;
/// One recorded cascade attempt for a single query.
///
/// Values of this type are handed to [`CascadeOptimizer::record`], which
/// groups them by [`query_class`](Self::query_class).
#[derive(Debug, Clone)]
pub struct CascadeOutcome {
    /// Key under which the optimizer buckets this observation.
    pub query_class: String,

    /// Counted as a success when the weak-model success rate is computed.
    pub weak_model_succeeded: bool,

    /// Weak-model latency, in milliseconds per the field name; averaged over
    /// the window by the optimizer.
    pub weak_latency_ms: u64,

    /// Strong-model latency, in milliseconds per the field name.
    ///
    /// `None` entries are left out of the strong-latency average.
    /// NOTE(review): presumably `None` means the strong model was never
    /// invoked for this query — confirm with the code producing outcomes.
    pub strong_latency_ms: Option<u64>,

    /// Cost attributed to the attempt. Stored only: nothing in this file
    /// reads it back when computing utilities.
    pub total_cost: f64,
}
/// Sliding-window statistics over cascade outcomes, kept separately for each
/// query class.
#[derive(Debug)]
pub struct CascadeOptimizer {
    /// Retained outcomes per query class; new outcomes are appended at the
    /// back and the oldest one is dropped from the front.
    outcomes: HashMap<String, VecDeque<CascadeOutcome>>,

    /// Upper bound on the number of outcomes retained for any one class.
    window_size: usize,
}
impl CascadeOptimizer {
pub fn new(window_size: usize) -> Self {
Self {
outcomes: HashMap::new(),
window_size,
}
}
pub fn record(&mut self, outcome: CascadeOutcome) {
let class = outcome.query_class.clone();
let entries = self.outcomes.entry(class).or_default();
entries.push_back(outcome);
if entries.len() > self.window_size {
entries.pop_front();
}
}
pub fn expected_utility(&self, query_class: &str) -> (f64, f64) {
let outcomes = match self.outcomes.get(query_class) {
Some(o) if !o.is_empty() => o,
_ => return (0.5, 0.5),
};
let total = outcomes.len() as f64;
let weak_success_count = outcomes.iter().filter(|o| o.weak_model_succeeded).count() as f64;
let weak_success_rate = weak_success_count / total;
let avg_weak_latency = outcomes
.iter()
.map(|o| o.weak_latency_ms as f64)
.sum::<f64>()
/ total;
let avg_strong_latency = outcomes
.iter()
.filter_map(|o| o.strong_latency_ms.map(|ms| ms as f64))
.sum::<f64>()
/ outcomes
.iter()
.filter(|o| o.strong_latency_ms.is_some())
.count()
.max(1) as f64;
let latency_weight = 0.001;
let cascade_utility = weak_success_rate
- latency_weight * avg_weak_latency
- (1.0 - weak_success_rate) * latency_weight * avg_strong_latency;
let direct_utility = 1.0 - latency_weight * avg_strong_latency;
(cascade_utility, direct_utility)
}
pub fn should_cascade(&self, query_class: &str) -> CascadeStrategy {
let (cascade_util, direct_util) = self.expected_utility(query_class);
if cascade_util > direct_util {
debug!(
class = query_class,
cascade = cascade_util,
direct = direct_util,
"cascade recommended"
);
CascadeStrategy::Cascade
} else {
debug!(
class = query_class,
cascade = cascade_util,
direct = direct_util,
"direct recommended"
);
CascadeStrategy::Direct
}
}
pub fn weak_success_rate(&self, query_class: &str) -> f64 {
match self.outcomes.get(query_class) {
Some(outcomes) if !outcomes.is_empty() => {
let successes = outcomes.iter().filter(|o| o.weak_model_succeeded).count();
successes as f64 / outcomes.len() as f64
}
_ => 0.5,
}
}
pub fn query_classes(&self) -> Vec<&str> {
self.outcomes.keys().map(|s| s.as_str()).collect()
}
pub fn observation_count(&self, query_class: &str) -> usize {
self.outcomes.get(query_class).map(|o| o.len()).unwrap_or(0)
}
}
/// Routing recommendation returned by `CascadeOptimizer::should_cascade`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CascadeStrategy {
    /// Chosen when the cascade utility is strictly greater than the direct one.
    Cascade,

    /// Chosen otherwise, including when the two utilities are equal.
    Direct,
}
/// Renders the strategy as its lowercase name: `cascade` or `direct`.
impl std::fmt::Display for CascadeStrategy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            CascadeStrategy::Cascade => "cascade",
            CascadeStrategy::Direct => "direct",
        };
        // Written verbatim: width and alignment flags on the formatter are
        // not applied to the label.
        f.write_str(label)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Outcome in which the weak model succeeded and no strong latency exists.
    fn cascade_win(class: &str) -> CascadeOutcome {
        CascadeOutcome {
            query_class: class.to_string(),
            weak_model_succeeded: true,
            weak_latency_ms: 200,
            strong_latency_ms: None,
            total_cost: 0.001,
        }
    }

    /// Outcome in which the weak model failed and a strong latency was logged.
    fn cascade_fail(class: &str) -> CascadeOutcome {
        CascadeOutcome {
            query_class: class.to_string(),
            weak_model_succeeded: false,
            weak_latency_ms: 200,
            strong_latency_ms: Some(2000),
            total_cost: 0.01,
        }
    }

    /// Records `count` outcomes built by `make` for `class`.
    fn record_many(
        opt: &mut CascadeOptimizer,
        count: usize,
        class: &str,
        make: fn(&str) -> CascadeOutcome,
    ) {
        for _ in 0..count {
            opt.record(make(class));
        }
    }

    /// True when `actual` is within `f64::EPSILON` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < f64::EPSILON
    }

    #[test]
    fn high_success_rate_favors_cascade() {
        let mut opt = CascadeOptimizer::new(100);
        record_many(&mut opt, 9, "simple", cascade_win);
        record_many(&mut opt, 1, "simple", cascade_fail);

        assert!(close(opt.weak_success_rate("simple"), 0.9));
        assert_eq!(opt.should_cascade("simple"), CascadeStrategy::Cascade);
    }

    #[test]
    fn low_success_rate_favors_direct() {
        let mut opt = CascadeOptimizer::new(100);
        record_many(&mut opt, 9, "complex", cascade_fail);
        record_many(&mut opt, 1, "complex", cascade_win);

        assert!(close(opt.weak_success_rate("complex"), 0.1));
        assert_eq!(opt.should_cascade("complex"), CascadeStrategy::Direct);
    }

    #[test]
    fn unknown_class_defaults() {
        let opt = CascadeOptimizer::new(100);
        assert!(close(opt.weak_success_rate("unknown"), 0.5));
    }

    #[test]
    fn window_eviction() {
        let mut opt = CascadeOptimizer::new(5);
        record_many(&mut opt, 10, "test", cascade_win);
        assert_eq!(opt.observation_count("test"), 5);
    }

    #[test]
    fn expected_utility_no_data() {
        let opt = CascadeOptimizer::new(100);
        let (cascade, direct) = opt.expected_utility("none");
        assert!(close(cascade, 0.5));
        assert!(close(direct, 0.5));
    }

    #[test]
    fn strategy_display() {
        assert_eq!(CascadeStrategy::Cascade.to_string(), "cascade");
        assert_eq!(CascadeStrategy::Direct.to_string(), "direct");
    }

    #[test]
    fn query_classes_listed() {
        let mut opt = CascadeOptimizer::new(100);
        for class in ["a", "b"] {
            opt.record(cascade_win(class));
        }
        assert_eq!(opt.query_classes().len(), 2);
    }
}