use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A named, versioned collection of evaluation test cases.
///
/// The type round-trips through serde. `tags` and `created_at` are optional
/// in the serialized form and fall back to their `Default` values (empty
/// vector, empty string) when missing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalDataset {
/// Name of the dataset.
pub name: String,
/// Version label; stored as free text, no format is enforced here.
pub version: String,
/// Free-text description; `EvalDataset::new` leaves it empty.
pub description: String,
/// Test cases, in the order they were added.
pub test_cases: Vec<EvalTestCase>,
/// Labels attached to the dataset as a whole; empty when omitted.
#[serde(default)]
pub tags: Vec<String>,
/// Creation timestamp kept as text. `EvalDataset::new` stores the current
/// UTC time in RFC 3339 form; deserializing input without this field
/// yields an empty string.
#[serde(default)]
pub created_at: String,
}
/// A single evaluation test case: an automation spec plus the outcome
/// expected from it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalTestCase {
/// Identifier of the test case. Uniqueness is not enforced in this file.
pub id: String,
/// Human-readable description of the case.
pub description: String,
/// Numeric priority used as the sort key by
/// `EvalDataset::sorted_by_priority`; defaults to `2` when omitted.
#[serde(default = "default_priority")]
pub priority: u32,
/// Automation spec associated with this case.
pub automation_spec: AutomationSpecTest,
/// Outcome this case is expected to produce.
pub expected_output: EvalExpectedOutput,
/// Whether the case is active; defaults to `true` when omitted. Disabled
/// cases are skipped by `EvalDataset::enabled_test_cases` and
/// `EvalDataset::filter_by_tag`.
#[serde(default = "default_true")]
pub enabled: bool,
/// Labels matched by `EvalDataset::filter_by_tag`; empty when omitted.
#[serde(default)]
pub tags: Vec<String>,
/// Optional per-metric tolerances; every tolerance is `None` when the
/// field is omitted.
#[serde(default)]
pub metric_tolerance: MetricTolerance,
}
/// Description of the automation a test case refers to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutomationSpecTest {
/// Name of the automation; `EvalTestCase::new` uses `"test_automation"`.
pub name: String,
/// Nodes that make up the automation, in declaration order.
pub nodes: Vec<TestNode>,
/// Validator entries as plain strings (presumably validator names —
/// confirm against the consumer); empty when omitted.
#[serde(default)]
pub validators: Vec<String>,
/// Arbitrary JSON configuration values keyed by string; empty when
/// omitted.
#[serde(default)]
pub config: HashMap<String, serde_json::Value>,
}
/// One node of an [`AutomationSpecTest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestNode {
/// Identifier of the node.
pub id: String,
/// Kind of node, stored as free text; the set of valid values is not
/// defined in this file.
pub node_type: String,
/// Text describing what the node is meant to achieve.
pub objective: String,
/// Text describing the node's output contract; empty when omitted.
/// NOTE(review): format of this string is not established here.
#[serde(default)]
pub output_contract: String,
}
/// Outcome a test case is expected to produce.
///
/// Only `artifact_status` is mandatory in the serialized form; the
/// `#[serde(default)]` fields fall back to their `Default` values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalExpectedOutput {
/// Expected final status of the artifact.
pub artifact_status: ArtifactStatus,
/// Validator entries listed as required; empty when omitted.
#[serde(default)]
pub required_validators: Vec<String>,
/// Validator entries listed as optional; empty when omitted.
#[serde(default)]
pub optional_validators: Vec<String>,
/// Flag for tolerating unmet requirements; `false` when omitted.
#[serde(default)]
pub unmet_requirements_acceptable: bool,
/// Optional upper bound on repair iterations; `EvalTestCase::new` sets
/// `Some(3)`.
pub max_repair_iterations: Option<u32>,
/// Expected output format as free text; `EvalTestCase::new` uses
/// `"json"`, and the field is empty when omitted from serialized input.
#[serde(default)]
pub output_format: String,
/// Free-text quality indicators; empty when omitted.
#[serde(default)]
pub quality_indicators: Vec<String>,
}
/// Status of an artifact, as recorded in [`EvalExpectedOutput`].
///
/// Variants are (de)serialized in `snake_case`, so the wire values are
/// derived from the variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ArtifactStatus {
/// Serialized as `"completed"`.
Completed,
/// Serialized as `"completed_with_warnings"`.
CompletedWithWarnings,
/// Serialized as `"blocked"`.
Blocked,
/// Serialized as `"failed"`.
Failed,
}
/// Optional tolerances attached to a test case.
///
/// Every field is optional and `None` by default (the struct derives
/// `Default`). How the tolerances are applied is not visible in this file.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MetricTolerance {
/// Allowed difference in repair-iteration count.
#[serde(default)]
pub repair_iterations_delta: Option<u32>,
/// Allowed token-usage deviation. Presumably a percentage; whether the
/// scale is 0-100 or 0-1 is not established here — TODO confirm.
#[serde(default)]
pub token_usage_tolerance_percent: Option<f64>,
/// Allowed cost deviation, presumably in US dollars — TODO confirm.
#[serde(default)]
pub cost_tolerance_usd: Option<f64>,
/// Acceptable failure rate. NOTE(review): likely a fraction in `0.0..=1.0`;
/// verify against the consumer.
#[serde(default)]
pub acceptable_failure_rate: Option<f32>,
}
/// Serde fallback for `EvalTestCase::priority` when the field is missing
/// from the input.
fn default_priority() -> u32 {
    // Named so the fallback value is self-describing at the use site.
    const FALLBACK_PRIORITY: u32 = 2;
    FALLBACK_PRIORITY
}
/// Serde fallback for boolean fields that should be switched on when the
/// field is missing from the input (used by `EvalTestCase::enabled`).
fn default_true() -> bool {
    const ON: bool = true;
    ON
}
impl EvalDataset {
    /// Creates an empty dataset with the given name and version.
    ///
    /// `description`, `test_cases` and `tags` start out empty; `created_at`
    /// is set to the current UTC time formatted as RFC 3339.
    pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            version: version.into(),
            description: String::new(),
            test_cases: Vec::new(),
            tags: Vec::new(),
            created_at: chrono::Utc::now().to_rfc3339(),
        }
    }

    /// Appends `test_case` and hands the dataset back, so calls can be
    /// chained builder-style.
    pub fn add_test_case(mut self, test_case: EvalTestCase) -> Self {
        self.test_cases.push(test_case);
        self
    }

    /// Returns every test case ordered by ascending `priority` value, i.e.
    /// priority `1` comes before priority `2`, and so on.
    ///
    /// The sort is stable: cases sharing a priority keep their insertion
    /// order. Disabled cases are included.
    pub fn sorted_by_priority(&self) -> Vec<&EvalTestCase> {
        let mut cases: Vec<_> = self.test_cases.iter().collect();
        // Plain ascending key: wrapping it in `Reverse` would put the
        // largest priority number first, which is not the intended order.
        cases.sort_by_key(|tc| tc.priority);
        cases
    }

    /// Returns the enabled test cases whose `tags` contain exactly `tag`
    /// (case-sensitive match), in insertion order.
    pub fn filter_by_tag(&self, tag: &str) -> Vec<&EvalTestCase> {
        self.test_cases
            .iter()
            // Compare against the borrowed `tag` directly instead of
            // allocating a `String` for every test case.
            .filter(|tc| tc.enabled && tc.tags.iter().any(|t| t == tag))
            .collect()
    }

    /// Returns the test cases whose `enabled` flag is set, in insertion
    /// order.
    pub fn enabled_test_cases(&self) -> Vec<&EvalTestCase> {
        self.test_cases.iter().filter(|tc| tc.enabled).collect()
    }
}
impl EvalTestCase {
    /// Builds an enabled test case with the given id and description.
    ///
    /// Everything else is filled with fixed starting values: priority `2`,
    /// an automation spec named `"test_automation"` with no nodes,
    /// validators or config, an expected output of
    /// [`ArtifactStatus::Completed`] in `"json"` format with at most three
    /// repair iterations, no tags, and default metric tolerances.
    pub fn new(id: impl Into<String>, description: impl Into<String>) -> Self {
        let id = id.into();
        let description = description.into();

        // Placeholder spec; callers are expected to overwrite the fields
        // they care about.
        let automation_spec = AutomationSpecTest {
            name: String::from("test_automation"),
            nodes: Vec::new(),
            validators: Vec::new(),
            config: HashMap::new(),
        };

        let expected_output = EvalExpectedOutput {
            artifact_status: ArtifactStatus::Completed,
            required_validators: Vec::new(),
            optional_validators: Vec::new(),
            unmet_requirements_acceptable: false,
            max_repair_iterations: Some(3),
            output_format: String::from("json"),
            quality_indicators: Vec::new(),
        };

        Self {
            id,
            description,
            priority: 2,
            automation_spec,
            expected_output,
            enabled: true,
            tags: Vec::new(),
            metric_tolerance: MetricTolerance::default(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a test case via `EvalTestCase::new` and lets `tweak` adjust it.
    fn case(
        id: &str,
        description: &str,
        tweak: impl FnOnce(&mut EvalTestCase),
    ) -> EvalTestCase {
        let mut tc = EvalTestCase::new(id, description);
        tweak(&mut tc);
        tc
    }

    /// Collects the ids of the given cases, preserving their order.
    fn ids<'a>(cases: &[&'a EvalTestCase]) -> Vec<&'a str> {
        cases.iter().map(|&tc| tc.id.as_str()).collect()
    }

