1use std::collections::HashMap;
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use serde::{Deserialize, Serialize};
8use tracing::{debug, info, warn};
9
10use crate::outcome::{InferenceTask, TaskStats};
11
/// How the router trades off quality, latency, and cost when scoring models
/// (see [`RoutingMode::weights`] for the concrete weighting).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum RoutingMode {
    /// Balanced weighting across quality, latency, and cost (the default).
    #[default]
    Auto,
    /// Weight cost (and latency) over quality.
    Fast,
    /// Weight quality over latency and cost.
    Best,
}
26
27impl RoutingMode {
28 pub fn weights(&self) -> (f64, f64, f64) {
31 match self {
32 RoutingMode::Auto => (0.45, 0.40, 0.15),
33 RoutingMode::Fast => (0.15, 0.35, 0.50),
34 RoutingMode::Best => (0.70, 0.20, 0.10),
35 }
36 }
37}
38
/// State of a per-model circuit breaker.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CircuitState {
    /// Healthy: requests flow normally.
    Closed,
    /// Tripped: requests are blocked until the cooldown elapses.
    Open,
    /// Cooldown elapsed: a single probe request has been let through and
    /// its outcome decides whether the breaker closes or re-opens.
    HalfOpen,
}
52
/// Per-model circuit breaker: trips open after a run of consecutive
/// failures and probes for recovery once a cooldown has elapsed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreaker {
    /// Current state of the breaker.
    pub state: CircuitState,
    /// Consecutive failures observed since the last success.
    pub failure_count: u32,
    /// Number of consecutive failures that trips the breaker open.
    pub failure_threshold: u32,
    /// Seconds to stay Open before allowing a HalfOpen probe.
    pub cooldown_secs: u64,
    /// Unix timestamp (seconds) of the most recent trip; 0 if never tripped.
    pub opened_at: u64,
    /// Total number of times the breaker has tripped open.
    pub trip_count: u32,
}
68
69impl CircuitBreaker {
70 pub fn new(failure_threshold: u32, cooldown_secs: u64) -> Self {
71 Self {
72 state: CircuitState::Closed,
73 failure_count: 0,
74 failure_threshold,
75 cooldown_secs,
76 opened_at: 0,
77 trip_count: 0,
78 }
79 }
80
81 pub fn allow_request(&mut self) -> bool {
83 match self.state {
84 CircuitState::Closed => true,
85 CircuitState::Open => {
86 let now = now_unix();
88 if now.saturating_sub(self.opened_at) >= self.cooldown_secs {
89 self.state = CircuitState::HalfOpen;
90 debug!("circuit breaker: Open → HalfOpen (cooldown expired)");
91 true } else {
93 false
94 }
95 }
96 CircuitState::HalfOpen => {
97 false
100 }
101 }
102 }
103
104 pub fn record_success(&mut self) {
106 match self.state {
107 CircuitState::HalfOpen => {
108 self.state = CircuitState::Closed;
110 self.failure_count = 0;
111 info!("circuit breaker: HalfOpen → Closed (probe succeeded)");
112 }
113 CircuitState::Closed => {
114 self.failure_count = 0;
115 }
116 CircuitState::Open => {} }
118 }
119
120 pub fn record_failure(&mut self) {
122 self.failure_count += 1;
123
124 match self.state {
125 CircuitState::Closed => {
126 if self.failure_count >= self.failure_threshold {
127 self.state = CircuitState::Open;
128 self.opened_at = now_unix();
129 self.trip_count += 1;
130 warn!(
131 failures = self.failure_count,
132 trips = self.trip_count,
133 "circuit breaker: Closed → Open"
134 );
135 }
136 }
137 CircuitState::HalfOpen => {
138 self.state = CircuitState::Open;
140 self.opened_at = now_unix();
141 self.trip_count += 1;
142 warn!("circuit breaker: HalfOpen → Open (probe failed)");
143 }
144 CircuitState::Open => {} }
146 }
147
148 pub fn is_blocking(&self) -> bool {
150 matches!(self.state, CircuitState::Open)
151 }
152}
153
154impl Default for CircuitBreaker {
155 fn default() -> Self {
156 Self::new(3, 60)
157 }
158}
159
/// Lazily-populated collection of per-model circuit breakers that share
/// default threshold/cooldown settings.
// NOTE(review): the derived Default yields default_threshold == 0, so
// breakers created through it trip on the very first failure
// (failure_count 1 >= threshold 0) — confirm callers construct via new().
#[derive(Debug, Default)]
pub struct CircuitBreakerRegistry {
    // One breaker per model id, created on first use.
    breakers: HashMap<String, CircuitBreaker>,
    // Failure threshold applied to newly created breakers.
    default_threshold: u32,
    // Cooldown (seconds) applied to newly created breakers.
    default_cooldown: u64,
}
167
168impl CircuitBreakerRegistry {
169 pub fn new(default_threshold: u32, default_cooldown_secs: u64) -> Self {
170 Self {
171 breakers: HashMap::new(),
172 default_threshold,
173 default_cooldown: default_cooldown_secs,
174 }
175 }
176
177 pub fn allow_request(&mut self, model_id: &str) -> bool {
179 self.get_or_create(model_id).allow_request()
180 }
181
182 pub fn record_success(&mut self, model_id: &str) {
184 self.get_or_create(model_id).record_success();
185 }
186
187 pub fn record_failure(&mut self, model_id: &str) {
189 self.get_or_create(model_id).record_failure();
190 }
191
192 pub fn state(&self, model_id: &str) -> Option<CircuitState> {
194 self.breakers.get(model_id).map(|b| b.state)
195 }
196
197 fn get_or_create(&mut self, model_id: &str) -> &mut CircuitBreaker {
198 self.breakers
199 .entry(model_id.to_string())
200 .or_insert_with(|| CircuitBreaker::new(self.default_threshold, self.default_cooldown))
201 }
202}
203
/// A quality signal inferred from a request's observable outcome (status
/// code, timeout, retry) rather than from explicit user feedback.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImplicitSignal {
    /// Model the signal applies to.
    pub model_id: String,
    /// What kind of outcome was observed.
    pub signal_type: ImplicitSignalType,
    /// Timestamp of the observation — presumably Unix seconds to match
    /// `now_unix()` elsewhere in this module; confirm at the call site.
    pub timestamp: u64,
}
213
/// Kinds of implicit outcome signals. Each maps to a quality delta
/// (`quality_delta`) and may count as a circuit-breaker failure
/// (`is_circuit_failure`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ImplicitSignalType {
    /// Request completed normally (HTTP 2xx via `signal_from_status`).
    Success,
    /// Provider returned HTTP 429.
    RateLimited,
    /// Provider returned HTTP 5xx.
    ServerError,
    /// Other HTTP 4xx — likely a bad request rather than a bad model.
    ClientError,
    /// Request timed out.
    Timeout,
    /// Request needed one or more retries.
    Retried,
}
230
231impl ImplicitSignalType {
232 pub fn quality_delta(&self) -> f64 {
235 match self {
236 ImplicitSignalType::Success => 1.0,
237 ImplicitSignalType::RateLimited => -0.3, ImplicitSignalType::ServerError => -0.8,
239 ImplicitSignalType::ClientError => -0.2, ImplicitSignalType::Timeout => -0.5,
241 ImplicitSignalType::Retried => -0.7, }
243 }
244
245 pub fn is_circuit_failure(&self) -> bool {
247 matches!(
248 self,
249 ImplicitSignalType::RateLimited
250 | ImplicitSignalType::ServerError
251 | ImplicitSignalType::Timeout
252 )
253 }
254}
255
256pub fn signal_from_status(status: u16) -> ImplicitSignalType {
258 match status {
259 200..=299 => ImplicitSignalType::Success,
260 429 => ImplicitSignalType::RateLimited,
261 400..=428 | 430..=499 => ImplicitSignalType::ClientError,
262 500..=599 => ImplicitSignalType::ServerError,
263 _ => ImplicitSignalType::ClientError,
264 }
265}
266
/// Optional USD spend caps; a `None` field disables that limit entirely.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpendLimits {
    /// Maximum estimated cost for a single request.
    #[serde(default)]
    pub per_request_usd: Option<f64>,
    /// Cap on spend within a rolling 1-hour window.
    #[serde(default)]
    pub hourly_usd: Option<f64>,
    /// Cap on spend within a rolling 24-hour window.
    #[serde(default)]
    pub daily_usd: Option<f64>,
}
282
283impl Default for SpendLimits {
284 fn default() -> Self {
285 Self {
286 per_request_usd: None,
287 hourly_usd: None,
288 daily_usd: None,
289 }
290 }
291}
292
/// One recorded spend event, used to compute rolling-window totals.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SpendRecord {
    /// Actual cost of the request in USD.
    cost_usd: f64,
    /// Unix timestamp (seconds) when the spend was recorded.
    timestamp: u64,
}
299
/// Enforces `SpendLimits` against an in-memory history of recorded costs.
/// History is pruned to the last 24h on each `record`, so memory stays
/// bounded by the request rate over one day.
#[derive(Debug)]
pub struct SpendControl {
    // Configured caps; `None` fields are not enforced.
    limits: SpendLimits,
    // Spend events from the last 24 hours, in recording order.
    records: Vec<SpendRecord>,
}
307
308impl SpendControl {
309 pub fn new(limits: SpendLimits) -> Self {
310 Self {
311 limits,
312 records: Vec::new(),
313 }
314 }
315
316 pub fn check(&self, estimated_cost_usd: f64) -> Result<(), SpendLimitExceeded> {
319 if let Some(max) = self.limits.per_request_usd {
321 if estimated_cost_usd > max {
322 return Err(SpendLimitExceeded {
323 limit_type: "per_request".into(),
324 limit_usd: max,
325 current_usd: estimated_cost_usd,
326 window_secs: 0,
327 });
328 }
329 }
330
331 let now = now_unix();
332
333 if let Some(max) = self.limits.hourly_usd {
335 let hourly_spend = self.spend_in_window(now, 3600);
336 if hourly_spend + estimated_cost_usd > max {
337 return Err(SpendLimitExceeded {
338 limit_type: "hourly".into(),
339 limit_usd: max,
340 current_usd: hourly_spend,
341 window_secs: 3600,
342 });
343 }
344 }
345
346 if let Some(max) = self.limits.daily_usd {
348 let daily_spend = self.spend_in_window(now, 86400);
349 if daily_spend + estimated_cost_usd > max {
350 return Err(SpendLimitExceeded {
351 limit_type: "daily".into(),
352 limit_usd: max,
353 current_usd: daily_spend,
354 window_secs: 86400,
355 });
356 }
357 }
358
359 Ok(())
360 }
361
362 pub fn record(&mut self, cost_usd: f64) {
364 self.records.push(SpendRecord {
365 cost_usd,
366 timestamp: now_unix(),
367 });
368 let cutoff = now_unix().saturating_sub(86400);
370 self.records.retain(|r| r.timestamp >= cutoff);
371 }
372
373 pub fn spend_in_window(&self, now: u64, window_secs: u64) -> f64 {
375 let cutoff = now.saturating_sub(window_secs);
376 self.records
377 .iter()
378 .filter(|r| r.timestamp >= cutoff)
379 .map(|r| r.cost_usd)
380 .sum()
381 }
382
383 pub fn hourly_spend(&self) -> f64 {
385 self.spend_in_window(now_unix(), 3600)
386 }
387
388 pub fn daily_spend(&self) -> f64 {
390 self.spend_in_window(now_unix(), 86400)
391 }
392
393 pub fn status(&self) -> SpendStatus {
395 SpendStatus {
396 hourly_spend: self.hourly_spend(),
397 daily_spend: self.daily_spend(),
398 hourly_limit: self.limits.hourly_usd,
399 daily_limit: self.limits.daily_usd,
400 per_request_limit: self.limits.per_request_usd,
401 }
402 }
403}
404
/// Error returned by `SpendControl::check` when a request would breach a
/// configured limit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpendLimitExceeded {
    /// Which limit was hit: "per_request", "hourly", or "daily".
    pub limit_type: String,
    /// The configured cap in USD.
    pub limit_usd: f64,
    /// Spend already accrued in the window (for per_request, the estimate).
    pub current_usd: f64,
    /// Window length in seconds (0 for the per-request limit).
    pub window_secs: u64,
}
413
414impl std::fmt::Display for SpendLimitExceeded {
415 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416 write!(
417 f,
418 "{} spend limit exceeded: ${:.4} / ${:.4}",
419 self.limit_type, self.current_usd, self.limit_usd
420 )
421 }
422}
423
/// Point-in-time snapshot of spend versus configured limits, as returned
/// by `SpendControl::status`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpendStatus {
    /// USD spent in the last hour.
    pub hourly_spend: f64,
    /// USD spent in the last 24 hours.
    pub daily_spend: f64,
    /// Configured hourly cap, if any.
    pub hourly_limit: Option<f64>,
    /// Configured daily cap, if any.
    pub daily_limit: Option<f64>,
    /// Configured per-request cap, if any.
    pub per_request_limit: Option<f64>,
}
433
/// Offline benchmark results used to seed quality/latency estimates for
/// models with no observed traffic (see `apply_benchmark_priors`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkPrior {
    /// Overall benchmark quality score (clamped to [0, 1] when applied).
    pub overall_score: f64,
    /// Average latency across the whole benchmark, if reported.
    // NOTE(review): not consumed by apply_benchmark_priors in this module;
    // confirm whether other callers use it.
    #[serde(default)]
    pub overall_latency_ms: Option<f64>,
    /// Per-task quality scores keyed by task name (e.g. "generate", "code").
    #[serde(default)]
    pub task_scores: HashMap<String, f64>,
    /// Per-task average latency in milliseconds, keyed like `task_scores`.
    #[serde(default)]
    pub task_latency_ms: HashMap<String, f64>,
}
446
447pub fn apply_benchmark_priors(
453 tracker: &mut crate::outcome::OutcomeTracker,
454 benchmark_priors: &HashMap<String, BenchmarkPrior>,
455) {
456 for (model_id, prior) in benchmark_priors {
457 let profile = tracker.profile(model_id);
458 if profile.is_none() || profile.map(|p| p.total_calls == 0).unwrap_or(true) {
459 let mut new_profile = crate::outcome::ModelProfile::new(model_id.clone());
461 new_profile.ema_quality = prior.overall_score.clamp(0.0, 1.0);
462 for (task, score) in &prior.task_scores {
463 new_profile.task_stats.insert(
464 task.clone(),
465 TaskStats {
466 ema_quality: score.clamp(0.0, 1.0),
467 avg_latency_ms: prior
468 .task_latency_ms
469 .get(task)
470 .copied()
471 .unwrap_or_default(),
472 ..Default::default()
473 },
474 );
475 }
476 tracker.import_profiles(vec![new_profile]);
477 debug!(
478 model = %model_id,
479 quality = prior.overall_score,
480 task_priors = prior.task_scores.len(),
481 latency_priors = prior.task_latency_ms.len(),
482 "set benchmark quality prior"
483 );
484 }
485 }
486}
487
488pub fn load_benchmark_priors(
490 path: &std::path::Path,
491) -> Result<HashMap<String, BenchmarkPrior>, String> {
492 if !path.exists() {
493 return Ok(HashMap::new());
494 }
495 let json = std::fs::read_to_string(path).map_err(|e| e.to_string())?;
496 let value: serde_json::Value = serde_json::from_str(&json).map_err(|e| e.to_string())?;
497
498 let mut priors = HashMap::new();
499
500 if let Some(model_id) = value.get("model_id").and_then(|v| v.as_str()) {
502 if let Some(overall) = value.get("overall_score").and_then(|v| v.as_f64()) {
503 priors.insert(
504 model_id.to_string(),
505 BenchmarkPrior {
506 overall_score: overall,
507 overall_latency_ms: value.get("avg_latency_ms").and_then(|v| v.as_f64()),
508 task_scores: extract_task_scores(&value),
509 task_latency_ms: extract_task_latencies(&value),
510 },
511 );
512 }
513 }
514
515 if let Some(arr) = value.as_array() {
517 for item in arr {
518 if let (Some(id), Some(score)) = (
519 item.get("model_id").and_then(|v| v.as_str()),
520 item.get("overall_score").and_then(|v| v.as_f64()),
521 ) {
522 priors.insert(
523 id.to_string(),
524 BenchmarkPrior {
525 overall_score: score,
526 overall_latency_ms: item.get("avg_latency_ms").and_then(|v| v.as_f64()),
527 task_scores: extract_task_scores(item),
528 task_latency_ms: extract_task_latencies(item),
529 },
530 );
531 }
532 }
533 }
534
535 Ok(priors)
536}
537
/// Current Unix time in whole seconds; 0 if the system clock reads as
/// being before the epoch.
fn now_unix() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
546
547fn extract_task_scores(value: &serde_json::Value) -> HashMap<String, f64> {
548 let mut task_scores: HashMap<String, Vec<f64>> = HashMap::new();
549 let Some(cases) = value.get("cases").and_then(|v| v.as_array()) else {
550 return HashMap::new();
551 };
552
553 for case in cases {
554 let Some(category) = case.get("category").and_then(|v| v.as_str()) else {
555 continue;
556 };
557 let Some(score) = case.get("score").and_then(|v| v.as_f64()) else {
558 continue;
559 };
560 if let Some(task) = benchmark_category_to_task(category) {
561 task_scores
562 .entry(task.to_string())
563 .or_default()
564 .push(score.clamp(0.0, 1.0));
565 }
566 }
567
568 task_scores
569 .into_iter()
570 .map(|(task, scores)| {
571 let avg = scores.iter().sum::<f64>() / scores.len() as f64;
572 (task, avg)
573 })
574 .collect()
575}
576
577fn extract_task_latencies(value: &serde_json::Value) -> HashMap<String, f64> {
578 let mut task_latencies: HashMap<String, Vec<f64>> = HashMap::new();
579 let Some(cases) = value.get("cases").and_then(|v| v.as_array()) else {
580 return HashMap::new();
581 };
582
583 for case in cases {
584 let Some(category) = case.get("category").and_then(|v| v.as_str()) else {
585 continue;
586 };
587 let Some(latency_ms) = case.get("latency_ms").and_then(|v| v.as_f64()) else {
588 continue;
589 };
590 if let Some(task) = benchmark_category_to_task(category) {
591 task_latencies
592 .entry(task.to_string())
593 .or_default()
594 .push(latency_ms.max(1.0));
595 }
596 }
597
598 task_latencies
599 .into_iter()
600 .map(|(task, latencies)| {
601 let avg = latencies.iter().sum::<f64>() / latencies.len() as f64;
602 (task, avg)
603 })
604 .collect()
605}
606
607fn benchmark_category_to_task(category: &str) -> Option<InferenceTask> {
608 match category {
609 "basic" | "generate" | "tool_use" | "vision" => Some(InferenceTask::Generate),
610 "code" | "coding" => Some(InferenceTask::Code),
611 "reasoning" | "analysis" => Some(InferenceTask::Reasoning),
612 "classify" | "classification" => Some(InferenceTask::Classify),
613 "embed" | "embedding" => Some(InferenceTask::Embed),
614 _ => None,
615 }
616}
617
#[cfg(test)]
mod tests {
    use super::*;

