1use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::collections::HashMap;
8use std::path::Path;
9
10use crate::error::{EvalError, Result};
11
/// A file of evaluation cases that can be read from and written to JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TestFile {
    /// Eval-set identifier carried by this file (required).
    pub eval_set_id: String,
    /// Name of this test file (required).
    pub name: String,
    /// Free-form description; an empty string when the field is absent from the input.
    #[serde(default)]
    pub description: String,
    /// The evaluation cases stored in this file (required).
    pub eval_cases: Vec<EvalCase>,
}
25
26impl TestFile {
27 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
29 let content = std::fs::read_to_string(path.as_ref())?;
30 let test_file: TestFile = serde_json::from_str(&content)?;
31 Ok(test_file)
32 }
33
34 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
36 let content = serde_json::to_string_pretty(self)?;
37 std::fs::write(path, content)?;
38 Ok(())
39 }
40}
41
/// A set of evaluation cases, given inline and/or through referenced test files.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalSet {
    /// Identifier of this eval set (required).
    pub eval_set_id: String,
    /// Name of this eval set (required).
    pub name: String,
    /// Free-form description; an empty string when the field is absent from the input.
    #[serde(default)]
    pub description: String,
    /// Paths of test files whose cases belong to this set; `get_all_cases` joins each
    /// one onto the base path it is given. Empty when absent from the input.
    #[serde(default)]
    pub test_files: Vec<String>,
    /// Cases defined directly in this set. Empty when absent from the input.
    #[serde(default)]
    pub eval_cases: Vec<EvalCase>,
}
59
60impl EvalSet {
61 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
63 let content = std::fs::read_to_string(path.as_ref())?;
64 let eval_set: EvalSet = serde_json::from_str(&content)?;
65 Ok(eval_set)
66 }
67
68 pub fn get_all_cases(&self, base_path: impl AsRef<Path>) -> Result<Vec<EvalCase>> {
70 let mut all_cases = self.eval_cases.clone();
71
72 for test_file_path in &self.test_files {
73 let full_path = base_path.as_ref().join(test_file_path);
74 let test_file = TestFile::load(&full_path).map_err(|e| {
75 EvalError::LoadError(format!("Failed to load {}: {}", test_file_path, e))
76 })?;
77 all_cases.extend(test_file.eval_cases);
78 }
79
80 Ok(all_cases)
81 }
82}
83
/// A single evaluation case: a conversation plus the session settings and tags attached to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalCase {
    /// Identifier of this case (required).
    pub eval_id: String,
    /// Free-form description; an empty string when the field is absent from the input.
    #[serde(default)]
    pub description: String,
    /// The turns of the conversation, in the order they appear in the input (required).
    pub conversation: Vec<Turn>,
    /// Session settings for this case; `SessionInput::default()` when absent.
    #[serde(default)]
    pub session_input: SessionInput,
    /// Tags attached to this case. Empty when absent from the input.
    #[serde(default)]
    pub tags: Vec<String>,
}
101
/// One turn of a conversation inside an [`EvalCase`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Turn {
    /// Identifier of this invocation (required).
    pub invocation_id: String,
    /// The user's content for this turn (required).
    pub user_content: ContentData,
    /// Final response content for this turn, if any.
    /// NOTE(review): presumably the expected agent reply — confirm against the evaluator.
    #[serde(default)]
    pub final_response: Option<ContentData>,
    /// Tool uses and intermediate responses recorded for this turn, if any.
    #[serde(default)]
    pub intermediate_data: Option<IntermediateData>,
}
116
/// A piece of conversation content: a list of parts attributed to a role.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContentData {
    /// The parts making up this content (required).
    pub parts: Vec<Part>,
    /// Role the content is attributed to; `default_role()` supplies `"user"` when the
    /// field is absent from the input.
    #[serde(default = "default_role")]
    pub role: String,
}
126
/// Serde default for `ContentData::role`: the `"user"` role.
fn default_role() -> String {
    String::from("user")
}
130
131impl ContentData {
132 pub fn text(text: &str) -> Self {
134 Self { parts: vec![Part::Text { text: text.to_string() }], role: "user".to_string() }
135 }
136
137 pub fn model_response(text: &str) -> Self {
139 Self { parts: vec![Part::Text { text: text.to_string() }], role: "model".to_string() }
140 }
141
142 pub fn get_text(&self) -> String {
144 self.parts
145 .iter()
146 .filter_map(|p| match p {
147 Part::Text { text } => Some(text.as_str()),
148 _ => None,
149 })
150 .collect::<Vec<_>>()
151 .join("")
152 }
153
154 pub fn to_adk_content(&self) -> adk_core::Content {
156 let mut content = adk_core::Content::new(&self.role);
157 for part in &self.parts {
158 match part {
159 Part::Text { text } => {
160 content = content.with_text(text);
161 }
162 Part::FunctionCall { .. } | Part::FunctionResponse { .. } => {
163 }
166 }
167 }
168 content
169 }
170}
171
/// One part of a [`ContentData`].
///
/// The enum is `#[serde(untagged)]`: no variant name is written to the JSON, and a
/// value is recognised by the fields it contains. Variant declaration order is
/// therefore part of the parsing behaviour and should not be changed casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
    /// Plain text.
    Text { text: String },
    /// A function call: the function's name and its JSON arguments.
    FunctionCall { name: String, args: Value },
    /// A function response: the function's name and its JSON response.
    FunctionResponse { name: String, response: Value },
}
183
/// Data attached to a [`Turn`] besides its user content and final response.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct IntermediateData {
    /// Tool uses for the turn. Empty when absent from the input.
    #[serde(default)]
    pub tool_uses: Vec<ToolUse>,
    /// Intermediate response contents for the turn. Empty when absent from the input.
    #[serde(default)]
    pub intermediate_responses: Vec<ContentData>,
}
194
/// A tool invocation: the tool's name, its JSON arguments and an optional expected response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolUse {
    /// Name of the tool (required).
    pub name: String,
    /// JSON arguments of the call. When the field is absent from the input this is
    /// `Value::default()` (JSON null in `serde_json`), whereas `ToolUse::new` starts
    /// from an empty JSON object.
    #[serde(default)]
    pub args: Value,
    /// Expected response of the tool, if one is given.
    #[serde(default)]
    pub expected_response: Option<Value>,
}
207
208impl ToolUse {
209 pub fn new(name: &str) -> Self {
211 Self {
212 name: name.to_string(),
213 args: Value::Object(Default::default()),
214 expected_response: None,
215 }
216 }
217
218 pub fn with_args(mut self, args: Value) -> Self {
220 self.args = args;
221 self
222 }
223
224 pub fn matches(&self, other: &ToolUse, strict_args: bool) -> bool {
226 if self.name != other.name {
227 return false;
228 }
229
230 if strict_args {
231 self.args == other.args
232 } else {
233 match (&self.args, &other.args) {
235 (Value::Object(expected), Value::Object(actual)) => {
236 expected.iter().all(|(k, v)| actual.get(k) == Some(v))
237 }
238 _ => self.args == other.args,
239 }
240 }
241 }
242}
243
/// Session settings attached to an [`EvalCase`]; every field falls back to its
/// `Default` value when absent from the input.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SessionInput {
    /// Application name; an empty string when absent.
    #[serde(default)]
    pub app_name: String,
    /// User identifier; an empty string when absent.
    #[serde(default)]
    pub user_id: String,
    /// Session state as a map from key to JSON value; empty when absent.
    #[serde(default)]
    pub state: HashMap<String, Value>,
}
257
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a `get_weather` tool use carrying the given arguments.
    fn weather(args: serde_json::Value) -> ToolUse {
        ToolUse::new("get_weather").with_args(args)
    }

    /// A minimal test file with one case and one turn deserializes with its ids intact.
    #[test]
    fn test_parse_test_file() {
        let raw = r#"{
            "eval_set_id": "test_set",
            "name": "Test Set",
            "description": "A test set",
            "eval_cases": [
                {
                    "eval_id": "test_1",
                    "conversation": [
                        {
                            "invocation_id": "inv_1",
                            "user_content": {
                                "parts": [{"text": "Hello"}],
                                "role": "user"
                            },
                            "final_response": {
                                "parts": [{"text": "Hi there!"}],
                                "role": "model"
                            }
                        }
                    ]
                }
            ]
        }"#;

        let parsed = serde_json::from_str::<TestFile>(raw).unwrap();
        assert_eq!(parsed.eval_set_id, "test_set");
        assert_eq!(parsed.eval_cases.len(), 1);
        assert_eq!(parsed.eval_cases[0].eval_id, "test_1");
    }

    /// Strict mode needs identical args; lenient mode accepts extra keys in the
    /// actual call but still rejects a differing value.
    #[test]
    fn test_tool_use_matching() {
        let expected = weather(json!({"location": "NYC"}));

        // Identical arguments match in strict mode.
        let identical = weather(json!({"location": "NYC"}));
        assert!(expected.matches(&identical, true));

        // An extra key breaks strict matching but is tolerated by lenient matching.
        let with_extra_key = weather(json!({"location": "NYC", "unit": "celsius"}));
        assert!(!expected.matches(&with_extra_key, true));
        assert!(expected.matches(&with_extra_key, false));

        // A different value for an expected key never matches.
        let other_city = weather(json!({"location": "LA"}));
        for strict in [true, false] {
            assert!(!expected.matches(&other_city, strict));
        }
    }

    /// The two constructors set the expected role, and `get_text` returns the text.
    #[test]
    fn test_content_data() {
        let from_user = ContentData::text("Hello world");
        assert_eq!(from_user.get_text(), "Hello world");
        assert_eq!(from_user.role, "user");

        let from_model = ContentData::model_response("Hi there!");
        assert_eq!(from_model.role, "model");
    }
}