adk_eval/
schema.rs

1//! Test file schema definitions
2//!
3//! Defines the structure for test files (`.test.json`) and eval sets (`.evalset.json`).
4
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::collections::HashMap;
8use std::path::Path;
9
10use crate::error::{EvalError, Result};
11
12/// A complete test file containing multiple evaluation cases
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct TestFile {
15    /// Unique identifier for this eval set
16    pub eval_set_id: String,
17    /// Human-readable name
18    pub name: String,
19    /// Description of what these tests cover
20    #[serde(default)]
21    pub description: String,
22    /// List of evaluation cases
23    pub eval_cases: Vec<EvalCase>,
24}
25
26impl TestFile {
27    /// Load a test file from disk
28    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
29        let content = std::fs::read_to_string(path.as_ref())?;
30        let test_file: TestFile = serde_json::from_str(&content)?;
31        Ok(test_file)
32    }
33
34    /// Save test file to disk
35    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
36        let content = serde_json::to_string_pretty(self)?;
37        std::fs::write(path, content)?;
38        Ok(())
39    }
40}
41
42/// An eval set references multiple test files
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct EvalSet {
45    /// Unique identifier
46    pub eval_set_id: String,
47    /// Human-readable name
48    pub name: String,
49    /// Description
50    #[serde(default)]
51    pub description: String,
52    /// List of test file paths or inline eval cases
53    #[serde(default)]
54    pub test_files: Vec<String>,
55    /// Inline eval cases (alternative to test_files)
56    #[serde(default)]
57    pub eval_cases: Vec<EvalCase>,
58}
59
60impl EvalSet {
61    /// Load an eval set from disk
62    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
63        let content = std::fs::read_to_string(path.as_ref())?;
64        let eval_set: EvalSet = serde_json::from_str(&content)?;
65        Ok(eval_set)
66    }
67
68    /// Get all eval cases, loading from test files if needed
69    pub fn get_all_cases(&self, base_path: impl AsRef<Path>) -> Result<Vec<EvalCase>> {
70        let mut all_cases = self.eval_cases.clone();
71
72        for test_file_path in &self.test_files {
73            let full_path = base_path.as_ref().join(test_file_path);
74            let test_file = TestFile::load(&full_path).map_err(|e| {
75                EvalError::LoadError(format!("Failed to load {}: {}", test_file_path, e))
76            })?;
77            all_cases.extend(test_file.eval_cases);
78        }
79
80        Ok(all_cases)
81    }
82}
83
84/// A single evaluation case (test case)
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct EvalCase {
87    /// Unique identifier for this test case
88    pub eval_id: String,
89    /// Optional description
90    #[serde(default)]
91    pub description: String,
92    /// The conversation turns to evaluate
93    pub conversation: Vec<Turn>,
94    /// Session configuration
95    #[serde(default)]
96    pub session_input: SessionInput,
97    /// Optional tags for filtering
98    #[serde(default)]
99    pub tags: Vec<String>,
100}
101
102/// A single turn in a conversation
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct Turn {
105    /// Unique identifier for this turn
106    pub invocation_id: String,
107    /// User input content
108    pub user_content: ContentData,
109    /// Expected final response from the agent
110    #[serde(default)]
111    pub final_response: Option<ContentData>,
112    /// Expected intermediate data (tool calls, etc.)
113    #[serde(default)]
114    pub intermediate_data: Option<IntermediateData>,
115}
116
117/// Content data structure (matches ADK Content)
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct ContentData {
120    /// Content parts
121    pub parts: Vec<Part>,
122    /// Role (user, model, tool)
123    #[serde(default = "default_role")]
124    pub role: String,
125}
126
/// Role assigned to deserialized content that does not specify one.
fn default_role() -> String {
    String::from("user")
}
130
131impl ContentData {
132    /// Create content from text
133    pub fn text(text: &str) -> Self {
134        Self { parts: vec![Part::Text { text: text.to_string() }], role: "user".to_string() }
135    }
136
137    /// Create model response content
138    pub fn model_response(text: &str) -> Self {
139        Self { parts: vec![Part::Text { text: text.to_string() }], role: "model".to_string() }
140    }
141
142    /// Get all text parts concatenated
143    pub fn get_text(&self) -> String {
144        self.parts
145            .iter()
146            .filter_map(|p| match p {
147                Part::Text { text } => Some(text.as_str()),
148                _ => None,
149            })
150            .collect::<Vec<_>>()
151            .join("")
152    }
153
154    /// Convert to ADK Content
155    pub fn to_adk_content(&self) -> adk_core::Content {
156        let mut content = adk_core::Content::new(&self.role);
157        for part in &self.parts {
158            match part {
159                Part::Text { text } => {
160                    content = content.with_text(text);
161                }
162                Part::FunctionCall { .. } | Part::FunctionResponse { .. } => {
163                    // Function calls/responses are handled separately in the evaluation
164                    // The Content type doesn't have direct methods for these
165                }
166            }
167        }
168        content
169    }
170}
171
172/// Content part variants
173#[derive(Debug, Clone, Serialize, Deserialize)]
174#[serde(untagged)]
175pub enum Part {
176    /// Text content
177    Text { text: String },
178    /// Function/tool call
179    FunctionCall { name: String, args: Value },
180    /// Function/tool response
181    FunctionResponse { name: String, response: Value },
182}
183
184/// Intermediate data during a turn (tool calls, etc.)
185#[derive(Debug, Clone, Default, Serialize, Deserialize)]
186pub struct IntermediateData {
187    /// Expected tool calls in order
188    #[serde(default)]
189    pub tool_uses: Vec<ToolUse>,
190    /// Intermediate responses before final
191    #[serde(default)]
192    pub intermediate_responses: Vec<ContentData>,
193}
194
195/// A tool use (function call)
196#[derive(Debug, Clone, Serialize, Deserialize)]
197pub struct ToolUse {
198    /// Tool/function name
199    pub name: String,
200    /// Arguments passed to the tool
201    #[serde(default)]
202    pub args: Value,
203    /// Expected response (optional, for mocking)
204    #[serde(default)]
205    pub expected_response: Option<Value>,
206}
207
208impl ToolUse {
209    /// Create a new tool use
210    pub fn new(name: &str) -> Self {
211        Self {
212            name: name.to_string(),
213            args: Value::Object(Default::default()),
214            expected_response: None,
215        }
216    }
217
218    /// Add arguments
219    pub fn with_args(mut self, args: Value) -> Self {
220        self.args = args;
221        self
222    }
223
224    /// Check if this tool use matches another (name and args)
225    pub fn matches(&self, other: &ToolUse, strict_args: bool) -> bool {
226        if self.name != other.name {
227            return false;
228        }
229
230        if strict_args {
231            self.args == other.args
232        } else {
233            // Partial match: check that expected args are present in actual
234            match (&self.args, &other.args) {
235                (Value::Object(expected), Value::Object(actual)) => {
236                    expected.iter().all(|(k, v)| actual.get(k) == Some(v))
237                }
238                _ => self.args == other.args,
239            }
240        }
241    }
242}
243
244/// Session input configuration
245#[derive(Debug, Clone, Default, Serialize, Deserialize)]
246pub struct SessionInput {
247    /// Application name
248    #[serde(default)]
249    pub app_name: String,
250    /// User identifier
251    #[serde(default)]
252    pub user_id: String,
253    /// Initial state
254    #[serde(default)]
255    pub state: HashMap<String, Value>,
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a `get_weather` tool use carrying the given arguments.
    fn weather_call(args: serde_json::Value) -> ToolUse {
        ToolUse::new("get_weather").with_args(args)
    }

    /// A minimal single-case test file deserializes with its identifiers intact.
    #[test]
    fn test_parse_test_file() {
        let raw = r#"{
            "eval_set_id": "test_set",
            "name": "Test Set",
            "description": "A test set",
            "eval_cases": [
                {
                    "eval_id": "test_1",
                    "conversation": [
                        {
                            "invocation_id": "inv_1",
                            "user_content": {
                                "parts": [{"text": "Hello"}],
                                "role": "user"
                            },
                            "final_response": {
                                "parts": [{"text": "Hi there!"}],
                                "role": "model"
                            }
                        }
                    ]
                }
            ]
        }"#;

        let parsed = serde_json::from_str::<TestFile>(raw).unwrap();
        assert_eq!(parsed.eval_set_id, "test_set");
        assert_eq!(parsed.eval_cases.len(), 1);
        assert_eq!(parsed.eval_cases[0].eval_id, "test_1");
    }

    /// Strict matching needs identical args; partial matching tolerates extra keys.
    #[test]
    fn test_tool_use_matching() {
        let expected = weather_call(json!({"location": "NYC"}));

        // Identical arguments.
        let same = weather_call(json!({"location": "NYC"}));
        assert!(expected.matches(&same, true));

        // A superset of the expected arguments.
        let superset = weather_call(json!({"location": "NYC", "unit": "celsius"}));
        assert!(!expected.matches(&superset, true)); // Strict fails
        assert!(expected.matches(&superset, false)); // Partial passes

        // A conflicting value for an expected key.
        let conflicting = weather_call(json!({"location": "LA"}));
        assert!(!expected.matches(&conflicting, true));
        assert!(!expected.matches(&conflicting, false));
    }

    /// The text constructors set the role and round-trip their text.
    #[test]
    fn test_content_data() {
        let user_msg = ContentData::text("Hello world");
        assert_eq!(user_msg.get_text(), "Hello world");
        assert_eq!(user_msg.role, "user");

        let model_msg = ContentData::model_response("Hi there!");
        assert_eq!(model_msg.role, "model");
    }
}