// a2a_agents/core/config.rs

//! Agent configuration with TOML support
//!
//! This module provides declarative configuration for A2A agents via TOML files.
//! It supports environment variable interpolation and sensible defaults.
5
6use serde::{Deserialize, Serialize};
7use std::path::Path;
8use thiserror::Error;
9
10#[derive(Debug, Error)]
11pub enum ConfigError {
12    #[error("Failed to read config file: {0}")]
13    IoError(#[from] std::io::Error),
14    #[error("Failed to parse TOML: {0}")]
15    TomlError(#[from] toml::de::Error),
16    #[error("Environment variable not found: {0}")]
17    EnvVarError(String),
18    #[error("Invalid configuration: {0}")]
19    ValidationError(String),
20}
21
22/// Complete agent configuration from TOML
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct AgentConfig {
25    /// Agent metadata
26    pub agent: AgentMetadata,
27
28    /// Server configuration
29    #[serde(default)]
30    pub server: ServerConfig,
31
32    /// Skills exposed by the agent
33    #[serde(default)]
34    pub skills: Vec<SkillConfig>,
35
36    /// Features enabled for the agent
37    #[serde(default)]
38    pub features: FeaturesConfig,
39}
40
41impl AgentConfig {
42    /// Load configuration from a TOML file
43    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
44        let content = std::fs::read_to_string(path)?;
45        Self::from_toml(&content)
46    }
47
48    /// Parse configuration from TOML string
49    pub fn from_toml(content: &str) -> Result<Self, ConfigError> {
50        // Expand environment variables
51        let expanded = expand_env_vars(content)?;
52        let config: AgentConfig = toml::from_str(&expanded)?;
53        config.validate()?;
54        Ok(config)
55    }
56
57    /// Validate the configuration
58    pub fn validate(&self) -> Result<(), ConfigError> {
59        if self.agent.name.is_empty() {
60            return Err(ConfigError::ValidationError(
61                "Agent name cannot be empty".to_string(),
62            ));
63        }
64
65        if self.server.http_port == 0 && self.server.ws_port == 0 {
66            return Err(ConfigError::ValidationError(
67                "At least one server port must be configured".to_string(),
68            ));
69        }
70
71        // Validate skills
72        for skill in &self.skills {
73            if skill.id.is_empty() {
74                return Err(ConfigError::ValidationError(
75                    "Skill ID cannot be empty".to_string(),
76                ));
77            }
78        }
79
80        Ok(())
81    }
82
83    /// Build agent card URL from server config
84    pub fn agent_url(&self) -> String {
85        format!("http://{}:{}", self.server.host, self.server.http_port)
86    }
87}
88
89/// Agent metadata and identity
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct AgentMetadata {
92    /// Agent name
93    pub name: String,
94
95    /// Agent description
96    #[serde(default)]
97    pub description: Option<String>,
98
99    /// Agent version
100    #[serde(default)]
101    pub version: Option<String>,
102
103    /// Provider information
104    #[serde(default)]
105    pub provider: Option<ProviderInfo>,
106
107    /// Documentation URL
108    #[serde(default)]
109    pub documentation_url: Option<String>,
110}
111
112/// Provider information
113#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct ProviderInfo {
115    pub name: String,
116    pub url: String,
117}
118
119/// Server configuration
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct ServerConfig {
122    /// Host to bind to
123    #[serde(default = "default_host")]
124    pub host: String,
125
126    /// HTTP server port (0 to disable)
127    #[serde(default = "default_http_port")]
128    pub http_port: u16,
129
130    /// WebSocket server port (0 to disable)
131    #[serde(default = "default_ws_port")]
132    pub ws_port: u16,
133
134    /// Storage configuration
135    #[serde(default)]
136    pub storage: StorageConfig,
137
138    /// Authentication configuration
139    #[serde(default)]
140    pub auth: AuthConfig,
141}
142
143impl Default for ServerConfig {
144    fn default() -> Self {
145        Self {
146            host: default_host(),
147            http_port: default_http_port(),
148            ws_port: default_ws_port(),
149            storage: StorageConfig::default(),
150            auth: AuthConfig::default(),
151        }
152    }
153}
154
155/// Storage backend configuration
156#[derive(Debug, Clone, Default, Serialize, Deserialize)]
157#[serde(tag = "type", rename_all = "lowercase")]
158pub enum StorageConfig {
159    /// In-memory storage (default)
160    #[default]
161    InMemory,
162
163    /// SQLx-based persistent storage
164    Sqlx {
165        /// Database URL (supports env vars like ${DATABASE_URL})
166        url: String,
167
168        /// Maximum number of connections in the pool
169        #[serde(default = "default_max_connections")]
170        max_connections: u32,
171
172        /// Enable SQL query logging
173        #[serde(default)]
174        enable_logging: bool,
175    },
176}
177
178/// Authentication configuration
179#[derive(Debug, Clone, Default, Serialize, Deserialize)]
180#[serde(tag = "type", rename_all = "lowercase")]
181pub enum AuthConfig {
182    /// No authentication (default for development)
183    #[default]
184    None,
185
186    /// Bearer token authentication
187    Bearer {
188        /// List of valid tokens (supports env vars)
189        tokens: Vec<String>,
190
191        /// Optional bearer format description (e.g., "JWT")
192        #[serde(skip_serializing_if = "Option::is_none")]
193        format: Option<String>,
194    },
195
196    /// API Key authentication
197    ApiKey {
198        /// Valid API keys
199        keys: Vec<String>,
200
201        /// Location of the API key: "header", "query", or "cookie"
202        #[serde(default = "default_api_key_location")]
203        location: String,
204
205        /// Name of the header/query param/cookie
206        #[serde(default = "default_api_key_name")]
207        name: String,
208    },
209
210    /// JWT (JSON Web Token) authentication
211    Jwt {
212        /// JWT secret for HMAC algorithms (HS256, HS384, HS512)
213        /// Use ${ENV_VAR} for environment variables
214        #[serde(skip_serializing_if = "Option::is_none")]
215        secret: Option<String>,
216
217        /// RSA public key in PEM format for RSA algorithms (RS256, RS384, RS512)
218        #[serde(skip_serializing_if = "Option::is_none")]
219        rsa_pem_path: Option<String>,
220
221        /// Algorithm to use (HS256, HS384, HS512, RS256, RS384, RS512)
222        #[serde(default = "default_jwt_algorithm")]
223        algorithm: String,
224
225        /// Required issuer (iss claim)
226        #[serde(skip_serializing_if = "Option::is_none")]
227        issuer: Option<String>,
228
229        /// Required audience (aud claim)
230        #[serde(skip_serializing_if = "Option::is_none")]
231        audience: Option<String>,
232    },
233
234    /// OAuth2 authentication
235    OAuth2 {
236        /// Client ID
237        client_id: String,
238
239        /// Client secret (use ${ENV_VAR} for environment variables)
240        client_secret: String,
241
242        /// Authorization URL
243        authorization_url: String,
244
245        /// Token URL
246        token_url: String,
247
248        /// Redirect URL for authorization code flow
249        #[serde(skip_serializing_if = "Option::is_none")]
250        redirect_url: Option<String>,
251
252        /// OAuth2 flow type: "authorization_code" or "client_credentials"
253        #[serde(default = "default_oauth2_flow")]
254        flow: String,
255
256        /// Required scopes
257        #[serde(default)]
258        scopes: Vec<String>,
259    },
260}
261
262/// Skill configuration
263#[derive(Debug, Clone, Serialize, Deserialize)]
264pub struct SkillConfig {
265    /// Unique skill identifier
266    pub id: String,
267
268    /// Human-readable skill name
269    pub name: String,
270
271    /// Skill description
272    #[serde(default)]
273    pub description: Option<String>,
274
275    /// Keywords for skill discovery
276    #[serde(default)]
277    pub keywords: Vec<String>,
278
279    /// Example queries for this skill
280    #[serde(default)]
281    pub examples: Vec<String>,
282
283    /// Supported input formats
284    #[serde(default = "default_formats")]
285    pub input_formats: Vec<String>,
286
287    /// Supported output formats
288    #[serde(default = "default_formats")]
289    pub output_formats: Vec<String>,
290}
291
292/// Features configuration
293#[derive(Debug, Clone, Serialize, Deserialize)]
294pub struct FeaturesConfig {
295    /// Enable streaming updates
296    #[serde(default)]
297    pub streaming: bool,
298
299    /// Enable push notifications
300    #[serde(default)]
301    pub push_notifications: bool,
302
303    /// Enable state transition history
304    #[serde(default)]
305    pub state_history: bool,
306
307    /// Enable authenticated extended card
308    #[serde(default)]
309    pub authenticated_card: bool,
310
311    /// Protocol extensions (AP2, etc.)
312    #[serde(default)]
313    pub extensions: ExtensionsConfig,
314
315    /// MCP server configuration (expose agent as MCP server)
316    #[serde(default)]
317    pub mcp_server: McpServerConfig,
318
319    /// MCP client configuration (connect to MCP servers to use their tools)
320    #[serde(default)]
321    pub mcp_client: McpClientConfig,
322}
323
324impl Default for FeaturesConfig {
325    fn default() -> Self {
326        Self {
327            streaming: true,
328            push_notifications: true,
329            state_history: true,
330            authenticated_card: false,
331            extensions: ExtensionsConfig::default(),
332            mcp_server: McpServerConfig::default(),
333            mcp_client: McpClientConfig::default(),
334        }
335    }
336}
337
338/// Protocol extensions configuration
339#[derive(Debug, Clone, Default, Serialize, Deserialize)]
340pub struct ExtensionsConfig {
341    /// AP2 (Agent Payments Protocol) extension
342    #[serde(default)]
343    pub ap2: Option<Ap2ExtensionConfig>,
344}
345
346/// AP2 extension configuration
347#[derive(Debug, Clone, Serialize, Deserialize)]
348pub struct Ap2ExtensionConfig {
349    /// AP2 roles this agent performs (merchant, shopper, credentials-provider, payment-processor)
350    pub roles: Vec<String>,
351
352    /// Whether clients must understand AP2 to interact with this agent
353    #[serde(default)]
354    pub required: bool,
355}
356
357/// MCP server configuration
358#[derive(Debug, Clone, Serialize, Deserialize)]
359pub struct McpServerConfig {
360    /// Enable MCP server (expose agent as MCP tools)
361    #[serde(default)]
362    pub enabled: bool,
363
364    /// Use stdio transport (for Claude Desktop integration)
365    #[serde(default = "default_true")]
366    pub stdio: bool,
367
368    /// Server name (defaults to agent name)
369    #[serde(skip_serializing_if = "Option::is_none")]
370    pub name: Option<String>,
371
372    /// Server version (defaults to agent version)
373    #[serde(skip_serializing_if = "Option::is_none")]
374    pub version: Option<String>,
375}
376
377impl Default for McpServerConfig {
378    fn default() -> Self {
379        Self {
380            enabled: false,
381            stdio: true,
382            name: None,
383            version: None,
384        }
385    }
386}
387
/// Serde default helper for boolean fields that should be on when omitted.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
391
392/// MCP client configuration
393#[derive(Debug, Clone, Default, Serialize, Deserialize)]
394pub struct McpClientConfig {
395    /// Enable MCP client (connect to MCP servers to use their tools)
396    #[serde(default)]
397    pub enabled: bool,
398
399    /// MCP servers to connect to
400    #[serde(default)]
401    pub servers: Vec<McpServerConnection>,
402}
403
404/// Configuration for connecting to an MCP server
405#[derive(Debug, Clone, Serialize, Deserialize)]
406pub struct McpServerConnection {
407    /// Unique name for this MCP server
408    pub name: String,
409
410    /// Command to run to start the MCP server
411    pub command: String,
412
413    /// Arguments to pass to the command
414    #[serde(default)]
415    pub args: Vec<String>,
416
417    /// Environment variables to set
418    #[serde(default)]
419    pub env: std::collections::HashMap<String, String>,
420
421    /// Working directory for the command
422    #[serde(skip_serializing_if = "Option::is_none")]
423    pub cwd: Option<String>,
424}
425
// Default value functions
427
/// Default bind host: the `HOST` environment variable when it can be read,
/// otherwise `127.0.0.1`.
fn default_host() -> String {
    match std::env::var("HOST") {
        Ok(host) => host,
        Err(_) => String::from("127.0.0.1"),
    }
}
431
/// Default HTTP port: the `HTTP_PORT` environment variable when it is set and
/// parses as a `u16`, otherwise `8080`.
fn default_http_port() -> u16 {
    match std::env::var("HTTP_PORT") {
        // An unparsable value is ignored and the built-in port is used instead.
        Ok(raw) => raw.parse().unwrap_or(8080),
        Err(_) => 8080,
    }
}
438
/// Default WebSocket port: the `WS_PORT` environment variable when it is set
/// and parses as a `u16`, otherwise `8081`.
fn default_ws_port() -> u16 {
    match std::env::var("WS_PORT") {
        // An unparsable value is ignored and the built-in port is used instead.
        Ok(raw) => raw.parse().unwrap_or(8081),
        Err(_) => 8081,
    }
}
445
/// Default maximum number of pooled database connections.
fn default_max_connections() -> u32 {
    const POOL_SIZE: u32 = 10;
    POOL_SIZE
}
449
/// Default JWT algorithm name.
fn default_jwt_algorithm() -> String {
    String::from("HS256")
}
453
/// Default OAuth2 flow type.
fn default_oauth2_flow() -> String {
    String::from("authorization_code")
}
457
/// Default location in which an API key is expected.
fn default_api_key_location() -> String {
    String::from("header")
}
461
/// Default name of the header, query parameter, or cookie carrying the API key.
fn default_api_key_name() -> String {
    String::from("X-API-Key")
}
465
/// Default input/output formats for a skill: `text` and `data`, in that order.
fn default_formats() -> Vec<String> {
    ["text", "data"].iter().map(|format| format.to_string()).collect()
}
469
470/// Expand environment variables in the config string
471/// Supports ${VAR_NAME} and ${VAR_NAME:-default} syntax
472fn expand_env_vars(content: &str) -> Result<String, ConfigError> {
473    use std::sync::LazyLock;
474    static ENV_VAR_RE: LazyLock<regex::Regex> =
475        LazyLock::new(|| regex::Regex::new(r"\$\{([A-Z_][A-Z0-9_]*)\}").unwrap());
476
477    let mut result = content.to_string();
478    let re = &*ENV_VAR_RE;
479
480    for cap in re.captures_iter(content) {
481        let full_match = &cap[0];
482        let var_name = &cap[1];
483
484        let value =
485            std::env::var(var_name).map_err(|_| ConfigError::EnvVarError(var_name.to_string()))?;
486
487        result = result.replace(full_match, &value);
488    }
489
490    Ok(result)
491}
492
#[cfg(test)]
mod tests {
    use super::*;

