use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::path::Path;
use crate::error::{EvalError, Result};
/// One JSON test file: eval-set metadata plus the eval cases it contains.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TestFile {
    /// Eval set identifier (`eval_set_id` in the JSON).
    pub eval_set_id: String,
    /// Human-readable name of the set.
    pub name: String,
    /// Free-form description; an empty string when the JSON omits it.
    #[serde(default)]
    pub description: String,
    /// The cases stored in this file (required in the JSON).
    pub eval_cases: Vec<EvalCase>,
}
impl TestFile {
    /// Reads the file at `path` and parses it as a JSON [`TestFile`].
    ///
    /// # Errors
    /// Returns the crate error converted (via `?`) from either the file read
    /// or the JSON parse.
    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
        let raw = std::fs::read_to_string(path)?;
        Ok(serde_json::from_str::<Self>(&raw)?)
    }

    /// Serializes `self` as pretty-printed JSON and writes it to `path`,
    /// replacing any existing content.
    ///
    /// # Errors
    /// Returns the crate error converted (via `?`) from serialization or the write.
    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
        std::fs::write(path, serde_json::to_string_pretty(self)?)?;
        Ok(())
    }
}
/// An eval set: metadata, inline cases, and references to external test files.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EvalSet {
    /// Eval set identifier (`eval_set_id` in the JSON).
    pub eval_set_id: String,
    /// Human-readable name of the set.
    pub name: String,
    /// Free-form description; an empty string when the JSON omits it.
    #[serde(default)]
    pub description: String,
    /// Paths of additional [`TestFile`]s; [`EvalSet::get_all_cases`] joins each
    /// onto a caller-supplied base path. Empty when omitted.
    #[serde(default)]
    pub test_files: Vec<String>,
    /// Cases defined directly in this set. Empty when omitted.
    #[serde(default)]
    pub eval_cases: Vec<EvalCase>,
}
impl EvalSet {
    /// Reads the file at `path` and parses it as a JSON [`EvalSet`].
    ///
    /// # Errors
    /// Returns the crate error converted (via `?`) from either the file read
    /// or the JSON parse.
    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
        let raw = std::fs::read_to_string(path)?;
        Ok(serde_json::from_str::<Self>(&raw)?)
    }

    /// Returns the inline `eval_cases` followed by the cases of every file in
    /// `test_files`, in list order. Each entry is resolved by joining it onto
    /// `base_path`.
    ///
    /// # Errors
    /// Stops at the first file that fails to load and returns
    /// `EvalError::LoadError`. The message names the entry exactly as written
    /// in `test_files`.
    pub fn get_all_cases(&self, base_path: impl AsRef<Path>) -> Result<Vec<EvalCase>> {
        let base = base_path.as_ref();
        self.test_files.iter().try_fold(
            self.eval_cases.clone(),
            |mut collected, relative| -> Result<Vec<EvalCase>> {
                let loaded = TestFile::load(base.join(relative)).map_err(|err| {
                    EvalError::LoadError(format!("Failed to load {}: {}", relative, err))
                })?;
                collected.extend(loaded.eval_cases);
                Ok(collected)
            },
        )
    }
}
/// A single evaluation case: an id, an ordered conversation, and session setup.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EvalCase {
    /// Case identifier (`eval_id` in the JSON).
    pub eval_id: String,
    /// Free-form description; an empty string when omitted.
    #[serde(default)]
    pub description: String,
    /// The turns of the conversation, in order (required in the JSON).
    pub conversation: Vec<Turn>,
    /// Session setup; falls back to `SessionInput::default()` when omitted.
    #[serde(default)]
    pub session_input: SessionInput,
    /// Arbitrary labels; empty when omitted.
    #[serde(default)]
    pub tags: Vec<String>,
}
/// One turn of a conversation: the user's input plus optional expectations.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Turn {
    /// Identifier of this invocation (`invocation_id` in the JSON).
    pub invocation_id: String,
    /// What the user sends in this turn.
    pub user_content: ContentData,
    /// The final response recorded for this turn, if any; `None` when omitted.
    #[serde(default)]
    pub final_response: Option<ContentData>,
    /// Tool uses / intermediate responses for this turn, if any; `None` when omitted.
    #[serde(default)]
    pub intermediate_data: Option<IntermediateData>,
}
/// A message made of ordered parts, tagged with the speaker's role.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ContentData {
    /// The message parts, in order.
    pub parts: Vec<Part>,
    /// Speaker role; `default_role()` (i.e. `"user"`) when the JSON omits it.
    #[serde(default = "default_role")]
    pub role: String,
}
/// Serde default for `ContentData::role` when the field is missing: `"user"`.
fn default_role() -> String {
    String::from("user")
}
impl ContentData {
    /// Builds a message with a single text part and role `"user"`.
    pub fn text(text: &str) -> Self {
        Self { parts: vec![Part::Text { text: text.to_string() }], role: "user".to_string() }
    }

    /// Builds a message with a single text part and role `"model"`.
    pub fn model_response(text: &str) -> Self {
        Self { parts: vec![Part::Text { text: text.to_string() }], role: "model".to_string() }
    }

    /// Concatenates the text of every [`Part::Text`] in order, with no separator.
    ///
    /// Function-call and function-response parts contribute nothing. Returns an
    /// empty string when there are no text parts.
    pub fn get_text(&self) -> String {
        // `String: FromIterator<&str>` appends directly; no intermediate Vec needed.
        self.parts
            .iter()
            .filter_map(|part| match part {
                Part::Text { text } => Some(text.as_str()),
                // Listed explicitly so adding a `Part` variant forces a decision here.
                Part::FunctionCall { .. } | Part::FunctionResponse { .. } => None,
            })
            .collect()
    }

