1mod cluster;
2mod database;
3
4pub use cluster::ClusterConfig;
5pub use database::DatabaseConfig;
6
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9
10use crate::error::{ForgeError, Result};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ForgeConfig {
15 #[serde(default)]
17 pub project: ProjectConfig,
18
19 pub database: DatabaseConfig,
21
22 #[serde(default)]
24 pub node: NodeConfig,
25
26 #[serde(default)]
28 pub gateway: GatewayConfig,
29
30 #[serde(default)]
32 pub function: FunctionConfig,
33
34 #[serde(default)]
36 pub worker: WorkerConfig,
37
38 #[serde(default)]
40 pub cluster: ClusterConfig,
41
42 #[serde(default)]
44 pub security: SecurityConfig,
45
46 #[serde(default)]
48 pub auth: AuthConfig,
49
50 #[serde(default)]
52 pub observability: ObservabilityConfig,
53
54 #[serde(default)]
56 pub mcp: McpConfig,
57}
58
59impl ForgeConfig {
60 pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
62 let content = std::fs::read_to_string(path.as_ref())
63 .map_err(|e| ForgeError::Config(format!("Failed to read config file: {}", e)))?;
64
65 Self::parse_toml(&content)
66 }
67
68 pub fn parse_toml(content: &str) -> Result<Self> {
70 let content = substitute_env_vars(content);
72
73 let config: Self = toml::from_str(&content)
74 .map_err(|e| ForgeError::Config(format!("Failed to parse config: {}", e)))?;
75
76 config.validate()?;
77 Ok(config)
78 }
79
80 pub fn validate(&self) -> Result<()> {
82 self.database.validate()?;
83 self.auth.validate()?;
84 self.mcp.validate()?;
85 Ok(())
86 }
87
88 pub fn default_with_database_url(url: &str) -> Self {
90 Self {
91 project: ProjectConfig::default(),
92 database: DatabaseConfig::new(url),
93 node: NodeConfig::default(),
94 gateway: GatewayConfig::default(),
95 function: FunctionConfig::default(),
96 worker: WorkerConfig::default(),
97 cluster: ClusterConfig::default(),
98 security: SecurityConfig::default(),
99 auth: AuthConfig::default(),
100 observability: ObservabilityConfig::default(),
101 mcp: McpConfig::default(),
102 }
103 }
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct ProjectConfig {
109 #[serde(default = "default_project_name")]
111 pub name: String,
112
113 #[serde(default = "default_version")]
115 pub version: String,
116}
117
118impl Default for ProjectConfig {
119 fn default() -> Self {
120 Self {
121 name: default_project_name(),
122 version: default_version(),
123 }
124 }
125}
126
/// Serde fallback for `project.name`.
fn default_project_name() -> String {
    String::from("forge-app")
}

