// car_inference/routing_ext.rs
1//! Extended routing features — routing modes, circuit breaker, implicit feedback,
2//! spend control, and benchmark quality priors.
3
4use std::collections::HashMap;
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use serde::{Deserialize, Serialize};
8use tracing::{debug, info, warn};
9
10use crate::outcome::{InferenceTask, TaskStats};
11
12// ─── Routing Modes ───────────────────────────────────────────────────────────
13
/// Quality-vs-cost preference dial for model routing.
///
/// Serialized in snake_case ("auto" / "fast" / "best"); `Auto` is the
/// `Default`. See [`RoutingMode::weights`] for the concrete scoring blend
/// each mode produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum RoutingMode {
    /// Balanced quality-per-dollar (default). Quality weight increases with difficulty.
    #[default]
    Auto,
    /// Cost-dominant — cheapest acceptable model.
    Fast,
    /// Quality-dominant — best model available, cost nearly ignored.
    Best,
}
26
27impl RoutingMode {
28    /// Get scoring weight adjustments for this mode.
29    /// Returns (quality_weight, latency_weight, cost_weight).
30    pub fn weights(&self) -> (f64, f64, f64) {
31        match self {
32            RoutingMode::Auto => (0.45, 0.40, 0.15),
33            RoutingMode::Fast => (0.15, 0.35, 0.50),
34            RoutingMode::Best => (0.70, 0.20, 0.10),
35        }
36    }
37}
38
39// ─── Circuit Breaker ─────────────────────────────────────────────────────────
40
/// Circuit breaker state for a single model.
///
/// Transitions (driven by `CircuitBreaker`): `Closed` → `Open` after enough
/// consecutive failures; `Open` → `HalfOpen` once the cooldown elapses;
/// `HalfOpen` → `Closed` on probe success or back to `Open` on probe failure.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CircuitState {
    /// Normal operation — all requests allowed.
    Closed,
    /// Model is failing — requests blocked until cooldown expires.
    Open,
    /// Cooldown expired — allowing one probe request.
    HalfOpen,
}
52
/// Per-model circuit breaker.
///
/// Serde-serializable; timestamps are unix seconds (from `now_unix()`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreaker {
    /// Current state — see [`CircuitState`] for the transition diagram.
    pub state: CircuitState,
    /// Number of consecutive failures.
    pub failure_count: u32,
    /// Threshold: trip to Open after this many failures in the window.
    pub failure_threshold: u32,
    /// Cooldown in seconds before transitioning from Open to HalfOpen.
    pub cooldown_secs: u64,
    /// When the circuit opened (unix timestamp).
    pub opened_at: u64,
    /// Total number of times the circuit has tripped.
    pub trip_count: u32,
}
68
69impl CircuitBreaker {
70    pub fn new(failure_threshold: u32, cooldown_secs: u64) -> Self {
71        Self {
72            state: CircuitState::Closed,
73            failure_count: 0,
74            failure_threshold,
75            cooldown_secs,
76            opened_at: 0,
77            trip_count: 0,
78        }
79    }
80
81    /// Check if a request should be allowed.
82    pub fn allow_request(&mut self) -> bool {
83        match self.state {
84            CircuitState::Closed => true,
85            CircuitState::Open => {
86                // Check if cooldown has expired → transition to HalfOpen
87                let now = now_unix();
88                if now.saturating_sub(self.opened_at) >= self.cooldown_secs {
89                    self.state = CircuitState::HalfOpen;
90                    debug!("circuit breaker: Open → HalfOpen (cooldown expired)");
91                    true // allow the probe request
92                } else {
93                    false
94                }
95            }
96            CircuitState::HalfOpen => {
97                // Already allowing one probe — block additional requests
98                // until the probe completes
99                false
100            }
101        }
102    }
103
104    /// Record a successful request.
105    pub fn record_success(&mut self) {
106        match self.state {
107            CircuitState::HalfOpen => {
108                // Probe succeeded — close the circuit
109                self.state = CircuitState::Closed;
110                self.failure_count = 0;
111                info!("circuit breaker: HalfOpen → Closed (probe succeeded)");
112            }
113            CircuitState::Closed => {
114                self.failure_count = 0;
115            }
116            CircuitState::Open => {} // shouldn't happen
117        }
118    }
119
120    /// Record a failed request.
121    pub fn record_failure(&mut self) {
122        self.failure_count += 1;
123
124        match self.state {
125            CircuitState::Closed => {
126                if self.failure_count >= self.failure_threshold {
127                    self.state = CircuitState::Open;
128                    self.opened_at = now_unix();
129                    self.trip_count += 1;
130                    warn!(
131                        failures = self.failure_count,
132                        trips = self.trip_count,
133                        "circuit breaker: Closed → Open"
134                    );
135                }
136            }
137            CircuitState::HalfOpen => {
138                // Probe failed — back to Open with fresh cooldown
139                self.state = CircuitState::Open;
140                self.opened_at = now_unix();
141                self.trip_count += 1;
142                warn!("circuit breaker: HalfOpen → Open (probe failed)");
143            }
144            CircuitState::Open => {} // already open
145        }
146    }
147
148    /// Whether the circuit is currently blocking requests.
149    pub fn is_blocking(&self) -> bool {
150        matches!(self.state, CircuitState::Open)
151    }
152}
153
154impl Default for CircuitBreaker {
155    fn default() -> Self {
156        Self::new(3, 60)
157    }
158}
159
160/// Manages circuit breakers for all models.
161#[derive(Debug, Default)]
162pub struct CircuitBreakerRegistry {
163    breakers: HashMap<String, CircuitBreaker>,
164    default_threshold: u32,
165    default_cooldown: u64,
166}
167
168impl CircuitBreakerRegistry {
169    pub fn new(default_threshold: u32, default_cooldown_secs: u64) -> Self {
170        Self {
171            breakers: HashMap::new(),
172            default_threshold,
173            default_cooldown: default_cooldown_secs,
174        }
175    }
176
177    /// Check if a model is allowed (not circuit-broken).
178    pub fn allow_request(&mut self, model_id: &str) -> bool {
179        self.get_or_create(model_id).allow_request()
180    }
181
182    /// Record success for a model.
183    pub fn record_success(&mut self, model_id: &str) {
184        self.get_or_create(model_id).record_success();
185    }
186
187    /// Record failure for a model.
188    pub fn record_failure(&mut self, model_id: &str) {
189        self.get_or_create(model_id).record_failure();
190    }
191
192    /// Get the circuit breaker state for a model.
193    pub fn state(&self, model_id: &str) -> Option<CircuitState> {
194        self.breakers.get(model_id).map(|b| b.state)
195    }
196
197    fn get_or_create(&mut self, model_id: &str) -> &mut CircuitBreaker {
198        self.breakers
199            .entry(model_id.to_string())
200            .or_insert_with(|| CircuitBreaker::new(self.default_threshold, self.default_cooldown))
201    }
202}
203
204// ─── Implicit Feedback ───────────────────────────────────────────────────────
205
/// Implicit feedback signals extracted from HTTP responses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImplicitSignal {
    /// Model that produced the response this signal was derived from.
    pub model_id: String,
    /// What was observed — see [`ImplicitSignalType`].
    pub signal_type: ImplicitSignalType,
    // NOTE(review): presumably unix seconds like the rest of this module
    // (`now_unix()`) — confirm at the producers.
    pub timestamp: u64,
}
213
/// Category of implicit feedback derived from transport-level outcomes.
///
/// See [`signal_from_status`] for the HTTP-status mapping and
/// [`ImplicitSignalType::quality_delta`] for the scoring impact.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ImplicitSignalType {
    /// HTTP 200 — model responded successfully.
    Success,
    /// HTTP 429 — rate limited.
    RateLimited,
    /// HTTP 5xx — server error.
    ServerError,
    /// HTTP 4xx (non-429) — client error (bad request, auth failure, etc.).
    ClientError,
    /// Request timed out.
    Timeout,
    /// Same prompt was sent again (retry detection = previous response was weak).
    Retried,
}
230
231impl ImplicitSignalType {
232    /// Convert to a quality delta for Thompson Sampling.
233    /// Positive = success, negative = failure.
234    pub fn quality_delta(&self) -> f64 {
235        match self {
236            ImplicitSignalType::Success => 1.0,
237            ImplicitSignalType::RateLimited => -0.3, // not quality, but availability
238            ImplicitSignalType::ServerError => -0.8,
239            ImplicitSignalType::ClientError => -0.2, // usually not model's fault
240            ImplicitSignalType::Timeout => -0.5,
241            ImplicitSignalType::Retried => -0.7, // strong signal of weak output
242        }
243    }
244
245    /// Whether this signal should count as a "failure" for the circuit breaker.
246    pub fn is_circuit_failure(&self) -> bool {
247        matches!(
248            self,
249            ImplicitSignalType::RateLimited
250                | ImplicitSignalType::ServerError
251                | ImplicitSignalType::Timeout
252        )
253    }
254}
255
256/// Extract an implicit signal from an HTTP status code.
257pub fn signal_from_status(status: u16) -> ImplicitSignalType {
258    match status {
259        200..=299 => ImplicitSignalType::Success,
260        429 => ImplicitSignalType::RateLimited,
261        400..=428 | 430..=499 => ImplicitSignalType::ClientError,
262        500..=599 => ImplicitSignalType::ServerError,
263        _ => ImplicitSignalType::ClientError,
264    }
265}
266
267// ─── Spend Control ───────────────────────────────────────────────────────────
268
269/// Budget limits for inference spending.
270#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct SpendLimits {
272    /// Maximum cost per individual request in USD.
273    #[serde(default)]
274    pub per_request_usd: Option<f64>,
275    /// Maximum cost per rolling hour in USD.
276    #[serde(default)]
277    pub hourly_usd: Option<f64>,
278    /// Maximum cost per rolling 24 hours in USD.
279    #[serde(default)]
280    pub daily_usd: Option<f64>,
281}
282
283impl Default for SpendLimits {
284    fn default() -> Self {
285        Self {
286            per_request_usd: None,
287            hourly_usd: None,
288            daily_usd: None,
289        }
290    }
291}
292
/// Tracked spending record.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SpendRecord {
    // Cost of one completed request in USD.
    cost_usd: f64,
    // When the cost was recorded (unix seconds, from `now_unix()`).
    timestamp: u64,
}
299
/// Spend controller that enforces budget limits.
///
/// Keeps a rolling window of records; entries older than 24 hours are
/// pruned on each `record()` call, so memory stays bounded by request rate.
#[derive(Debug)]
pub struct SpendControl {
    // Configured limits; any of them may be disabled (None).
    limits: SpendLimits,
    /// Rolling window of spending records.
    records: Vec<SpendRecord>,
}
307
308impl SpendControl {
309    pub fn new(limits: SpendLimits) -> Self {
310        Self {
311            limits,
312            records: Vec::new(),
313        }
314    }
315
316    /// Check if a request with estimated cost would exceed any limit.
317    /// Returns Ok(()) if allowed, Err(reason) if blocked.
318    pub fn check(&self, estimated_cost_usd: f64) -> Result<(), SpendLimitExceeded> {
319        // Per-request limit
320        if let Some(max) = self.limits.per_request_usd {
321            if estimated_cost_usd > max {
322                return Err(SpendLimitExceeded {
323                    limit_type: "per_request".into(),
324                    limit_usd: max,
325                    current_usd: estimated_cost_usd,
326                    window_secs: 0,
327                });
328            }
329        }
330
331        let now = now_unix();
332
333        // Hourly limit
334        if let Some(max) = self.limits.hourly_usd {
335            let hourly_spend = self.spend_in_window(now, 3600);
336            if hourly_spend + estimated_cost_usd > max {
337                return Err(SpendLimitExceeded {
338                    limit_type: "hourly".into(),
339                    limit_usd: max,
340                    current_usd: hourly_spend,
341                    window_secs: 3600,
342                });
343            }
344        }
345
346        // Daily limit
347        if let Some(max) = self.limits.daily_usd {
348            let daily_spend = self.spend_in_window(now, 86400);
349            if daily_spend + estimated_cost_usd > max {
350                return Err(SpendLimitExceeded {
351                    limit_type: "daily".into(),
352                    limit_usd: max,
353                    current_usd: daily_spend,
354                    window_secs: 86400,
355                });
356            }
357        }
358
359        Ok(())
360    }
361
362    /// Record a completed request's actual cost.
363    pub fn record(&mut self, cost_usd: f64) {
364        self.records.push(SpendRecord {
365            cost_usd,
366            timestamp: now_unix(),
367        });
368        // Prune records older than 24 hours
369        let cutoff = now_unix().saturating_sub(86400);
370        self.records.retain(|r| r.timestamp >= cutoff);
371    }
372
373    /// Get total spend in a time window (seconds from now).
374    pub fn spend_in_window(&self, now: u64, window_secs: u64) -> f64 {
375        let cutoff = now.saturating_sub(window_secs);
376        self.records
377            .iter()
378            .filter(|r| r.timestamp >= cutoff)
379            .map(|r| r.cost_usd)
380            .sum()
381    }
382
383    /// Get total spend in the last hour.
384    pub fn hourly_spend(&self) -> f64 {
385        self.spend_in_window(now_unix(), 3600)
386    }
387
388    /// Get total spend in the last 24 hours.
389    pub fn daily_spend(&self) -> f64 {
390        self.spend_in_window(now_unix(), 86400)
391    }
392
393    /// Get spend status summary.
394    pub fn status(&self) -> SpendStatus {
395        SpendStatus {
396            hourly_spend: self.hourly_spend(),
397            daily_spend: self.daily_spend(),
398            hourly_limit: self.limits.hourly_usd,
399            daily_limit: self.limits.daily_usd,
400            per_request_limit: self.limits.per_request_usd,
401        }
402    }
403}
404
/// Error returned when a spend limit is exceeded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpendLimitExceeded {
    /// Which limit tripped: "per_request", "hourly", or "daily".
    pub limit_type: String,
    /// The configured limit in USD.
    pub limit_usd: f64,
    /// Spend already in the window — or, for "per_request", the estimated
    /// cost of the offending request — in USD.
    pub current_usd: f64,
    /// Window length in seconds (0 for the per-request limit).
    pub window_secs: u64,
}
413
414impl std::fmt::Display for SpendLimitExceeded {
415    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416        write!(
417            f,
418            "{} spend limit exceeded: ${:.4} / ${:.4}",
419            self.limit_type, self.current_usd, self.limit_usd
420        )
421    }
422}
423
/// Current spend status.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpendStatus {
    /// Spend over the last rolling hour, USD.
    pub hourly_spend: f64,
    /// Spend over the last rolling 24 hours, USD.
    pub daily_spend: f64,
    /// Configured hourly limit, if any.
    pub hourly_limit: Option<f64>,
    /// Configured daily limit, if any.
    pub daily_limit: Option<f64>,
    /// Configured per-request limit, if any.
    pub per_request_limit: Option<f64>,
}
433
434// ─── Benchmark Priors ────────────────────────────────────────────────────────
435
/// Quality/latency priors for one model, derived from benchmark results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkPrior {
    /// Overall benchmark score; clamped to [0, 1] on import.
    pub overall_score: f64,
    /// Average latency across all cases in milliseconds, if reported.
    #[serde(default)]
    pub overall_latency_ms: Option<f64>,
    /// Per-task average scores, keyed by task name (e.g. "generate", "code").
    #[serde(default)]
    pub task_scores: HashMap<String, f64>,
    /// Per-task average latencies in milliseconds, keyed like `task_scores`.
    #[serde(default)]
    pub task_latency_ms: HashMap<String, f64>,
}
446
447/// Import benchmark results as quality priors for the OutcomeTracker.
448///
449/// For each model with benchmark data, set the initial EMA quality and any
450/// task-specific EMA priors instead of the neutral 0.5 prior. This gives the
451/// Thompson Sampling informed priors from the start.
452pub fn apply_benchmark_priors(
453    tracker: &mut crate::outcome::OutcomeTracker,
454    benchmark_priors: &HashMap<String, BenchmarkPrior>,
455) {
456    for (model_id, prior) in benchmark_priors {
457        let profile = tracker.profile(model_id);
458        if profile.is_none() || profile.map(|p| p.total_calls == 0).unwrap_or(true) {
459            // No observed data — set the prior from benchmark
460            let mut new_profile = crate::outcome::ModelProfile::new(model_id.clone());
461            new_profile.ema_quality = prior.overall_score.clamp(0.0, 1.0);
462            for (task, score) in &prior.task_scores {
463                new_profile.task_stats.insert(
464                    task.clone(),
465                    TaskStats {
466                        ema_quality: score.clamp(0.0, 1.0),
467                        avg_latency_ms: prior
468                            .task_latency_ms
469                            .get(task)
470                            .copied()
471                            .unwrap_or_default(),
472                        ..Default::default()
473                    },
474                );
475            }
476            tracker.import_profiles(vec![new_profile]);
477            debug!(
478                model = %model_id,
479                quality = prior.overall_score,
480                task_priors = prior.task_scores.len(),
481                latency_priors = prior.task_latency_ms.len(),
482                "set benchmark quality prior"
483            );
484        }
485    }
486}
487
488/// Load benchmark priors from a benchmark result JSON file.
489pub fn load_benchmark_priors(
490    path: &std::path::Path,
491) -> Result<HashMap<String, BenchmarkPrior>, String> {
492    if !path.exists() {
493        return Ok(HashMap::new());
494    }
495    let json = std::fs::read_to_string(path).map_err(|e| e.to_string())?;
496    let value: serde_json::Value = serde_json::from_str(&json).map_err(|e| e.to_string())?;
497
498    let mut priors = HashMap::new();
499
500    // Try car-bench BenchmarkResult format
501    if let Some(model_id) = value.get("model_id").and_then(|v| v.as_str()) {
502        if let Some(overall) = value.get("overall_score").and_then(|v| v.as_f64()) {
503            priors.insert(
504                model_id.to_string(),
505                BenchmarkPrior {
506                    overall_score: overall,
507                    overall_latency_ms: value.get("avg_latency_ms").and_then(|v| v.as_f64()),
508                    task_scores: extract_task_scores(&value),
509                    task_latency_ms: extract_task_latencies(&value),
510                },
511            );
512        }
513    }
514
515    // Try array of results
516    if let Some(arr) = value.as_array() {
517        for item in arr {
518            if let (Some(id), Some(score)) = (
519                item.get("model_id").and_then(|v| v.as_str()),
520                item.get("overall_score").and_then(|v| v.as_f64()),
521            ) {
522                priors.insert(
523                    id.to_string(),
524                    BenchmarkPrior {
525                        overall_score: score,
526                        overall_latency_ms: item.get("avg_latency_ms").and_then(|v| v.as_f64()),
527                        task_scores: extract_task_scores(item),
528                        task_latency_ms: extract_task_latencies(item),
529                    },
530                );
531            }
532        }
533    }
534
535    Ok(priors)
536}
537
538// ─── Helpers ─────────────────────────────────────────────────────────────────
539
/// Current wall-clock time as whole seconds since the unix epoch.
/// Falls back to 0 if the system clock reads before the epoch.
fn now_unix() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
546
547fn extract_task_scores(value: &serde_json::Value) -> HashMap<String, f64> {
548    let mut task_scores: HashMap<String, Vec<f64>> = HashMap::new();
549    let Some(cases) = value.get("cases").and_then(|v| v.as_array()) else {
550        return HashMap::new();
551    };
552
553    for case in cases {
554        let Some(category) = case.get("category").and_then(|v| v.as_str()) else {
555            continue;
556        };
557        let Some(score) = case.get("score").and_then(|v| v.as_f64()) else {
558            continue;
559        };
560        if let Some(task) = benchmark_category_to_task(category) {
561            task_scores
562                .entry(task.to_string())
563                .or_default()
564                .push(score.clamp(0.0, 1.0));
565        }
566    }
567
568    task_scores
569        .into_iter()
570        .map(|(task, scores)| {
571            let avg = scores.iter().sum::<f64>() / scores.len() as f64;
572            (task, avg)
573        })
574        .collect()
575}
576
577fn extract_task_latencies(value: &serde_json::Value) -> HashMap<String, f64> {
578    let mut task_latencies: HashMap<String, Vec<f64>> = HashMap::new();
579    let Some(cases) = value.get("cases").and_then(|v| v.as_array()) else {
580        return HashMap::new();
581    };
582
583    for case in cases {
584        let Some(category) = case.get("category").and_then(|v| v.as_str()) else {
585            continue;
586        };
587        let Some(latency_ms) = case.get("latency_ms").and_then(|v| v.as_f64()) else {
588            continue;
589        };
590        if let Some(task) = benchmark_category_to_task(category) {
591            task_latencies
592                .entry(task.to_string())
593                .or_default()
594                .push(latency_ms.max(1.0));
595        }
596    }
597
598    task_latencies
599        .into_iter()
600        .map(|(task, latencies)| {
601            let avg = latencies.iter().sum::<f64>() / latencies.len() as f64;
602            (task, avg)
603        })
604        .collect()
605}
606
607fn benchmark_category_to_task(category: &str) -> Option<InferenceTask> {
608    match category {
609        "basic" | "generate" | "tool_use" | "vision" => Some(InferenceTask::Generate),
610        "code" | "coding" => Some(InferenceTask::Code),
611        "reasoning" | "analysis" => Some(InferenceTask::Reasoning),
612        "classify" | "classification" => Some(InferenceTask::Classify),
613        "embed" | "embedding" => Some(InferenceTask::Embed),
614        _ => None,
615    }
616}
617
618// ─── Tests ───────────────────────────────────────────────────────────────────
619
#[cfg(test)]
mod tests {
    //! Unit tests for routing modes, circuit-breaker lifecycle, implicit
    //! feedback mapping, spend control, and benchmark quality priors.
    use super::*;

