1use serde::{Deserialize, Serialize};
7use std::path::Path;
8use thiserror::Error;
9
10#[derive(Debug, Error)]
11pub enum ConfigError {
12 #[error("Failed to read config file: {0}")]
13 IoError(#[from] std::io::Error),
14 #[error("Failed to parse TOML: {0}")]
15 TomlError(#[from] toml::de::Error),
16 #[error("Environment variable not found: {0}")]
17 EnvVarError(String),
18 #[error("Invalid configuration: {0}")]
19 ValidationError(String),
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct AgentConfig {
25 pub agent: AgentMetadata,
27
28 #[serde(default)]
30 pub server: ServerConfig,
31
32 #[serde(default)]
34 pub skills: Vec<SkillConfig>,
35
36 #[serde(default)]
38 pub features: FeaturesConfig,
39}
40
41impl AgentConfig {
42 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
44 let content = std::fs::read_to_string(path)?;
45 Self::from_toml(&content)
46 }
47
48 pub fn from_toml(content: &str) -> Result<Self, ConfigError> {
50 let expanded = expand_env_vars(content)?;
52 let config: AgentConfig = toml::from_str(&expanded)?;
53 config.validate()?;
54 Ok(config)
55 }
56
57 pub fn validate(&self) -> Result<(), ConfigError> {
59 if self.agent.name.is_empty() {
60 return Err(ConfigError::ValidationError(
61 "Agent name cannot be empty".to_string(),
62 ));
63 }
64
65 if self.server.http_port == 0 && self.server.ws_port == 0 {
66 return Err(ConfigError::ValidationError(
67 "At least one server port must be configured".to_string(),
68 ));
69 }
70
71 for skill in &self.skills {
73 if skill.id.is_empty() {
74 return Err(ConfigError::ValidationError(
75 "Skill ID cannot be empty".to_string(),
76 ));
77 }
78 }
79
80 Ok(())
81 }
82
83 pub fn agent_url(&self) -> String {
85 format!("http://{}:{}", self.server.host, self.server.http_port)
86 }
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct AgentMetadata {
92 pub name: String,
94
95 #[serde(default)]
97 pub description: Option<String>,
98
99 #[serde(default)]
101 pub version: Option<String>,
102
103 #[serde(default)]
105 pub provider: Option<ProviderInfo>,
106
107 #[serde(default)]
109 pub documentation_url: Option<String>,
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct ProviderInfo {
115 pub name: String,
116 pub url: String,
117}
118
119#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct ServerConfig {
122 #[serde(default = "default_host")]
124 pub host: String,
125
126 #[serde(default = "default_http_port")]
128 pub http_port: u16,
129
130 #[serde(default = "default_ws_port")]
132 pub ws_port: u16,
133
134 #[serde(default)]
136 pub storage: StorageConfig,
137
138 #[serde(default)]
140 pub auth: AuthConfig,
141}
142
143impl Default for ServerConfig {
144 fn default() -> Self {
145 Self {
146 host: default_host(),
147 http_port: default_http_port(),
148 ws_port: default_ws_port(),
149 storage: StorageConfig::default(),
150 auth: AuthConfig::default(),
151 }
152 }
153}
154
155#[derive(Debug, Clone, Default, Serialize, Deserialize)]
157#[serde(tag = "type", rename_all = "lowercase")]
158pub enum StorageConfig {
159 #[default]
161 InMemory,
162
163 Sqlx {
165 url: String,
167
168 #[serde(default = "default_max_connections")]
170 max_connections: u32,
171
172 #[serde(default)]
174 enable_logging: bool,
175 },
176}
177
178#[derive(Debug, Clone, Default, Serialize, Deserialize)]
180#[serde(tag = "type", rename_all = "lowercase")]
181pub enum AuthConfig {
182 #[default]
184 None,
185
186 Bearer {
188 tokens: Vec<String>,
190
191 #[serde(skip_serializing_if = "Option::is_none")]
193 format: Option<String>,
194 },
195
196 ApiKey {
198 keys: Vec<String>,
200
201 #[serde(default = "default_api_key_location")]
203 location: String,
204
205 #[serde(default = "default_api_key_name")]
207 name: String,
208 },
209
210 Jwt {
212 #[serde(skip_serializing_if = "Option::is_none")]
215 secret: Option<String>,
216
217 #[serde(skip_serializing_if = "Option::is_none")]
219 rsa_pem_path: Option<String>,
220
221 #[serde(default = "default_jwt_algorithm")]
223 algorithm: String,
224
225 #[serde(skip_serializing_if = "Option::is_none")]
227 issuer: Option<String>,
228
229 #[serde(skip_serializing_if = "Option::is_none")]
231 audience: Option<String>,
232 },
233
234 OAuth2 {
236 client_id: String,
238
239 client_secret: String,
241
242 authorization_url: String,
244
245 token_url: String,
247
248 #[serde(skip_serializing_if = "Option::is_none")]
250 redirect_url: Option<String>,
251
252 #[serde(default = "default_oauth2_flow")]
254 flow: String,
255
256 #[serde(default)]
258 scopes: Vec<String>,
259 },
260}
261
262#[derive(Debug, Clone, Serialize, Deserialize)]
264pub struct SkillConfig {
265 pub id: String,
267
268 pub name: String,
270
271 #[serde(default)]
273 pub description: Option<String>,
274
275 #[serde(default)]
277 pub keywords: Vec<String>,
278
279 #[serde(default)]
281 pub examples: Vec<String>,
282
283 #[serde(default = "default_formats")]
285 pub input_formats: Vec<String>,
286
287 #[serde(default = "default_formats")]
289 pub output_formats: Vec<String>,
290}
291
292#[derive(Debug, Clone, Serialize, Deserialize)]
294pub struct FeaturesConfig {
295 #[serde(default)]
297 pub streaming: bool,
298
299 #[serde(default)]
301 pub push_notifications: bool,
302
303 #[serde(default)]
305 pub state_history: bool,
306
307 #[serde(default)]
309 pub authenticated_card: bool,
310
311 #[serde(default)]
313 pub extensions: ExtensionsConfig,
314
315 #[serde(default)]
317 pub mcp_server: McpServerConfig,
318
319 #[serde(default)]
321 pub mcp_client: McpClientConfig,
322}
323
324impl Default for FeaturesConfig {
325 fn default() -> Self {
326 Self {
327 streaming: true,
328 push_notifications: true,
329 state_history: true,
330 authenticated_card: false,
331 extensions: ExtensionsConfig::default(),
332 mcp_server: McpServerConfig::default(),
333 mcp_client: McpClientConfig::default(),
334 }
335 }
336}
337
338#[derive(Debug, Clone, Default, Serialize, Deserialize)]
340pub struct ExtensionsConfig {
341 #[serde(default)]
343 pub ap2: Option<Ap2ExtensionConfig>,
344}
345
346#[derive(Debug, Clone, Serialize, Deserialize)]
348pub struct Ap2ExtensionConfig {
349 pub roles: Vec<String>,
351
352 #[serde(default)]
354 pub required: bool,
355}
356
357#[derive(Debug, Clone, Serialize, Deserialize)]
359pub struct McpServerConfig {
360 #[serde(default)]
362 pub enabled: bool,
363
364 #[serde(default = "default_true")]
366 pub stdio: bool,
367
368 #[serde(skip_serializing_if = "Option::is_none")]
370 pub name: Option<String>,
371
372 #[serde(skip_serializing_if = "Option::is_none")]
374 pub version: Option<String>,
375}
376
377impl Default for McpServerConfig {
378 fn default() -> Self {
379 Self {
380 enabled: false,
381 stdio: true,
382 name: None,
383 version: None,
384 }
385 }
386}
387
/// Serde default helper for boolean fields that should be on unless configured off.
fn default_true() -> bool {
    let on = true;
    on
}
391
392#[derive(Debug, Clone, Default, Serialize, Deserialize)]
394pub struct McpClientConfig {
395 #[serde(default)]
397 pub enabled: bool,
398
399 #[serde(default)]
401 pub servers: Vec<McpServerConnection>,
402}
403
404#[derive(Debug, Clone, Serialize, Deserialize)]
406pub struct McpServerConnection {
407 pub name: String,
409
410 pub command: String,
412
413 #[serde(default)]
415 pub args: Vec<String>,
416
417 #[serde(default)]
419 pub env: std::collections::HashMap<String, String>,
420
421 #[serde(skip_serializing_if = "Option::is_none")]
423 pub cwd: Option<String>,
424}
425
/// Default `ServerConfig::host`: the `HOST` environment variable when it is readable,
/// else `127.0.0.1`.
fn default_host() -> String {
    match std::env::var("HOST") {
        Ok(host) => host,
        Err(_) => String::from("127.0.0.1"),
    }
}
431
/// Default `ServerConfig::http_port`: `HTTP_PORT` when it parses as a `u16`, else 8080.
fn default_http_port() -> u16 {
    port_from_env("HTTP_PORT", 8080)
}

/// Default `ServerConfig::ws_port`: `WS_PORT` when it parses as a `u16`, else 8081.
fn default_ws_port() -> u16 {
    port_from_env("WS_PORT", 8081)
}

/// Reads environment variable `var` as a port number. Returns `fallback` when the
/// variable is unset, not valid Unicode, or not a valid `u16`.
fn port_from_env(var: &str, fallback: u16) -> u16 {
    match std::env::var(var) {
        Ok(raw) => raw.parse().unwrap_or(fallback),
        Err(_) => fallback,
    }
}
445
/// Default `max_connections` for [`StorageConfig::Sqlx`].
fn default_max_connections() -> u32 {
    10
}

/// Default `algorithm` for [`AuthConfig::Jwt`].
fn default_jwt_algorithm() -> String {
    String::from("HS256")
}

/// Default `flow` for [`AuthConfig::OAuth2`].
fn default_oauth2_flow() -> String {
    String::from("authorization_code")
}

/// Default `location` for [`AuthConfig::ApiKey`].
fn default_api_key_location() -> String {
    String::from("header")
}

/// Default `name` for [`AuthConfig::ApiKey`].
fn default_api_key_name() -> String {
    String::from("X-API-Key")
}

/// Default `input_formats` / `output_formats` for a [`SkillConfig`].
fn default_formats() -> Vec<String> {
    ["text", "data"].iter().map(|fmt| fmt.to_string()).collect()
}
469
470fn expand_env_vars(content: &str) -> Result<String, ConfigError> {
473 use std::sync::LazyLock;
474 static ENV_VAR_RE: LazyLock<regex::Regex> =
475 LazyLock::new(|| regex::Regex::new(r"\$\{([A-Z_][A-Z0-9_]*)\}").unwrap());
476
477 let mut result = content.to_string();
478 let re = &*ENV_VAR_RE;
479
480 for cap in re.captures_iter(content) {
481 let full_match = &cap[0];
482 let var_name = &cap[1];
483
484 let value =
485 std::env::var(var_name).map_err(|_| ConfigError::EnvVarError(var_name.to_string()))?;
486
487 result = result.replace(full_match, &value);
488 }
489
490 Ok(result)
491}
492
#[cfg(test)]
mod tests {
    use super::*;

    /// Mirrors the variable-name pattern recognised by `expand_env_vars`
    /// (`[A-Z_][A-Z0-9_]*`), so tests can pick a variable it will expand.
    fn is_expandable_name(name: &str) -> bool {
        let mut chars = name.chars();
        matches!(chars.next(), Some(c) if c == '_' || c.is_ascii_uppercase())
            && chars.all(|c| c == '_' || c.is_ascii_uppercase() || c.is_ascii_digit())
    }

    #[test]
    fn test_minimal_config() {
        // Assumes HTTP_PORT is not set in the test environment (default_http_port reads it).
        let toml = r#"
        [agent]
        name = "Test Agent"
        "#;

        let config = AgentConfig::from_toml(toml).unwrap();
        assert_eq!(config.agent.name, "Test Agent");
        assert_eq!(config.server.http_port, 8080);
    }

    #[test]
    fn test_complete_config() {
        let toml = r#"
        [agent]
        name = "Reimbursement Agent"
        description = "Handles employee reimbursements"
        version = "1.0.0"

        [agent.provider]
        name = "Example Corp"
        url = "https://example.com"

        [server]
        host = "0.0.0.0"
        http_port = 3000
        ws_port = 3001

        [server.storage]
        type = "sqlx"
        url = "sqlite:test.db"
        max_connections = 5
        enable_logging = true

        [server.auth]
        type = "bearer"
        tokens = ["token123"]
        format = "JWT"

        [[skills]]
        id = "process_expense"
        name = "Process Expense"
        description = "Process expense reimbursements"
        keywords = ["expense", "reimbursement"]
        examples = ["Reimburse my $50 lunch"]
        input_formats = ["text", "data"]
        output_formats = ["text", "data"]

        [features]
        streaming = true
        push_notifications = true
        state_history = true
        authenticated_card = false
        "#;

        let config = AgentConfig::from_toml(toml).unwrap();
        assert_eq!(config.agent.name, "Reimbursement Agent");
        assert_eq!(config.server.http_port, 3000);
        assert_eq!(config.skills.len(), 1);
        assert_eq!(config.skills[0].id, "process_expense");
        assert!(config.features.streaming);
    }

    #[test]
    fn test_env_var_expansion() {
        // Reuse a variable that already exists instead of calling `std::env::set_var`.
        // The harness runs tests on parallel threads that read the environment (e.g. the
        // HOST / HTTP_PORT server defaults), and mutating it concurrently is unsound on
        // most platforms. `vars_os` avoids the panic `vars` raises on non-Unicode entries.
        let candidate = std::env::vars_os()
            .filter_map(|(key, value)| {
                Some((key.into_string().ok()?, value.into_string().ok()?))
            })
            .find(|(key, value)| {
                is_expandable_name(key) && !value.is_empty() && !value.contains("${")
            });
        let Some((name, value)) = candidate else {
            // No suitable variable in this environment; nothing meaningful to check.
            return;
        };

        let reference = format!("${{{}}}", name);
        let content = format!(
            "[server.auth]\ntype = \"bearer\"\ntokens = [\"{}\"]\n",
            reference
        );

        let expanded = expand_env_vars(&content).unwrap();
        assert!(expanded.contains(&value));
        assert!(!expanded.contains(&reference));
    }

    #[test]
    fn test_env_var_expansion_missing_var_is_error() {
        const NAME: &str = "AGENT_CONFIG_TEST_UNSET_VARIABLE_9F27C";
        if std::env::var_os(NAME).is_some() {
            // Extremely unlikely, but the premise of the test would not hold.
            return;
        }

        let content = format!("token = \"${{{}}}\"", NAME);
        let err = expand_env_vars(&content).unwrap_err();
        assert!(matches!(err, ConfigError::EnvVarError(ref var) if var == NAME));
    }

    #[test]
    #[cfg(feature = "auth")]
    fn test_jwt_auth_config() {
        let toml = r#"
        [agent]
        name = "JWT Agent"

        [server.auth]
        type = "jwt"
        secret = "my-jwt-secret"
        algorithm = "HS256"
        issuer = "https://auth.example.com"
        audience = "api://my-agent"
        "#;

        let config = AgentConfig::from_toml(toml).unwrap();
        match &config.server.auth {
            AuthConfig::Jwt {
                secret,
                algorithm,
                issuer,
                audience,
                ..
            } => {
                assert_eq!(secret.as_ref().unwrap(), "my-jwt-secret");
                assert_eq!(algorithm, "HS256");
                assert_eq!(issuer.as_ref().unwrap(), "https://auth.example.com");
                assert_eq!(audience.as_ref().unwrap(), "api://my-agent");
            }
            _ => panic!("Expected JWT auth config"),
        }
    }

    #[test]
    #[cfg(feature = "auth")]
    fn test_oauth2_auth_config() {
        let toml = r#"
        [agent]
        name = "OAuth2 Agent"

        [server.auth]
        type = "oauth2"
        client_id = "my-client-id"
        client_secret = "my-client-secret"
        authorization_url = "https://provider.com/auth"
        token_url = "https://provider.com/token"
        flow = "authorization_code"
        scopes = ["read", "write"]
        "#;

        let config = AgentConfig::from_toml(toml).unwrap();
        match &config.server.auth {
            AuthConfig::OAuth2 {
                client_id,
                client_secret,
                flow,
                scopes,
                ..
            } => {
                assert_eq!(client_id, "my-client-id");
                assert_eq!(client_secret, "my-client-secret");
                assert_eq!(flow, "authorization_code");
                assert_eq!(scopes.len(), 2);
                assert_eq!(scopes[0], "read");
            }
            _ => panic!("Expected OAuth2 auth config"),
        }
    }

    #[test]
    fn test_validation_empty_name() {
        let toml = r#"
        [agent]
        name = ""
        "#;

        let result = AgentConfig::from_toml(toml);
        assert!(matches!(result, Err(ConfigError::ValidationError(_))));
    }

    #[test]
    fn test_validation_no_ports() {
        let toml = r#"
        [agent]
        name = "Portless Agent"

        [server]
        http_port = 0
        ws_port = 0
        "#;

        let result = AgentConfig::from_toml(toml);
        assert!(matches!(result, Err(ConfigError::ValidationError(_))));
    }

    #[test]
    fn test_validation_empty_skill_id() {
        let toml = r#"
        [agent]
        name = "Skilled Agent"

        [[skills]]
        id = ""
        name = "Anonymous Skill"
        "#;

        let result = AgentConfig::from_toml(toml);
        assert!(matches!(result, Err(ConfigError::ValidationError(_))));
    }

    #[test]
    fn test_agent_url() {
        let toml = r#"
        [agent]
        name = "Url Agent"

        [server]
        host = "0.0.0.0"
        http_port = 3000
        "#;

        let config = AgentConfig::from_toml(toml).unwrap();
        assert_eq!(config.agent_url(), "http://0.0.0.0:3000");
    }

    #[test]
    fn test_ap2_extension_config() {
        let toml = r#"
        [agent]
        name = "Merchant Agent"

        [features.extensions.ap2]
        roles = ["merchant", "payment-processor"]
        required = true
        "#;

        let config = AgentConfig::from_toml(toml).unwrap();
        let ap2 = config.features.extensions.ap2.unwrap();
        assert_eq!(ap2.roles, vec!["merchant", "payment-processor"]);
        assert!(ap2.required);
    }

    #[test]
    fn test_ap2_extension_config_optional() {
        let toml = r#"
        [agent]
        name = "Plain Agent"
        "#;

        let config = AgentConfig::from_toml(toml).unwrap();
        assert!(config.features.extensions.ap2.is_none());
    }
}