    /// Converts this message into an `adk_core::Content` with the same role,
    /// appending each text part in order via `with_text`.
    pub fn to_adk_content(&self) -> adk_core::Content {
        self.parts.iter().fold(adk_core::Content::new(&self.role), |content, part| match part {
            Part::Text { text } => content.with_text(text),
            // NOTE(review): function-call / function-response parts are currently
            // dropped from the converted content. Confirm whether adk_core::Content
            // should carry them.
            Part::FunctionCall { .. } | Part::FunctionResponse { .. } => content,
        })
    }
}
/// One piece of a message.
///
/// Untagged: serde picks the variant by the fields present, trying variants in
/// declaration order. The order below is therefore significant and must be kept:
/// `text` first, then `name` + `args`, then `name` + `response`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum Part {
    /// Plain text content.
    Text { text: String },
    /// A call to a named function with JSON arguments.
    FunctionCall { name: String, args: Value },
    /// The JSON result returned by a named function.
    FunctionResponse { name: String, response: Value },
}
/// Data produced within a turn before its final response.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct IntermediateData {
    /// Tool invocations for the turn; empty when omitted.
    #[serde(default)]
    pub tool_uses: Vec<ToolUse>,
    /// Intermediate messages for the turn; empty when omitted.
    #[serde(default)]
    pub intermediate_responses: Vec<ContentData>,
}
/// A single tool invocation: tool name, JSON arguments, optional expected result.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolUse {
    /// Name of the tool.
    pub name: String,
    /// JSON arguments. When omitted in the JSON this is `Value::default()`
    /// (JSON `null`), not an empty object.
    #[serde(default)]
    pub args: Value,
    /// Expected tool result, if specified; `None` when omitted.
    #[serde(default)]
    pub expected_response: Option<Value>,
}
impl ToolUse {
    /// Creates a tool use named `name` with an empty JSON object as arguments
    /// and no expected response.
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            args: Value::Object(Default::default()),
            expected_response: None,
        }
    }

    /// Replaces the arguments and returns the updated value (builder style).
    #[must_use]
    pub fn with_args(mut self, args: Value) -> Self {
        self.args = args;
        self
    }

    /// Returns whether `other` (the actual call) satisfies `self` (the expectation).
    ///
    /// Tool names must be equal. With `strict_args`, the two argument values
    /// must be equal as a whole. Otherwise matching is lenient:
    /// - expected args of `null` (what serde yields when `args` is omitted) impose
    ///   no constraint;
    /// - expected object args must each appear with an equal value in the actual
    ///   object (top-level keys only; nested values are compared exactly);
    /// - actual args of `null` are treated as an empty object;
    /// - any other combination falls back to exact equality.
    pub fn matches(&self, other: &ToolUse, strict_args: bool) -> bool {
        if self.name != other.name {
            return false;
        }
        if strict_args {
            return self.args == other.args;
        }
        match (&self.args, &other.args) {
            // A deserialized expectation without `args` is `null`, whereas
            // `ToolUse::new` uses `{}`; both mean "no required arguments".
            (Value::Null, _) => true,
            (Value::Object(expected), Value::Object(actual)) => {
                expected.iter().all(|(key, value)| actual.get(key) == Some(value))
            }
            // An actual call without arguments satisfies only an empty expectation.
            (Value::Object(expected), Value::Null) => expected.is_empty(),
            _ => self.args == other.args,
        }
    }
}
/// Session parameters for running a case; every field is optional in the JSON.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct SessionInput {
    /// Application name; empty string when omitted.
    #[serde(default)]
    pub app_name: String,
    /// User identifier; empty string when omitted.
    #[serde(default)]
    pub user_id: String,
    /// Initial key/value session state; empty map when omitted.
    #[serde(default)]
    pub state: HashMap<String, Value>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A minimal one-case, one-turn file parses with its identifiers intact.
    #[test]
    fn test_parse_test_file() {
        let raw = r#"{
            "eval_set_id": "test_set",
            "name": "Test Set",
            "description": "A test set",
            "eval_cases": [
                {
                    "eval_id": "test_1",
                    "conversation": [
                        {
                            "invocation_id": "inv_1",
                            "user_content": {
                                "parts": [{"text": "Hello"}],
                                "role": "user"
                            },
                            "final_response": {
                                "parts": [{"text": "Hi there!"}],
                                "role": "model"
                            }
                        }
                    ]
                }
            ]
        }"#;

        let parsed: TestFile = serde_json::from_str(raw).expect("fixture JSON is valid");
        assert_eq!(parsed.eval_set_id, "test_set");
        assert_eq!(parsed.eval_cases.len(), 1);
        assert_eq!(parsed.eval_cases[0].eval_id, "test_1");
    }

    /// Strict matching requires equal args; lenient matching accepts extra keys
    /// but still rejects differing values.
    #[test]
    fn test_tool_use_matching() {
        let want = ToolUse::new("get_weather").with_args(json!({"location": "NYC"}));

        let same = ToolUse::new("get_weather").with_args(json!({"location": "NYC"}));
        assert!(want.matches(&same, true));

        let superset =
            ToolUse::new("get_weather").with_args(json!({"location": "NYC", "unit": "celsius"}));
        assert!(!want.matches(&superset, true));
        assert!(want.matches(&superset, false));

        let different = ToolUse::new("get_weather").with_args(json!({"location": "LA"}));
        assert!(!want.matches(&different, true));
        assert!(!want.matches(&different, false));
    }

    /// The convenience constructors set the expected role and text.
    #[test]
    fn test_content_data() {
        let greeting = ContentData::text("Hello world");
        assert_eq!(greeting.get_text(), "Hello world");
        assert_eq!(greeting.role, "user");

        assert_eq!(ContentData::model_response("Hi there!").role, "model");
    }
}