use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use tracing::{debug, info, warn};
use crate::outcome::{InferenceTask, TaskStats};
/// How the router trades off quality vs. latency vs. cost when scoring models.
///
/// Serialized in snake_case (`"auto"`, `"fast"`, `"best"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum RoutingMode {
    /// Balanced weighting; the default mode.
    #[default]
    Auto,
    /// Weights cost above quality (see `weights`).
    Fast,
    /// Weights quality above cost (see `weights`).
    Best,
}
impl RoutingMode {
pub fn weights(&self) -> (f64, f64, f64) {
match self {
RoutingMode::Auto => (0.45, 0.40, 0.15),
RoutingMode::Fast => (0.15, 0.35, 0.50),
RoutingMode::Best => (0.70, 0.20, 0.10),
}
}
}
/// Circuit-breaker state machine:
/// `Closed` (normal) -> `Open` (rejecting) -> `HalfOpen` (probing) -> back.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CircuitState {
    /// Normal operation; all requests allowed.
    Closed,
    /// Tripped; requests rejected until the cooldown expires.
    Open,
    /// Cooldown expired; a single probe request has been let through.
    HalfOpen,
}
/// Consecutive-failure circuit breaker for a single model/endpoint.
///
/// Timestamps are unix seconds (see `now_unix`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreaker {
    /// Current position in the state machine.
    pub state: CircuitState,
    /// Consecutive failures since the last success.
    pub failure_count: u32,
    /// Failure count at which a `Closed` breaker trips `Open`.
    pub failure_threshold: u32,
    /// How long the breaker stays `Open` before allowing a probe.
    pub cooldown_secs: u64,
    /// Unix timestamp of the most recent trip; 0 if never tripped.
    pub opened_at: u64,
    /// Total number of transitions into `Open`.
    pub trip_count: u32,
}
impl CircuitBreaker {
    /// Creates a breaker in the `Closed` state.
    ///
    /// `failure_threshold` consecutive failures trip the breaker;
    /// `cooldown_secs` is how long it then stays `Open` before a probe
    /// request is allowed through.
    pub fn new(failure_threshold: u32, cooldown_secs: u64) -> Self {
        Self {
            state: CircuitState::Closed,
            failure_count: 0,
            failure_threshold,
            cooldown_secs,
            opened_at: 0,
            trip_count: 0,
        }
    }

    /// Returns whether a request may be sent, advancing `Open -> HalfOpen`
    /// when the cooldown has elapsed.
    ///
    /// In `HalfOpen`, only the call that performed the transition returns
    /// `true` (the probe); subsequent calls return `false` until the probe's
    /// outcome is reported via `record_success` / `record_failure`.
    pub fn allow_request(&mut self) -> bool {
        match self.state {
            CircuitState::Closed => true,
            CircuitState::Open => {
                let now = now_unix();
                if now.saturating_sub(self.opened_at) >= self.cooldown_secs {
                    self.state = CircuitState::HalfOpen;
                    debug!("circuit breaker: Open → HalfOpen (cooldown expired)");
                    true
                } else {
                    false
                }
            }
            // A probe is already in flight; block everything else.
            CircuitState::HalfOpen => false,
        }
    }

    /// Records a successful request: closes a `HalfOpen` breaker and resets
    /// the consecutive-failure counter.
    pub fn record_success(&mut self) {
        match self.state {
            CircuitState::HalfOpen => {
                self.state = CircuitState::Closed;
                self.failure_count = 0;
                info!("circuit breaker: HalfOpen → Closed (probe succeeded)");
            }
            CircuitState::Closed => {
                self.failure_count = 0;
            }
            // No probe was issued while Open; nothing to acknowledge.
            CircuitState::Open => {}
        }
    }

