Skip to main content

rs_adk/
agent_config.rs

1//! YAML/TOML agent configuration — define agents without code.
2//!
3//! Mirrors upstream ADK's `root_agent.yaml` format. Agents can be defined
4//! declaratively and loaded at runtime by the CLI or API server.
5
6use std::collections::HashMap;
7use std::path::Path;
8
9use serde::{Deserialize, Serialize};
10
11/// Declarative agent configuration — loadable from YAML or TOML.
12///
13/// # Example YAML
14///
15/// ```yaml
16/// name: weather_agent
17/// model: gemini-2.0-flash
18/// instruction: "You are a helpful weather assistant."
19/// description: "Provides weather information for cities."
20/// tools:
21///   - name: get_weather
22///     description: "Get weather for a city"
23///   - builtin: google_search
24/// sub_agents:
25///   - name: forecast_agent
26///     model: gemini-2.0-flash
27///     instruction: "Provide 5-day forecasts."
28/// output_key: weather_result
29/// ```
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct AgentConfig {
32    /// Agent name (required).
33    pub name: String,
34
35    /// Model identifier (e.g., "gemini-2.0-flash", "gemini-2.5-pro").
36    #[serde(default)]
37    pub model: Option<String>,
38
39    /// System instruction for the agent.
40    #[serde(default)]
41    pub instruction: Option<String>,
42
43    /// Human-readable description of what this agent does.
44    #[serde(default)]
45    pub description: Option<String>,
46
47    /// Tool declarations.
48    #[serde(default)]
49    pub tools: Vec<ToolConfig>,
50
51    /// Sub-agent configurations (for multi-agent hierarchies).
52    #[serde(default)]
53    pub sub_agents: Vec<AgentConfig>,
54
55    /// Temperature for generation (0.0 - 2.0).
56    #[serde(default)]
57    pub temperature: Option<f32>,
58
59    /// Maximum output tokens.
60    #[serde(default)]
61    pub max_output_tokens: Option<u32>,
62
63    /// Thinking budget (Google AI only).
64    #[serde(default)]
65    pub thinking_budget: Option<u32>,
66
67    /// State key to auto-save the agent's final response into.
68    #[serde(default)]
69    pub output_key: Option<String>,
70
71    /// JSON Schema for structured output.
72    #[serde(default)]
73    pub output_schema: Option<serde_json::Value>,
74
75    /// Maximum number of LLM calls per invocation (safety limit).
76    #[serde(default)]
77    pub max_llm_calls: Option<u32>,
78
79    /// Agent type: "llm" (default), "sequential", "parallel", "loop".
80    #[serde(default = "default_agent_type")]
81    pub agent_type: String,
82
83    /// For loop agents: maximum iterations.
84    #[serde(default)]
85    pub max_iterations: Option<u32>,
86
87    /// Custom metadata (passed through to state or callbacks).
88    #[serde(default)]
89    pub metadata: HashMap<String, serde_json::Value>,
90
91    /// Voice configuration for live agents.
92    #[serde(default)]
93    pub voice: Option<String>,
94
95    /// Greeting message (model speaks first on connect).
96    #[serde(default)]
97    pub greeting: Option<String>,
98
99    /// Whether to enable transcription.
100    #[serde(default)]
101    pub transcription: Option<bool>,
102
103    /// Whether to enable A2A protocol endpoint.
104    #[serde(default)]
105    pub a2a: Option<bool>,
106
107    /// Environment variables to set when loading this agent.
108    #[serde(default)]
109    pub env: HashMap<String, String>,
110}
111
/// Serde default for `AgentConfig::agent_type` when the key is absent: a plain LLM agent.
fn default_agent_type() -> String {
    String::from("llm")
}
115
116/// Tool configuration within an agent config.
117#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct ToolConfig {
119    /// Tool name (for custom tools).
120    #[serde(default)]
121    pub name: Option<String>,
122
123    /// Tool description.
124    #[serde(default)]
125    pub description: Option<String>,
126
127    /// Built-in tool type (e.g., "google_search", "code_execution", "url_context").
128    #[serde(default)]
129    pub builtin: Option<String>,
130
131    /// JSON Schema for the tool's parameters.
132    #[serde(default)]
133    pub parameters: Option<serde_json::Value>,
134}
135
136/// Errors from agent config operations.
137#[derive(Debug, thiserror::Error)]
138pub enum AgentConfigError {
139    /// Failed to read the config file.
140    #[error("IO error: {0}")]
141    Io(#[from] std::io::Error),
142
143    /// Failed to parse YAML.
144    #[error("YAML parse error: {0}")]
145    Yaml(String),
146
147    /// Failed to parse TOML.
148    #[error("TOML parse error: {0}")]
149    Toml(String),
150
151    /// Failed to parse JSON.
152    #[error("JSON parse error: {0}")]
153    Json(#[from] serde_json::Error),
154
155    /// Invalid configuration.
156    #[error("Invalid config: {0}")]
157    Invalid(String),
158}
159
160impl AgentConfig {
161    /// Load agent config from a YAML file.
162    pub fn from_yaml_file(path: &Path) -> Result<Self, AgentConfigError> {
163        let content = std::fs::read_to_string(path)?;
164        Self::from_yaml(&content)
165    }
166
167    /// Parse agent config from a YAML string.
168    pub fn from_yaml(yaml: &str) -> Result<Self, AgentConfigError> {
169        serde_json::from_value(
170            serde_json::to_value(
171                // Use serde_json roundtrip since we don't want to add serde_yaml dep.
172                // In practice, the CLI crate will parse YAML and pass the Value.
173                serde_json::from_str::<serde_json::Value>(yaml)
174                    .map_err(|e| AgentConfigError::Yaml(e.to_string()))?,
175            )
176            .map_err(|e| AgentConfigError::Yaml(e.to_string()))?,
177        )
178        .map_err(|e| AgentConfigError::Yaml(e.to_string()))
179    }
180
181    /// Parse agent config from a JSON string.
182    pub fn from_json(json: &str) -> Result<Self, AgentConfigError> {
183        Ok(serde_json::from_str(json)?)
184    }
185
186    /// Parse agent config from a JSON value.
187    pub fn from_value(value: serde_json::Value) -> Result<Self, AgentConfigError> {
188        Ok(serde_json::from_value(value)?)
189    }
190
191    /// Validate the configuration.
192    pub fn validate(&self) -> Result<(), AgentConfigError> {
193        if self.name.is_empty() {
194            return Err(AgentConfigError::Invalid("Agent name is required".into()));
195        }
196        if let Some(temp) = self.temperature {
197            if !(0.0..=2.0).contains(&temp) {
198                return Err(AgentConfigError::Invalid(format!(
199                    "Temperature must be 0.0-2.0, got {}",
200                    temp
201                )));
202            }
203        }
204        // Validate sub-agents recursively.
205        for sub in &self.sub_agents {
206            sub.validate()?;
207        }
208        Ok(())
209    }
210
211    /// Check if this is a built-in tool reference.
212    pub fn builtin_tools(&self) -> Vec<&str> {
213        self.tools
214            .iter()
215            .filter_map(|t| t.builtin.as_deref())
216            .collect()
217    }
218
219    /// Check if this is a workflow agent (non-LLM).
220    pub fn is_workflow(&self) -> bool {
221        matches!(self.agent_type.as_str(), "sequential" | "parallel" | "loop")
222    }
223}
224
225/// Discover agent configurations in a directory.
226///
227/// Scans for files named `agent.yaml`, `agent.json`, `agent.toml`,
228/// `root_agent.yaml`, or `root_agent.json`.
229pub fn discover_agent_configs(dir: &Path) -> Result<Vec<AgentConfig>, AgentConfigError> {
230    let candidates = ["agent.json", "root_agent.json", "agent.toml"];
231
232    let mut configs = Vec::new();
233    for candidate in &candidates {
234        let path = dir.join(candidate);
235        if path.exists() {
236            let content = std::fs::read_to_string(&path)?;
237            let config: AgentConfig = if candidate.ends_with(".json") {
238                serde_json::from_str(&content)?
239            } else if candidate.ends_with(".toml") {
240                // TOML parsing delegated to CLI crate which has the toml dep
241                return Err(AgentConfigError::Toml(
242                    "TOML parsing requires the adk-cli crate".into(),
243                ));
244            } else {
245                return Err(AgentConfigError::Yaml(
246                    "YAML parsing requires the adk-cli crate".into(),
247                ));
248            };
249            config.validate()?;
250            configs.push(config);
251        }
252    }
253
254    // Also scan subdirectories for agent configs.
255    if let Ok(entries) = std::fs::read_dir(dir) {
256        for entry in entries.flatten() {
257            let path = entry.path();
258            if path.is_dir() {
259                if let Ok(sub_configs) = discover_agent_configs(&path) {
260                    configs.extend(sub_configs);
261                }
262            }
263        }
264    }
265
266    Ok(configs)
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_minimal_json_config() {
        let json = r#"{"name": "test_agent"}"#;
        let config = AgentConfig::from_json(json).unwrap();
        assert_eq!(config.name, "test_agent");
        assert_eq!(config.agent_type, "llm");
        assert!(config.model.is_none());
        assert!(config.tools.is_empty());
    }

    #[test]
    fn parse_full_json_config() {
        let json = r#"{
            "name": "weather_agent",
            "model": "gemini-2.0-flash",
            "instruction": "You are a weather assistant.",
            "description": "Gets weather info",
            "temperature": 0.3,
            "max_output_tokens": 1024,
            "output_key": "weather_result",
            "max_llm_calls": 10,
            "tools": [
                {"name": "get_weather", "description": "Get weather for a city"},
                {"builtin": "google_search"}
            ],
            "sub_agents": [
                {"name": "forecast", "instruction": "Give forecasts"}
            ]
        }"#;
        let config = AgentConfig::from_json(json).unwrap();
        assert_eq!(config.name, "weather_agent");
        assert_eq!(config.model.as_deref(), Some("gemini-2.0-flash"));
        assert_eq!(config.temperature, Some(0.3));
        assert_eq!(config.output_key.as_deref(), Some("weather_result"));
        assert_eq!(config.max_llm_calls, Some(10));
        assert_eq!(config.tools.len(), 2);
        assert_eq!(config.sub_agents.len(), 1);
        assert_eq!(config.builtin_tools(), vec!["google_search"]);
    }

    #[test]
    fn from_yaml_accepts_json_subset() {
        let config = AgentConfig::from_yaml(r#"{"name": "yaml_agent"}"#).unwrap();
        assert_eq!(config.name, "yaml_agent");
    }

    #[test]
    fn from_yaml_rejects_block_style_yaml() {
        // No YAML parser is bundled, so block-style YAML surfaces as a Yaml error.
        let err = AgentConfig::from_yaml("name: yaml_agent").unwrap_err();
        assert!(matches!(err, AgentConfigError::Yaml(_)));
    }

    #[test]
    fn from_value_parses_workflow_agent() {
        let value = serde_json::json!({"name": "looper", "agent_type": "loop", "max_iterations": 3});
        let config = AgentConfig::from_value(value).unwrap();
        assert_eq!(config.agent_type, "loop");
        assert_eq!(config.max_iterations, Some(3));
        assert!(config.is_workflow());
    }

    #[test]
    fn validate_empty_name_fails() {
        let config = AgentConfig::from_json(r#"{"name": ""}"#).unwrap();
        assert!(config.validate().is_err());
    }

    #[test]
    fn validate_bad_temperature_fails() {
        let config = AgentConfig::from_json(r#"{"name": "test", "temperature": 3.0}"#).unwrap();
        assert!(config.validate().is_err());
    }

    #[test]
    fn validate_good_config_passes() {
        let config = AgentConfig::from_json(r#"{"name": "test", "temperature": 0.7}"#).unwrap();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn validate_checks_sub_agents() {
        // The root is valid; only the nested agent has an out-of-range temperature.
        let config = AgentConfig::from_json(
            r#"{"name": "root", "sub_agents": [{"name": "child", "temperature": 5.0}]}"#,
        )
        .unwrap();
        assert!(config.validate().is_err());
    }

    #[test]
    fn is_workflow_detection() {
        for workflow_type in ["sequential", "parallel", "loop"] {
            let json = format!(r#"{{"name": "wf", "agent_type": "{}"}}"#, workflow_type);
            let config = AgentConfig::from_json(&json).unwrap();
            assert!(config.is_workflow(), "{} should be a workflow", workflow_type);
        }

        let llm = AgentConfig::from_json(r#"{"name": "llm"}"#).unwrap();
        assert!(!llm.is_workflow());
    }

    #[test]
    fn tool_config_deserializes_both_variants() {
        let custom: ToolConfig = serde_json::from_str(
            r#"{"name": "my_tool", "description": "Does stuff", "parameters": {"type": "object"}}"#,
        )
        .unwrap();
        assert_eq!(custom.name.as_deref(), Some("my_tool"));
        assert_eq!(custom.description.as_deref(), Some("Does stuff"));
        assert!(custom.builtin.is_none());
        assert_eq!(custom.parameters, Some(serde_json::json!({"type": "object"})));

        let builtin: ToolConfig = serde_json::from_str(r#"{"builtin": "google_search"}"#).unwrap();
        assert!(builtin.name.is_none());
        assert_eq!(builtin.builtin.as_deref(), Some("google_search"));
        assert!(builtin.parameters.is_none());
    }
}