// aprender/bench/mod.rs

1//! Model Evaluation and Benchmarking Framework (`aprender::bench`)
2//!
3//! Provides multi-model comparison for evaluating `.apr` models on custom tasks.
4//! Unlike QA (single-model validation), this module compares multiple models
5//! to find the **smallest model that meets a performance threshold**.
6//!
7//! # Toyota Way Alignment
8//! - **Pull Systems (P3)**: Pareto frontier pulls smallest viable model
9//! - **Muda Elimination**: Avoid overprovisioning with right-sized models
10//!
11//! # References
12//! - Deb et al. (2002) "NSGA-II" for Pareto optimization
13//!
14//! # Example
15//! ```
16//! use aprender::bench::{EvalResult, ModelComparison};
17//!
18//! let comparison = ModelComparison::new("python-to-rust");
19//! assert!(comparison.results.is_empty());
20//! ```
21
22pub mod pareto;
23pub mod py2rs;
24
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use std::path::PathBuf;
28use std::time::Duration;
29
/// Custom evaluation task trait
///
/// Implementors define a named set of [`Example`]s to evaluate models
/// against. The `Send + Sync` bounds allow task objects to be shared
/// across threads during evaluation.
pub trait EvalTask: Send + Sync {
    /// Unique task identifier
    fn id(&self) -> &str;

    /// Human-readable description
    fn description(&self) -> &str;

    /// Input examples to evaluate
    fn examples(&self) -> &[Example];

    /// Maximum turns before declaring failure (default: 5)
    fn max_turns(&self) -> u32 {
        5
    }