    #[test]
    fn test_create_dataset() {
        let dataset = EvalDataset::new("test", "1.0");

        assert_eq!(dataset.name, "test");
        assert_eq!(dataset.version, "1.0");
        assert!(dataset.test_cases.is_empty());
    }

    #[test]
    fn test_add_test_case() {
        let dataset =
            EvalDataset::new("test", "1.0").add_test_case(EvalTestCase::new("test1", "Test 1"));

        assert_eq!(dataset.test_cases.len(), 1);
        assert_eq!(dataset.test_cases[0].id, "test1");
    }

    #[test]
    fn test_sorted_by_priority() {
        // Inserted out of order on purpose: priorities 1, 3, 2.
        let dataset = EvalDataset::new("test", "1.0")
            .add_test_case(case("test1", "Test 1", |tc| tc.priority = 1))
            .add_test_case(case("test2", "Test 2", |tc| tc.priority = 3))
            .add_test_case(case("test3", "Test 3", |tc| tc.priority = 2));

        let sorted = dataset.sorted_by_priority();

        // Expected order is by ascending priority number: 1, 2, 3.
        assert_eq!(ids(&sorted), ["test1", "test3", "test2"]);
    }

    #[test]
    fn test_filter_by_tag() {
        let dataset = EvalDataset::new("test", "1.0")
            .add_test_case(case("test1", "Test 1", |tc| {
                tc.tags = vec!["critical".to_string()];
            }))
            .add_test_case(case("test2", "Test 2", |tc| {
                tc.tags = vec!["regression".to_string()];
            }));

        let critical = dataset.filter_by_tag("critical");

        assert_eq!(ids(&critical), ["test1"]);
    }