    /// Only `[agent].name` is given; the server section must take its defaults.
    /// The port assertion assumes `HTTP_PORT` is not set in the environment.
    #[test]
    fn test_minimal_config() {
        let source = r#"
            [agent]
            name = "Test Agent"
        "#;

        let parsed = AgentConfig::from_toml(source).unwrap();
        assert_eq!(parsed.agent.name, "Test Agent");
        assert_eq!(parsed.server.http_port, 8080);
    }

    /// Every top-level section is populated and parsed.
    #[test]
    fn test_complete_config() {
        let source = r#"
            [agent]
            name = "Reimbursement Agent"
            description = "Handles employee reimbursements"
            version = "1.0.0"

            [agent.provider]
            name = "Example Corp"
            url = "https://example.com"

            [server]
            host = "0.0.0.0"
            http_port = 3000
            ws_port = 3001

            [server.storage]
            type = "sqlx"
            url = "sqlite:test.db"
            max_connections = 5
            enable_logging = true

            [server.auth]
            type = "bearer"
            tokens = ["token123"]
            format = "JWT"

            [[skills]]
            id = "process_expense"
            name = "Process Expense"
            description = "Process expense reimbursements"
            keywords = ["expense", "reimbursement"]
            examples = ["Reimburse my $50 lunch"]
            input_formats = ["text", "data"]
            output_formats = ["text", "data"]

            [features]
            streaming = true
            push_notifications = true
            state_history = true
            authenticated_card = false
        "#;

        let parsed = AgentConfig::from_toml(source).unwrap();
        assert_eq!(parsed.agent.name, "Reimbursement Agent");
        assert_eq!(parsed.server.http_port, 3000);
        assert_eq!(parsed.skills.len(), 1);
        assert_eq!(parsed.skills[0].id, "process_expense");
        assert!(parsed.features.streaming);
    }

