1mod cluster;
2mod database;
3
4pub use cluster::ClusterConfig;
5pub use database::{DatabaseConfig, PoolConfig};
6
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9
10use crate::error::{ForgeError, Result};
11
/// Root configuration, deserialized from TOML by [`ForgeConfig::parse_toml`].
///
/// Every section except `database` carries `#[serde(default)]`, so a missing
/// table falls back to that section's `Default` implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForgeConfig {
    /// `[project]` table: project name and version.
    #[serde(default)]
    pub project: ProjectConfig,

    /// `[database]` table. The only section without a serde default, so it
    /// must be present in the input.
    pub database: DatabaseConfig,

    /// `[node]` table: roles and worker capabilities.
    #[serde(default)]
    pub node: NodeConfig,

    /// `[gateway]` table: ports, connection limit, request timeout, CORS and
    /// quiet routes.
    #[serde(default)]
    pub gateway: GatewayConfig,

    /// `[function]` table: concurrency, timeout and memory limit.
    #[serde(default)]
    pub function: FunctionConfig,

    /// `[worker]` table: job concurrency, job timeout and poll interval.
    #[serde(default)]
    pub worker: WorkerConfig,

    /// `[cluster]` table; the type is re-exported from the `cluster` submodule.
    #[serde(default)]
    pub cluster: ClusterConfig,

    /// `[security]` table: optional secret key.
    #[serde(default)]
    pub security: SecurityConfig,

    /// `[auth]` table: JWT / JWKS settings, checked by [`ForgeConfig::validate`].
    #[serde(default)]
    pub auth: AuthConfig,

    /// `[observability]` table: OTLP endpoint, signal toggles, sampling and
    /// log level.
    #[serde(default)]
    pub observability: ObservabilityConfig,

    /// `[mcp]` table: MCP endpoint settings, checked by [`ForgeConfig::validate`].
    #[serde(default)]
    pub mcp: McpConfig,
}
58
59impl ForgeConfig {
60 pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
62 let content = std::fs::read_to_string(path.as_ref())
63 .map_err(|e| ForgeError::Config(format!("Failed to read config file: {}", e)))?;
64
65 Self::parse_toml(&content)
66 }
67
68 pub fn parse_toml(content: &str) -> Result<Self> {
70 let content = substitute_env_vars(content);
72
73 let config: Self = toml::from_str(&content)
74 .map_err(|e| ForgeError::Config(format!("Failed to parse config: {}", e)))?;
75
76 config.validate()?;
77 Ok(config)
78 }
79
80 pub fn validate(&self) -> Result<()> {
82 self.database.validate()?;
83 self.auth.validate()?;
84 self.mcp.validate()?;
85 Ok(())
86 }
87
88 pub fn default_with_database_url(url: &str) -> Self {
90 Self {
91 project: ProjectConfig::default(),
92 database: DatabaseConfig::new(url),
93 node: NodeConfig::default(),
94 gateway: GatewayConfig::default(),
95 function: FunctionConfig::default(),
96 worker: WorkerConfig::default(),
97 cluster: ClusterConfig::default(),
98 security: SecurityConfig::default(),
99 auth: AuthConfig::default(),
100 observability: ObservabilityConfig::default(),
101 mcp: McpConfig::default(),
102 }
103 }
104}
105
/// `[project]` section: identifying metadata for the application.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectConfig {
    /// Project name; defaults to `"forge-app"` when omitted.
    #[serde(default = "default_project_name")]
    pub name: String,

    /// Project version string; defaults to `"0.1.0"` when omitted.
    #[serde(default = "default_version")]
    pub version: String,
}
117
118impl Default for ProjectConfig {
119 fn default() -> Self {
120 Self {
121 name: default_project_name(),
122 version: default_version(),
123 }
124 }
125}
126
/// Serde default for `ProjectConfig::name`.
fn default_project_name() -> String {
    String::from("forge-app")
}
130
/// Serde default for `ProjectConfig::version`.
fn default_version() -> String {
    String::from("0.1.0")
}
134
/// `[node]` section: what this node does.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeConfig {
    /// Roles assigned to the node; defaults to all four [`NodeRole`] variants.
    #[serde(default = "default_roles")]
    pub roles: Vec<NodeRole>,

    /// Free-form capability labels; defaults to `["general"]`.
    /// NOTE(review): presumably used to route jobs to workers — confirm
    /// against the worker implementation.
    #[serde(default = "default_capabilities")]
    pub worker_capabilities: Vec<String>,
}
146
147impl Default for NodeConfig {
148 fn default() -> Self {
149 Self {
150 roles: default_roles(),
151 worker_capabilities: default_capabilities(),
152 }
153 }
154}
155
156fn default_roles() -> Vec<NodeRole> {
157 vec![
158 NodeRole::Gateway,
159 NodeRole::Function,
160 NodeRole::Worker,
161 NodeRole::Scheduler,
162 ]
163}
164
/// Serde default for `NodeConfig::worker_capabilities`: the single label
/// `"general"`.
fn default_capabilities() -> Vec<String> {
    let mut capabilities = Vec::with_capacity(1);
    capabilities.push(String::from("general"));
    capabilities
}
168
/// Role a node can take on.
///
/// Serialized in lowercase (`"gateway"`, `"function"`, `"worker"`,
/// `"scheduler"`) because of `rename_all = "lowercase"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum NodeRole {
    /// Written as `"gateway"` in configuration.
    Gateway,
    /// Written as `"function"` in configuration.
    Function,
    /// Written as `"worker"` in configuration.
    Worker,
    /// Written as `"scheduler"` in configuration.
    Scheduler,
}
178
/// `[gateway]` section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfig {
    /// Gateway port; defaults to 8080. The default helper is named
    /// `default_http_port`, so this is presumably the HTTP listener.
    #[serde(default = "default_http_port")]
    pub port: u16,

    /// gRPC port; defaults to 9000.
    #[serde(default = "default_grpc_port")]
    pub grpc_port: u16,

    /// Connection limit; defaults to 512.
    #[serde(default = "default_max_connections")]
    pub max_connections: usize,

    /// Request timeout in seconds; defaults to 30.
    #[serde(default = "default_request_timeout")]
    pub request_timeout_secs: u64,

    /// Whether CORS handling is enabled; defaults to `false`.
    #[serde(default = "default_cors_enabled")]
    pub cors_enabled: bool,

    /// CORS origins; defaults to an empty list.
    #[serde(default = "default_cors_origins")]
    pub cors_origins: Vec<String>,

    /// Defaults to `["/_api/health", "/_api/ready"]`.
    /// NOTE(review): presumably routes kept out of request logging — confirm
    /// against the gateway implementation.
    #[serde(default = "default_quiet_routes")]
    pub quiet_routes: Vec<String>,
}
211
212impl Default for GatewayConfig {
213 fn default() -> Self {
214 Self {
215 port: default_http_port(),
216 grpc_port: default_grpc_port(),
217 max_connections: default_max_connections(),
218 request_timeout_secs: default_request_timeout(),
219 cors_enabled: default_cors_enabled(),
220 cors_origins: default_cors_origins(),
221 quiet_routes: default_quiet_routes(),
222 }
223 }
224}
225
/// Serde default for `GatewayConfig::port`.
fn default_http_port() -> u16 {
    const DEFAULT_HTTP_PORT: u16 = 8080;
    DEFAULT_HTTP_PORT
}
229
/// Serde default for `GatewayConfig::grpc_port`.
fn default_grpc_port() -> u16 {
    const DEFAULT_GRPC_PORT: u16 = 9000;
    DEFAULT_GRPC_PORT
}
233
/// Serde default for `GatewayConfig::max_connections`.
fn default_max_connections() -> usize {
    const DEFAULT_MAX_CONNECTIONS: usize = 512;
    DEFAULT_MAX_CONNECTIONS
}
237
/// Serde default for `GatewayConfig::request_timeout_secs`.
fn default_request_timeout() -> u64 {
    const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 30;
    DEFAULT_REQUEST_TIMEOUT_SECS
}
241
/// Serde default for `GatewayConfig::cors_enabled`: CORS is off unless
/// configured.
fn default_cors_enabled() -> bool {
    const CORS_ENABLED_BY_DEFAULT: bool = false;
    CORS_ENABLED_BY_DEFAULT
}
245
/// Serde default for `GatewayConfig::cors_origins`: no origins.
fn default_cors_origins() -> Vec<String> {
    vec![]
}
249
/// Serde default for `GatewayConfig::quiet_routes`.
fn default_quiet_routes() -> Vec<String> {
    ["/_api/health", "/_api/ready"]
        .iter()
        .map(|route| route.to_string())
        .collect()
}
253
/// `[function]` section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionConfig {
    /// Concurrency limit; defaults to 1000.
    #[serde(default = "default_max_concurrent")]
    pub max_concurrent: usize,

    /// Timeout in seconds; defaults to 30.
    #[serde(default = "default_function_timeout")]
    pub timeout_secs: u64,

    /// Memory limit; defaults to `512 * 1024 * 1024`.
    /// NOTE(review): presumably bytes (512 MiB) — confirm against the consumer.
    #[serde(default = "default_memory_limit")]
    pub memory_limit: usize,
}
269
270impl Default for FunctionConfig {
271 fn default() -> Self {
272 Self {
273 max_concurrent: default_max_concurrent(),
274 timeout_secs: default_function_timeout(),
275 memory_limit: default_memory_limit(),
276 }
277 }
278}
279
/// Serde default for `FunctionConfig::max_concurrent`.
fn default_max_concurrent() -> usize {
    const DEFAULT_MAX_CONCURRENT: usize = 1000;
    DEFAULT_MAX_CONCURRENT
}
283
/// Serde default for `FunctionConfig::timeout_secs`.
fn default_function_timeout() -> u64 {
    const DEFAULT_FUNCTION_TIMEOUT_SECS: u64 = 30;
    DEFAULT_FUNCTION_TIMEOUT_SECS
}
287
/// Serde default for `FunctionConfig::memory_limit`: 512 × 1024 × 1024.
fn default_memory_limit() -> usize {
    const KIB: usize = 1024;
    const MIB: usize = KIB * KIB;
    512 * MIB
}
291
/// `[worker]` section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerConfig {
    /// Job concurrency limit; defaults to 50.
    #[serde(default = "default_max_concurrent_jobs")]
    pub max_concurrent_jobs: usize,

    /// Job timeout in seconds; defaults to 3600.
    #[serde(default = "default_job_timeout")]
    pub job_timeout_secs: u64,

    /// Poll interval in milliseconds; defaults to 100.
    #[serde(default = "default_poll_interval")]
    pub poll_interval_ms: u64,
}
307
308impl Default for WorkerConfig {
309 fn default() -> Self {
310 Self {
311 max_concurrent_jobs: default_max_concurrent_jobs(),
312 job_timeout_secs: default_job_timeout(),
313 poll_interval_ms: default_poll_interval(),
314 }
315 }
316}
317
/// Serde default for `WorkerConfig::max_concurrent_jobs`.
fn default_max_concurrent_jobs() -> usize {
    const DEFAULT_MAX_CONCURRENT_JOBS: usize = 50;
    DEFAULT_MAX_CONCURRENT_JOBS
}
321
/// Serde default for `WorkerConfig::job_timeout_secs`: one hour.
fn default_job_timeout() -> u64 {
    const SECS_PER_HOUR: u64 = 60 * 60;
    SECS_PER_HOUR
}
325
/// Serde default for `WorkerConfig::poll_interval_ms`.
fn default_poll_interval() -> u64 {
    const DEFAULT_POLL_INTERVAL_MS: u64 = 100;
    DEFAULT_POLL_INTERVAL_MS
}
329
/// `[security]` section. The derived `Default` leaves `secret_key` unset.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SecurityConfig {
    /// Optional secret key. Nothing in this file reads or validates it.
    pub secret_key: Option<String>,
}
336
/// JWT signing algorithm selected by `auth.jwt_algorithm`.
///
/// Serialized in upper case (e.g. `"RS256"`). `AuthConfig::validate` treats
/// the `HS*` variants as HMAC algorithms that need `auth.jwt_secret`, and the
/// `RS*` variants as RSA algorithms that need `auth.jwks_url`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "UPPERCASE")]
pub enum JwtAlgorithm {
    /// HMAC family; the default when `jwt_algorithm` is omitted.
    #[default]
    HS256,
    /// HMAC family.
    HS384,
    /// HMAC family.
    HS512,
    /// RSA family.
    RS256,
    /// RSA family.
    RS384,
    /// RSA family.
    RS512,
}
355
/// `[auth]` section: JWT verification settings.
///
/// The section counts as "configured" once any of `jwt_secret`, `jwks_url`,
/// `jwt_issuer` or `jwt_audience` is set; only then does `validate` enforce
/// the per-algorithm requirements.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthConfig {
    /// Shared secret. Required by `validate` when the algorithm is one of the
    /// `HS*` variants and the section is configured.
    pub jwt_secret: Option<String>,

    /// Signing algorithm; defaults to `HS256` when omitted.
    #[serde(default)]
    pub jwt_algorithm: JwtAlgorithm,

    /// Optional issuer. Setting it marks the section as configured.
    pub jwt_issuer: Option<String>,

    /// Optional audience. Setting it marks the section as configured.
    pub jwt_audience: Option<String>,

    /// Optional token expiry. Not interpreted in this file.
    /// NOTE(review): the accepted string format is defined by the consumer —
    /// confirm there.
    pub token_expiry: Option<String>,

    /// JWKS endpoint URL. Required by `validate` when the algorithm is one of
    /// the `RS*` variants and the section is configured.
    pub jwks_url: Option<String>,

    /// JWKS cache time-to-live in seconds; defaults to 3600.
    #[serde(default = "default_jwks_cache_ttl")]
    pub jwks_cache_ttl_secs: u64,

    /// Session time-to-live in seconds; defaults to 604800 (7 days).
    #[serde(default = "default_session_ttl")]
    pub session_ttl_secs: u64,
}
392
393impl Default for AuthConfig {
394 fn default() -> Self {
395 Self {
396 jwt_secret: None,
397 jwt_algorithm: JwtAlgorithm::default(),
398 jwt_issuer: None,
399 jwt_audience: None,
400 token_expiry: None,
401 jwks_url: None,
402 jwks_cache_ttl_secs: default_jwks_cache_ttl(),
403 session_ttl_secs: default_session_ttl(),
404 }
405 }
406}
407
408impl AuthConfig {
409 fn is_configured(&self) -> bool {
411 self.jwt_secret.is_some()
412 || self.jwks_url.is_some()
413 || self.jwt_issuer.is_some()
414 || self.jwt_audience.is_some()
415 }
416
417 pub fn validate(&self) -> Result<()> {
420 if !self.is_configured() {
421 return Ok(());
422 }
423
424 match self.jwt_algorithm {
425 JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512 => {
426 if self.jwt_secret.is_none() {
427 return Err(ForgeError::Config(
428 "auth.jwt_secret is required for HMAC algorithms (HS256, HS384, HS512). \
429 Set auth.jwt_secret to a secure random string, \
430 or switch to RS256 and provide auth.jwks_url for external identity providers."
431 .into(),
432 ));
433 }
434 }
435 JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512 => {
436 if self.jwks_url.is_none() {
437 return Err(ForgeError::Config(
438 "auth.jwks_url is required for RSA algorithms (RS256, RS384, RS512). \
439 Set auth.jwks_url to your identity provider's JWKS endpoint, \
440 or switch to HS256 and provide auth.jwt_secret for symmetric signing."
441 .into(),
442 ));
443 }
444 }
445 }
446 Ok(())
447 }
448
449 pub fn is_hmac(&self) -> bool {
451 matches!(
452 self.jwt_algorithm,
453 JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
454 )
455 }
456
457 pub fn is_rsa(&self) -> bool {
459 matches!(
460 self.jwt_algorithm,
461 JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
462 )
463 }
464}
465
/// Serde default for `AuthConfig::jwks_cache_ttl_secs`: one hour.
fn default_jwks_cache_ttl() -> u64 {
    const SECS_PER_HOUR: u64 = 60 * 60;
    SECS_PER_HOUR
}
469
/// Serde default for `AuthConfig::session_ttl_secs`: seven days.
fn default_session_ttl() -> u64 {
    const SECS_PER_DAY: u64 = 24 * 60 * 60;
    7 * SECS_PER_DAY
}
473
/// `[observability]` section. Nothing in this file validates these values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ObservabilityConfig {
    /// Master switch; defaults to `false`.
    #[serde(default)]
    pub enabled: bool,

    /// OTLP endpoint; defaults to `"http://localhost:4318"`.
    #[serde(default = "default_otlp_endpoint")]
    pub otlp_endpoint: String,

    /// Optional service name; `None` when omitted.
    pub service_name: Option<String>,

    /// Trace toggle; defaults to `true`.
    #[serde(default = "default_true")]
    pub enable_traces: bool,

    /// Metrics toggle; defaults to `true`.
    #[serde(default = "default_true")]
    pub enable_metrics: bool,

    /// Log toggle; defaults to `true`.
    #[serde(default = "default_true")]
    pub enable_logs: bool,

    /// Sampling ratio; defaults to `1.0`. Not range-checked in this file.
    #[serde(default = "default_sampling_ratio")]
    pub sampling_ratio: f64,

    /// Log level name; defaults to `"info"`. Stored as a free-form string.
    #[serde(default = "default_log_level")]
    pub log_level: String,
}
508
509impl Default for ObservabilityConfig {
510 fn default() -> Self {
511 Self {
512 enabled: false,
513 otlp_endpoint: default_otlp_endpoint(),
514 service_name: None,
515 enable_traces: true,
516 enable_metrics: true,
517 enable_logs: true,
518 sampling_ratio: default_sampling_ratio(),
519 log_level: default_log_level(),
520 }
521 }
522}
523
/// Serde default for `ObservabilityConfig::otlp_endpoint`.
fn default_otlp_endpoint() -> String {
    String::from("http://localhost:4318")
}
527
/// Serde default for boolean options that are on unless explicitly disabled.
fn default_true() -> bool {
    const ON: bool = true;
    ON
}
531
/// Serde default for `ObservabilityConfig::sampling_ratio`.
fn default_sampling_ratio() -> f64 {
    const DEFAULT_SAMPLING_RATIO: f64 = 1.0;
    DEFAULT_SAMPLING_RATIO
}
535
/// Serde default for `ObservabilityConfig::log_level`.
fn default_log_level() -> String {
    String::from("info")
}
539
/// `[mcp]` section. `path` and `session_ttl_secs` are checked by
/// `McpConfig::validate`, whether or not `enabled` is set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpConfig {
    /// Master switch; defaults to `false`.
    #[serde(default)]
    pub enabled: bool,

    /// Endpoint path; defaults to `"/mcp"`. Must start with `/` and contain
    /// no spaces.
    #[serde(default = "default_mcp_path")]
    pub path: String,

    /// Session time-to-live in seconds; defaults to 3600. Must be non-zero.
    #[serde(default = "default_mcp_session_ttl_secs")]
    pub session_ttl_secs: u64,

    /// Allowed origins; defaults to an empty list.
    /// NOTE(review): whether an empty list means "allow all" or "allow none"
    /// is decided by the consumer — confirm there.
    #[serde(default)]
    pub allowed_origins: Vec<String>,

    /// Defaults to `true`.
    /// NOTE(review): presumably rejects requests lacking the MCP protocol
    /// version header — confirm against the handler.
    #[serde(default = "default_true")]
    pub require_protocol_version_header: bool,
}
563
564impl Default for McpConfig {
565 fn default() -> Self {
566 Self {
567 enabled: false,
568 path: default_mcp_path(),
569 session_ttl_secs: default_mcp_session_ttl_secs(),
570 allowed_origins: Vec::new(),
571 require_protocol_version_header: default_true(),
572 }
573 }
574}
575
576impl McpConfig {
577 pub fn validate(&self) -> Result<()> {
578 if self.path.is_empty() || !self.path.starts_with('/') {
579 return Err(ForgeError::Config(
580 "mcp.path must start with '/' (example: /mcp)".to_string(),
581 ));
582 }
583 if self.path.contains(' ') {
584 return Err(ForgeError::Config(
585 "mcp.path cannot contain spaces".to_string(),
586 ));
587 }
588 if self.session_ttl_secs == 0 {
589 return Err(ForgeError::Config(
590 "mcp.session_ttl_secs must be greater than 0".to_string(),
591 ));
592 }
593 Ok(())
594 }
595}
596
/// Serde default for `McpConfig::path`.
fn default_mcp_path() -> String {
    String::from("/mcp")
}
600
/// Serde default for `McpConfig::session_ttl_secs`: one hour.
fn default_mcp_session_ttl_secs() -> u64 {
    const SECS_PER_MINUTE: u64 = 60;
    const MINUTES_PER_HOUR: u64 = 60;
    SECS_PER_MINUTE * MINUTES_PER_HOUR
}
604
/// Expands `${NAME}` placeholders in `content` with the value of the
/// environment variable `NAME`.
///
/// Rules:
/// - `NAME` must satisfy [`is_valid_env_var_name`]; any other `${...}`
///   sequence is copied through unchanged.
/// - A placeholder whose variable is unset (or not valid Unicode) is left in
///   the output verbatim.
/// - Substituted values are inserted as-is and are not rescanned for further
///   placeholders.
///
/// All other text is copied as UTF-8 string slices. Copying it byte by byte
/// with `byte as char` would turn each byte of a multi-byte character into a
/// separate Latin-1 code point and corrupt non-ASCII input.
fn substitute_env_vars(content: &str) -> String {
    let mut result = String::with_capacity(content.len());
    let mut rest = content;

    while let Some((before, after)) = rest.split_once("${") {
        // Text ahead of the marker can never be part of a placeholder.
        result.push_str(before);

        match after.split_once('}') {
            Some((name, tail)) if is_valid_env_var_name(name) => {
                match std::env::var(name) {
                    Ok(value) => result.push_str(&value),
                    Err(_) => {
                        // Keep the placeholder so the unresolved reference
                        // stays visible in the resulting text.
                        result.push_str("${");
                        result.push_str(name);
                        result.push('}');
                    }
                }
                rest = tail;
            }
            _ => {
                // Not a placeholder: emit the marker literally and keep
                // scanning right after it, since a valid placeholder may
                // still begin inside the rejected span.
                result.push_str("${");
                rest = after;
            }
        }
    }

    result.push_str(rest);
    result
}