    #[test]
    fn test_enabled_test_cases() {
        let dataset = EvalDataset::new("test", "1.0")
            .add_test_case(case("test1", "Test 1", |tc| tc.enabled = true))
            .add_test_case(case("test2", "Test 2", |tc| tc.enabled = false));

        let enabled = dataset.enabled_test_cases();

        assert_eq!(ids(&enabled), ["test1"]);
    }

    #[test]
    fn test_serde_roundtrip() {
        let original = EvalDataset::new("test_dataset", "1.0");

        let json = serde_json::to_string(&original).unwrap();
        let restored: EvalDataset = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.name, "test_dataset");
        assert_eq!(restored.version, "1.0");
    }

    #[test]
    fn test_test_case_defaults() {
        let tc = EvalTestCase::new("test1", "Test 1");

        assert_eq!(tc.priority, 2);
        assert!(tc.enabled);
        assert_eq!(
            tc.expected_output.artifact_status,
            ArtifactStatus::Completed
        );
    }

    #[test]
    fn test_artifact_status_serialization() {
        // Serialization uses the snake_case variant name.
        let encoded = serde_json::to_string(&ArtifactStatus::Completed).unwrap();
        assert_eq!(encoded, "\"completed\"");

        // And the same text parses back to the variant.
        let decoded: ArtifactStatus = serde_json::from_str("\"completed\"").unwrap();
        assert_eq!(decoded, ArtifactStatus::Completed);
    }
}