    // Weight tuples are ordered (quality, latency, cost).
    #[test]
    fn routing_mode_weights() {
        let (q, l, c) = RoutingMode::Auto.weights();
        assert!((q + l + c - 1.0).abs() < 0.01);

        let (q, _, c) = RoutingMode::Fast.weights();
        assert!(c > q, "Fast mode should weight cost > quality");

        let (q, _, c) = RoutingMode::Best.weights();
        assert!(q > c, "Best mode should weight quality > cost");
    }

    // Closed → Open on hitting the threshold; a success while Closed
    // resets the consecutive-failure counter.
    #[test]
    fn circuit_breaker_lifecycle() {
        let mut cb = CircuitBreaker::new(3, 60);
        assert_eq!(cb.state, CircuitState::Closed);
        assert!(cb.allow_request());

        // Two failures: still under the threshold of 3.
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state, CircuitState::Closed);
        assert!(cb.allow_request());

        // Third failure trips the breaker; the 60s cooldown blocks requests.
        cb.record_failure();
        assert_eq!(cb.state, CircuitState::Open);
        assert!(!cb.allow_request());

        // Fresh breaker: a success in Closed clears accumulated failures.
        let mut cb2 = CircuitBreaker::new(3, 60);
        cb2.record_failure();
        cb2.record_failure();
        cb2.record_success();
        assert_eq!(cb2.failure_count, 0);
    }

    // With a 0s cooldown, the first allow_request after tripping moves
    // Open → HalfOpen, and a successful probe closes the breaker.
    #[test]
    fn circuit_breaker_half_open_recovery() {
        // cooldown of 0 seconds expires immediately
        let mut cb = CircuitBreaker::new(2, 0);
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state, CircuitState::Open);

        // Cooldown already elapsed: probe allowed, state advances.
        assert!(cb.allow_request());
        assert_eq!(cb.state, CircuitState::HalfOpen);

        // Successful probe closes the breaker.
        cb.record_success();
        assert_eq!(cb.state, CircuitState::Closed);
    }

    // A failed HalfOpen probe re-opens the breaker and counts a second trip.
    #[test]
    fn circuit_breaker_half_open_failure() {
        let mut cb = CircuitBreaker::new(2, 0);
        cb.record_failure();
        cb.record_failure();
        assert!(cb.allow_request());
        assert_eq!(cb.state, CircuitState::HalfOpen);

        // Probe fails: back to Open, trip_count incremented again.
        cb.record_failure();
        assert_eq!(cb.state, CircuitState::Open);
        assert_eq!(cb.trip_count, 2);
    }

    // Breakers are tracked independently per model id.
    #[test]
    fn circuit_breaker_registry() {
        let mut reg = CircuitBreakerRegistry::new(2, 0);
        assert!(reg.allow_request("model-a"));

        reg.record_failure("model-a");
        reg.record_failure("model-a");
        // With a 0s cooldown the tripped breaker may already have advanced
        // to HalfOpen on this call, so accept either observable outcome.
        assert!(
            !reg.allow_request("model-a") || reg.state("model-a") == Some(CircuitState::HalfOpen)
        );

        // model-b gets its own, untripped breaker.
        assert!(reg.allow_request("model-b"));
    }

    // HTTP status → implicit signal mapping.
    #[test]
    fn signal_from_http_status() {
        assert_eq!(signal_from_status(200), ImplicitSignalType::Success);
        assert_eq!(signal_from_status(429), ImplicitSignalType::RateLimited);
        assert_eq!(signal_from_status(500), ImplicitSignalType::ServerError);
        assert_eq!(signal_from_status(400), ImplicitSignalType::ClientError);
    }

    // Only Success carries a positive quality delta.
    #[test]
    fn quality_deltas() {
        assert!(ImplicitSignalType::Success.quality_delta() > 0.0);
        assert!(ImplicitSignalType::ServerError.quality_delta() < 0.0);
        assert!(ImplicitSignalType::Retried.quality_delta() < 0.0);
    }

    // Per-request cap rejects single requests above the limit.
    #[test]
    fn spend_per_request_limit() {
        let sc = SpendControl::new(SpendLimits {
            per_request_usd: Some(0.10),
            ..Default::default()
        });
        assert!(sc.check(0.05).is_ok());
        assert!(sc.check(0.15).is_err());
    }

    // Hourly cap counts already-recorded spend plus the new estimate.
    #[test]
    fn spend_hourly_limit() {
        let mut sc = SpendControl::new(SpendLimits {
            hourly_usd: Some(1.00),
            ..Default::default()
        });
        sc.record(0.40);
        sc.record(0.40);
        // 0.80 + 0.10 <= 1.00 passes; 0.80 + 0.25 > 1.00 is rejected.
        assert!(sc.check(0.10).is_ok());
        assert!(sc.check(0.25).is_err());
    }

    // status() reflects recorded spend and the configured limits.
    #[test]
    fn spend_status() {
        let mut sc = SpendControl::new(SpendLimits {
            hourly_usd: Some(5.0),
            daily_usd: Some(20.0),
            ..Default::default()
        });
        sc.record(1.50);
        let status = sc.status();
        assert!((status.hourly_spend - 1.50).abs() < 0.01);
        assert_eq!(status.hourly_limit, Some(5.0));
    }

    // Priors seed overall and per-task quality/latency for unseen models.
    #[test]
    fn apply_priors() {
        let mut tracker = crate::outcome::OutcomeTracker::new();
        let mut priors = HashMap::new();
        priors.insert(
            "model-a".to_string(),
            BenchmarkPrior {
                overall_score: 0.85,
                overall_latency_ms: Some(1100.0),
                task_scores: HashMap::from([
                    ("generate".to_string(), 0.82),
                    ("code".to_string(), 0.91),
                ]),
                task_latency_ms: HashMap::from([
                    ("generate".to_string(), 900.0),
                    ("code".to_string(), 2100.0),
                ]),
            },
        );
        priors.insert(
            "model-b".to_string(),
            BenchmarkPrior {
                overall_score: 0.60,
                overall_latency_ms: Some(2000.0),
                task_scores: HashMap::new(),
                task_latency_ms: HashMap::new(),
            },
        );

        apply_benchmark_priors(&mut tracker, &priors);

        let profile_a = tracker.profile("model-a").unwrap();
        assert!((profile_a.ema_quality - 0.85).abs() < 0.01);
        assert!(
            (profile_a
                .task_stats(crate::outcome::InferenceTask::Generate)
                .unwrap()
                .ema_quality
                - 0.82)
                .abs()
                < 0.01
        );
        assert!(
            (profile_a
                .task_stats(crate::outcome::InferenceTask::Code)
                .unwrap()
                .ema_quality
                - 0.91)
                .abs()
                < 0.01
        );
        assert!(
            (profile_a
                .task_stats(crate::outcome::InferenceTask::Code)
                .unwrap()
                .avg_latency_ms
                - 2100.0)
                .abs()
                < 0.01
        );

        // model-b had no per-task data: only the overall prior applies.
        let profile_b = tracker.profile("model-b").unwrap();
        assert!((profile_b.ema_quality - 0.60).abs() < 0.01);
    }

    // A model with recorded traffic keeps its observed stats; the 0.99
    // benchmark prior must not replace them.
    #[test]
    fn priors_dont_overwrite_observed() {
        let mut tracker = crate::outcome::OutcomeTracker::new();

        // Record one real outcome for model-a.
        let trace =
            tracker.record_start("model-a", crate::outcome::InferenceTask::Generate, "test");
        tracker.record_complete(&trace, 100, 10, 5);

        let mut priors = HashMap::new();
        priors.insert(
            "model-a".to_string(),
            BenchmarkPrior {
                overall_score: 0.99,
                overall_latency_ms: Some(1500.0),
                task_scores: HashMap::from([("generate".to_string(), 0.99)]),
                task_latency_ms: HashMap::from([("generate".to_string(), 1500.0)]),
            },
        );
        apply_benchmark_priors(&mut tracker, &priors);

        // Observed quality (< 0.9) survives; the 0.99 prior was skipped.
        let profile = tracker.profile("model-a").unwrap();
        assert!(profile.ema_quality < 0.9);
    }

    // Single-object report: per-category case scores/latencies are averaged
    // (reasoning: (0.7+0.5)/2 = 0.6 score, (3300+2700)/2 = 3000 ms).
    #[test]
    fn load_benchmark_priors_extracts_task_scores_from_cases() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(
            tmp.path(),
            serde_json::json!({
                "model_id": "model-a",
                "overall_score": 0.78,
                "cases": [
                    {"id": "basic_exact", "category": "basic", "score": 0.9, "latency_ms": 800},
                    {"id": "code_fibonacci", "category": "code", "score": 0.8, "latency_ms": 2200},
                    {"id": "reasoning_lp", "category": "reasoning", "score": 0.7, "latency_ms": 3300},
                    {"id": "reasoning_lp_2", "category": "reasoning", "score": 0.5, "latency_ms": 2700}
                ]
            })
            .to_string(),
        )
        .unwrap();

        let priors = load_benchmark_priors(tmp.path()).unwrap();
        let prior = priors.get("model-a").unwrap();

        assert!((prior.overall_score - 0.78).abs() < 0.01);
        assert!((prior.task_scores["generate"] - 0.9).abs() < 0.01);
        assert!((prior.task_scores["code"] - 0.8).abs() < 0.01);
        assert!((prior.task_scores["reasoning"] - 0.6).abs() < 0.01);
        assert!((prior.task_latency_ms["generate"] - 800.0).abs() < 0.01);
        assert!((prior.task_latency_ms["code"] - 2200.0).abs() < 0.01);
        assert!((prior.task_latency_ms["reasoning"] - 3000.0).abs() < 0.01);
    }

    // tool_use and vision categories both fold into the "generate" task,
    // so their latencies average together ((1300+1700)/2 = 1500 ms).
    #[test]
    fn load_benchmark_priors_maps_tool_and_vision_cases_to_generate() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(
            tmp.path(),
            serde_json::json!({
                "model_id": "model-a",
                "overall_score": 1.0,
                "cases": [
                    {"id": "tool_weather", "category": "tool_use", "score": 1.0, "latency_ms": 1300},
                    {"id": "vision_cat", "category": "vision", "score": 1.0, "latency_ms": 1700}
                ]
            })
            .to_string(),
        )
        .unwrap();

        let priors = load_benchmark_priors(tmp.path()).unwrap();
        let prior = priors.get("model-a").unwrap();

        assert!((prior.task_scores["generate"] - 1.0).abs() < 0.01);
        assert!((prior.task_latency_ms["generate"] - 1500.0).abs() < 0.01);
    }
}