    /// Timeout per turn (default: 60 seconds)
    fn turn_timeout(&self) -> Duration {
        Duration::from_secs(60)
    }
}
51
/// Example input for evaluation
///
/// A single (input, expected) pair; collections of these make up an
/// [`EvalTask`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Example {
    /// Unique example ID
    pub id: String,
    /// Input prompt/code
    pub input: String,
    /// Expected output or behavior
    pub expected: String,
    /// Difficulty tier (defaults to `Medium` when built via [`Example::new`])
    pub difficulty: Difficulty,
    /// Tags for filtering
    pub tags: Vec<String>,
}
66
67impl Example {
68    /// Create a new example
69    #[must_use]
70    pub fn new(
71        id: impl Into<String>,
72        input: impl Into<String>,
73        expected: impl Into<String>,
74    ) -> Self {
75        Self {
76            id: id.into(),
77            input: input.into(),
78            expected: expected.into(),
79            difficulty: Difficulty::Medium,
80            tags: Vec::new(),
81        }
82    }
83
84    /// Set difficulty
85    #[must_use]
86    pub fn with_difficulty(mut self, difficulty: Difficulty) -> Self {
87        self.difficulty = difficulty;
88        self
89    }
90
91    /// Add tags
92    #[must_use]
93    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
94        self.tags = tags;
95        self
96    }
97}
98
/// Difficulty tier for stratified analysis
///
/// Ordered from easiest (`Trivial`, level 1) to hardest (`Expert`, level 5);
/// see [`Difficulty::level`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Difficulty {
    /// 1-liner, obvious translation
    Trivial,
    /// Simple logic, standard patterns
    Easy,
    /// Multiple functions, error handling
    Medium,
    /// Complex algorithms, unsafe/FFI
    Hard,
    /// Requires deep language knowledge
    Expert,
}
113
114impl Difficulty {
115    /// Get numeric level (1-5)
116    #[must_use]
117    pub const fn level(&self) -> u8 {
118        match self {
119            Self::Trivial => 1,
120            Self::Easy => 2,
121            Self::Medium => 3,
122            Self::Hard => 4,
123            Self::Expert => 5,
124        }
125    }
126
127    /// Get display name
128    #[must_use]
129    pub const fn name(&self) -> &'static str {
130        match self {
131            Self::Trivial => "Trivial",
132            Self::Easy => "Easy",
133            Self::Medium => "Medium",
134            Self::Hard => "Hard",
135            Self::Expert => "Expert",
136        }
137    }
138
139    /// All difficulties in order
140    #[must_use]
141    pub const fn all() -> [Self; 5] {
142        [
143            Self::Trivial,
144            Self::Easy,
145            Self::Medium,
146            Self::Hard,
147            Self::Expert,
148        ]
149    }
150}
151
/// Result of evaluating a single model on a single task
///
/// Built incrementally via [`EvalResult::add_example`]; the aggregate
/// fields (`success_by_turn`, `avg_turns_to_success`,
/// `overall_success_rate`) are only meaningful after
/// [`EvalResult::finalize`] has been called.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalResult {
    /// Model identifier
    pub model_id: String,
    /// Model size in bytes
    pub model_size_bytes: u64,
    /// Model parameter count (if known)
    pub model_params: Option<u64>,
    /// Task evaluated
    pub task_id: String,
    /// Per-example results
    pub example_results: Vec<ExampleResult>,
    /// Success rate by turn (cumulative)
    /// e.g., [0.60, 0.85, 0.95] = 60% turn 1, 85% by turn 2, 95% by turn 3
    pub success_by_turn: Vec<f64>,
    /// Average turns to success (for successful examples)
    pub avg_turns_to_success: f64,
    /// Overall success rate (any turn)
    pub overall_success_rate: f64,
    /// Total tokens consumed
    pub total_tokens: u64,
    /// Total latency
    pub total_latency: Duration,
}
177
178impl EvalResult {
179    /// Create a new evaluation result
180    #[must_use]
181    pub fn new(
182        model_id: impl Into<String>,
183        task_id: impl Into<String>,
184        model_size_bytes: u64,
185    ) -> Self {
186        Self {
187            model_id: model_id.into(),
188            model_size_bytes,
189            model_params: None,
190            task_id: task_id.into(),
191            example_results: Vec::new(),
192            success_by_turn: Vec::new(),
193            avg_turns_to_success: 0.0,
194            overall_success_rate: 0.0,
195            total_tokens: 0,
196            total_latency: Duration::ZERO,
197        }
198    }
199
200    /// Add an example result
201    pub fn add_example(&mut self, result: ExampleResult) {
202        self.total_tokens += result.tokens_per_turn.iter().sum::<u64>();
203        self.total_latency += result.latency_per_turn.iter().sum::<Duration>();
204        self.example_results.push(result);
205    }
206
207    /// Finalize the evaluation (compute aggregate metrics)
208    pub fn finalize(&mut self, max_turns: u32) {
209        let total = self.example_results.len();
210        if total == 0 {
211            return;
212        }
213
214        // Compute success by turn (cumulative)
215        self.success_by_turn = Vec::with_capacity(max_turns as usize);
216        for turn in 1..=max_turns {
217            let solved = self
218                .example_results
219                .iter()
220                .filter(|r| matches!(r.status, ExampleStatus::Solved { turn: t } if t <= turn))
221                .count();
222            self.success_by_turn.push(solved as f64 / total as f64);
223        }
224
225        // Compute average turns to success
226        let solved_examples: Vec<_> = self
227            .example_results
228            .iter()
229            .filter_map(|r| match r.status {
230                ExampleStatus::Solved { turn } => Some(turn),
231                _ => None,
232            })
233            .collect();
234
235        if !solved_examples.is_empty() {
236            self.avg_turns_to_success =
237                f64::from(solved_examples.iter().sum::<u32>()) / solved_examples.len() as f64;
238        }
239
240        // Overall success rate
241        self.overall_success_rate = solved_examples.len() as f64 / total as f64;
242    }
243
244    /// Get success rate at a specific turn
245    #[must_use]
246    pub fn success_at_turn(&self, turn: u32) -> f64 {
247        self.success_by_turn
248            .get((turn - 1) as usize)
249            .copied()
250            .unwrap_or(0.0)
251    }
252}
253
/// Result for a single example
///
/// `solved_at_turn` mirrors the `Solved` case of `status`; both are set
/// together by the [`ExampleResult::solved`] / [`ExampleResult::failed`]
/// constructors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExampleResult {
    /// Example ID
    pub example_id: String,
    /// Difficulty
    pub difficulty: Difficulty,
    /// Which turn solved it (None = failed all turns)
    pub solved_at_turn: Option<u32>,
    /// Tokens per turn
    pub tokens_per_turn: Vec<u64>,
    /// Latency per turn
    pub latency_per_turn: Vec<Duration>,
    /// Final status
    pub status: ExampleStatus,
}
270
271impl ExampleResult {
272    /// Create a solved example result
273    #[must_use]
274    pub fn solved(
275        example_id: impl Into<String>,
276        difficulty: Difficulty,
277        turn: u32,
278        tokens: Vec<u64>,
279        latencies: Vec<Duration>,
280    ) -> Self {
281        Self {
282            example_id: example_id.into(),
283            difficulty,
284            solved_at_turn: Some(turn),
285            tokens_per_turn: tokens,
286            latency_per_turn: latencies,
287            status: ExampleStatus::Solved { turn },
288        }
289    }
290
291    /// Create a failed example result
292    #[must_use]
293    pub fn failed(
294        example_id: impl Into<String>,
295        difficulty: Difficulty,
296        attempts: u32,
297        last_error: impl Into<String>,
298        tokens: Vec<u64>,
299        latencies: Vec<Duration>,
300    ) -> Self {
301        Self {
302            example_id: example_id.into(),
303            difficulty,
304            solved_at_turn: None,
305            tokens_per_turn: tokens,
306            latency_per_turn: latencies,
307            status: ExampleStatus::Failed {
308                attempts,
309                last_error: last_error.into(),
310            },
311        }
312    }
313}
314
/// Status of an example evaluation
///
/// Terminal outcome of running one [`Example`] against a model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExampleStatus {
    /// Solved within `max_turns`
    Solved {
        /// Turn on which it was solved
        turn: u32,
    },
    /// Failed all turns
    Failed {
        /// Number of attempts
        attempts: u32,
        /// Last error message
        last_error: String,
    },
    /// Timed out
    Timeout {
        /// Turn on which it timed out
        turn: u32,
    },
    /// Skipped
    Skipped {
        /// Reason for skipping
        reason: String,
    },
}
341
/// Compare multiple models on the same task
///
/// `pareto_frontier` and `recommendations` start empty; populate them with
/// [`ModelComparison::compute_pareto_frontier`] and
/// [`ModelComparison::generate_recommendations`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelComparison {
    /// Task being evaluated
    pub task_id: String,
    /// Results per model
    pub results: Vec<EvalResult>,
    /// Pareto-optimal models
    pub pareto_frontier: Vec<ParetoPoint>,
    /// Recommendations
    pub recommendations: Vec<Recommendation>,
}
354
355impl ModelComparison {
356    /// Create a new model comparison
357    #[must_use]
358    pub fn new(task_id: impl Into<String>) -> Self {
359        Self {
360            task_id: task_id.into(),
361            results: Vec::new(),
362            pareto_frontier: Vec::new(),
363            recommendations: Vec::new(),
364        }
365    }
366
367    /// Add an evaluation result
368    pub fn add_result(&mut self, result: EvalResult) {
369        self.results.push(result);
370    }
371
372    /// Find smallest model meeting a success threshold
373    #[must_use]
374    pub fn smallest_meeting_threshold(&self, min_success: f64) -> Option<&EvalResult> {
375        self.results
376            .iter()
377            .filter(|r| r.overall_success_rate >= min_success)
378            .min_by_key(|r| r.model_size_bytes)
379    }
380
381    /// Find fastest model meeting a success threshold
382    #[must_use]
383    pub fn fastest_meeting_threshold(&self, min_success: f64) -> Option<&EvalResult> {
384        self.results
385            .iter()
386            .filter(|r| r.overall_success_rate >= min_success)
387            .min_by(|a, b| {
388                a.avg_turns_to_success
389                    .partial_cmp(&b.avg_turns_to_success)
390                    .unwrap_or(std::cmp::Ordering::Equal)
391            })
392    }
393
394    /// Compute Pareto frontier (size vs accuracy)
395    pub fn compute_pareto_frontier(&mut self) {
396        self.pareto_frontier = pareto::compute_pareto_frontier(&self.results);
397    }
398
399    /// Generate recommendations based on results
400    pub fn generate_recommendations(&mut self) {
401        self.recommendations.clear();
402
403        // Smallest overall
404        if let Some(smallest) = self.results.iter().min_by_key(|r| r.model_size_bytes) {
405            self.recommendations.push(Recommendation {
406                scenario: "Minimum footprint".to_string(),
407                model_id: smallest.model_id.clone(),
408                rationale: format!(
409                    "Smallest model at {} bytes, {:.1}% success",
410                    smallest.model_size_bytes,
411                    smallest.overall_success_rate * 100.0
412                ),
413            });
414        }
415
416        // Best accuracy
417        if let Some(best) = self.results.iter().max_by(|a, b| {
418            a.overall_success_rate
419                .partial_cmp(&b.overall_success_rate)
420                .unwrap_or(std::cmp::Ordering::Equal)
421        }) {
422            self.recommendations.push(Recommendation {
423                scenario: "Maximum accuracy".to_string(),
424                model_id: best.model_id.clone(),
425                rationale: format!(
426                    "Highest success rate at {:.1}%",
427                    best.overall_success_rate * 100.0
428                ),
429            });
430        }
431
432        // Best balance (on Pareto frontier, closest to ideal)
433        if let Some(best) = self.pareto_frontier.iter().max_by(|a, b| {
434            // Score: success_rate - normalized_size
435            let score_a = a.success_rate - (a.size_bytes as f64 / 1e9);
436            let score_b = b.success_rate - (b.size_bytes as f64 / 1e9);
437            score_a
438                .partial_cmp(&score_b)
439                .unwrap_or(std::cmp::Ordering::Equal)
440        }) {
441            self.recommendations.push(Recommendation {
442                scenario: "Best balance".to_string(),
443                model_id: best.model_id.clone(),
444                rationale: format!(
445                    "Pareto optimal: {:.1}% success at {} bytes",
446                    best.success_rate * 100.0,
447                    best.size_bytes
448                ),
449            });
450        }
451    }
452
453    /// Get stratified success rates by difficulty
454    #[must_use]
455    pub fn stratified_by_difficulty(&self) -> HashMap<String, HashMap<Difficulty, f64>> {
456        let mut result = HashMap::new();
457
458        for eval in &self.results {
459            let mut by_difficulty: HashMap<Difficulty, (usize, usize)> = HashMap::new();
460
461            for example in &eval.example_results {
462                let entry = by_difficulty.entry(example.difficulty).or_insert((0, 0));
463                entry.1 += 1; // total
464                if example.solved_at_turn.is_some() {
465                    entry.0 += 1; // solved
466                }
467            }
468
469            let rates: HashMap<Difficulty, f64> = by_difficulty
470                .into_iter()
471                .map(|(d, (solved, total))| {
472                    let rate = if total > 0 {
473                        solved as f64 / total as f64
474                    } else {
475                        0.0
476                    };
477                    (d, rate)
478                })
479                .collect();
480
481            result.insert(eval.model_id.clone(), rates);
482        }
483
484        result
485    }
486}
487
/// Point on the Pareto frontier
///
/// Summarizes one model's size/accuracy trade-off as produced by the
/// `pareto` submodule.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParetoPoint {
    /// Model ID
    pub model_id: String,
    /// Size in bytes
    pub size_bytes: u64,
    /// Success rate
    pub success_rate: f64,
    /// Average turns
    pub avg_turns: f64,
    /// Is Pareto optimal
    pub is_pareto_optimal: bool,
}
502
/// Recommendation for a specific scenario
///
/// Emitted by [`ModelComparison::generate_recommendations`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Recommendation {
    /// Scenario description
    pub scenario: String,
    /// Recommended model ID
    pub model_id: String,
    /// Rationale
    pub rationale: String,
}
513
/// Evaluation suite configuration
///
/// Serializable settings for running an evaluation suite; see
/// [`EvalSuiteConfig::default`] for the stock values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalSuiteConfig {
    /// Suite ID
    pub id: String,
    /// Description
    pub description: String,
    /// Maximum turns
    pub max_turns: u32,
    /// Turn timeout
    pub turn_timeout_secs: u64,
    /// Examples path
    pub examples_path: PathBuf,
    /// Success thresholds for recommendations
    pub success_thresholds: Vec<f64>,
}
530
531impl Default for EvalSuiteConfig {
532    fn default() -> Self {
533        Self {
534            id: "default".to_string(),
535            description: "Default evaluation suite".to_string(),
536            max_turns: 5,
537            turn_timeout_secs: 60,
538            examples_path: PathBuf::from("./examples"),
539            success_thresholds: vec![0.80, 0.90, 0.95],
540        }
541    }
542}
543
#[cfg(test)]
mod tests {
    // Unit tests covering the bench module's data types, aggregation logic
    // (finalize / success_by_turn), and comparison/recommendation helpers.
    use super::*;

