use std::path::Path;
use serde::Deserialize;
use super::error::EvalError;
/// Largest benchmark file, in bytes, that `BenchmarkSet::from_file` will load (10 MiB).
const MAX_BENCHMARK_SIZE: u64 = 10 * 1024 * 1024;
/// A set of benchmark cases, deserialized from a document with a `cases` array.
#[derive(Debug, Clone, Deserialize)]
pub struct BenchmarkSet {
/// The cases that make up this set; the key is required when deserializing.
pub cases: Vec<BenchmarkCase>,
}
/// A single benchmark entry; only `prompt` is required, every other key may be omitted.
#[derive(Debug, Clone, Deserialize)]
pub struct BenchmarkCase {
/// The prompt text for this case (required).
pub prompt: String,
/// Optional context string; `None` when the key is absent.
// NOTE(review): presumably supplied alongside the prompt at evaluation time — confirm with the consumer.
#[serde(default)]
pub context: Option<String>,
/// Optional reference string; `None` when the key is absent.
// NOTE(review): presumably the expected answer used for scoring — confirm with the consumer.
#[serde(default)]
pub reference: Option<String>,
/// Optional list of tag strings; `None` when the key is absent.
#[serde(default)]
pub tags: Option<Vec<String>>,
}
impl BenchmarkSet {
pub fn from_file(path: &Path) -> Result<Self, EvalError> {
let canonical = std::fs::canonicalize(path)
.map_err(|e| EvalError::BenchmarkLoad(path.display().to_string(), e))?;
if let Some(parent) = path.parent()
&& let Ok(canonical_parent) = std::fs::canonicalize(parent)
&& !canonical.starts_with(&canonical_parent)
{
return Err(EvalError::PathTraversal(canonical.display().to_string()));
}
let metadata = std::fs::metadata(&canonical)
.map_err(|e| EvalError::BenchmarkLoad(canonical.display().to_string(), e))?;
if metadata.len() > MAX_BENCHMARK_SIZE {
return Err(EvalError::BenchmarkTooLarge {
path: canonical.display().to_string(),
size: metadata.len(),
limit: MAX_BENCHMARK_SIZE,
});
}
let content = std::fs::read_to_string(&canonical)
.map_err(|e| EvalError::BenchmarkLoad(canonical.display().to_string(), e))?;
toml::from_str(&content)
.map_err(|e| EvalError::BenchmarkParse(canonical.display().to_string(), e.to_string()))
}
pub fn validate(&self) -> Result<(), EvalError> {
if self.cases.is_empty() {
return Err(EvalError::EmptyBenchmarkSet);
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Deserializes an in-memory TOML document, panicking if it is invalid.
    fn parse(src: &str) -> BenchmarkSet {
        toml::from_str(src).expect("valid TOML")
    }

    #[test]
    fn benchmark_from_toml_happy_path() {
        let doc = r#"
[[cases]]
prompt = "What is 2+2?"
"#;
        let set = parse(doc);
        assert_eq!(set.cases.len(), 1);
        assert_eq!(set.cases[0].prompt, "What is 2+2?");
        assert!(set.cases[0].context.is_none());
        assert!(set.cases[0].reference.is_none());
        assert!(set.cases[0].tags.is_none());
    }

    #[test]
    fn benchmark_from_toml_with_all_fields() {
        let doc = r#"
[[cases]]
prompt = "Explain Rust ownership."
context = "You are a Rust expert."
reference = "Ownership is Rust's memory management model."
tags = ["rust", "concepts"]
"#;
        let set = parse(doc);
        assert_eq!(set.cases.len(), 1);
        let case = &set.cases[0];
        assert_eq!(case.context.as_deref(), Some("You are a Rust expert."));
        assert!(case.reference.is_some());
        assert_eq!(case.tags.as_ref().map(|t| t.len()), Some(2));
    }

    #[test]
    fn benchmark_empty_cases_rejected() {
        let set = BenchmarkSet { cases: vec![] };
        assert!(matches!(set.validate(), Err(EvalError::EmptyBenchmarkSet)));
    }

    #[test]
    fn benchmark_validate_passes_for_nonempty() {
        let set = BenchmarkSet {
            cases: vec![BenchmarkCase {
                prompt: "hello".into(),
                context: None,
                reference: None,
                tags: None,
            }],
        };
        assert!(set.validate().is_ok());
    }

    #[test]
    fn benchmark_from_file_missing_file() {
        let result = BenchmarkSet::from_file(Path::new("/nonexistent/path/benchmark.toml"));
        assert!(matches!(result, Err(EvalError::BenchmarkLoad(_, _))));
    }

    #[test]
    fn benchmark_from_toml_invalid_syntax() {
        let bad = "[[cases\nprompt = 'unclosed'";
        let result: Result<BenchmarkSet, _> = toml::from_str(bad);
        assert!(result.is_err());
    }

    #[test]
    fn benchmark_from_file_invalid_toml() {
        let mut f = tempfile::NamedTempFile::new().unwrap();
        writeln!(f, "not valid toml ][[]").unwrap();
        let result = BenchmarkSet::from_file(f.path());
        assert!(matches!(result, Err(EvalError::BenchmarkParse(_, _))));
    }

    /// Exercises the size guard in `from_file` itself rather than only the
    /// error's message.
    #[test]
    fn benchmark_from_file_too_large() {
        let f = tempfile::NamedTempFile::new().unwrap();
        // Extending the length yields an oversized file without having to
        // write more than 10 MiB of actual data.
        f.as_file().set_len(MAX_BENCHMARK_SIZE + 1).unwrap();
        let result = BenchmarkSet::from_file(f.path());
        assert!(matches!(
            result,
            Err(EvalError::BenchmarkTooLarge { size, limit, .. })
                if size == MAX_BENCHMARK_SIZE + 1 && limit == MAX_BENCHMARK_SIZE
        ));
    }

    #[test]
    fn benchmark_too_large_error_message() {
        let err = EvalError::BenchmarkTooLarge {
            path: "/tmp/bench.toml".into(),
            size: MAX_BENCHMARK_SIZE + 1,
            limit: MAX_BENCHMARK_SIZE,
        };
        assert!(err.to_string().contains("exceeds size limit"));
    }

    #[test]
    fn benchmark_from_file_size_guard_allows_normal_file() {
        let mut f = tempfile::NamedTempFile::new().unwrap();
        writeln!(f, "[[cases]]\nprompt = \"hello\"").unwrap();
        let result = BenchmarkSet::from_file(f.path());
        assert!(result.is_ok());
    }
}