/// Serde fallback for `project.version`.
fn default_version() -> String {
    String::from("0.1.0")
}
134
135#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct NodeConfig {
138 #[serde(default = "default_roles")]
140 pub roles: Vec<NodeRole>,
141
142 #[serde(default = "default_capabilities")]
144 pub worker_capabilities: Vec<String>,
145}
146
147impl Default for NodeConfig {
148 fn default() -> Self {
149 Self {
150 roles: default_roles(),
151 worker_capabilities: default_capabilities(),
152 }
153 }
154}
155
156fn default_roles() -> Vec<NodeRole> {
157 vec![
158 NodeRole::Gateway,
159 NodeRole::Function,
160 NodeRole::Worker,
161 NodeRole::Scheduler,
162 ]
163}
164
/// Serde fallback for `node.worker_capabilities`: the single tag `"general"`.
fn default_capabilities() -> Vec<String> {
    let general = String::from("general");
    vec![general]
}
168
169#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
171#[serde(rename_all = "lowercase")]
172pub enum NodeRole {
173 Gateway,
174 Function,
175 Worker,
176 Scheduler,
177}
178
179#[derive(Debug, Clone, Serialize, Deserialize)]
181pub struct GatewayConfig {
182 #[serde(default = "default_http_port")]
184 pub port: u16,
185
186 #[serde(default = "default_grpc_port")]
188 pub grpc_port: u16,
189
190 #[serde(default = "default_max_connections")]
192 pub max_connections: usize,
193
194 #[serde(default = "default_request_timeout")]
196 pub request_timeout_secs: u64,
197
198 #[serde(default = "default_cors_enabled")]
200 pub cors_enabled: bool,
201
202 #[serde(default = "default_cors_origins")]
204 pub cors_origins: Vec<String>,
205}
206
207impl Default for GatewayConfig {
208 fn default() -> Self {
209 Self {
210 port: default_http_port(),
211 grpc_port: default_grpc_port(),
212 max_connections: default_max_connections(),
213 request_timeout_secs: default_request_timeout(),
214 cors_enabled: default_cors_enabled(),
215 cors_origins: default_cors_origins(),
216 }
217 }
218}
219
/// Serde fallback for `gateway.port`.
fn default_http_port() -> u16 {
    8080_u16
}

/// Serde fallback for `gateway.grpc_port`.
fn default_grpc_port() -> u16 {
    9000_u16
}

/// Serde fallback for `gateway.max_connections`.
fn default_max_connections() -> usize {
    512_usize
}

/// Serde fallback for `gateway.request_timeout_secs`.
fn default_request_timeout() -> u64 {
    30_u64
}

/// Serde fallback for `gateway.cors_enabled`: off.
fn default_cors_enabled() -> bool {
    bool::default()
}

/// Serde fallback for `gateway.cors_origins`: empty.
fn default_cors_origins() -> Vec<String> {
    vec![]
}
243
244#[derive(Debug, Clone, Serialize, Deserialize)]
246pub struct FunctionConfig {
247 #[serde(default = "default_max_concurrent")]
249 pub max_concurrent: usize,
250
251 #[serde(default = "default_function_timeout")]
253 pub timeout_secs: u64,
254
255 #[serde(default = "default_memory_limit")]
257 pub memory_limit: usize,
258}
259
260impl Default for FunctionConfig {
261 fn default() -> Self {
262 Self {
263 max_concurrent: default_max_concurrent(),
264 timeout_secs: default_function_timeout(),
265 memory_limit: default_memory_limit(),
266 }
267 }
268}
269
/// Serde fallback for `function.max_concurrent`.
fn default_max_concurrent() -> usize {
    1_000
}

/// Serde fallback for `function.timeout_secs`.
fn default_function_timeout() -> u64 {
    30
}

/// Serde fallback for `function.memory_limit`: 512 × 2^20 (512 MiB if the unit
/// is bytes).
fn default_memory_limit() -> usize {
    512 << 20
}
281
282#[derive(Debug, Clone, Serialize, Deserialize)]
284pub struct WorkerConfig {
285 #[serde(default = "default_max_concurrent_jobs")]
287 pub max_concurrent_jobs: usize,
288
289 #[serde(default = "default_job_timeout")]
291 pub job_timeout_secs: u64,
292
293 #[serde(default = "default_poll_interval")]
295 pub poll_interval_ms: u64,
296}
297
298impl Default for WorkerConfig {
299 fn default() -> Self {
300 Self {
301 max_concurrent_jobs: default_max_concurrent_jobs(),
302 job_timeout_secs: default_job_timeout(),
303 poll_interval_ms: default_poll_interval(),
304 }
305 }
306}
307
/// Serde fallback for `worker.max_concurrent_jobs`.
fn default_max_concurrent_jobs() -> usize {
    50
}

/// Serde fallback for `worker.job_timeout_secs`: one hour.
fn default_job_timeout() -> u64 {
    60 * 60
}