    #[test]
    fn test_example_creation() {
        let example = Example::new("ex1", "print('hello')", "println!(\"hello\");")
            .with_difficulty(Difficulty::Trivial)
            .with_tags(vec!["hello".to_string()]);

        assert_eq!(example.id, "ex1");
        assert_eq!(example.difficulty, Difficulty::Trivial);
        assert_eq!(example.tags.len(), 1);
    }

    #[test]
    fn test_difficulty_levels() {
        assert_eq!(Difficulty::Trivial.level(), 1);
        assert_eq!(Difficulty::Expert.level(), 5);
        assert_eq!(Difficulty::all().len(), 5);
    }

    #[test]
    fn test_eval_result_creation() {
        let mut result = EvalResult::new("model1", "py2rs", 1_000_000);
        assert_eq!(result.model_id, "model1");
        assert_eq!(result.model_size_bytes, 1_000_000);
        assert!(result.example_results.is_empty());

        // Add some example results
        result.add_example(ExampleResult::solved(
            "ex1",
            Difficulty::Easy,
            1,
            vec![100],
            vec![Duration::from_millis(50)],
        ));
        result.add_example(ExampleResult::solved(
            "ex2",
            Difficulty::Medium,
            2,
            vec![100, 150],
            vec![Duration::from_millis(50), Duration::from_millis(75)],
        ));
        result.add_example(ExampleResult::failed(
            "ex3",
            Difficulty::Hard,
            3,
            "Compile error",
            vec![100, 150, 200],
            vec![Duration::from_millis(50); 3],
        ));

        result.finalize(3);

        assert_eq!(result.example_results.len(), 3);
        assert!((result.overall_success_rate - 2.0 / 3.0).abs() < 0.01);
        assert!(result.avg_turns_to_success > 0.0);
    }