    // --- Routing Modes ---

    #[test]
    fn routing_mode_weights() {
        // Every mode's (quality, latency, cost) triple should be a
        // roughly-normalized blend.
        let (q, l, c) = RoutingMode::Auto.weights();
        assert!((q + l + c - 1.0).abs() < 0.01);

        let (q, _, c) = RoutingMode::Fast.weights();
        assert!(c > q, "Fast mode should weight cost > quality");

        let (q, _, c) = RoutingMode::Best.weights();
        assert!(q > c, "Best mode should weight quality > cost");
    }

    // --- Circuit Breaker ---

    #[test]
    fn circuit_breaker_lifecycle() {
        let mut cb = CircuitBreaker::new(3, 60);
        assert_eq!(cb.state, CircuitState::Closed);
        assert!(cb.allow_request());

        // 2 failures — still closed
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state, CircuitState::Closed);
        assert!(cb.allow_request());

        // 3rd failure — trips to open
        cb.record_failure();
        assert_eq!(cb.state, CircuitState::Open);
        assert!(!cb.allow_request()); // blocked

        // Success resets on closed
        let mut cb2 = CircuitBreaker::new(3, 60);
        cb2.record_failure();
        cb2.record_failure();
        cb2.record_success();
        assert_eq!(cb2.failure_count, 0);
    }

    #[test]
    fn circuit_breaker_half_open_recovery() {
        let mut cb = CircuitBreaker::new(2, 0); // 0 cooldown for testing
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state, CircuitState::Open);

        // With 0 cooldown, should transition to HalfOpen immediately
        assert!(cb.allow_request());
        assert_eq!(cb.state, CircuitState::HalfOpen);

        // Probe succeeds → Closed
        cb.record_success();
        assert_eq!(cb.state, CircuitState::Closed);
    }

    #[test]
    fn circuit_breaker_half_open_failure() {
        let mut cb = CircuitBreaker::new(2, 0);
        cb.record_failure();
        cb.record_failure();
        assert!(cb.allow_request()); // transitions to HalfOpen
        assert_eq!(cb.state, CircuitState::HalfOpen);

        // Probe fails → back to Open
        cb.record_failure();
        assert_eq!(cb.state, CircuitState::Open);
        // Trip #1 was Closed → Open; trip #2 is HalfOpen → Open.
        assert_eq!(cb.trip_count, 2);
    }

    #[test]
    fn circuit_breaker_registry() {
        let mut reg = CircuitBreakerRegistry::new(2, 0);
        assert!(reg.allow_request("model-a"));

        reg.record_failure("model-a");
        reg.record_failure("model-a");
        // With a 0-second cooldown the tripped breaker may already be
        // probing, so accept either "blocked" or "half-open".
        assert!(
            !reg.allow_request("model-a") || reg.state("model-a") == Some(CircuitState::HalfOpen)
        );

        // Different model unaffected
        assert!(reg.allow_request("model-b"));
    }

    // --- Implicit Feedback ---

    #[test]
    fn signal_from_http_status() {
        assert_eq!(signal_from_status(200), ImplicitSignalType::Success);
        assert_eq!(signal_from_status(429), ImplicitSignalType::RateLimited);
        assert_eq!(signal_from_status(500), ImplicitSignalType::ServerError);
        assert_eq!(signal_from_status(400), ImplicitSignalType::ClientError);
    }

    #[test]
    fn quality_deltas() {
        assert!(ImplicitSignalType::Success.quality_delta() > 0.0);
        assert!(ImplicitSignalType::ServerError.quality_delta() < 0.0);
        assert!(ImplicitSignalType::Retried.quality_delta() < 0.0);
    }

    // --- Spend Control ---

    #[test]
    fn spend_per_request_limit() {
        let sc = SpendControl::new(SpendLimits {
            per_request_usd: Some(0.10),
            ..Default::default()
        });
        assert!(sc.check(0.05).is_ok());
        assert!(sc.check(0.15).is_err());
    }

    #[test]
    fn spend_hourly_limit() {
        let mut sc = SpendControl::new(SpendLimits {
            hourly_usd: Some(1.00),
            ..Default::default()
        });
        sc.record(0.40);
        sc.record(0.40);
        // 0.80 recorded: +0.10 stays under the $1.00 cap, +0.25 exceeds it.
        assert!(sc.check(0.10).is_ok());
        assert!(sc.check(0.25).is_err());
    }

    #[test]
    fn spend_status() {
        let mut sc = SpendControl::new(SpendLimits {
            hourly_usd: Some(5.0),
            daily_usd: Some(20.0),
            ..Default::default()
        });
        sc.record(1.50);
        let status = sc.status();
        assert!((status.hourly_spend - 1.50).abs() < 0.01);
        assert_eq!(status.hourly_limit, Some(5.0));
    }

    // --- Benchmark Priors ---

    #[test]
    fn apply_priors() {
        let mut tracker = crate::outcome::OutcomeTracker::new();
        let mut priors = HashMap::new();
        priors.insert(
            "model-a".to_string(),
            BenchmarkPrior {
                overall_score: 0.85,
                overall_latency_ms: Some(1100.0),
                task_scores: HashMap::from([
                    ("generate".to_string(), 0.82),
                    ("code".to_string(), 0.91),
                ]),
                task_latency_ms: HashMap::from([
                    ("generate".to_string(), 900.0),
                    ("code".to_string(), 2100.0),
                ]),
            },
        );
        priors.insert(
            "model-b".to_string(),
            BenchmarkPrior {
                overall_score: 0.60,
                overall_latency_ms: Some(2000.0),
                task_scores: HashMap::new(),
                task_latency_ms: HashMap::new(),
            },
        );

        apply_benchmark_priors(&mut tracker, &priors);

        // Both overall and per-task EMAs should mirror the prior exactly,
        // since neither model has observed calls.
        let profile_a = tracker.profile("model-a").unwrap();
        assert!((profile_a.ema_quality - 0.85).abs() < 0.01);
        assert!(
            (profile_a
                .task_stats(crate::outcome::InferenceTask::Generate)
                .unwrap()
                .ema_quality
                - 0.82)
                .abs()
                < 0.01
        );
        assert!(
            (profile_a
                .task_stats(crate::outcome::InferenceTask::Code)
                .unwrap()
                .ema_quality
                - 0.91)
                .abs()
                < 0.01
        );
        assert!(
            (profile_a
                .task_stats(crate::outcome::InferenceTask::Code)
                .unwrap()
                .avg_latency_ms
                - 2100.0)
                .abs()
                < 0.01
        );

        let profile_b = tracker.profile("model-b").unwrap();
        assert!((profile_b.ema_quality - 0.60).abs() < 0.01);
    }

    #[test]
    fn priors_dont_overwrite_observed() {
        let mut tracker = crate::outcome::OutcomeTracker::new();

        // Record some observations first
        let trace =
            tracker.record_start("model-a", crate::outcome::InferenceTask::Generate, "test");
        tracker.record_complete(&trace, 100, 10, 5);

        // Now try to apply a prior — should NOT overwrite since we have observations
        let mut priors = HashMap::new();
        priors.insert(
            "model-a".to_string(),
            BenchmarkPrior {
                overall_score: 0.99,
                overall_latency_ms: Some(1500.0),
                task_scores: HashMap::from([("generate".to_string(), 0.99)]),
                task_latency_ms: HashMap::from([("generate".to_string(), 1500.0)]),
            },
        );
        apply_benchmark_priors(&mut tracker, &priors);

        let profile = tracker.profile("model-a").unwrap();
        // Should still have original quality (0.5 neutral), not 0.99
        assert!(profile.ema_quality < 0.9);
    }

    #[test]
    fn load_benchmark_priors_extracts_task_scores_from_cases() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(
            tmp.path(),
            serde_json::json!({
                "model_id": "model-a",
                "overall_score": 0.78,
                "cases": [
                    {"id": "basic_exact", "category": "basic", "score": 0.9, "latency_ms": 800},
                    {"id": "code_fibonacci", "category": "code", "score": 0.8, "latency_ms": 2200},
                    {"id": "reasoning_lp", "category": "reasoning", "score": 0.7, "latency_ms": 3300},
                    {"id": "reasoning_lp_2", "category": "reasoning", "score": 0.5, "latency_ms": 2700}
                ]
            })
            .to_string(),
        )
        .unwrap();

        let priors = load_benchmark_priors(tmp.path()).unwrap();
        let prior = priors.get("model-a").unwrap();

        assert!((prior.overall_score - 0.78).abs() < 0.01);
        assert!((prior.task_scores["generate"] - 0.9).abs() < 0.01);
        assert!((prior.task_scores["code"] - 0.8).abs() < 0.01);
        // reasoning has two cases: scores average to (0.7 + 0.5) / 2 = 0.6,
        // latencies to (3300 + 2700) / 2 = 3000.
        assert!((prior.task_scores["reasoning"] - 0.6).abs() < 0.01);
        assert!((prior.task_latency_ms["generate"] - 800.0).abs() < 0.01);
        assert!((prior.task_latency_ms["code"] - 2200.0).abs() < 0.01);
        assert!((prior.task_latency_ms["reasoning"] - 3000.0).abs() < 0.01);
    }

    #[test]
    fn load_benchmark_priors_maps_tool_and_vision_cases_to_generate() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(
            tmp.path(),
            serde_json::json!({
                "model_id": "model-a",
                "overall_score": 1.0,
                "cases": [
                    {"id": "tool_weather", "category": "tool_use", "score": 1.0, "latency_ms": 1300},
                    {"id": "vision_cat", "category": "vision", "score": 1.0, "latency_ms": 1700}
                ]
            })
            .to_string(),
        )
        .unwrap();

        let priors = load_benchmark_priors(tmp.path()).unwrap();
        let prior = priors.get("model-a").unwrap();

        // tool_use and vision both fold into "generate": score 1.0,
        // latency (1300 + 1700) / 2 = 1500.
        assert!((prior.task_scores["generate"] - 1.0).abs() < 0.01);
        assert!((prior.task_latency_ms["generate"] - 1500.0).abs() < 0.01);
    }
}