/// Serde fallback for `worker.poll_interval_ms`.
fn default_poll_interval() -> u64 {
    100
}
319
320#[derive(Debug, Clone, Serialize, Deserialize, Default)]
322pub struct SecurityConfig {
323 pub secret_key: Option<String>,
325}
326
327#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
329#[serde(rename_all = "UPPERCASE")]
330pub enum JwtAlgorithm {
331 #[default]
333 HS256,
334 HS384,
336 HS512,
338 RS256,
340 RS384,
342 RS512,
344}
345
346#[derive(Debug, Clone, Serialize, Deserialize)]
348pub struct AuthConfig {
349 pub jwt_secret: Option<String>,
352
353 #[serde(default)]
357 pub jwt_algorithm: JwtAlgorithm,
358
359 pub jwt_issuer: Option<String>,
362
363 pub jwt_audience: Option<String>,
366
367 pub token_expiry: Option<String>,
369
370 pub jwks_url: Option<String>,
373
374 #[serde(default = "default_jwks_cache_ttl")]
376 pub jwks_cache_ttl_secs: u64,
377
378 #[serde(default = "default_session_ttl")]
380 pub session_ttl_secs: u64,
381}
382
383impl Default for AuthConfig {
384 fn default() -> Self {
385 Self {
386 jwt_secret: None,
387 jwt_algorithm: JwtAlgorithm::default(),
388 jwt_issuer: None,
389 jwt_audience: None,
390 token_expiry: None,
391 jwks_url: None,
392 jwks_cache_ttl_secs: default_jwks_cache_ttl(),
393 session_ttl_secs: default_session_ttl(),
394 }
395 }
396}
397
398impl AuthConfig {
399 fn is_configured(&self) -> bool {
401 self.jwt_secret.is_some()
402 || self.jwks_url.is_some()
403 || self.jwt_issuer.is_some()
404 || self.jwt_audience.is_some()
405 }
406
407 pub fn validate(&self) -> Result<()> {
410 if !self.is_configured() {
411 return Ok(());
412 }
413
414 match self.jwt_algorithm {
415 JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512 => {
416 if self.jwt_secret.is_none() {
417 return Err(ForgeError::Config(
418 "auth.jwt_secret is required for HMAC algorithms (HS256, HS384, HS512). \
419 Set auth.jwt_secret to a secure random string, \
420 or switch to RS256 and provide auth.jwks_url for external identity providers."
421 .into(),
422 ));
423 }
424 }
425 JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512 => {
426 if self.jwks_url.is_none() {
427 return Err(ForgeError::Config(
428 "auth.jwks_url is required for RSA algorithms (RS256, RS384, RS512). \
429 Set auth.jwks_url to your identity provider's JWKS endpoint, \
430 or switch to HS256 and provide auth.jwt_secret for symmetric signing."
431 .into(),
432 ));
433 }
434 }
435 }
436 Ok(())
437 }
438
439 pub fn is_hmac(&self) -> bool {
441 matches!(
442 self.jwt_algorithm,
443 JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
444 )
445 }
446
447 pub fn is_rsa(&self) -> bool {
449 matches!(
450 self.jwt_algorithm,
451 JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
452 )
453 }
454}
455
/// Serde fallback for `auth.jwks_cache_ttl_secs`: one hour.
fn default_jwks_cache_ttl() -> u64 {
    60 * 60
}