    #[test]
    fn test_eval_result_success_by_turn() {
        let mut result = EvalResult::new("model1", "task1", 1000);

        // 3 examples: turn 1, turn 2, failed
        result.add_example(ExampleResult::solved(
            "ex1",
            Difficulty::Easy,
            1,
            vec![100],
            vec![Duration::from_millis(10)],
        ));
        result.add_example(ExampleResult::solved(
            "ex2",
            Difficulty::Medium,
            2,
            vec![100, 100],
            vec![Duration::from_millis(10); 2],
        ));
        result.add_example(ExampleResult::failed(
            "ex3",
            Difficulty::Hard,
            3,
            "Failed",
            vec![100; 3],
            vec![Duration::from_millis(10); 3],
        ));

        result.finalize(3);

        // Turn 1: 1/3 solved
        assert!((result.success_at_turn(1) - 1.0 / 3.0).abs() < 0.01);
        // Turn 2: 2/3 solved
        assert!((result.success_at_turn(2) - 2.0 / 3.0).abs() < 0.01);
        // Turn 3: still 2/3 (failed didn't solve)
        assert!((result.success_at_turn(3) - 2.0 / 3.0).abs() < 0.01);
    }

    #[test]
    fn test_model_comparison() {
        let mut comparison = ModelComparison::new("py2rs");

        // Add two models
        let mut model1 = EvalResult::new("small", "py2rs", 1_000_000);
        model1.add_example(ExampleResult::solved(
            "ex1",
            Difficulty::Easy,
            1,
            vec![100],
            vec![Duration::from_millis(10)],
        ));
        model1.finalize(3);

        let mut model2 = EvalResult::new("large", "py2rs", 10_000_000);
        model2.add_example(ExampleResult::solved(
            "ex1",
            Difficulty::Easy,
            1,
            vec![100],
            vec![Duration::from_millis(10)],
        ));
        model2.finalize(3);

        comparison.add_result(model1);
        comparison.add_result(model2);

        assert_eq!(comparison.results.len(), 2);

        // Both have 100% success
        let smallest = comparison.smallest_meeting_threshold(0.9);
        assert!(smallest.is_some());
        assert_eq!(smallest.unwrap().model_id, "small");
    }