    /// A `${VAR}` placeholder is replaced by the variable's value.
    #[test]
    fn test_env_var_expansion() {
        // SAFETY: runs inside the test harness only, and the variable name is
        // specific to this test, so no other test reads or writes it.
        unsafe {
            std::env::set_var("TEST_TOKEN", "secret123");
        }

        let source = r#"
            [server.auth]
            type = "bearer"
            tokens = ["${TEST_TOKEN}"]
        "#;

        let output = expand_env_vars(source).unwrap();
        assert!(output.contains("secret123"));
    }

    /// The `jwt` auth variant round-trips its fields.
    #[test]
    #[cfg(feature = "auth")]
    fn test_jwt_auth_config() {
        let source = r#"
            [agent]
            name = "JWT Agent"

            [server.auth]
            type = "jwt"
            secret = "my-jwt-secret"
            algorithm = "HS256"
            issuer = "https://auth.example.com"
            audience = "api://my-agent"
        "#;

        let parsed = AgentConfig::from_toml(source).unwrap();
        let AuthConfig::Jwt {
            secret,
            algorithm,
            issuer,
            audience,
            ..
        } = &parsed.server.auth
        else {
            panic!("Expected JWT auth config");
        };

        assert_eq!(secret.as_deref(), Some("my-jwt-secret"));
        assert_eq!(algorithm, "HS256");
        assert_eq!(issuer.as_deref(), Some("https://auth.example.com"));
        assert_eq!(audience.as_deref(), Some("api://my-agent"));
    }

