Skip to main content

adk_eval/
schema.rs

1//! Test file schema definitions
2//!
3//! Defines the structure for test files (`.test.json`) and eval sets (`.evalset.json`).
4
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::collections::HashMap;
8use std::path::Path;
9
10use crate::error::{EvalError, Result};
11use crate::test_generator::EvalCaseMetadata;
12
13/// A complete test file containing multiple evaluation cases
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct TestFile {
16    /// Unique identifier for this eval set
17    pub eval_set_id: String,
18    /// Human-readable name
19    pub name: String,
20    /// Description of what these tests cover
21    #[serde(default)]
22    pub description: String,
23    /// List of evaluation cases
24    pub eval_cases: Vec<EvalCase>,
25}
26
27impl TestFile {
28    /// Load a test file from disk
29    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
30        let content = std::fs::read_to_string(path.as_ref())?;
31        let test_file: TestFile = serde_json::from_str(&content)?;
32        Ok(test_file)
33    }
34
35    /// Save test file to disk
36    pub fn save(&self, path: impl AsRef<Path>) -> Result<()> {
37        let content = serde_json::to_string_pretty(self)?;
38        std::fs::write(path, content)?;
39        Ok(())
40    }
41}
42
43/// An eval set references multiple test files
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct EvalSet {
46    /// Unique identifier
47    pub eval_set_id: String,
48    /// Human-readable name
49    pub name: String,
50    /// Description
51    #[serde(default)]
52    pub description: String,
53    /// List of test file paths or inline eval cases
54    #[serde(default)]
55    pub test_files: Vec<String>,
56    /// Inline eval cases (alternative to test_files)
57    #[serde(default)]
58    pub eval_cases: Vec<EvalCase>,
59}
60
61impl EvalSet {
62    /// Load an eval set from disk
63    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
64        let content = std::fs::read_to_string(path.as_ref())?;
65        let eval_set: EvalSet = serde_json::from_str(&content)?;
66        Ok(eval_set)
67    }
68
69    /// Get all eval cases, loading from test files if needed
70    pub fn get_all_cases(&self, base_path: impl AsRef<Path>) -> Result<Vec<EvalCase>> {
71        let mut all_cases = self.eval_cases.clone();
72
73        for test_file_path in &self.test_files {
74            let full_path = base_path.as_ref().join(test_file_path);
75            let test_file = TestFile::load(&full_path).map_err(|e| {
76                EvalError::LoadError(format!("Failed to load {}: {}", test_file_path, e))
77            })?;
78            all_cases.extend(test_file.eval_cases);
79        }
80
81        Ok(all_cases)
82    }
83}
84
85/// A single evaluation case (test case)
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct EvalCase {
88    /// Unique identifier for this test case
89    pub eval_id: String,
90    /// Optional description
91    #[serde(default)]
92    pub description: String,
93    /// The conversation turns to evaluate
94    pub conversation: Vec<Turn>,
95    /// Session configuration
96    #[serde(default)]
97    pub session_input: SessionInput,
98    /// Optional tags for filtering
99    #[serde(default)]
100    pub tags: Vec<String>,
101    /// Optional metadata (generation info, etc.)
102    #[serde(default, skip_serializing_if = "Option::is_none")]
103    pub metadata: Option<EvalCaseMetadata>,
104}
105
106/// A single turn in a conversation
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct Turn {
109    /// Unique identifier for this turn
110    pub invocation_id: String,
111    /// User input content
112    pub user_content: ContentData,
113    /// Expected final response from the agent
114    #[serde(default)]
115    pub final_response: Option<ContentData>,
116    /// Expected intermediate data (tool calls, etc.)
117    #[serde(default)]
118    pub intermediate_data: Option<IntermediateData>,
119}
120
121/// Content data structure (matches ADK Content)
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct ContentData {
124    /// Content parts
125    pub parts: Vec<Part>,
126    /// Role (user, model, tool)
127    #[serde(default = "default_role")]
128    pub role: String,
129}
130
/// Serde default for `ContentData::role` when the JSON does not specify one.
fn default_role() -> String {
    String::from("user")
}
134
135impl ContentData {
136    /// Create content from text
137    pub fn text(text: &str) -> Self {
138        Self { parts: vec![Part::Text { text: text.to_string() }], role: "user".to_string() }
139    }
140
141    /// Create model response content
142    pub fn model_response(text: &str) -> Self {
143        Self { parts: vec![Part::Text { text: text.to_string() }], role: "model".to_string() }
144    }
145
146    /// Get all text parts concatenated
147    pub fn get_text(&self) -> String {
148        self.parts
149            .iter()
150            .filter_map(|p| match p {
151                Part::Text { text } => Some(text.as_str()),
152                _ => None,
153            })
154            .collect::<Vec<_>>()
155            .join("")
156    }
157
158    /// Convert to ADK Content
159    pub fn to_adk_content(&self) -> adk_core::Content {
160        let mut content = adk_core::Content::new(&self.role);
161        for part in &self.parts {
162            match part {
163                Part::Text { text } => {
164                    content = content.with_text(text);
165                }
166                Part::FunctionCall { .. } | Part::FunctionResponse { .. } => {
167                    // Function calls/responses are handled separately in the evaluation
168                    // The Content type doesn't have direct methods for these
169                }
170            }
171        }
172        content
173    }
174}
175
176/// Content part variants
177#[derive(Debug, Clone, Serialize, Deserialize)]
178#[serde(untagged)]
179pub enum Part {
180    /// Text content
181    Text { text: String },
182    /// Function/tool call
183    FunctionCall { name: String, args: Value },
184    /// Function/tool response
185    FunctionResponse { name: String, response: Value },
186}
187
188/// Intermediate data during a turn (tool calls, etc.)
189#[derive(Debug, Clone, Default, Serialize, Deserialize)]
190pub struct IntermediateData {
191    /// Expected tool calls in order
192    #[serde(default)]
193    pub tool_uses: Vec<ToolUse>,
194    /// Intermediate responses before final
195    #[serde(default)]
196    pub intermediate_responses: Vec<ContentData>,
197}
198
199/// A tool use (function call)
200#[derive(Debug, Clone, Serialize, Deserialize)]
201pub struct ToolUse {
202    /// Tool/function name
203    pub name: String,
204    /// Arguments passed to the tool
205    #[serde(default)]
206    pub args: Value,
207    /// Expected response (optional, for mocking)
208    #[serde(default)]
209    pub expected_response: Option<Value>,
210}
211
212impl ToolUse {
213    /// Create a new tool use
214    pub fn new(name: &str) -> Self {
215        Self {
216            name: name.to_string(),
217            args: Value::Object(Default::default()),
218            expected_response: None,
219        }
220    }
221
222    /// Add arguments
223    pub fn with_args(mut self, args: Value) -> Self {
224        self.args = args;
225        self
226    }
227
228    /// Check if this tool use matches another (name and args)
229    pub fn matches(&self, other: &ToolUse, strict_args: bool) -> bool {
230        if self.name != other.name {
231            return false;
232        }
233
234        if strict_args {
235            self.args == other.args
236        } else {
237            // Partial match: check that expected args are present in actual
238            match (&self.args, &other.args) {
239                (Value::Object(expected), Value::Object(actual)) => {
240                    expected.iter().all(|(k, v)| actual.get(k) == Some(v))
241                }
242                _ => self.args == other.args,
243            }
244        }
245    }
246}
247
248/// Session input configuration
249#[derive(Debug, Clone, Default, Serialize, Deserialize)]
250pub struct SessionInput {
251    /// Application name
252    #[serde(default)]
253    pub app_name: String,
254    /// User identifier
255    #[serde(default)]
256    pub user_id: String,
257    /// Initial state
258    #[serde(default)]
259    pub state: HashMap<String, Value>,
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A minimal test file with one single-turn case deserializes with its ids intact.
    #[test]
    fn test_parse_test_file() {
        let raw = r#"{
            "eval_set_id": "test_set",
            "name": "Test Set",
            "description": "A test set",
            "eval_cases": [
                {
                    "eval_id": "test_1",
                    "conversation": [
                        {
                            "invocation_id": "inv_1",
                            "user_content": {
                                "parts": [{"text": "Hello"}],
                                "role": "user"
                            },
                            "final_response": {
                                "parts": [{"text": "Hi there!"}],
                                "role": "model"
                            }
                        }
                    ]
                }
            ]
        }"#;

        let parsed = serde_json::from_str::<TestFile>(raw).unwrap();
        assert_eq!(parsed.eval_set_id, "test_set");
        assert_eq!(parsed.eval_cases.len(), 1);
        assert_eq!(parsed.eval_cases[0].eval_id, "test_1");
    }

    /// Strict matching needs identical args.
    /// Partial matching tolerates extra actual keys but still rejects differing values.
    #[test]
    fn test_tool_use_matching() {
        let weather = |args| ToolUse::new("get_weather").with_args(args);
        let expected = weather(json!({"location": "NYC"}));

        let exact = weather(json!({"location": "NYC"}));
        assert!(expected.matches(&exact, true));

        let with_extra = weather(json!({"location": "NYC", "unit": "celsius"}));
        assert!(!expected.matches(&with_extra, true)); // strict: extra key rejected
        assert!(expected.matches(&with_extra, false)); // partial: extra key allowed

        let mismatched = weather(json!({"location": "LA"}));
        assert!(!expected.matches(&mismatched, true));
        assert!(!expected.matches(&mismatched, false));
    }

    /// The text constructors set the expected roles, and `get_text` returns the text.
    #[test]
    fn test_content_data() {
        let user = ContentData::text("Hello world");
        assert_eq!(user.get_text(), "Hello world");
        assert_eq!(user.role, "user");

        assert_eq!(ContentData::model_response("Hi there!").role, "model");
    }
}