    #[test]
    fn test_example_status() {
        let solved = ExampleStatus::Solved { turn: 2 };
        let failed = ExampleStatus::Failed {
            attempts: 3,
            last_error: "Error".to_string(),
        };
        let timeout = ExampleStatus::Timeout { turn: 1 };
        let skipped = ExampleStatus::Skipped {
            reason: "No deps".to_string(),
        };

        // Just verify they serialize
        assert!(serde_json::to_string(&solved).is_ok());
        assert!(serde_json::to_string(&failed).is_ok());
        assert!(serde_json::to_string(&timeout).is_ok());
        assert!(serde_json::to_string(&skipped).is_ok());
    }

    #[test]
    fn test_eval_suite_config_default() {
        let config = EvalSuiteConfig::default();
        assert_eq!(config.max_turns, 5);
        assert_eq!(config.turn_timeout_secs, 60);
        assert!(config.success_thresholds.contains(&0.9));
    }

    #[test]
    fn test_stratified_by_difficulty() {
        let mut comparison = ModelComparison::new("test");

        let mut result = EvalResult::new("model1", "test", 1000);
        result.add_example(ExampleResult::solved(
            "e1",
            Difficulty::Easy,
            1,
            vec![10],
            vec![Duration::ZERO],
        ));
        result.add_example(ExampleResult::solved(
            "e2",
            Difficulty::Easy,
            1,
            vec![10],
            vec![Duration::ZERO],
        ));
        result.add_example(ExampleResult::failed(
            "e3",
            Difficulty::Hard,
            3,
            "err",
            vec![10; 3],
            vec![Duration::ZERO; 3],
        ));
        result.finalize(3);
        comparison.add_result(result);

        let stratified = comparison.stratified_by_difficulty();
        let model1 = stratified.get("model1").unwrap();

        // Easy: 2/2 = 100%
        assert!((model1.get(&Difficulty::Easy).unwrap() - 1.0).abs() < 0.01);
        // Hard: 0/1 = 0%
        assert!((model1.get(&Difficulty::Hard).unwrap() - 0.0).abs() < 0.01);
    }