/// Serde fallback for `auth.session_ttl_secs`: seven days.
fn default_session_ttl() -> u64 {
    const SECS_PER_DAY: u64 = 24 * 60 * 60;
    7 * SECS_PER_DAY
}
463
464#[derive(Debug, Clone, Serialize, Deserialize)]
466pub struct ObservabilityConfig {
467 #[serde(default)]
469 pub enabled: bool,
470
471 #[serde(default = "default_otlp_endpoint")]
473 pub otlp_endpoint: String,
474
475 pub service_name: Option<String>,
477
478 #[serde(default = "default_true")]
480 pub enable_traces: bool,
481
482 #[serde(default = "default_true")]
484 pub enable_metrics: bool,
485
486 #[serde(default = "default_true")]
488 pub enable_logs: bool,
489
490 #[serde(default = "default_sampling_ratio")]
492 pub sampling_ratio: f64,
493
494 #[serde(default = "default_log_level")]
496 pub log_level: String,
497}
498
499impl Default for ObservabilityConfig {
500 fn default() -> Self {
501 Self {
502 enabled: false,
503 otlp_endpoint: default_otlp_endpoint(),
504 service_name: None,
505 enable_traces: true,
506 enable_metrics: true,
507 enable_logs: true,
508 sampling_ratio: default_sampling_ratio(),
509 log_level: default_log_level(),
510 }
511 }
512}
513
/// Serde fallback for `observability.otlp_endpoint`.
fn default_otlp_endpoint() -> String {
    String::from("http://localhost:4317")
}

/// Serde fallback shared by boolean fields that default to on.
fn default_true() -> bool {
    !false
}

/// Serde fallback for `observability.sampling_ratio`.
fn default_sampling_ratio() -> f64 {
    1.0_f64
}

/// Serde fallback for `observability.log_level`.
fn default_log_level() -> String {
    String::from("info")
}
529
530#[derive(Debug, Clone, Serialize, Deserialize)]
532pub struct McpConfig {
533 #[serde(default)]
535 pub enabled: bool,
536
537 #[serde(default = "default_mcp_path")]
539 pub path: String,
540
541 #[serde(default = "default_mcp_session_ttl_secs")]
543 pub session_ttl_secs: u64,
544
545 #[serde(default)]
547 pub allowed_origins: Vec<String>,
548
549 #[serde(default = "default_true")]
551 pub require_protocol_version_header: bool,
552}
553
554impl Default for McpConfig {
555 fn default() -> Self {
556 Self {
557 enabled: false,
558 path: default_mcp_path(),
559 session_ttl_secs: default_mcp_session_ttl_secs(),
560 allowed_origins: Vec::new(),
561 require_protocol_version_header: default_true(),
562 }
563 }
564}
565
566impl McpConfig {
567 pub fn validate(&self) -> Result<()> {
568 if self.path.is_empty() || !self.path.starts_with('/') {
569 return Err(ForgeError::Config(
570 "mcp.path must start with '/' (example: /mcp)".to_string(),
571 ));
572 }
573 if self.path.contains(' ') {
574 return Err(ForgeError::Config(
575 "mcp.path cannot contain spaces".to_string(),
576 ));
577 }
578 if self.session_ttl_secs == 0 {
579 return Err(ForgeError::Config(
580 "mcp.session_ttl_secs must be greater than 0".to_string(),
581 ));
582 }
583 Ok(())
584 }
585}
586
/// Serde fallback for `mcp.path`.
fn default_mcp_path() -> String {
    "/mcp".to_owned()
}

