use std::collections::HashMap;
use std::path::Path;
use serde::{Deserialize, Serialize};
use super::evaluator::EvalError;
/// Root of a parsed test-configuration file.
///
/// Maps each criterion name (e.g. `"response_quality"`) to the rule that
/// decides whether a score for that criterion passes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestConfig {
/// Keyed by criterion name; see [`CriterionConfig`] for the value shapes.
pub criteria: HashMap<String, CriterionConfig>,
}
/// Pass/fail rule for a single criterion.
///
/// Deserialized with `#[serde(untagged)]`: a bare JSON number becomes
/// `Threshold`, while a JSON object with a `"threshold"` field (and optional
/// judge settings) becomes `LlmJudge`. Serde tries the variants in
/// declaration order, so ambiguous inputs resolve to `Threshold` first.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CriterionConfig {
/// Shorthand form: `"tool_accuracy": 0.9` — just a minimum passing score.
Threshold(f64),
/// Object form: `{"threshold": 0.7, "judge_model": "...", "num_samples": 3}`.
LlmJudge {
/// Minimum score required to pass (same meaning as `Threshold`).
threshold: f64,
/// Optional judge model override; missing in JSON => `None`.
/// NOTE(review): the consumer of this field isn't visible in this file —
/// presumably `None` falls back to a default model; confirm at call site.
#[serde(default)]
judge_model: Option<String>,
/// Optional sample count; missing in JSON => `None`.
/// NOTE(review): semantics of `None` (default count?) not visible here.
#[serde(default)]
num_samples: Option<u32>,
},
}
impl CriterionConfig {
    /// Returns the minimum passing score, regardless of which variant
    /// carries it.
    pub fn threshold(&self) -> f64 {
        // Both variants hold exactly one threshold, so an or-pattern in a
        // plain `let` is irrefutable — no `match` arms needed.
        let (Self::Threshold(t) | Self::LlmJudge { threshold: t, .. }) = self;
        *t
    }

    /// True when `score` meets or exceeds this criterion's threshold.
    /// (A NaN score never passes, since the comparison is then false.)
    pub fn passes(&self, score: f64) -> bool {
        self.threshold() <= score
    }
}
impl TestConfig {
    /// Evaluates every configured criterion against `scores`.
    ///
    /// Returns a map from criterion name to `(passed, score, threshold)`.
    /// A criterion with no entry in `scores` is evaluated with a score of
    /// `0.0`, so it fails unless its threshold is `<= 0.0`.
    pub fn check_all(&self, scores: &HashMap<String, f64>) -> HashMap<String, (bool, f64, f64)> {
        self.criteria
            .iter()
            .map(|(name, config)| {
                let score = scores.get(name).copied().unwrap_or(0.0);
                let threshold = config.threshold();
                let passed = config.passes(score);
                (name.clone(), (passed, score, threshold))
            })
            .collect()
    }

    /// True when every configured criterion passes (missing scores count
    /// as `0.0`, exactly as in [`check_all`](Self::check_all)).
    ///
    /// Iterates directly instead of materializing the full `check_all`
    /// result map — same answer, no allocation, and it short-circuits on
    /// the first failure.
    pub fn all_pass(&self, scores: &HashMap<String, f64>) -> bool {
        self.criteria
            .iter()
            .all(|(name, config)| config.passes(scores.get(name).copied().unwrap_or(0.0)))
    }
}
/// Loads and parses a test configuration from a JSON file.
///
/// Generalized to accept anything path-like (`&Path`, `&PathBuf`, `&str`,
/// `String`, ...); existing `&Path` callers are unaffected.
///
/// # Errors
/// Returns `EvalError::Io` (with the path included in the message) if the
/// file cannot be read, or `EvalError::Parse` if the contents are not a
/// valid test configuration.
pub fn parse_test_config(path: impl AsRef<Path>) -> Result<TestConfig, EvalError> {
    let path = path.as_ref();
    let contents = std::fs::read_to_string(path).map_err(|e| {
        EvalError::Io(format!(
            "Failed to read test config {}: {e}",
            path.display()
        ))
    })?;
    parse_test_config_str(&contents)
}
pub fn parse_test_config_str(json: &str) -> Result<TestConfig, EvalError> {
serde_json::from_str(json)
.map_err(|e| EvalError::Parse(format!("Invalid test config JSON: {e}")))
}
#[cfg(test)]
mod tests {
use super::*;
// Bare numbers in "criteria" should deserialize to the untagged
// `Threshold` variant and report the number via `threshold()`.
#[test]
fn parse_simple_thresholds() {
let json = r#"{
        "criteria": {
            "response_quality": 0.8,
            "tool_accuracy": 0.9
        }
    }"#;
let config = parse_test_config_str(json).unwrap();
assert_eq!(config.criteria.len(), 2);
let rq = &config.criteria["response_quality"];
assert!((rq.threshold() - 0.8).abs() < f64::EPSILON);
}
// An object with threshold/judge_model/num_samples should deserialize to
// the `LlmJudge` variant with all optional fields populated.
#[test]
fn parse_llm_judge_config() {
let json = r#"{
        "criteria": {
            "coherence": {
                "threshold": 0.7,
                "judge_model": "gemini-2.0-flash",
                "num_samples": 3
            }
        }
    }"#;
let config = parse_test_config_str(json).unwrap();
match &config.criteria["coherence"] {
CriterionConfig::LlmJudge {
threshold,
judge_model,
num_samples,
} => {
assert!((threshold - 0.7).abs() < f64::EPSILON);
assert_eq!(judge_model.as_deref(), Some("gemini-2.0-flash"));
assert_eq!(*num_samples, Some(3));
}
_ => panic!("Expected LlmJudge variant"),
}
}
// All scores meet or exceed their thresholds => all_pass is true.
#[test]
fn check_all_passing() {
let json = r#"{"criteria": {"a": 0.5, "b": 0.8}}"#;
let config = parse_test_config_str(json).unwrap();
let scores: HashMap<String, f64> =
[("a".into(), 0.6), ("b".into(), 0.9)].into_iter().collect();
assert!(config.all_pass(&scores));
}
// One failing criterion ("b": 0.7 < 0.8) makes all_pass false.
#[test]
fn check_all_failing() {
let json = r#"{"criteria": {"a": 0.5, "b": 0.8}}"#;
let config = parse_test_config_str(json).unwrap();
let scores: HashMap<String, f64> =
[("a".into(), 0.6), ("b".into(), 0.7)].into_iter().collect();
assert!(!config.all_pass(&scores));
}
// A criterion absent from the score map is scored as 0.0, so a positive
// threshold fails — pins the `unwrap_or(0.0)` behavior in check_all.
#[test]
fn missing_score_defaults_to_zero() {
let json = r#"{"criteria": {"a": 0.5}}"#;
let config = parse_test_config_str(json).unwrap();
let scores: HashMap<String, f64> = HashMap::new();
assert!(!config.all_pass(&scores));
}
// Malformed JSON surfaces as an Err rather than a panic.
#[test]
fn parse_invalid_json() {
let result = parse_test_config_str("bad");
assert!(result.is_err());
}
}