1use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::collections::HashMap;
8use std::path::Path;
9
10use crate::error::{EvalError, Result};
11use crate::test_generator::EvalCaseMetadata;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct TestFile {
16 pub eval_set_id: String,
18 pub name: String,
20 #[serde(default)]
22 pub description: String,
23 pub eval_cases: Vec<EvalCase>,
25}
26
27impl TestFile {
28 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
30 let content = std::fs::read_to_string(path.as_ref())?;
31 let test_file: TestFile = serde_json::from_str(&content)?;
32 Ok(test_file)
33 }
34
35 pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
37 let content = serde_json::to_string_pretty(self)?;
38 std::fs::write(path, content)?;
39 Ok(())
40 }
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct EvalSet {
46 pub eval_set_id: String,
48 pub name: String,
50 #[serde(default)]
52 pub description: String,
53 #[serde(default)]
55 pub test_files: Vec<String>,
56 #[serde(default)]
58 pub eval_cases: Vec<EvalCase>,
59}
60
61impl EvalSet {
62 pub fn load(path: impl AsRef<Path>) -> Result<Self> {
64 let content = std::fs::read_to_string(path.as_ref())?;
65 let eval_set: EvalSet = serde_json::from_str(&content)?;
66 Ok(eval_set)
67 }
68
69 pub fn get_all_cases(&self, base_path: impl AsRef<Path>) -> Result<Vec<EvalCase>> {
71 let mut all_cases = self.eval_cases.clone();
72
73 for test_file_path in &self.test_files {
74 let full_path = base_path.as_ref().join(test_file_path);
75 let test_file = TestFile::load(&full_path).map_err(|e| {
76 EvalError::LoadError(format!("Failed to load {}: {}", test_file_path, e))
77 })?;
78 all_cases.extend(test_file.eval_cases);
79 }
80
81 Ok(all_cases)
82 }
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct EvalCase {
88 pub eval_id: String,
90 #[serde(default)]
92 pub description: String,
93 pub conversation: Vec<Turn>,
95 #[serde(default)]
97 pub session_input: SessionInput,
98 #[serde(default)]
100 pub tags: Vec<String>,
101 #[serde(default, skip_serializing_if = "Option::is_none")]
103 pub metadata: Option<EvalCaseMetadata>,
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct Turn {
109 pub invocation_id: String,
111 pub user_content: ContentData,
113 #[serde(default)]
115 pub final_response: Option<ContentData>,
116 #[serde(default)]
118 pub intermediate_data: Option<IntermediateData>,
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct ContentData {
124 pub parts: Vec<Part>,
126 #[serde(default = "default_role")]
128 pub role: String,
129}
130
/// Serde default for [`ContentData::role`]: a missing role means `"user"`.
fn default_role() -> String {
    String::from("user")
}
134
135impl ContentData {
136 pub fn text(text: &str) -> Self {
138 Self { parts: vec![Part::Text { text: text.to_string() }], role: "user".to_string() }
139 }
140
141 pub fn model_response(text: &str) -> Self {
143 Self { parts: vec![Part::Text { text: text.to_string() }], role: "model".to_string() }
144 }
145
146 pub fn get_text(&self) -> String {
148 self.parts
149 .iter()
150 .filter_map(|p| match p {
151 Part::Text { text } => Some(text.as_str()),
152 _ => None,
153 })
154 .collect::<Vec<_>>()
155 .join("")
156 }
157
158 pub fn to_adk_content(&self) -> adk_core::Content {
160 let mut content = adk_core::Content::new(&self.role);
161 for part in &self.parts {
162 match part {
163 Part::Text { text } => {
164 content = content.with_text(text);
165 }
166 Part::FunctionCall { .. } | Part::FunctionResponse { .. } => {
167 }
170 }
171 }
172 content
173 }
174}
175
176#[derive(Debug, Clone, Serialize, Deserialize)]
178#[serde(untagged)]
179pub enum Part {
180 Text { text: String },
182 FunctionCall { name: String, args: Value },
184 FunctionResponse { name: String, response: Value },
186}
187
188#[derive(Debug, Clone, Default, Serialize, Deserialize)]
190pub struct IntermediateData {
191 #[serde(default)]
193 pub tool_uses: Vec<ToolUse>,
194 #[serde(default)]
196 pub intermediate_responses: Vec<ContentData>,
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize)]
201pub struct ToolUse {
202 pub name: String,
204 #[serde(default)]
206 pub args: Value,
207 #[serde(default)]
209 pub expected_response: Option<Value>,
210}
211
212impl ToolUse {
213 pub fn new(name: &str) -> Self {
215 Self {
216 name: name.to_string(),
217 args: Value::Object(Default::default()),
218 expected_response: None,
219 }
220 }
221
222 pub fn with_args(mut self, args: Value) -> Self {
224 self.args = args;
225 self
226 }
227
228 pub fn matches(&self, other: &ToolUse, strict_args: bool) -> bool {
230 if self.name != other.name {
231 return false;
232 }
233
234 if strict_args {
235 self.args == other.args
236 } else {
237 match (&self.args, &other.args) {
239 (Value::Object(expected), Value::Object(actual)) => {
240 expected.iter().all(|(k, v)| actual.get(k) == Some(v))
241 }
242 _ => self.args == other.args,
243 }
244 }
245 }
246}
247
248#[derive(Debug, Clone, Default, Serialize, Deserialize)]
250pub struct SessionInput {
251 #[serde(default)]
253 pub app_name: String,
254 #[serde(default)]
256 pub user_id: String,
257 #[serde(default)]
259 pub state: HashMap<String, Value>,
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A full test file parses, and optional fields take their defaults.
    #[test]
    fn test_parse_test_file() {
        let json = r#"{
            "eval_set_id": "test_set",
            "name": "Test Set",
            "description": "A test set",
            "eval_cases": [
                {
                    "eval_id": "test_1",
                    "conversation": [
                        {
                            "invocation_id": "inv_1",
                            "user_content": {
                                "parts": [{"text": "Hello"}],
                                "role": "user"
                            },
                            "final_response": {
                                "parts": [{"text": "Hi there!"}],
                                "role": "model"
                            }
                        }
                    ]
                }
            ]
        }"#;

        let test_file: TestFile = serde_json::from_str(json).unwrap();
        assert_eq!(test_file.eval_set_id, "test_set");
        assert_eq!(test_file.description, "A test set");
        assert_eq!(test_file.eval_cases.len(), 1);

        let case = &test_file.eval_cases[0];
        assert_eq!(case.eval_id, "test_1");
        assert!(case.tags.is_empty());
        assert!(case.metadata.is_none());
        assert!(case.session_input.app_name.is_empty());

        let turn = &case.conversation[0];
        assert_eq!(turn.user_content.get_text(), "Hello");
        let final_response = turn.final_response.as_ref().expect("final_response is present");
        assert_eq!(final_response.role, "model");
        assert!(turn.intermediate_data.is_none());
    }

    /// Strict and lenient argument matching against exact, extra and wrong args.
    #[test]
    fn test_tool_use_matching() {
        let expected = ToolUse::new("get_weather").with_args(json!({"location": "NYC"}));

        let actual_exact = ToolUse::new("get_weather").with_args(json!({"location": "NYC"}));
        assert!(expected.matches(&actual_exact, true));

        let actual_extra =
            ToolUse::new("get_weather").with_args(json!({"location": "NYC", "unit": "celsius"}));
        assert!(!expected.matches(&actual_extra, true));
        assert!(expected.matches(&actual_extra, false));

        let actual_wrong = ToolUse::new("get_weather").with_args(json!({"location": "LA"}));
        assert!(!expected.matches(&actual_wrong, true));
        assert!(!expected.matches(&actual_wrong, false));
    }

    /// A different tool name never matches, whatever the mode.
    #[test]
    fn test_tool_use_name_mismatch() {
        let expected = ToolUse::new("get_weather");
        let actual = ToolUse::new("get_time");
        assert!(!expected.matches(&actual, true));
        assert!(!expected.matches(&actual, false));
    }

    /// An empty expected argument object constrains nothing in lenient mode.
    #[test]
    fn test_tool_use_empty_expected_args() {
        let expected = ToolUse::new("search");
        let actual = ToolUse::new("search").with_args(json!({"query": "rust"}));
        assert!(expected.matches(&actual, false));
        assert!(!expected.matches(&actual, true));
    }

    /// Constructors set the role, and `get_text` returns the single text part.
    #[test]
    fn test_content_data() {
        let content = ContentData::text("Hello world");
        assert_eq!(content.get_text(), "Hello world");
        assert_eq!(content.role, "user");

        let model = ContentData::model_response("Hi there!");
        assert_eq!(model.role, "model");
        assert_eq!(model.get_text(), "Hi there!");
    }

    /// Untagged parts resolve by shape, a missing role defaults to "user", and
    /// `get_text` skips function parts.
    #[test]
    fn test_parse_function_parts() {
        let json = r#"{
            "parts": [
                {"text": "a"},
                {"name": "lookup", "args": {"id": 1}},
                {"name": "lookup", "response": {"ok": true}},
                {"text": "b"}
            ]
        }"#;

        let content: ContentData = serde_json::from_str(json).unwrap();
        assert_eq!(content.role, "user");
        assert_eq!(content.parts.len(), 4);
        assert!(matches!(&content.parts[1], Part::FunctionCall { name, .. } if name == "lookup"));
        assert!(matches!(
            &content.parts[2],
            Part::FunctionResponse { name, .. } if name == "lookup"
        ));
        assert_eq!(content.get_text(), "ab");
    }

    /// An eval set with only required keys has empty collections and yields no cases.
    #[test]
    fn test_eval_set_defaults() {
        let eval_set: EvalSet =
            serde_json::from_str(r#"{"eval_set_id": "set", "name": "Set"}"#).unwrap();
        assert!(eval_set.description.is_empty());
        assert!(eval_set.test_files.is_empty());
        assert!(eval_set.eval_cases.is_empty());

        match eval_set.get_all_cases(".") {
            Ok(cases) => assert!(cases.is_empty()),
            Err(_) => panic!("an eval set without test files should not touch the filesystem"),
        }
    }
}