    #[test]
    fn test_generate_recommendations() {
        let mut comparison = ModelComparison::new("test");

        let mut small = EvalResult::new("small", "test", 1000);
        small.overall_success_rate = 0.8;
        comparison.add_result(small);

        let mut large = EvalResult::new("large", "test", 10000);
        large.overall_success_rate = 0.95;
        comparison.add_result(large);

        comparison.generate_recommendations();

        assert!(comparison.recommendations.len() >= 2);
        assert!(comparison
            .recommendations
            .iter()
            .any(|r| r.scenario.contains("footprint")));
        assert!(comparison
            .recommendations
            .iter()
            .any(|r| r.scenario.contains("accuracy")));
    }

    // =========================================================================
    // Additional coverage tests
    // =========================================================================

    #[test]
    fn test_difficulty_name() {
        assert_eq!(Difficulty::Trivial.name(), "Trivial");
        assert_eq!(Difficulty::Easy.name(), "Easy");
        assert_eq!(Difficulty::Medium.name(), "Medium");
        assert_eq!(Difficulty::Hard.name(), "Hard");
        assert_eq!(Difficulty::Expert.name(), "Expert");
    }

    #[test]
    fn test_eval_result_finalize_empty() {
        let mut result = EvalResult::new("model1", "task1", 1000);
        result.finalize(3); // Should early return
        assert!(result.success_by_turn.is_empty());
        assert!((result.avg_turns_to_success - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_success_at_turn_out_of_bounds() {
        let result = EvalResult::new("model1", "task1", 1000);
        // No finalize, so success_by_turn is empty
        assert!((result.success_at_turn(1) - 0.0).abs() < 0.001);
        assert!((result.success_at_turn(10) - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_fastest_meeting_threshold() {
        let mut comparison = ModelComparison::new("test");

        let mut fast = EvalResult::new("fast", "test", 5000);
        fast.overall_success_rate = 0.9;
        fast.avg_turns_to_success = 1.2;
        comparison.add_result(fast);

        let mut slow = EvalResult::new("slow", "test", 2000);
        slow.overall_success_rate = 0.9;
        slow.avg_turns_to_success = 2.8;
        comparison.add_result(slow);

        let fastest = comparison.fastest_meeting_threshold(0.85);
        assert!(fastest.is_some());
        assert_eq!(fastest.unwrap().model_id, "fast");
    }

    #[test]
    fn test_fastest_meeting_threshold_none() {
        let mut comparison = ModelComparison::new("test");

        let mut result = EvalResult::new("model", "test", 1000);
        result.overall_success_rate = 0.5;
        comparison.add_result(result);

        let fastest = comparison.fastest_meeting_threshold(0.9);
        assert!(fastest.is_none());
    }

    #[test]
    fn test_smallest_meeting_threshold_none() {
        let mut comparison = ModelComparison::new("test");

        let mut result = EvalResult::new("model", "test", 1000);
        result.overall_success_rate = 0.5;
        comparison.add_result(result);

        let smallest = comparison.smallest_meeting_threshold(0.9);
        assert!(smallest.is_none());
    }

    #[test]
    fn test_compute_pareto_frontier() {
        let mut comparison = ModelComparison::new("test");

        let mut model1 = EvalResult::new("m1", "test", 1000);
        model1.overall_success_rate = 0.9;
        model1.avg_turns_to_success = 1.5;
        comparison.add_result(model1);

        let mut model2 = EvalResult::new("m2", "test", 5000);
        model2.overall_success_rate = 0.95;
        model2.avg_turns_to_success = 1.2;
        comparison.add_result(model2);

        comparison.compute_pareto_frontier();
        // Both should be on frontier (different trade-offs)
        assert!(!comparison.pareto_frontier.is_empty());
    }

    #[test]
    fn test_generate_recommendations_with_pareto() {
        let mut comparison = ModelComparison::new("test");

        let mut model1 = EvalResult::new("m1", "test", 1000);
        model1.overall_success_rate = 0.9;
        model1.avg_turns_to_success = 1.5;
        comparison.add_result(model1);

        let mut model2 = EvalResult::new("m2", "test", 5000);
        model2.overall_success_rate = 0.95;
        model2.avg_turns_to_success = 1.2;
        comparison.add_result(model2);

        // Compute pareto first so the "best balance" recommendation uses it
        comparison.compute_pareto_frontier();
        comparison.generate_recommendations();

        assert!(comparison.recommendations.len() >= 2);
        assert!(comparison
            .recommendations
            .iter()
            .any(|r| r.scenario.contains("balance")));
    }

    #[test]
    fn test_pareto_point_struct() {
        let point = ParetoPoint {
            model_id: "test".to_string(),
            size_bytes: 1000,
            success_rate: 0.9,
            avg_turns: 1.5,
            is_pareto_optimal: true,
        };
        assert_eq!(point.model_id, "test");
        assert!(point.is_pareto_optimal);
    }

    #[test]
    fn test_recommendation_struct() {
        let rec = Recommendation {
            scenario: "test scenario".to_string(),
            model_id: "model1".to_string(),
            rationale: "because".to_string(),
        };
        assert_eq!(rec.scenario, "test scenario");
        assert_eq!(rec.model_id, "model1");
    }

    #[test]
    fn test_eval_suite_config_custom() {
        let config = EvalSuiteConfig {
            id: "custom".to_string(),
            description: "Custom suite".to_string(),
            max_turns: 10,
            turn_timeout_secs: 120,
            examples_path: PathBuf::from("/custom/path"),
            success_thresholds: vec![0.7, 0.8],
        };
        assert_eq!(config.id, "custom");
        assert_eq!(config.max_turns, 10);
    }

    #[test]
    fn test_example_default_values() {
        let example = Example::new("id", "input", "expected");
        assert_eq!(example.difficulty, Difficulty::Medium);
        assert!(example.tags.is_empty());
    }

    #[test]
    fn test_model_comparison_empty() {
        let comparison = ModelComparison::new("test");
        assert!(comparison.results.is_empty());
        assert!(comparison.pareto_frontier.is_empty());
        assert!(comparison.recommendations.is_empty());
    }

    #[test]
    fn test_eval_result_with_params() {
        let mut result = EvalResult::new("model", "task", 1000);
        result.model_params = Some(1_000_000);
        assert_eq!(result.model_params, Some(1_000_000));
    }

    #[test]
    fn test_finalize_with_only_failed() {
        let mut result = EvalResult::new("model", "task", 1000);
        result.add_example(ExampleResult::failed(
            "ex1",
            Difficulty::Hard,
            3,
            "error",
            vec![100; 3],
            vec![Duration::ZERO; 3],
        ));
        result.finalize(3);

        assert!((result.overall_success_rate - 0.0).abs() < 0.001);
        assert!((result.avg_turns_to_success - 0.0).abs() < 0.001);
    }

    /// Test implementation of EvalTask trait
    /// (exercises the trait's default `max_turns` / `turn_timeout` methods)
    struct TestTask {
        id: String,
        description: String,
        examples: Vec<Example>,
    }

    impl EvalTask for TestTask {
        fn id(&self) -> &str {
            &self.id
        }

        fn description(&self) -> &str {
            &self.description
        }

        fn examples(&self) -> &[Example] {
            &self.examples
        }
    }

    #[test]
    fn test_eval_task_defaults() {
        let task = TestTask {
            id: "test".to_string(),
            description: "Test task".to_string(),
            examples: vec![],
        };

        assert_eq!(task.max_turns(), 5);
        assert_eq!(task.turn_timeout(), Duration::from_secs(60));
        assert_eq!(task.id(), "test");
        assert_eq!(task.description(), "Test task");
    }

    #[test]
    fn test_difficulty_clone_copy() {
        let d1 = Difficulty::Expert;
        let d2 = d1; // Copy
        let d3 = d1.clone();
        assert_eq!(d1, d2);
        assert_eq!(d1, d3);
    }

    #[test]
    fn test_difficulty_hash() {
        use std::collections::HashSet;
        let mut set = HashSet::new();
        set.insert(Difficulty::Easy);
        set.insert(Difficulty::Hard);
        assert!(set.contains(&Difficulty::Easy));
        assert!(!set.contains(&Difficulty::Expert));
    }
}
1015}