Skip to main content

ravenclaws/
eval.rs

1//! RavenClaws
2//!
3//! Provides a framework for defining, running, and scoring evaluation tasks
4//! against LLM agents. Captures full run traces for inspection and debugging.
5//!
6//! # Architecture
7//!
8//! ```text
9//! EvalConfig (TOML file)
10//!   └── Vec<EvalTask>
11//!         ├── prompt + golden answer
12//!         ├── assertions (contains, not_contains, regex, exact)
13//!         └── scoring weights
14//!
15//! EvalRunner
16//!   ├── run_task() → EvalResult (with RunTrace)
17//!   └── run_suite() → EvalReport (summary of all results)
18//!
19//! RunTrace
20//!   ├── steps: Vec<TraceStep>
21//!   ├── llm_calls: Vec<LlmCallTrace>
22//!   └── tool_calls: Vec<ToolCallTrace>
23//! ```
24
25use crate::error::{RavenClawsError, Result};
26use crate::llm::{ChatMessage, LLMProviderTrait};
27use serde::{Deserialize, Serialize};
28use std::sync::Arc;
29use tracing::{info, instrument, warn};
30
31// ── Configuration ───────────────────────────────────────────────────────────
32
33/// Configuration for an eval suite — loaded from a TOML file
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct EvalConfig {
36    /// Name of this eval suite
37    #[serde(default = "default_suite_name")]
38    pub name: String,
39    /// Description of what this suite tests
40    #[serde(default)]
41    pub description: String,
42    /// System prompt to use for all tasks in this suite
43    #[serde(default = "default_system_prompt")]
44    pub system_prompt: String,
45    /// Maximum iterations per task
46    #[serde(default = "default_max_iterations")]
47    pub max_iterations: usize,
48    /// List of eval tasks to run
49    #[serde(default)]
50    pub tasks: Vec<EvalTask>,
51}
52
/// Serde fallback for `EvalConfig::name` when the suite is not named.
fn default_suite_name() -> String {
    String::from("unnamed")
}
56
/// Serde fallback for `EvalConfig::system_prompt`.
fn default_system_prompt() -> String {
    String::from("You are a helpful assistant. Be concise and accurate.")
}
60
/// Serde fallback for `EvalConfig::max_iterations`.
fn default_max_iterations() -> usize {
    5_usize
}
64
65/// A single eval task with prompt, golden answer, and assertions
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct EvalTask {
68    /// Name of this task (used in reports)
69    pub name: String,
70    /// Description of what this task tests
71    #[serde(default)]
72    pub description: String,
73    /// The prompt to send to the agent
74    pub prompt: String,
75    /// Expected golden answer (used for exact match scoring)
76    #[serde(default)]
77    pub golden: String,
78    /// List of assertions to check against the response
79    #[serde(default)]
80    pub assertions: Vec<Assertion>,
81    /// Weight of this task in the overall score (0.0 - 1.0)
82    #[serde(default = "default_weight")]
83    pub weight: f64,
84    /// Whether this task is required to pass (fails the suite if not)
85    #[serde(default)]
86    pub required: bool,
87}
88
/// Serde fallback for `EvalTask::weight`: every task counts fully unless told otherwise.
fn default_weight() -> f64 {
    1.0_f64
}
92
93/// Types of assertions that can be checked against a response
94#[derive(Debug, Clone, Serialize, Deserialize)]
95#[serde(tag = "type", content = "value")]
96pub enum Assertion {
97    /// Response must contain this substring
98    #[serde(rename = "contains")]
99    Contains(String),
100    /// Response must NOT contain this substring
101    #[serde(rename = "not_contains")]
102    NotContains(String),
103    /// Response must exactly match this string
104    #[serde(rename = "exact")]
105    Exact(String),
106    /// Response must match this regex pattern
107    #[serde(rename = "regex")]
108    Regex(String),
109    /// Response must be non-empty
110    #[serde(rename = "non_empty")]
111    NonEmpty,
112    /// Response length must be at least N characters
113    #[serde(rename = "min_length")]
114    MinLength(usize),
115    /// Response length must be at most N characters
116    #[serde(rename = "max_length")]
117    MaxLength(usize),
118    /// A tool with this name must have been called during execution (v0.9.6)
119    #[serde(rename = "tool_called")]
120    ToolCalled(String),
121    /// A tool with this name must NOT have been called during execution (v0.9.6)
122    #[serde(rename = "tool_not_called")]
123    ToolNotCalled(String),
124}
125
126// ── Run Trace ───────────────────────────────────────────────────────────────
127
128/// Full trace of a single agent run — captures every step for inspection
129#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct RunTrace {
131    /// Task name
132    pub task_name: String,
133    /// When the run started (ISO 8601)
134    pub started_at: String,
135    /// When the run ended (ISO 8601)
136    pub ended_at: String,
137    /// Duration in milliseconds
138    pub duration_ms: u64,
139    /// Number of iterations used
140    pub iterations: usize,
141    /// All steps in chronological order
142    pub steps: Vec<TraceStep>,
143    /// LLM calls made during the run
144    pub llm_calls: Vec<LlmCallTrace>,
145    /// Tool calls made during the run
146    pub tool_calls: Vec<ToolCallTrace>,
147    /// Final response from the agent
148    pub final_response: String,
149}
150
151/// A single step in the agent loop
152#[derive(Debug, Clone, Serialize, Deserialize)]
153pub struct TraceStep {
154    /// Step number (0-based)
155    pub number: usize,
156    /// Type of step
157    pub step_type: StepType,
158    /// Content of the step (LLM response, tool result, etc.)
159    pub content: String,
160    /// Duration of this step in milliseconds
161    pub duration_ms: u64,
162}
163
164/// Type of a trace step
165#[derive(Debug, Clone, Serialize, Deserialize)]
166pub enum StepType {
167    /// LLM thought/response
168    Thought,
169    /// Tool call
170    ToolCall,
171    /// Tool result/observation
172    Observation,
173    /// Final answer
174    Final,
175    /// Error
176    Error,
177}
178
179/// Trace of a single LLM call
180#[derive(Debug, Clone, Serialize, Deserialize)]
181pub struct LlmCallTrace {
182    /// Iteration number
183    pub iteration: usize,
184    /// Provider name
185    pub provider: String,
186    /// Model name
187    pub model: String,
188    /// Prompt tokens (if available)
189    pub prompt_tokens: Option<u32>,
190    /// Completion tokens (if available)
191    pub completion_tokens: Option<u32>,
192    /// Duration in milliseconds
193    pub duration_ms: u64,
194    /// Response content (truncated to 1000 chars for storage)
195    pub response_preview: String,
196}
197
198/// Trace of a single tool call
199#[derive(Debug, Clone, Serialize, Deserialize)]
200pub struct ToolCallTrace {
201    /// Iteration number
202    pub iteration: usize,
203    /// Tool name
204    pub tool_name: String,
205    /// Arguments (JSON)
206    pub arguments: serde_json::Value,
207    /// Whether the tool succeeded
208    pub success: bool,
209    /// Output preview (truncated to 500 chars)
210    pub output_preview: String,
211    /// Duration in milliseconds
212    pub duration_ms: u64,
213}
214
215// ── Results ─────────────────────────────────────────────────────────────────
216
217/// Result of a single eval task
218#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct EvalResult {
220    /// Task name
221    pub task_name: String,
222    /// Whether the task passed all assertions
223    pub passed: bool,
224    /// Score (0.0 - 1.0)
225    pub score: f64,
226    /// Number of assertions that passed
227    pub assertions_passed: usize,
228    /// Number of assertions that failed
229    pub assertions_failed: usize,
230    /// Details of each assertion check
231    pub assertion_results: Vec<AssertionResult>,
232    /// Full run trace for inspection
233    pub trace: RunTrace,
234    /// Error message if the task failed to run
235    pub error: Option<String>,
236}
237
238/// Result of a single assertion check
239#[derive(Debug, Clone, Serialize, Deserialize)]
240pub struct AssertionResult {
241    /// The assertion that was checked
242    pub assertion: String,
243    /// Whether it passed
244    pub passed: bool,
245    /// Details about the check
246    pub details: String,
247}
248
249/// Summary report of an entire eval suite run
250#[derive(Debug, Clone, Serialize, Deserialize)]
251pub struct EvalReport {
252    /// Suite name
253    pub suite_name: String,
254    /// When the suite was run (ISO 8601)
255    pub ran_at: String,
256    /// Duration in milliseconds
257    pub duration_ms: u64,
258    /// Overall score (0.0 - 1.0)
259    pub overall_score: f64,
260    /// Number of tasks
261    pub total_tasks: usize,
262    /// Number of tasks that passed
263    pub passed_tasks: usize,
264    /// Number of tasks that failed
265    pub failed_tasks: usize,
266    /// Individual task results
267    pub results: Vec<EvalResult>,
268}
269
270// ── Eval Runner ─────────────────────────────────────────────────────────────
271
272/// Runs eval tasks against an LLM provider and captures traces
273pub struct EvalRunner {
274    /// The LLM provider to test
275    llm: Arc<dyn LLMProviderTrait>,
276    /// Eval configuration
277    config: EvalConfig,
278}
279
280impl EvalRunner {
281    /// Create a new eval runner
282    pub fn new(llm: Arc<dyn LLMProviderTrait>, config: EvalConfig) -> Self {
283        Self { llm, config }
284    }
285
286    /// Run the full eval suite and return a report
287    #[instrument(skip(self), fields(suite = %self.config.name, task_count = self.config.tasks.len()))]
288    pub async fn run_suite(&self) -> EvalReport {
289        let started_at = chrono::Utc::now().to_rfc3339();
290        let suite_start = std::time::Instant::now();
291        let mut results = Vec::with_capacity(self.config.tasks.len());
292
293        info!(
294            suite = %self.config.name,
295            task_count = self.config.tasks.len(),
296            "Starting eval suite"
297        );
298
299        for task in &self.config.tasks {
300            let result = self.run_task(task).await;
301            let passed = result.passed;
302            let name = &result.task_name;
303
304            if passed {
305                info!(task = %name, score = result.score, "Eval task passed");
306            } else {
307                warn!(
308                    task = %name,
309                    score = result.score,
310                    passed = result.assertions_passed,
311                    failed = result.assertions_failed,
312                    "Eval task failed"
313                );
314            }
315
316            results.push(result);
317        }
318
319        let duration_ms = suite_start.elapsed().as_millis() as u64;
320        let total_tasks = results.len();
321        let passed_tasks = results.iter().filter(|r| r.passed).count();
322        let failed_tasks = total_tasks - passed_tasks;
323        let overall_score = if total_tasks > 0 {
324            results
325                .iter()
326                .map(|r| r.score * r.trace.iterations as f64)
327                .sum::<f64>()
328                / results
329                    .iter()
330                    .map(|r| r.trace.iterations as f64)
331                    .sum::<f64>()
332        } else {
333            0.0
334        };
335
336        info!(
337            suite = %self.config.name,
338            passed = passed_tasks,
339            failed = failed_tasks,
340            overall_score = overall_score,
341            duration_ms = duration_ms,
342            "Eval suite completed"
343        );
344
345        EvalReport {
346            suite_name: self.config.name.clone(),
347            ran_at: started_at,
348            duration_ms,
349            overall_score,
350            total_tasks,
351            passed_tasks,
352            failed_tasks,
353            results,
354        }
355    }
356
357    /// Run a single eval task and return the result with trace
358    #[instrument(skip(self), fields(task = %task.name))]
359    async fn run_task(&self, task: &EvalTask) -> EvalResult {
360        let task_start = std::time::Instant::now();
361        let started_at = chrono::Utc::now().to_rfc3339();
362        let mut trace = RunTrace {
363            task_name: task.name.clone(),
364            started_at: started_at.clone(),
365            ended_at: String::new(),
366            duration_ms: 0,
367            iterations: 0,
368            steps: Vec::new(),
369            llm_calls: Vec::new(),
370            tool_calls: Vec::new(),
371            final_response: String::new(),
372        };
373
374        // Build messages
375        let messages = vec![
376            ChatMessage {
377                role: "system".to_string(),
378                content: self.config.system_prompt.clone(),
379            },
380            ChatMessage {
381                role: "user".to_string(),
382                content: task.prompt.clone(),
383            },
384        ];
385
386        // Make the LLM call
387        let call_start = std::time::Instant::now();
388        let response = match self.llm.chat(messages).await {
389            Ok(r) => r,
390            Err(e) => {
391                let duration_ms = task_start.elapsed().as_millis() as u64;
392                trace.ended_at = chrono::Utc::now().to_rfc3339();
393                trace.duration_ms = duration_ms;
394                trace.steps.push(TraceStep {
395                    number: 0,
396                    step_type: StepType::Error,
397                    content: format!("LLM call failed: {}", e),
398                    duration_ms,
399                });
400
401                return EvalResult {
402                    task_name: task.name.clone(),
403                    passed: false,
404                    score: 0.0,
405                    assertions_passed: 0,
406                    assertions_failed: 1,
407                    assertion_results: vec![AssertionResult {
408                        assertion: "llm_call".to_string(),
409                        passed: false,
410                        details: format!("LLM call failed: {}", e),
411                    }],
412                    trace,
413                    error: Some(e.to_string()),
414                };
415            }
416        };
417        let call_duration_ms = call_start.elapsed().as_millis() as u64;
418
419        let first_choice = response.choices.first();
420        let content = first_choice
421            .map(|c| c.message.content.clone())
422            .unwrap_or_default();
423
424        // Record LLM call trace
425        trace.llm_calls.push(LlmCallTrace {
426            iteration: 0,
427            provider: self.llm.provider_name().to_string(),
428            model: self.llm.model().to_string(),
429            prompt_tokens: response.usage.as_ref().map(|u| u.prompt_tokens),
430            completion_tokens: response.usage.as_ref().map(|u| u.completion_tokens),
431            duration_ms: call_duration_ms,
432            response_preview: content.chars().take(1000).collect(),
433        });
434
435        // Record step
436        trace.steps.push(TraceStep {
437            number: 0,
438            step_type: if content.contains("FINAL:") {
439                StepType::Final
440            } else {
441                StepType::Thought
442            },
443            content: content.clone(),
444            duration_ms: call_duration_ms,
445        });
446
447        trace.iterations = 1;
448        trace.final_response = content.clone();
449        trace.ended_at = chrono::Utc::now().to_rfc3339();
450        trace.duration_ms = task_start.elapsed().as_millis() as u64;
451
452        // Run assertions
453        let (assertion_results, assertions_passed, assertions_failed) =
454            check_assertions(&content, &task.assertions, Some(&trace));
455
456        // Calculate score
457        let score = if task.assertions.is_empty() {
458            // No assertions: score based on non-empty response
459            if content.is_empty() || content.len() < 10 {
460                0.0
461            } else {
462                1.0
463            }
464        } else if task.assertions.len() == assertions_passed + assertions_failed {
465            assertions_passed as f64 / task.assertions.len() as f64
466        } else {
467            0.0
468        };
469
470        let passed = assertions_failed == 0 && !content.is_empty();
471
472        EvalResult {
473            task_name: task.name.clone(),
474            passed,
475            score,
476            assertions_passed,
477            assertions_failed,
478            assertion_results,
479            trace,
480            error: None,
481        }
482    }
483}
484
485// ── Assertion Checking ──────────────────────────────────────────────────────
486
487/// Check all assertions against a response string
488fn check_assertions(
489    response: &str,
490    assertions: &[Assertion],
491    run_trace: Option<&RunTrace>,
492) -> (Vec<AssertionResult>, usize, usize) {
493    let mut results = Vec::with_capacity(assertions.len());
494    let mut passed = 0;
495    let mut failed = 0;
496
497    for assertion in assertions {
498        let result = check_single_assertion(response, assertion, run_trace);
499        if result.passed {
500            passed += 1;
501        } else {
502            failed += 1;
503        }
504        results.push(result);
505    }
506
507    (results, passed, failed)
508}
509
510/// Check a single assertion against a response
511fn check_single_assertion(
512    response: &str,
513    assertion: &Assertion,
514    run_trace: Option<&RunTrace>,
515) -> AssertionResult {
516    match assertion {
517        Assertion::Contains(pattern) => {
518            let passed = response.contains(pattern);
519            AssertionResult {
520                assertion: format!("contains: {}", pattern),
521                passed,
522                details: if passed {
523                    format!("Response contains '{}'", pattern)
524                } else {
525                    format!("Response does not contain '{}'", pattern)
526                },
527            }
528        }
529        Assertion::NotContains(pattern) => {
530            let passed = !response.contains(pattern);
531            AssertionResult {
532                assertion: format!("not_contains: {}", pattern),
533                passed,
534                details: if passed {
535                    format!("Response does not contain '{}'", pattern)
536                } else {
537                    format!("Response contains '{}'", pattern)
538                },
539            }
540        }
541        Assertion::Exact(expected) => {
542            let trimmed_response = response.trim();
543            let passed = trimmed_response == expected.as_str();
544            AssertionResult {
545                assertion: format!("exact: {}", expected),
546                passed,
547                details: if passed {
548                    "Response matches exactly".to_string()
549                } else {
550                    format!(
551                        "Expected '{}', got '{}'",
552                        expected,
553                        trimmed_response.chars().take(100).collect::<String>()
554                    )
555                },
556            }
557        }
558        Assertion::Regex(pattern) => {
559            let re = regex_lite::Regex::new(pattern);
560            match re {
561                Ok(re) => {
562                    let passed = re.is_match(response);
563                    AssertionResult {
564                        assertion: format!("regex: {}", pattern),
565                        passed,
566                        details: if passed {
567                            format!("Response matches pattern '{}'", pattern)
568                        } else {
569                            format!("Response does not match pattern '{}'", pattern)
570                        },
571                    }
572                }
573                Err(e) => AssertionResult {
574                    assertion: format!("regex: {}", pattern),
575                    passed: false,
576                    details: format!("Invalid regex pattern: {}", e),
577                },
578            }
579        }
580        Assertion::NonEmpty => {
581            let passed = !response.is_empty();
582            AssertionResult {
583                assertion: "non_empty".to_string(),
584                passed,
585                details: if passed {
586                    format!("Response is non-empty ({} chars)", response.len())
587                } else {
588                    "Response is empty".to_string()
589                },
590            }
591        }
592        Assertion::MinLength(min) => {
593            let passed = response.len() >= *min;
594            AssertionResult {
595                assertion: format!("min_length: {}", min),
596                passed,
597                details: if passed {
598                    format!("Response length {} >= {}", response.len(), min)
599                } else {
600                    format!("Response length {} < {}", response.len(), min)
601                },
602            }
603        }
604        Assertion::MaxLength(max) => {
605            let passed = response.len() <= *max;
606            AssertionResult {
607                assertion: format!("max_length: {}", max),
608                passed,
609                details: if passed {
610                    format!("Response length {} <= {}", response.len(), max)
611                } else {
612                    format!("Response length {} > {}", response.len(), max)
613                },
614            }
615        }
616        Assertion::ToolCalled(tool_name) => {
617            let tool_calls = run_trace
618                .map(|t| &t.tool_calls)
619                .filter(|calls| calls.iter().any(|tc| tc.tool_name == *tool_name));
620            let passed = tool_calls.is_some();
621            AssertionResult {
622                assertion: format!("tool_called: {}", tool_name),
623                passed,
624                details: if passed {
625                    format!("Tool '{}' was called", tool_name)
626                } else {
627                    let all_tools: Vec<&str> = run_trace
628                        .map(|t| {
629                            t.tool_calls
630                                .iter()
631                                .map(|tc| tc.tool_name.as_str())
632                                .collect()
633                        })
634                        .unwrap_or_default();
635                    if all_tools.is_empty() {
636                        format!("Tool '{}' was not called (no tools were called)", tool_name)
637                    } else {
638                        format!(
639                            "Tool '{}' was not called (called: {})",
640                            tool_name,
641                            all_tools.join(", ")
642                        )
643                    }
644                },
645            }
646        }
647        Assertion::ToolNotCalled(tool_name) => {
648            let tool_calls = run_trace
649                .map(|t| &t.tool_calls)
650                .filter(|calls| calls.iter().any(|tc| tc.tool_name == *tool_name));
651            let passed = tool_calls.is_none();
652            AssertionResult {
653                assertion: format!("tool_not_called: {}", tool_name),
654                passed,
655                details: if passed {
656                    format!("Tool '{}' was not called", tool_name)
657                } else {
658                    format!("Tool '{}' was called but should not have been", tool_name)
659                },
660            }
661        }
662    }
663}
664
665// ── Report Formatting ───────────────────────────────────────────────────────
666
667impl EvalReport {
668    /// Format the report as a human-readable string
669    pub fn format_text(&self) -> String {
670        let mut output = String::new();
671
672        output.push_str(&format!("\n🐦‍⬛ Eval Report: {}\n", self.suite_name));
673        output.push_str(&format!("{:-^60}\n", ""));
674        output.push_str(&format!(
675            "Ran at:       {}\n",
676            &self.ran_at[..19].replace('T', " ")
677        ));
678        output.push_str(&format!("Duration:     {} ms\n", self.duration_ms));
679        output.push_str(&format!(
680            "Overall score: {:.1}%\n",
681            self.overall_score * 100.0
682        ));
683        output.push_str(&format!(
684            "Tasks:        {}/{} passed\n",
685            self.passed_tasks, self.total_tasks
686        ));
687        output.push_str(&format!("{:-^60}\n", ""));
688
689        for result in &self.results {
690            output.push_str(&format!(
691                "\n  {} {} — {:.1}%\n",
692                if result.passed { "✅" } else { "❌" },
693                result.task_name,
694                result.score * 100.0
695            ));
696
697            if let Some(ref error) = result.error {
698                output.push_str(&format!("    Error: {}\n", error));
699            }
700
701            if !result.assertion_results.is_empty() {
702                for ar in &result.assertion_results {
703                    output.push_str(&format!(
704                        "    {} {}\n",
705                        if ar.passed { "  ✅" } else { "  ❌" },
706                        ar.details
707                    ));
708                }
709            }
710
711            // Show trace summary
712            let trace = &result.trace;
713            output.push_str(&format!(
714                "    Iterations: {} · LLM calls: {} · Tool calls: {} · Duration: {} ms\n",
715                trace.iterations,
716                trace.llm_calls.len(),
717                trace.tool_calls.len(),
718                trace.duration_ms
719            ));
720
721            // Show response preview
722            let preview: String = trace.final_response.chars().take(200).collect();
723            if !preview.is_empty() {
724                output.push_str(&format!("    Response: {}\n", preview));
725            }
726        }
727
728        output
729    }
730
731    /// Format the report as JSON
732    pub fn format_json(&self) -> serde_json::Value {
733        serde_json::to_value(self).unwrap_or(serde_json::json!({"error": "serialization failed"}))
734    }
735}
736
737// ── Config Loading ──────────────────────────────────────────────────────────
738
739impl EvalConfig {
740    /// Load eval config from a TOML file
741    pub fn from_file(path: &str) -> Result<Self> {
742        let content = std::fs::read_to_string(path).map_err(|e| {
743            RavenClawsError::CommandExecution(format!("Failed to read eval config: {}", e))
744        })?;
745
746        let config: EvalConfig = toml::from_str(&content).map_err(|e| {
747            RavenClawsError::CommandExecution(format!("Failed to parse eval config: {}", e))
748        })?;
749
750        Ok(config)
751    }
752}
753
754// ── Tests ───────────────────────────────────────────────────────────────────
755
756#[cfg(test)]
757mod tests {
758    use super::*;
759
760    #[test]
761    fn test_assertion_contains_pass() {
762        let result = check_single_assertion(
763            "hello world",
764            &Assertion::Contains("world".to_string()),
765            None,
766        );
767        assert!(result.passed);
768        assert!(result.details.contains("contains"));
769    }
770
771    #[test]
772    fn test_assertion_contains_fail() {
773        let result =
774            check_single_assertion("hello world", &Assertion::Contains("foo".to_string()), None);
775        assert!(!result.passed);
776    }
777
778    #[test]
779    fn test_assertion_not_contains_pass() {
780        let result = check_single_assertion(
781            "hello world",
782            &Assertion::NotContains("foo".to_string()),
783            None,
784        );
785        assert!(result.passed);
786    }
787
788    #[test]
789    fn test_assertion_not_contains_fail() {
790        let result = check_single_assertion(
791            "hello world",
792            &Assertion::NotContains("world".to_string()),
793            None,
794        );
795        assert!(!result.passed);
796    }
797
798    #[test]
799    fn test_assertion_exact_pass() {
800        let result = check_single_assertion("hello", &Assertion::Exact("hello".to_string()), None);
801        assert!(result.passed);
802    }
803
804    #[test]
805    fn test_assertion_exact_fail() {
806        let result = check_single_assertion("world", &Assertion::Exact("hello".to_string()), None);
807        assert!(!result.passed);
808    }
809
810    #[test]
811    fn test_assertion_regex_pass() {
812        let result =
813            check_single_assertion("hello 123", &Assertion::Regex(r"\d+".to_string()), None);
814        assert!(result.passed);
815    }
816
817    #[test]
818    fn test_assertion_regex_fail() {
819        let result = check_single_assertion("hello", &Assertion::Regex(r"\d+".to_string()), None);
820        assert!(!result.passed);
821    }
822
823    #[test]
824    fn test_assertion_non_empty_pass() {
825        let result = check_single_assertion("hello", &Assertion::NonEmpty, None);
826        assert!(result.passed);
827    }
828
829    #[test]
830    fn test_assertion_non_empty_fail() {
831        let result = check_single_assertion("", &Assertion::NonEmpty, None);
832        assert!(!result.passed);
833    }
834
835    #[test]
836    fn test_assertion_min_length_pass() {
837        let result = check_single_assertion("hello", &Assertion::MinLength(3), None);
838        assert!(result.passed);
839    }
840
841    #[test]
842    fn test_assertion_min_length_fail() {
843        let result = check_single_assertion("hi", &Assertion::MinLength(5), None);
844        assert!(!result.passed);
845    }
846
847    #[test]
848    fn test_assertion_max_length_pass() {
849        let result = check_single_assertion("hi", &Assertion::MaxLength(5), None);
850        assert!(result.passed);
851    }
852
853    #[test]
854    fn test_assertion_max_length_fail() {
855        let result = check_single_assertion("hello world", &Assertion::MaxLength(5), None);
856        assert!(!result.passed);
857    }
858
859    #[test]
860    fn test_check_assertions_empty() {
861        let (results, passed, failed) = check_assertions("hello", &[], None);
862        assert!(results.is_empty());
863        assert_eq!(passed, 0);
864        assert_eq!(failed, 0);
865    }
866
867    #[test]
868    fn test_check_assertions_multiple() {
869        let assertions = vec![
870            Assertion::Contains("hello".to_string()),
871            Assertion::Contains("world".to_string()),
872            Assertion::NonEmpty,
873        ];
874        let (results, passed, failed) = check_assertions("hello world", &assertions, None);
875        assert_eq!(passed, 3);
876        assert_eq!(failed, 0);
877        assert_eq!(results.len(), 3);
878    }
879
880    #[test]
881    fn test_check_assertions_tool_called() {
882        let trace = RunTrace {
883            task_name: "test".to_string(),
884            started_at: "2026-01-01T00:00:00Z".to_string(),
885            ended_at: "2026-01-01T00:00:01Z".to_string(),
886            duration_ms: 1000,
887            iterations: 1,
888            steps: vec![],
889            llm_calls: vec![],
890            tool_calls: vec![
891                ToolCallTrace {
892                    iteration: 0,
893                    tool_name: "web_search".to_string(),
894                    arguments: serde_json::json!({"query": "test"}),
895                    success: true,
896                    output_preview: "results".to_string(),
897                    duration_ms: 100,
898                },
899                ToolCallTrace {
900                    iteration: 0,
901                    tool_name: "read_file".to_string(),
902                    arguments: serde_json::json!({"path": "/tmp/test"}),
903                    success: true,
904                    output_preview: "content".to_string(),
905                    duration_ms: 50,
906                },
907            ],
908            final_response: "response".to_string(),
909        };
910
911        // ToolCalled — should pass
912        let (results, passed, failed) = check_assertions(
913            "response",
914            &[Assertion::ToolCalled("web_search".to_string())],
915            Some(&trace),
916        );
917        assert_eq!(passed, 1);
918        assert_eq!(failed, 0);
919        assert!(results[0].passed);
920
921        // ToolCalled — should fail (tool not called)
922        let (results, passed, failed) = check_assertions(
923            "response",
924            &[Assertion::ToolCalled("nonexistent".to_string())],
925            Some(&trace),
926        );
927        assert_eq!(passed, 0);
928        assert_eq!(failed, 1);
929        assert!(!results[0].passed);
930
931        // ToolNotCalled — should pass (tool not in list)
932        let (results, passed, failed) = check_assertions(
933            "response",
934            &[Assertion::ToolNotCalled("nonexistent".to_string())],
935            Some(&trace),
936        );
937        assert_eq!(passed, 1);
938        assert_eq!(failed, 0);
939        assert!(results[0].passed);
940
941        // ToolNotCalled — should fail (tool was called)
942        let (results, passed, failed) = check_assertions(
943            "response",
944            &[Assertion::ToolNotCalled("web_search".to_string())],
945            Some(&trace),
946        );
947        assert_eq!(passed, 0);
948        assert_eq!(failed, 1);
949        assert!(!results[0].passed);
950
951        // ToolCalled with no trace — should fail
952        let (results, passed, failed) = check_assertions(
953            "response",
954            &[Assertion::ToolCalled("web_search".to_string())],
955            None,
956        );
957        assert_eq!(passed, 0);
958        assert_eq!(failed, 1);
959        assert!(!results[0].passed);
960    }
961
962    #[test]
963    fn test_eval_config_from_toml() {
964        let toml_str = r#"
965name = "test-suite"
966description = "A test suite"
967system_prompt = "Be concise"
968max_iterations = 3
969
970[[tasks]]
971name = "test-1"
972prompt = "What is 2+2?"
973golden = "4"
974assertions = [{ type = "contains", value = "4" }]
975weight = 1.0
976required = true
977"#;
978
979        let config: EvalConfig = toml::from_str(toml_str).unwrap();
980        assert_eq!(config.name, "test-suite");
981        assert_eq!(config.tasks.len(), 1);
982        assert_eq!(config.tasks[0].name, "test-1");
983        assert_eq!(config.tasks[0].prompt, "What is 2+2?");
984        assert_eq!(config.tasks[0].golden, "4");
985        assert_eq!(config.tasks[0].assertions.len(), 1);
986    }
987
988    #[test]
989    fn test_eval_config_defaults() {
990        let toml_str = r#"
991[[tasks]]
992name = "simple"
993prompt = "Say hello"
994"#;
995
996        let config: EvalConfig = toml::from_str(toml_str).unwrap();
997        assert_eq!(config.name, "unnamed");
998        assert_eq!(config.system_prompt, default_system_prompt());
999        assert_eq!(config.max_iterations, 5);
1000        assert_eq!(config.tasks[0].weight, 1.0);
1001        assert!(!config.tasks[0].required);
1002    }
1003
1004    #[test]
1005    fn test_report_format_text() {
1006        let report = EvalReport {
1007            suite_name: "test".to_string(),
1008            ran_at: "2026-06-22T12:00:00+00:00".to_string(),
1009            duration_ms: 100,
1010            overall_score: 0.75,
1011            total_tasks: 2,
1012            passed_tasks: 1,
1013            failed_tasks: 1,
1014            results: vec![
1015                EvalResult {
1016                    task_name: "pass-task".to_string(),
1017                    passed: true,
1018                    score: 1.0,
1019                    assertions_passed: 2,
1020                    assertions_failed: 0,
1021                    assertion_results: vec![AssertionResult {
1022                        assertion: "contains: hello".to_string(),
1023                        passed: true,
1024                        details: "Response contains 'hello'".to_string(),
1025                    }],
1026                    trace: RunTrace {
1027                        task_name: "pass-task".to_string(),
1028                        started_at: "2026-06-22T12:00:00+00:00".to_string(),
1029                        ended_at: "2026-06-22T12:00:01+00:00".to_string(),
1030                        duration_ms: 50,
1031                        iterations: 1,
1032                        steps: vec![],
1033                        llm_calls: vec![],
1034                        tool_calls: vec![],
1035                        final_response: "hello world".to_string(),
1036                    },
1037                    error: None,
1038                },
1039                EvalResult {
1040                    task_name: "fail-task".to_string(),
1041                    passed: false,
1042                    score: 0.0,
1043                    assertions_passed: 0,
1044                    assertions_failed: 1,
1045                    assertion_results: vec![AssertionResult {
1046                        assertion: "contains: foo".to_string(),
1047                        passed: false,
1048                        details: "Response does not contain 'foo'".to_string(),
1049                    }],
1050                    trace: RunTrace {
1051                        task_name: "fail-task".to_string(),
1052                        started_at: "2026-06-22T12:00:01+00:00".to_string(),
1053                        ended_at: "2026-06-22T12:00:02+00:00".to_string(),
1054                        duration_ms: 50,
1055                        iterations: 1,
1056                        steps: vec![],
1057                        llm_calls: vec![],
1058                        tool_calls: vec![],
1059                        final_response: "bar".to_string(),
1060                    },
1061                    error: None,
1062                },
1063            ],
1064        };
1065
1066        let text = report.format_text();
1067        assert!(text.contains("Eval Report: test"));
1068        assert!(text.contains("75.0%"));
1069        assert!(text.contains("1/2 passed"));
1070        assert!(text.contains("✅ pass-task"));
1071        assert!(text.contains("❌ fail-task"));
1072    }
1073
1074    #[test]
1075    fn test_report_format_json() {
1076        let report = EvalReport {
1077            suite_name: "test".to_string(),
1078            ran_at: "2026-06-22T12:00:00+00:00".to_string(),
1079            duration_ms: 100,
1080            overall_score: 1.0,
1081            total_tasks: 1,
1082            passed_tasks: 1,
1083            failed_tasks: 0,
1084            results: vec![],
1085        };
1086
1087        let json = report.format_json();
1088        assert_eq!(json["suite_name"], "test");
1089        assert_eq!(json["overall_score"], 1.0);
1090    }
1091
1092    #[test]
1093    fn test_eval_config_from_file_not_found() {
1094        let result = EvalConfig::from_file("/tmp/nonexistent-eval-config.toml");
1095        assert!(result.is_err());
1096    }
1097
1098    #[test]
1099    fn test_assertion_regex_invalid_pattern() {
1100        let result =
1101            check_single_assertion("hello", &Assertion::Regex(r"[invalid".to_string()), None);
1102        assert!(!result.passed);
1103        assert!(result.details.contains("Invalid regex"));
1104    }
1105
1106    #[test]
1107    fn test_trace_step_serialization() {
1108        let step = TraceStep {
1109            number: 0,
1110            step_type: StepType::Thought,
1111            content: "test".to_string(),
1112            duration_ms: 100,
1113        };
1114        let json = serde_json::to_string(&step).unwrap();
1115        assert!(json.contains("Thought"));
1116    }
1117
1118    #[test]
1119    fn test_tool_call_trace_serialization() {
1120        let trace = ToolCallTrace {
1121            iteration: 0,
1122            tool_name: "shell_exec".to_string(),
1123            arguments: serde_json::json!({"command": "echo hello"}),
1124            success: true,
1125            output_preview: "hello".to_string(),
1126            duration_ms: 50,
1127        };
1128        let json = serde_json::to_string(&trace).unwrap();
1129        assert!(json.contains("shell_exec"));
1130        assert!(json.contains("echo hello"));
1131    }
1132}