    /// The `oauth2` auth variant round-trips its fields.
    #[test]
    #[cfg(feature = "auth")]
    fn test_oauth2_auth_config() {
        let source = r#"
            [agent]
            name = "OAuth2 Agent"

            [server.auth]
            type = "oauth2"
            client_id = "my-client-id"
            client_secret = "my-client-secret"
            authorization_url = "https://provider.com/auth"
            token_url = "https://provider.com/token"
            flow = "authorization_code"
            scopes = ["read", "write"]
        "#;

        let parsed = AgentConfig::from_toml(source).unwrap();
        let AuthConfig::OAuth2 {
            client_id,
            client_secret,
            flow,
            scopes,
            ..
        } = &parsed.server.auth
        else {
            panic!("Expected OAuth2 auth config");
        };

        assert_eq!(client_id, "my-client-id");
        assert_eq!(client_secret, "my-client-secret");
        assert_eq!(flow, "authorization_code");
        assert_eq!(scopes.len(), 2);
        assert_eq!(scopes[0], "read");
    }

    /// An empty agent name is rejected.
    #[test]
    fn test_validation_empty_name() {
        let source = r#"
            [agent]
            name = ""
        "#;

        assert!(AgentConfig::from_toml(source).is_err());
    }

    /// The AP2 extension table is parsed when present.
    #[test]
    fn test_ap2_extension_config() {
        let source = r#"
            [agent]
            name = "Merchant Agent"

            [features.extensions.ap2]
            roles = ["merchant", "payment-processor"]
            required = true
        "#;

        let parsed = AgentConfig::from_toml(source).unwrap();
        let ap2 = parsed.features.extensions.ap2.unwrap();
        assert_eq!(ap2.roles, ["merchant", "payment-processor"]);
        assert!(ap2.required);
    }

    /// The AP2 extension is absent when its table is omitted.
    #[test]
    fn test_ap2_extension_config_optional() {
        let source = r#"
            [agent]
            name = "Plain Agent"
        "#;

        let parsed = AgentConfig::from_toml(source).unwrap();
        assert!(parsed.features.extensions.ap2.is_none());
    }
}