/// Serde fallback for `mcp.session_ttl_secs`: one hour.
fn default_mcp_session_ttl_secs() -> u64 {
    3_600
}
594
595fn substitute_env_vars(content: &str) -> String {
597 let mut result = content.to_string();
598 let re = regex_lite::Regex::new(r"\$\{([A-Z_][A-Z0-9_]*)\}").expect("valid regex pattern");
599
600 for cap in re.captures_iter(content) {
601 let var_name = &cap[1];
602 if let Ok(value) = std::env::var(var_name) {
603 result = result.replace(&cap[0], &value);
604 }
605 }
606
607 result
608}
609
#[cfg(test)]
// Tests may unwrap and index freely; `unsafe_code` is needed because
// `std::env::set_var` / `remove_var` are `unsafe` in `test_env_var_substitution`.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, unsafe_code)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = ForgeConfig::default_with_database_url("postgres://localhost/test");
        assert_eq!(config.gateway.port, 8080);
        assert_eq!(config.node.roles.len(), 4);
        assert_eq!(config.mcp.path, "/mcp");
        assert!(!config.mcp.enabled);
    }

    #[test]
    fn test_parse_minimal_config() {
        let toml = r#"
            [database]
            url = "postgres://localhost/myapp"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.database.url(), "postgres://localhost/myapp");
        assert_eq!(config.gateway.port, 8080);
    }

    #[test]
    fn test_parse_full_config() {
        let toml = r#"
            [project]
            name = "my-app"
            version = "1.0.0"

            [database]
            url = "postgres://localhost/myapp"
            pool_size = 100

            [node]
            roles = ["gateway", "worker"]
            worker_capabilities = ["media", "general"]

            [gateway]
            port = 3000
            grpc_port = 9001
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.project.name, "my-app");
        assert_eq!(config.database.pool_size, 100);
        assert_eq!(config.node.roles.len(), 2);
        assert_eq!(config.gateway.port, 3000);
    }

    #[test]
    fn test_env_var_substitution() {
        // A name used by no other test here, so it cannot clash with a real
        // variable or with another test's expectations.
        const VAR: &str = "FORGE_CONFIG_TEST_DB_URL";

        // SAFETY: mutating the environment is unsound if another thread reads or
        // writes it at the same time. No other test in this module reads `VAR`,
        // but libtest runs tests on parallel threads by default, so this is an
        // accepted test-only risk rather than a proven invariant.
        unsafe {
            std::env::set_var(VAR, "postgres://test:test@localhost/test");
        }

        let toml = r#"
            [database]
            url = "${FORGE_CONFIG_TEST_DB_URL}"
        "#;
        let result = ForgeConfig::parse_toml(toml);

        // Clean up before any unwrap/assert so that a failure cannot leave the
        // variable set for the rest of the test run.
        // SAFETY: same caveat as the `set_var` call above.
        unsafe {
            std::env::remove_var(VAR);
        }

        let config = result.unwrap();
        assert_eq!(config.database.url(), "postgres://test:test@localhost/test");
    }

    #[test]
    fn test_auth_validation_no_config() {
        let auth = AuthConfig::default();
        assert!(auth.validate().is_ok());
    }

    #[test]
    fn test_auth_validation_hmac_with_secret() {
        let auth = AuthConfig {
            jwt_secret: Some("my-secret".into()),
            jwt_algorithm: JwtAlgorithm::HS256,
            ..Default::default()
        };
        assert!(auth.validate().is_ok());
    }

    #[test]
    fn test_auth_validation_hmac_missing_secret() {
        let auth = AuthConfig {
            jwt_issuer: Some("my-issuer".into()),
            jwt_algorithm: JwtAlgorithm::HS256,
            ..Default::default()
        };
        let result = auth.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwt_secret is required"));
    }

    #[test]
    fn test_auth_validation_rsa_with_jwks() {
        let auth = AuthConfig {
            jwks_url: Some("https://example.com/.well-known/jwks.json".into()),
            jwt_algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(auth.validate().is_ok());
    }

    #[test]
    fn test_auth_validation_rsa_missing_jwks() {
        let auth = AuthConfig {
            jwt_issuer: Some("my-issuer".into()),
            jwt_algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        let result = auth.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwks_url is required"));
    }

    #[test]
    fn test_forge_config_validation_fails_on_empty_url() {
        let toml = r#"
            [database]

            url = ""
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("database.url is required"));
    }

    #[test]
    fn test_forge_config_validation_fails_on_invalid_auth() {
        let toml = r#"
            [database]

            url = "postgres://localhost/test"

            [auth]
            jwt_issuer = "my-issuer"
            jwt_algorithm = "RS256"
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwks_url is required"));
    }

    #[test]
    fn test_mcp_config_validation_rejects_invalid_path() {
        let toml = r#"
            [database]

            url = "postgres://localhost/test"

            [mcp]
            enabled = true
            path = "mcp"
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("mcp.path must start with '/'"));
    }
}