/// Returns `true` when `name` is non-empty, starts with an ASCII uppercase
/// letter or `_`, and otherwise contains only ASCII uppercase letters, ASCII
/// digits and `_`.
fn is_valid_env_var_name(name: &str) -> bool {
    match name.as_bytes().split_first() {
        Some((&first, remainder)) => {
            (first.is_ascii_uppercase() || first == b'_')
                && remainder
                    .iter()
                    .all(|&b| b.is_ascii_uppercase() || b.is_ascii_digit() || b == b'_')
        }
        None => false,
    }
}
647
// Unit tests for configuration parsing, defaults, environment substitution
// and the auth / MCP validators.
//
// The lint allowances cover `unwrap()` on parse results and the `unsafe`
// blocks around environment mutation in `test_env_var_substitution`.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, unsafe_code)]
mod tests {
    use super::*;

    /// The programmatic constructor yields the documented defaults.
    #[test]
    fn test_default_config() {
        let config = ForgeConfig::default_with_database_url("postgres://localhost/test");
        assert_eq!(config.gateway.port, 8080);
        assert_eq!(config.node.roles.len(), 4);
        assert_eq!(config.mcp.path, "/mcp");
        assert!(!config.mcp.enabled);
    }

    /// Only `[database]` is required; every other section falls back to its
    /// default.
    #[test]
    fn test_parse_minimal_config() {
        let toml = r#"
            [database]
            url = "postgres://localhost/myapp"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.database.url(), "postgres://localhost/myapp");
        assert_eq!(config.gateway.port, 8080);
    }

    /// Explicit values override the defaults, including lowercase role names.
    #[test]
    fn test_parse_full_config() {
        let toml = r#"
            [project]
            name = "my-app"
            version = "1.0.0"

            [database]
            url = "postgres://localhost/myapp"
            pool_size = 100

            [node]
            roles = ["gateway", "worker"]
            worker_capabilities = ["media", "general"]

            [gateway]
            port = 3000
            grpc_port = 9001
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.project.name, "my-app");
        assert_eq!(config.database.pool_size, 100);
        assert_eq!(config.node.roles.len(), 2);
        assert_eq!(config.gateway.port, 3000);
    }

    /// A `${NAME}` placeholder is replaced with the variable's value before
    /// the TOML is parsed.
    #[test]
    fn test_env_var_substitution() {
        // SAFETY: no other test in this module feeds a `${...}` placeholder
        // to `parse_toml`, so none of them reads the environment through it.
        // NOTE(review): tests elsewhere in the crate may touch the
        // environment concurrently — confirm this variable name stays unique.
        unsafe {
            std::env::set_var("TEST_DB_URL", "postgres://test:test@localhost/test");
        }

        let toml = r#"
            [database]
            url = "${TEST_DB_URL}"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.database.url(), "postgres://test:test@localhost/test");

        // SAFETY: same reasoning as for `set_var` above. This cleanup is
        // skipped if an assertion above panics.
        unsafe {
            std::env::remove_var("TEST_DB_URL");
        }
    }

    /// An untouched auth section is treated as unconfigured and passes.
    #[test]
    fn test_auth_validation_no_config() {
        let auth = AuthConfig::default();
        assert!(auth.validate().is_ok());
    }

    /// An HMAC algorithm with a secret is valid.
    #[test]
    fn test_auth_validation_hmac_with_secret() {
        let auth = AuthConfig {
            jwt_secret: Some("my-secret".into()),
            jwt_algorithm: JwtAlgorithm::HS256,
            ..Default::default()
        };
        assert!(auth.validate().is_ok());
    }

    /// Setting only the issuer marks the section as configured, so the
    /// missing HMAC secret is reported.
    #[test]
    fn test_auth_validation_hmac_missing_secret() {
        let auth = AuthConfig {
            jwt_issuer: Some("my-issuer".into()),
            jwt_algorithm: JwtAlgorithm::HS256,
            ..Default::default()
        };
        let result = auth.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwt_secret is required"));
    }

    /// An RSA algorithm with a JWKS URL is valid.
    #[test]
    fn test_auth_validation_rsa_with_jwks() {
        let auth = AuthConfig {
            jwks_url: Some("https://example.com/.well-known/jwks.json".into()),
            jwt_algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(auth.validate().is_ok());
    }

    /// An RSA algorithm without a JWKS URL is rejected once the section is
    /// configured.
    #[test]
    fn test_auth_validation_rsa_missing_jwks() {
        let auth = AuthConfig {
            jwt_issuer: Some("my-issuer".into()),
            jwt_algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        let result = auth.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwks_url is required"));
    }

    /// Database validation runs as part of `parse_toml`.
    #[test]
    fn test_forge_config_validation_fails_on_empty_url() {
        let toml = r#"
            [database]

            url = ""
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("database.url is required"));
    }

    /// Auth validation runs as part of `parse_toml`.
    #[test]
    fn test_forge_config_validation_fails_on_invalid_auth() {
        let toml = r#"
            [database]

            url = "postgres://localhost/test"

            [auth]
            jwt_issuer = "my-issuer"
            jwt_algorithm = "RS256"
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwks_url is required"));
    }

    /// MCP validation runs as part of `parse_toml` and rejects a path with
    /// no leading slash.
    #[test]
    fn test_mcp_config_validation_rejects_invalid_path() {
        let toml = r#"
            [database]

            url = "postgres://localhost/test"

            [mcp]
            enabled = true
            path = "mcp"
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("mcp.path must start with '/'"));
    }
}