    /// Records a failed request, tripping `Closed -> Open` at the threshold
    /// or re-opening after a failed probe (`HalfOpen -> Open`).
    pub fn record_failure(&mut self) {
        // Saturating: callers may keep reporting failures while Open, and an
        // unchecked `+= 1` would eventually panic in debug builds.
        self.failure_count = self.failure_count.saturating_add(1);
        match self.state {
            CircuitState::Closed => {
                if self.failure_count >= self.failure_threshold {
                    self.state = CircuitState::Open;
                    self.opened_at = now_unix();
                    self.trip_count = self.trip_count.saturating_add(1);
                    warn!(
                        failures = self.failure_count,
                        trips = self.trip_count,
                        "circuit breaker: Closed → Open"
                    );
                }
            }
            CircuitState::HalfOpen => {
                self.state = CircuitState::Open;
                self.opened_at = now_unix();
                self.trip_count = self.trip_count.saturating_add(1);
                warn!("circuit breaker: HalfOpen → Open (probe failed)");
            }
            CircuitState::Open => {}
        }
    }

    /// Whether the breaker is currently rejecting requests outright.
    pub fn is_blocking(&self) -> bool {
        matches!(self.state, CircuitState::Open)
    }
}
impl Default for CircuitBreaker {
    /// Default breaker: trips after 3 consecutive failures, 60s cooldown.
    fn default() -> Self {
        const DEFAULT_FAILURE_THRESHOLD: u32 = 3;
        const DEFAULT_COOLDOWN_SECS: u64 = 60;
        Self::new(DEFAULT_FAILURE_THRESHOLD, DEFAULT_COOLDOWN_SECS)
    }
}
/// Lazily-populated map of per-model circuit breakers sharing default settings.
#[derive(Debug, Default)]
pub struct CircuitBreakerRegistry {
    /// One breaker per model id, created on first use.
    breakers: HashMap<String, CircuitBreaker>,
    /// Failure threshold applied to newly created breakers.
    default_threshold: u32,
    /// Cooldown (seconds) applied to newly created breakers.
    default_cooldown: u64,
}
impl CircuitBreakerRegistry {
    /// Creates an empty registry; breakers are built lazily per model with
    /// the given default threshold and cooldown.
    pub fn new(default_threshold: u32, default_cooldown_secs: u64) -> Self {
        Self {
            breakers: HashMap::new(),
            default_threshold,
            default_cooldown: default_cooldown_secs,
        }
    }

    /// Whether a request to `model_id` is currently allowed.
    pub fn allow_request(&mut self, model_id: &str) -> bool {
        self.get_or_create(model_id).allow_request()
    }

    /// Reports a successful call to `model_id`'s breaker.
    pub fn record_success(&mut self, model_id: &str) {
        self.get_or_create(model_id).record_success();
    }

    /// Reports a failed call to `model_id`'s breaker.
    pub fn record_failure(&mut self, model_id: &str) {
        self.get_or_create(model_id).record_failure();
    }

    /// Current state of `model_id`'s breaker, or `None` if it has never
    /// been touched.
    pub fn state(&self, model_id: &str) -> Option<CircuitState> {
        self.breakers.get(model_id).map(|breaker| breaker.state)
    }

    /// Fetches the breaker for `model_id`, creating one with the registry
    /// defaults on first use.
    fn get_or_create(&mut self, model_id: &str) -> &mut CircuitBreaker {
        let (threshold, cooldown) = (self.default_threshold, self.default_cooldown);
        self.breakers
            .entry(model_id.to_owned())
            .or_insert_with(|| CircuitBreaker::new(threshold, cooldown))
    }
}
/// A timestamped implicit feedback event observed for a model.
///
/// NOTE(review): not constructed in this file; timestamp is presumably unix
/// seconds like the rest of this module — confirm at the call sites.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImplicitSignal {
    /// Model the signal applies to.
    pub model_id: String,
    /// What kind of outcome was observed.
    pub signal_type: ImplicitSignalType,
    /// When the signal occurred (unix seconds, presumably — see note above).
    pub timestamp: u64,
}
/// Outcome classes inferred from request results rather than explicit user
/// feedback (see `signal_from_status` for the HTTP mapping).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ImplicitSignalType {
    /// Request completed normally (HTTP 2xx).
    Success,
    /// HTTP 429.
    RateLimited,
    /// HTTP 5xx.
    ServerError,
    /// HTTP 4xx other than 429 (and unclassified codes).
    ClientError,
    /// Request timed out.
    Timeout,
    /// Request had to be retried.
    Retried,
}
impl ImplicitSignalType {
pub fn quality_delta(&self) -> f64 {
match self {
ImplicitSignalType::Success => 1.0,
ImplicitSignalType::RateLimited => -0.3, ImplicitSignalType::ServerError => -0.8,
ImplicitSignalType::ClientError => -0.2, ImplicitSignalType::Timeout => -0.5,
ImplicitSignalType::Retried => -0.7, }
}
pub fn is_circuit_failure(&self) -> bool {
matches!(
self,
ImplicitSignalType::RateLimited
| ImplicitSignalType::ServerError
| ImplicitSignalType::Timeout
)
}
}
/// Maps an HTTP status code to an implicit quality signal.
///
/// 429 is matched before the generic 4xx arm, so rate limiting gets its own
/// (circuit-relevant) signal. Codes outside the known ranges — 1xx, 3xx, or
/// invalid values — fall back to `ClientError`.
pub fn signal_from_status(status: u16) -> ImplicitSignalType {
    match status {
        200..=299 => ImplicitSignalType::Success,
        429 => ImplicitSignalType::RateLimited,
        // 429 was consumed by the arm above, so the full 4xx range is safe.
        400..=499 => ImplicitSignalType::ClientError,
        500..=599 => ImplicitSignalType::ServerError,
        _ => ImplicitSignalType::ClientError,
    }
}
/// Optional spend caps enforced by `SpendControl`; a `None` field disables
/// that cap.
///
/// `Default` (all `None`) means unlimited spend. The hand-written `Default`
/// impl was byte-equivalent to the derive, so it is derived instead.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SpendLimits {
    /// Maximum cost of a single request, in USD.
    #[serde(default)]
    pub per_request_usd: Option<f64>,
    /// Rolling 1-hour spend cap, in USD.
    #[serde(default)]
    pub hourly_usd: Option<f64>,
    /// Rolling 24-hour spend cap, in USD.
    #[serde(default)]
    pub daily_usd: Option<f64>,
}
/// A single spend event used for rolling-window accounting.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SpendRecord {
    /// Cost of the request, in USD.
    cost_usd: f64,
    /// Unix timestamp (seconds) when the spend was recorded.
    timestamp: u64,
}
/// Enforces per-request / hourly / daily spend limits over a rolling window.
#[derive(Debug)]
pub struct SpendControl {
    /// Configured caps; `None` fields are not enforced.
    limits: SpendLimits,
    /// Spend events from the last 24 hours (pruned on `record`).
    records: Vec<SpendRecord>,
}
impl SpendControl {
    /// Creates a control with the given limits and no spend history.
    pub fn new(limits: SpendLimits) -> Self {
        Self {
            limits,
            records: Vec::new(),
        }
    }

    /// Checks whether spending `estimated_cost_usd` now would violate any
    /// configured limit (per-request, hourly, daily).
    ///
    /// Limits set to `None` are skipped. Returns the first violated limit.
    pub fn check(&self, estimated_cost_usd: f64) -> Result<(), SpendLimitExceeded> {
        if let Some(max) = self.limits.per_request_usd {
            if estimated_cost_usd > max {
                return Err(SpendLimitExceeded {
                    limit_type: "per_request".into(),
                    limit_usd: max,
                    current_usd: estimated_cost_usd,
                    window_secs: 0,
                });
            }
        }
        let now = now_unix();
        // The hourly and daily checks share one code path instead of the
        // previous copy-pasted branches.
        self.check_window(now, estimated_cost_usd, self.limits.hourly_usd, 3600, "hourly")?;
        self.check_window(now, estimated_cost_usd, self.limits.daily_usd, 86400, "daily")?;
        Ok(())
    }

    /// Rejects the estimate when rolling-window spend plus the new cost
    /// would exceed `limit`; a `None` limit is a no-op.
    fn check_window(
        &self,
        now: u64,
        estimated_cost_usd: f64,
        limit: Option<f64>,
        window_secs: u64,
        limit_type: &str,
    ) -> Result<(), SpendLimitExceeded> {
        if let Some(max) = limit {
            let spent = self.spend_in_window(now, window_secs);
            if spent + estimated_cost_usd > max {
                return Err(SpendLimitExceeded {
                    limit_type: limit_type.into(),
                    limit_usd: max,
                    current_usd: spent,
                    window_secs,
                });
            }
        }
        Ok(())
    }

    /// Records an actual spend and prunes records older than the largest
    /// window (24h) so the history stays bounded.
    pub fn record(&mut self, cost_usd: f64) {
        // Read the clock once so the new record's timestamp and the prune
        // cutoff cannot straddle a second boundary.
        let now = now_unix();
        self.records.push(SpendRecord {
            cost_usd,
            timestamp: now,
        });
        let cutoff = now.saturating_sub(86400);
        self.records.retain(|r| r.timestamp >= cutoff);
    }

    /// Total spend recorded within the last `window_secs` seconds of `now`.
    pub fn spend_in_window(&self, now: u64, window_secs: u64) -> f64 {
        let cutoff = now.saturating_sub(window_secs);
        self.records
            .iter()
            .filter(|r| r.timestamp >= cutoff)
            .map(|r| r.cost_usd)
            .sum()
    }

    /// Spend over the last hour, in USD.
    pub fn hourly_spend(&self) -> f64 {
        self.spend_in_window(now_unix(), 3600)
    }

    /// Spend over the last 24 hours, in USD.
    pub fn daily_spend(&self) -> f64 {
        self.spend_in_window(now_unix(), 86400)
    }

    /// Snapshot of current spend against the configured limits.
    pub fn status(&self) -> SpendStatus {
        SpendStatus {
            hourly_spend: self.hourly_spend(),
            daily_spend: self.daily_spend(),
            hourly_limit: self.limits.hourly_usd,
            daily_limit: self.limits.daily_usd,
            per_request_limit: self.limits.per_request_usd,
        }
    }
}
/// Error returned by `SpendControl::check` when a limit would be exceeded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpendLimitExceeded {
    /// Which limit tripped: "per_request", "hourly", or "daily".
    pub limit_type: String,
    /// The configured cap, in USD.
    pub limit_usd: f64,
    /// Spend already accrued in the window (or the request's estimated cost
    /// for the per-request limit).
    pub current_usd: f64,
    /// Window size in seconds (0 for the per-request limit).
    pub window_secs: u64,
}
impl std::fmt::Display for SpendLimitExceeded {
    /// Human-readable summary, e.g. `hourly spend limit exceeded: $0.8000 / $1.0000`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            limit_type,
            limit_usd,
            current_usd,
            ..
        } = self;
        write!(
            f,
            "{} spend limit exceeded: ${:.4} / ${:.4}",
            limit_type, current_usd, limit_usd
        )
    }
}
/// Read-only snapshot of current spend versus configured limits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpendStatus {
    /// Spend in the last hour, USD.
    pub hourly_spend: f64,
    /// Spend in the last 24 hours, USD.
    pub daily_spend: f64,
    /// Hourly cap, if configured.
    pub hourly_limit: Option<f64>,
    /// Daily cap, if configured.
    pub daily_limit: Option<f64>,
    /// Per-request cap, if configured.
    pub per_request_limit: Option<f64>,
}
/// Quality/latency priors for a model derived from offline benchmark runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkPrior {
    /// Overall benchmark score; clamped to [0, 1] when applied.
    pub overall_score: f64,
    /// Average latency across the run, if the report includes one.
    /// NOTE(review): not consumed by `apply_benchmark_priors` in this file —
    /// confirm it is used elsewhere or consider it as the per-task fallback.
    #[serde(default)]
    pub overall_latency_ms: Option<f64>,
    /// Per-task average score keyed by task name (e.g. "generate", "code").
    #[serde(default)]
    pub task_scores: HashMap<String, f64>,
    /// Per-task average latency in milliseconds, keyed the same way.
    #[serde(default)]
    pub task_latency_ms: HashMap<String, f64>,
}
/// Seeds `tracker` with benchmark-derived quality/latency priors for every
/// model that has no observed traffic yet.
///
/// Models with at least one recorded call keep their live statistics —
/// priors never overwrite observed data. Scores are clamped to [0, 1].
pub fn apply_benchmark_priors(
    tracker: &mut crate::outcome::OutcomeTracker,
    benchmark_priors: &HashMap<String, BenchmarkPrior>,
) {
    for (model_id, prior) in benchmark_priors {
        // A missing profile or one with zero calls is "untouched"; the old
        // `is_none() || map(..).unwrap_or(true)` double-check collapses to
        // a single map_or.
        let untouched = tracker
            .profile(model_id)
            .map_or(true, |p| p.total_calls == 0);
        if untouched {
            let mut new_profile = crate::outcome::ModelProfile::new(model_id.clone());
            new_profile.ema_quality = prior.overall_score.clamp(0.0, 1.0);
            for (task, score) in &prior.task_scores {
                new_profile.task_stats.insert(
                    task.clone(),
                    TaskStats {
                        ema_quality: score.clamp(0.0, 1.0),
                        // Missing per-task latency defaults to 0.0.
                        // NOTE(review): confirm downstream treats 0 as
                        // "unknown" rather than "instant".
                        avg_latency_ms: prior.task_latency_ms.get(task).copied().unwrap_or_default(),
                        ..Default::default()
                    },
                );
            }
            tracker.import_profiles(vec![new_profile]);
            debug!(
                model = %model_id,
                quality = prior.overall_score,
                task_priors = prior.task_scores.len(),
                latency_priors = prior.task_latency_ms.len(),
                "set benchmark quality prior"
            );
        }
    }
}
/// Loads benchmark priors from a JSON file at `path`.
///
/// Accepts either a single report object (`{"model_id": ..,
/// "overall_score": .., "cases": [..]}`) or an array of such objects. A
/// missing file yields an empty map; other I/O or parse errors are returned
/// as strings. Entries lacking `model_id` or `overall_score` are skipped.
pub fn load_benchmark_priors(
    path: &std::path::Path,
) -> Result<HashMap<String, BenchmarkPrior>, String> {
    // Shared by the object and array paths (previously duplicated inline).
    fn prior_from_value(item: &serde_json::Value) -> Option<(String, BenchmarkPrior)> {
        let id = item.get("model_id")?.as_str()?;
        let score = item.get("overall_score")?.as_f64()?;
        Some((
            id.to_string(),
            BenchmarkPrior {
                overall_score: score,
                overall_latency_ms: item.get("avg_latency_ms").and_then(|v| v.as_f64()),
                task_scores: extract_task_scores(item),
                task_latency_ms: extract_task_latencies(item),
            },
        ))
    }
    // Read directly and treat NotFound as "no priors" instead of a
    // `path.exists()` pre-check, which raced with file deletion (TOCTOU).
    let json = match std::fs::read_to_string(path) {
        Ok(s) => s,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(HashMap::new()),
        Err(e) => return Err(e.to_string()),
    };
    let value: serde_json::Value = serde_json::from_str(&json).map_err(|e| e.to_string())?;
    let mut priors = HashMap::new();
    if let Some((id, prior)) = prior_from_value(&value) {
        priors.insert(id, prior);
    }
    if let Some(arr) = value.as_array() {
        priors.extend(arr.iter().filter_map(prior_from_value));
    }
    Ok(priors)
}
/// Current wall-clock time as unix seconds; 0 if the system clock reads
/// before the epoch.
fn now_unix() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Averages per-case benchmark scores into one score per inferred task.
///
/// Cases without a recognizable category or a numeric score are skipped;
/// scores are clamped to [0, 1] before averaging.
fn extract_task_scores(value: &serde_json::Value) -> HashMap<String, f64> {
    let cases = match value.get("cases").and_then(|v| v.as_array()) {
        Some(cases) => cases,
        None => return HashMap::new(),
    };
    // Accumulate (sum, count) per task; accumulation order matches the
    // case order, so the averages are bit-identical to a collect-then-sum.
    let mut buckets: HashMap<String, (f64, usize)> = HashMap::new();
    for case in cases {
        let category = case.get("category").and_then(|v| v.as_str());
        let score = case.get("score").and_then(|v| v.as_f64());
        if let (Some(category), Some(score)) = (category, score) {
            if let Some(task) = benchmark_category_to_task(category) {
                let bucket = buckets.entry(task.to_string()).or_insert((0.0, 0));
                bucket.0 += score.clamp(0.0, 1.0);
                bucket.1 += 1;
            }
        }
    }
    buckets
        .into_iter()
        .map(|(task, (sum, count))| (task, sum / count as f64))
        .collect()
}
fn extract_task_latencies(value: &serde_json::Value) -> HashMap<String, f64> {
let mut task_latencies: HashMap<String, Vec<f64>> = HashMap::new();
let Some(cases) = value.get("cases").and_then(|v| v.as_array()) else {
return HashMap::new();
};
for case in cases {
let Some(category) = case.get("category").and_then(|v| v.as_str()) else {
continue;
};
let Some(latency_ms) = case.get("latency_ms").and_then(|v| v.as_f64()) else {
continue;
};
if let Some(task) = benchmark_category_to_task(category) {
task_latencies
.entry(task.to_string())
.or_default()
.push(latency_ms.max(1.0));
}
}
task_latencies
.into_iter()
.map(|(task, latencies)| {
let avg = latencies.iter().sum::<f64>() / latencies.len() as f64;
(task, avg)
})
.collect()
}
/// Maps a benchmark case category string to an inference task, or `None`
/// for unrecognized categories.
///
/// Tool-use and vision cases are folded into `Generate`.
fn benchmark_category_to_task(category: &str) -> Option<InferenceTask> {
    let task = match category {
        "basic" | "generate" | "tool_use" | "vision" => InferenceTask::Generate,
        "code" | "coding" => InferenceTask::Code,
        "reasoning" | "analysis" => InferenceTask::Reasoning,
        "classify" | "classification" => InferenceTask::Classify,
        "embed" | "embedding" => InferenceTask::Embed,
        _ => return None,
    };
    Some(task)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each mode's weights sum to ~1 and order quality vs. cost per intent.
    #[test]
    fn routing_mode_weights() {
        let (q, l, c) = RoutingMode::Auto.weights();
        assert!((q + l + c - 1.0).abs() < 0.01);
        let (q, _, c) = RoutingMode::Fast.weights();
        assert!(c > q, "Fast mode should weight cost > quality");
        let (q, _, c) = RoutingMode::Best.weights();
        assert!(q > c, "Best mode should weight quality > cost");
    }

    /// Stays Closed below the threshold; trips at it; success resets count.
    #[test]
    fn circuit_breaker_lifecycle() {
        let mut cb = CircuitBreaker::new(3, 60);
        assert_eq!(cb.state, CircuitState::Closed);
        assert!(cb.allow_request());
        cb.record_failure();
        cb.record_failure();
        // Two failures < threshold of 3: still Closed.
        assert_eq!(cb.state, CircuitState::Closed);
        assert!(cb.allow_request());
        cb.record_failure();
        assert_eq!(cb.state, CircuitState::Open);
        assert!(!cb.allow_request());
        // A success while Closed resets the consecutive-failure count.
        let mut cb2 = CircuitBreaker::new(3, 60);
        cb2.record_failure();
        cb2.record_failure();
        cb2.record_success();
        assert_eq!(cb2.failure_count, 0);
    }

    /// Zero cooldown: Open immediately yields a HalfOpen probe, and a
    /// successful probe closes the breaker.
    #[test]
    fn circuit_breaker_half_open_recovery() {
        let mut cb = CircuitBreaker::new(2, 0);
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state, CircuitState::Open);
        assert!(cb.allow_request());
        assert_eq!(cb.state, CircuitState::HalfOpen);
        cb.record_success();
        assert_eq!(cb.state, CircuitState::Closed);
    }

    /// A failed probe re-opens the breaker and counts a second trip.
    #[test]
    fn circuit_breaker_half_open_failure() {
        let mut cb = CircuitBreaker::new(2, 0);
        cb.record_failure();
        cb.record_failure();
        assert!(cb.allow_request());
        assert_eq!(cb.state, CircuitState::HalfOpen);
        cb.record_failure();
        assert_eq!(cb.state, CircuitState::Open);
        assert_eq!(cb.trip_count, 2);
    }

    /// Breakers are isolated per model id.
    #[test]
    fn circuit_breaker_registry() {
        let mut reg = CircuitBreakerRegistry::new(2, 0);
        assert!(reg.allow_request("model-a"));
        reg.record_failure("model-a");
        reg.record_failure("model-a");
        // With a zero cooldown the tripped breaker may already be probing
        // (HalfOpen), so accept either outcome.
        assert!(
            !reg.allow_request("model-a") || reg.state("model-a") == Some(CircuitState::HalfOpen)
        );
        assert!(reg.allow_request("model-b"));
    }

    /// HTTP status classes map to the expected implicit signals.
    #[test]
    fn signal_from_http_status() {
        assert_eq!(signal_from_status(200), ImplicitSignalType::Success);
        assert_eq!(signal_from_status(429), ImplicitSignalType::RateLimited);
        assert_eq!(signal_from_status(500), ImplicitSignalType::ServerError);
        assert_eq!(signal_from_status(400), ImplicitSignalType::ClientError);
    }

    /// Success is the only positive quality delta.
    #[test]
    fn quality_deltas() {
        assert!(ImplicitSignalType::Success.quality_delta() > 0.0);
        assert!(ImplicitSignalType::ServerError.quality_delta() < 0.0);
        assert!(ImplicitSignalType::Retried.quality_delta() < 0.0);
    }

    /// The per-request cap rejects a single over-budget request.
    #[test]
    fn spend_per_request_limit() {
        let sc = SpendControl::new(SpendLimits {
            per_request_usd: Some(0.10),
            ..Default::default()
        });
        assert!(sc.check(0.05).is_ok());
        assert!(sc.check(0.15).is_err());
    }

    /// The projected (recorded + estimated) hourly spend is what's limited.
    #[test]
    fn spend_hourly_limit() {
        let mut sc = SpendControl::new(SpendLimits {
            hourly_usd: Some(1.00),
            ..Default::default()
        });
        sc.record(0.40);
        sc.record(0.40);
        // 0.80 + 0.10 <= 1.00 ok; 0.80 + 0.25 > 1.00 rejected.
        assert!(sc.check(0.10).is_ok());
        assert!(sc.check(0.25).is_err());
    }

    /// `status()` reflects recorded spend and configured limits.
    #[test]
    fn spend_status() {
        let mut sc = SpendControl::new(SpendLimits {
            hourly_usd: Some(5.0),
            daily_usd: Some(20.0),
            ..Default::default()
        });
        sc.record(1.50);
        let status = sc.status();
        assert!((status.hourly_spend - 1.50).abs() < 0.01);
        assert_eq!(status.hourly_limit, Some(5.0));
    }

    /// Priors seed fresh profiles with overall and per-task quality/latency.
    #[test]
    fn apply_priors() {
        let mut tracker = crate::outcome::OutcomeTracker::new();
        let mut priors = HashMap::new();
        priors.insert(
            "model-a".to_string(),
            BenchmarkPrior {
                overall_score: 0.85,
                overall_latency_ms: Some(1100.0),
                task_scores: HashMap::from([
                    ("generate".to_string(), 0.82),
                    ("code".to_string(), 0.91),
                ]),
                task_latency_ms: HashMap::from([
                    ("generate".to_string(), 900.0),
                    ("code".to_string(), 2100.0),
                ]),
            },
        );
        priors.insert(
            "model-b".to_string(),
            BenchmarkPrior {
                overall_score: 0.60,
                overall_latency_ms: Some(2000.0),
                task_scores: HashMap::new(),
                task_latency_ms: HashMap::new(),
            },
        );
        apply_benchmark_priors(&mut tracker, &priors);
        let profile_a = tracker.profile("model-a").unwrap();
        assert!((profile_a.ema_quality - 0.85).abs() < 0.01);
        assert!(
            (profile_a
                .task_stats(crate::outcome::InferenceTask::Generate)
                .unwrap()
                .ema_quality
                - 0.82)
                .abs()
                < 0.01
        );
        assert!(
            (profile_a
                .task_stats(crate::outcome::InferenceTask::Code)
                .unwrap()
                .ema_quality
                - 0.91)
                .abs()
                < 0.01
        );
        assert!(
            (profile_a
                .task_stats(crate::outcome::InferenceTask::Code)
                .unwrap()
                .avg_latency_ms
                - 2100.0)
                .abs()
                < 0.01
        );
        let profile_b = tracker.profile("model-b").unwrap();
        assert!((profile_b.ema_quality - 0.60).abs() < 0.01);
    }

    /// A model with recorded traffic keeps its observed stats; a 0.99 prior
    /// must not overwrite them.
    #[test]
    fn priors_dont_overwrite_observed() {
        let mut tracker = crate::outcome::OutcomeTracker::new();
        let trace =
            tracker.record_start("model-a", crate::outcome::InferenceTask::Generate, "test");
        tracker.record_complete(&trace, 100, 10, 5);
        let mut priors = HashMap::new();
        priors.insert(
            "model-a".to_string(),
            BenchmarkPrior {
                overall_score: 0.99,
                overall_latency_ms: Some(1500.0),
                task_scores: HashMap::from([("generate".to_string(), 0.99)]),
                task_latency_ms: HashMap::from([("generate".to_string(), 1500.0)]),
            },
        );
        apply_benchmark_priors(&mut tracker, &priors);
        let profile = tracker.profile("model-a").unwrap();
        assert!(profile.ema_quality < 0.9);
    }

    /// A single-report JSON file yields per-task averages from its cases
    /// (the two "reasoning" cases average to 0.6 / 3000ms).
    #[test]
    fn load_benchmark_priors_extracts_task_scores_from_cases() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(
            tmp.path(),
            serde_json::json!({
                "model_id": "model-a",
                "overall_score": 0.78,
                "cases": [
                    {"id": "basic_exact", "category": "basic", "score": 0.9, "latency_ms": 800},
                    {"id": "code_fibonacci", "category": "code", "score": 0.8, "latency_ms": 2200},
                    {"id": "reasoning_lp", "category": "reasoning", "score": 0.7, "latency_ms": 3300},
                    {"id": "reasoning_lp_2", "category": "reasoning", "score": 0.5, "latency_ms": 2700}
                ]
            })
            .to_string(),
        )
        .unwrap();
        let priors = load_benchmark_priors(tmp.path()).unwrap();
        let prior = priors.get("model-a").unwrap();
        assert!((prior.overall_score - 0.78).abs() < 0.01);
        assert!((prior.task_scores["generate"] - 0.9).abs() < 0.01);
        assert!((prior.task_scores["code"] - 0.8).abs() < 0.01);
        assert!((prior.task_scores["reasoning"] - 0.6).abs() < 0.01);
        assert!((prior.task_latency_ms["generate"] - 800.0).abs() < 0.01);
        assert!((prior.task_latency_ms["code"] - 2200.0).abs() < 0.01);
        assert!((prior.task_latency_ms["reasoning"] - 3000.0).abs() < 0.01);
    }

    /// tool_use and vision cases fold into the "generate" task (latencies
    /// 1300 and 1700 average to 1500).
    #[test]
    fn load_benchmark_priors_maps_tool_and_vision_cases_to_generate() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(
            tmp.path(),
            serde_json::json!({
                "model_id": "model-a",
                "overall_score": 1.0,
                "cases": [
                    {"id": "tool_weather", "category": "tool_use", "score": 1.0, "latency_ms": 1300},
                    {"id": "vision_cat", "category": "vision", "score": 1.0, "latency_ms": 1700}
                ]
            })
            .to_string(),
        )
        .unwrap();
        let priors = load_benchmark_priors(tmp.path()).unwrap();
        let prior = priors.get("model-a").unwrap();
        assert!((prior.task_scores["generate"] - 1.0).abs() < 0.01);
        assert!((prior.task_latency_ms["generate"] - 1500.0).abs() < 0.01);
    }
}