1mod cluster;
2mod database;
3
4pub use cluster::ClusterConfig;
5pub use database::{DatabaseConfig, PoolConfig};
6
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9
10use crate::error::{ForgeError, Result};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ForgeConfig {
15 #[serde(default)]
17 pub project: ProjectConfig,
18
19 pub database: DatabaseConfig,
21
22 #[serde(default)]
24 pub node: NodeConfig,
25
26 #[serde(default)]
28 pub gateway: GatewayConfig,
29
30 #[serde(default)]
32 pub function: FunctionConfig,
33
34 #[serde(default)]
36 pub worker: WorkerConfig,
37
38 #[serde(default)]
40 pub cluster: ClusterConfig,
41
42 #[serde(default)]
44 pub security: SecurityConfig,
45
46 #[serde(default)]
48 pub auth: AuthConfig,
49
50 #[serde(default)]
52 pub observability: ObservabilityConfig,
53
54 #[serde(default)]
56 pub mcp: McpConfig,
57}
58
59impl ForgeConfig {
60 pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
62 let content = std::fs::read_to_string(path.as_ref())
63 .map_err(|e| ForgeError::Config(format!("Failed to read config file: {}", e)))?;
64
65 Self::parse_toml(&content)
66 }
67
68 pub fn parse_toml(content: &str) -> Result<Self> {
70 let content = substitute_env_vars(content);
72
73 let config: Self = toml::from_str(&content)
74 .map_err(|e| ForgeError::Config(format!("Failed to parse config: {}", e)))?;
75
76 config.validate()?;
77 Ok(config)
78 }
79
80 pub fn validate(&self) -> Result<()> {
82 self.database.validate()?;
83 self.auth.validate()?;
84 self.mcp.validate()?;
85 Ok(())
86 }
87
88 pub fn default_with_database_url(url: &str) -> Self {
90 Self {
91 project: ProjectConfig::default(),
92 database: DatabaseConfig::new(url),
93 node: NodeConfig::default(),
94 gateway: GatewayConfig::default(),
95 function: FunctionConfig::default(),
96 worker: WorkerConfig::default(),
97 cluster: ClusterConfig::default(),
98 security: SecurityConfig::default(),
99 auth: AuthConfig::default(),
100 observability: ObservabilityConfig::default(),
101 mcp: McpConfig::default(),
102 }
103 }
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct ProjectConfig {
109 #[serde(default = "default_project_name")]
111 pub name: String,
112
113 #[serde(default = "default_version")]
115 pub version: String,
116}
117
118impl Default for ProjectConfig {
119 fn default() -> Self {
120 Self {
121 name: default_project_name(),
122 version: default_version(),
123 }
124 }
125}
126
/// Fallback for `project.name`.
fn default_project_name() -> String {
    String::from("forge-app")
}

/// Fallback for `project.version`.
fn default_version() -> String {
    "0.1.0".to_owned()
}
134
135#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct NodeConfig {
138 #[serde(default = "default_roles")]
140 pub roles: Vec<NodeRole>,
141
142 #[serde(default = "default_capabilities")]
144 pub worker_capabilities: Vec<String>,
145}
146
147impl Default for NodeConfig {
148 fn default() -> Self {
149 Self {
150 roles: default_roles(),
151 worker_capabilities: default_capabilities(),
152 }
153 }
154}
155
156fn default_roles() -> Vec<NodeRole> {
157 vec![
158 NodeRole::Gateway,
159 NodeRole::Function,
160 NodeRole::Worker,
161 NodeRole::Scheduler,
162 ]
163}
164
/// Capability tags used when `node.worker_capabilities` is omitted.
fn default_capabilities() -> Vec<String> {
    std::iter::once("general".to_owned()).collect()
}
168
169#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
171#[serde(rename_all = "lowercase")]
172pub enum NodeRole {
173 Gateway,
174 Function,
175 Worker,
176 Scheduler,
177}
178
179#[derive(Debug, Clone, Serialize, Deserialize)]
181pub struct GatewayConfig {
182 #[serde(default = "default_http_port")]
184 pub port: u16,
185
186 #[serde(default = "default_grpc_port")]
188 pub grpc_port: u16,
189
190 #[serde(default = "default_max_connections")]
192 pub max_connections: usize,
193
194 #[serde(default = "default_request_timeout")]
196 pub request_timeout_secs: u64,
197
198 #[serde(default = "default_cors_enabled")]
200 pub cors_enabled: bool,
201
202 #[serde(default = "default_cors_origins")]
204 pub cors_origins: Vec<String>,
205
206 #[serde(default = "default_quiet_routes")]
209 pub quiet_routes: Vec<String>,
210}
211
212impl Default for GatewayConfig {
213 fn default() -> Self {
214 Self {
215 port: default_http_port(),
216 grpc_port: default_grpc_port(),
217 max_connections: default_max_connections(),
218 request_timeout_secs: default_request_timeout(),
219 cors_enabled: default_cors_enabled(),
220 cors_origins: default_cors_origins(),
221 quiet_routes: default_quiet_routes(),
222 }
223 }
224}
225
/// Default for `gateway.port`.
fn default_http_port() -> u16 {
    8_080
}

/// Default for `gateway.grpc_port`.
fn default_grpc_port() -> u16 {
    9_000
}

/// Default for `gateway.max_connections` (512).
fn default_max_connections() -> usize {
    1 << 9
}

/// Default for `gateway.request_timeout_secs`.
fn default_request_timeout() -> u64 {
    30
}

/// CORS is opt-in: `gateway.cors_enabled` defaults to `false`.
fn default_cors_enabled() -> bool {
    bool::default()
}

/// Default for `gateway.cors_origins`: no origins.
fn default_cors_origins() -> Vec<String> {
    Vec::default()
}

/// Default for `gateway.quiet_routes`: the health and readiness endpoints.
fn default_quiet_routes() -> Vec<String> {
    ["/_api/health", "/_api/ready"]
        .iter()
        .map(|route| (*route).to_owned())
        .collect()
}
253
254#[derive(Debug, Clone, Serialize, Deserialize)]
256pub struct FunctionConfig {
257 #[serde(default = "default_max_concurrent")]
259 pub max_concurrent: usize,
260
261 #[serde(default = "default_function_timeout")]
263 pub timeout_secs: u64,
264
265 #[serde(default = "default_memory_limit")]
267 pub memory_limit: usize,
268}
269
270impl Default for FunctionConfig {
271 fn default() -> Self {
272 Self {
273 max_concurrent: default_max_concurrent(),
274 timeout_secs: default_function_timeout(),
275 memory_limit: default_memory_limit(),
276 }
277 }
278}
279
/// Default for `function.max_concurrent`.
fn default_max_concurrent() -> usize {
    1_000
}

/// Default for `function.timeout_secs`.
fn default_function_timeout() -> u64 {
    30
}

/// Default for `function.memory_limit`: 512 MiB expressed as a byte count.
fn default_memory_limit() -> usize {
    512 << 20
}
291
292#[derive(Debug, Clone, Serialize, Deserialize)]
294pub struct WorkerConfig {
295 #[serde(default = "default_max_concurrent_jobs")]
297 pub max_concurrent_jobs: usize,
298
299 #[serde(default = "default_job_timeout")]
301 pub job_timeout_secs: u64,
302
303 #[serde(default = "default_poll_interval")]
305 pub poll_interval_ms: u64,
306}
307
308impl Default for WorkerConfig {
309 fn default() -> Self {
310 Self {
311 max_concurrent_jobs: default_max_concurrent_jobs(),
312 job_timeout_secs: default_job_timeout(),
313 poll_interval_ms: default_poll_interval(),
314 }
315 }
316}
317
/// Default for `worker.max_concurrent_jobs`.
fn default_max_concurrent_jobs() -> usize {
    50
}

/// Default for `worker.job_timeout_secs`: one hour.
fn default_job_timeout() -> u64 {
    60 * 60
}

/// Default for `worker.poll_interval_ms`.
fn default_poll_interval() -> u64 {
    100
}
329
330#[derive(Debug, Clone, Serialize, Deserialize, Default)]
332pub struct SecurityConfig {
333 pub secret_key: Option<String>,
335}
336
337#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
339#[serde(rename_all = "UPPERCASE")]
340pub enum JwtAlgorithm {
341 #[default]
343 HS256,
344 HS384,
346 HS512,
348 RS256,
350 RS384,
352 RS512,
354}
355
356#[derive(Debug, Clone, Serialize, Deserialize)]
358pub struct AuthConfig {
359 pub jwt_secret: Option<String>,
362
363 #[serde(default)]
367 pub jwt_algorithm: JwtAlgorithm,
368
369 pub jwt_issuer: Option<String>,
372
373 pub jwt_audience: Option<String>,
376
377 pub token_expiry: Option<String>,
379
380 pub jwks_url: Option<String>,
383
384 #[serde(default = "default_jwks_cache_ttl")]
386 pub jwks_cache_ttl_secs: u64,
387
388 #[serde(default = "default_session_ttl")]
390 pub session_ttl_secs: u64,
391}
392
393impl Default for AuthConfig {
394 fn default() -> Self {
395 Self {
396 jwt_secret: None,
397 jwt_algorithm: JwtAlgorithm::default(),
398 jwt_issuer: None,
399 jwt_audience: None,
400 token_expiry: None,
401 jwks_url: None,
402 jwks_cache_ttl_secs: default_jwks_cache_ttl(),
403 session_ttl_secs: default_session_ttl(),
404 }
405 }
406}
407
408impl AuthConfig {
409 fn is_configured(&self) -> bool {
411 self.jwt_secret.is_some()
412 || self.jwks_url.is_some()
413 || self.jwt_issuer.is_some()
414 || self.jwt_audience.is_some()
415 }
416
417 pub fn validate(&self) -> Result<()> {
420 if !self.is_configured() {
421 return Ok(());
422 }
423
424 match self.jwt_algorithm {
425 JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512 => {
426 if self.jwt_secret.is_none() {
427 return Err(ForgeError::Config(
428 "auth.jwt_secret is required for HMAC algorithms (HS256, HS384, HS512). \
429 Set auth.jwt_secret to a secure random string, \
430 or switch to RS256 and provide auth.jwks_url for external identity providers."
431 .into(),
432 ));
433 }
434 }
435 JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512 => {
436 if self.jwks_url.is_none() {
437 return Err(ForgeError::Config(
438 "auth.jwks_url is required for RSA algorithms (RS256, RS384, RS512). \
439 Set auth.jwks_url to your identity provider's JWKS endpoint, \
440 or switch to HS256 and provide auth.jwt_secret for symmetric signing."
441 .into(),
442 ));
443 }
444 }
445 }
446 Ok(())
447 }
448
449 pub fn is_hmac(&self) -> bool {
451 matches!(
452 self.jwt_algorithm,
453 JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
454 )
455 }
456
457 pub fn is_rsa(&self) -> bool {
459 matches!(
460 self.jwt_algorithm,
461 JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
462 )
463 }
464}
465
/// Default for `auth.jwks_cache_ttl_secs`: one hour.
fn default_jwks_cache_ttl() -> u64 {
    60 * 60
}

/// Default for `auth.session_ttl_secs`: one week.
fn default_session_ttl() -> u64 {
    const SECS_PER_DAY: u64 = 24 * 60 * 60;
    7 * SECS_PER_DAY
}
473
474#[derive(Debug, Clone, Serialize, Deserialize)]
476pub struct ObservabilityConfig {
477 #[serde(default)]
479 pub enabled: bool,
480
481 #[serde(default = "default_otlp_endpoint")]
483 pub otlp_endpoint: String,
484
485 pub service_name: Option<String>,
487
488 #[serde(default = "default_true")]
490 pub enable_traces: bool,
491
492 #[serde(default = "default_true")]
494 pub enable_metrics: bool,
495
496 #[serde(default = "default_true")]
498 pub enable_logs: bool,
499
500 #[serde(default = "default_sampling_ratio")]
502 pub sampling_ratio: f64,
503
504 #[serde(default = "default_log_level")]
506 pub log_level: String,
507}
508
509impl Default for ObservabilityConfig {
510 fn default() -> Self {
511 Self {
512 enabled: false,
513 otlp_endpoint: default_otlp_endpoint(),
514 service_name: None,
515 enable_traces: true,
516 enable_metrics: true,
517 enable_logs: true,
518 sampling_ratio: default_sampling_ratio(),
519 log_level: default_log_level(),
520 }
521 }
522}
523
524impl ObservabilityConfig {
525 pub fn otlp_active(&self) -> bool {
526 self.enabled && (self.enable_traces || self.enable_metrics || self.enable_logs)
527 }
528}
529
/// Default OTLP endpoint: a collector on localhost port 4318.
fn default_otlp_endpoint() -> String {
    String::from("http://localhost:4318")
}

/// Serde default helper for boolean fields that are on unless disabled.
fn default_true() -> bool {
    true
}

/// Default for `observability.sampling_ratio`.
fn default_sampling_ratio() -> f64 {
    1.0_f64
}

/// Default for `observability.log_level`.
fn default_log_level() -> String {
    "info".into()
}
545
546#[derive(Debug, Clone, Serialize, Deserialize)]
548pub struct McpConfig {
549 #[serde(default)]
551 pub enabled: bool,
552
553 #[serde(default = "default_mcp_path")]
555 pub path: String,
556
557 #[serde(default = "default_mcp_session_ttl_secs")]
559 pub session_ttl_secs: u64,
560
561 #[serde(default)]
563 pub allowed_origins: Vec<String>,
564
565 #[serde(default = "default_true")]
567 pub require_protocol_version_header: bool,
568}
569
570impl Default for McpConfig {
571 fn default() -> Self {
572 Self {
573 enabled: false,
574 path: default_mcp_path(),
575 session_ttl_secs: default_mcp_session_ttl_secs(),
576 allowed_origins: Vec::new(),
577 require_protocol_version_header: default_true(),
578 }
579 }
580}
581
582impl McpConfig {
583 pub fn validate(&self) -> Result<()> {
584 if self.path.is_empty() || !self.path.starts_with('/') {
585 return Err(ForgeError::Config(
586 "mcp.path must start with '/' (example: /mcp)".to_string(),
587 ));
588 }
589 if self.path.contains(' ') {
590 return Err(ForgeError::Config(
591 "mcp.path cannot contain spaces".to_string(),
592 ));
593 }
594 if self.session_ttl_secs == 0 {
595 return Err(ForgeError::Config(
596 "mcp.session_ttl_secs must be greater than 0".to_string(),
597 ));
598 }
599 Ok(())
600 }
601}
602
/// Default for `mcp.path`.
fn default_mcp_path() -> String {
    String::from("/mcp")
}

/// Default for `mcp.session_ttl_secs`: one hour.
fn default_mcp_session_ttl_secs() -> u64 {
    3_600
}
610
/// Expands environment-variable references in `content`.
///
/// Supported forms follow POSIX shell parameter expansion:
///
/// * `${VAR}` is replaced by the value of `VAR`. It is left verbatim if `VAR`
///   is unset.
/// * `${VAR-default}` uses `default` only when `VAR` is unset.
/// * `${VAR:-default}` uses `default` when `VAR` is unset *or* set to an
///   empty string.
///
/// Only names made of `A-Z`, `0-9` and `_` that do not start with a digit are
/// recognised. Any other `${...}` text is copied unchanged, while a valid
/// reference nested inside it is still expanded. Variables whose value is not
/// valid Unicode are treated as unset. Text outside references, including
/// non-ASCII UTF-8, is copied byte-for-byte.
pub fn substitute_env_vars(content: &str) -> String {
    let mut result = String::with_capacity(content.len());
    let mut rest = content;

    while let Some((before, after_open)) = rest.split_once("${") {
        // Copy whole slices so multi-byte UTF-8 characters stay intact.
        result.push_str(before);

        let Some((inner, after_close)) = after_open.split_once('}') else {
            // Unterminated reference. No later `${` can be closed either, so
            // the remainder is copied verbatim below.
            result.push_str("${");
            rest = after_open;
            break;
        };

        let (name, fallback) = parse_var_with_default(inner);
        if !is_valid_env_var_name(name) {
            // Not a reference we understand. Keep `${` literally and resume
            // scanning right after it, so a nested `${VAR}` is still expanded.
            result.push_str("${");
            rest = after_open;
            continue;
        }

        match (std::env::var(name), fallback) {
            (Ok(value), Fallback::IfUnsetOrEmpty(default)) if value.is_empty() => {
                result.push_str(default);
            }
            (Ok(value), _) => result.push_str(&value),
            (Err(_), Fallback::IfUnset(default) | Fallback::IfUnsetOrEmpty(default)) => {
                result.push_str(default);
            }
            (Err(_), Fallback::NoDefault) => {
                // No value and no default: preserve the reference verbatim.
                result.push_str("${");
                result.push_str(inner);
                result.push('}');
            }
        }
        rest = after_close;
    }

    result.push_str(rest);
    result
}

/// Fallback clause of a `${...}` reference.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Fallback<'a> {
    /// `${VAR}`: there is no default.
    NoDefault,
    /// `${VAR-default}`: the default applies only when `VAR` is unset.
    IfUnset(&'a str),
    /// `${VAR:-default}`: the default applies when `VAR` is unset or empty.
    IfUnsetOrEmpty(&'a str),
}

/// Splits the text between `${` and `}` into a variable name and its fallback.
///
/// The name ends at the first `-`. If that `-` is preceded by `:`, the
/// reference uses the `:-` form. Everything after the `-` is the default, so
/// `${A-x:-y}` yields name `A` with default `x:-y`.
fn parse_var_with_default(inner: &str) -> (&str, Fallback<'_>) {
    match inner.split_once('-') {
        None => (inner, Fallback::NoDefault),
        Some((head, default)) => match head.strip_suffix(':') {
            Some(name) => (name, Fallback::IfUnsetOrEmpty(default)),
            None => (head, Fallback::IfUnset(default)),
        },
    }
}

/// Returns `true` if `name` is non-empty, consists only of `A-Z`, `0-9` and
/// `_`, and does not start with a digit.
fn is_valid_env_var_name(name: &str) -> bool {
    let mut bytes = name.bytes();
    match bytes.next() {
        Some(first) if first.is_ascii_uppercase() || first == b'_' => {
            bytes.all(|b| b.is_ascii_uppercase() || b.is_ascii_digit() || b == b'_')
        }
        _ => false,
    }
}
677
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, unsafe_code)]
mod tests {
    use super::*;

    /// Sets `key` to `value` in the process environment.
    ///
    /// `std::env::set_var` is `unsafe` because another thread may read the
    /// environment concurrently, and the test harness runs tests in parallel.
    /// Each test uses variable names that no other test touches, so tests never
    /// observe each other's values. The remaining risk is accepted for test code.
    fn set_env(key: &str, value: &str) {
        // SAFETY: `key` belongs to the calling test only (see doc comment).
        unsafe { std::env::set_var(key, value) }
    }

    /// Removes `key` from the process environment; same caveats as `set_env`.
    fn unset_env(key: &str) {
        // SAFETY: `key` belongs to the calling test only (see `set_env`).
        unsafe { std::env::remove_var(key) }
    }

    /// Asserts that `toml` is rejected by `ForgeConfig::parse_toml` and returns
    /// the rendered error message.
    fn parse_failure(toml: &str) -> String {
        match ForgeConfig::parse_toml(toml) {
            Ok(_) => panic!("expected ForgeConfig::parse_toml to fail"),
            Err(err) => err.to_string(),
        }
    }

    /// Asserts that `auth` fails validation and returns the rendered error.
    fn validation_failure(auth: &AuthConfig) -> String {
        auth.validate()
            .expect_err("expected AuthConfig::validate to fail")
            .to_string()
    }

    #[test]
    fn test_default_config() {
        let ForgeConfig {
            gateway, node, mcp, ..
        } = ForgeConfig::default_with_database_url("postgres://localhost/test");

        assert_eq!(gateway.port, 8080);
        assert_eq!(node.roles.len(), 4);
        assert_eq!(mcp.path, "/mcp");
        assert!(!mcp.enabled);
    }

    #[test]
    fn test_parse_minimal_config() {
        let toml = r#"
        [database]
        url = "postgres://localhost/myapp"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();

        assert_eq!(config.database.url(), "postgres://localhost/myapp");
        assert_eq!(config.gateway.port, 8080);
    }

    #[test]
    fn test_parse_full_config() {
        let toml = r#"
        [project]
        name = "my-app"
        version = "1.0.0"

        [database]
        url = "postgres://localhost/myapp"
        pool_size = 100

        [node]
        roles = ["gateway", "worker"]
        worker_capabilities = ["media", "general"]

        [gateway]
        port = 3000
        grpc_port = 9001
        "#;

        let ForgeConfig {
            project,
            database,
            node,
            gateway,
            ..
        } = ForgeConfig::parse_toml(toml).unwrap();

        assert_eq!(project.name, "my-app");
        assert_eq!(database.pool_size, 100);
        assert_eq!(node.roles.len(), 2);
        assert_eq!(gateway.port, 3000);
    }

    #[test]
    fn test_env_var_substitution() {
        set_env("TEST_DB_URL", "postgres://test:test@localhost/test");

        let parsed = ForgeConfig::parse_toml(
            r#"
        [database]
        url = "${TEST_DB_URL}"
        "#,
        );
        // Clean up before asserting so a failure does not leak the variable.
        unset_env("TEST_DB_URL");

        let config = parsed.unwrap();
        assert_eq!(config.database.url(), "postgres://test:test@localhost/test");
    }

    #[test]
    fn test_auth_validation_no_config() {
        assert!(AuthConfig::default().validate().is_ok());
    }

    #[test]
    fn test_auth_validation_hmac_with_secret() {
        let auth = AuthConfig {
            jwt_algorithm: JwtAlgorithm::HS256,
            jwt_secret: Some("my-secret".into()),
            ..AuthConfig::default()
        };
        assert!(auth.validate().is_ok());
    }

    #[test]
    fn test_auth_validation_hmac_missing_secret() {
        let auth = AuthConfig {
            jwt_algorithm: JwtAlgorithm::HS256,
            jwt_issuer: Some("my-issuer".into()),
            ..AuthConfig::default()
        };
        assert!(validation_failure(&auth).contains("jwt_secret is required"));
    }

    #[test]
    fn test_auth_validation_rsa_with_jwks() {
        let auth = AuthConfig {
            jwt_algorithm: JwtAlgorithm::RS256,
            jwks_url: Some("https://example.com/.well-known/jwks.json".into()),
            ..AuthConfig::default()
        };
        assert!(auth.validate().is_ok());
    }

    #[test]
    fn test_auth_validation_rsa_missing_jwks() {
        let auth = AuthConfig {
            jwt_algorithm: JwtAlgorithm::RS256,
            jwt_issuer: Some("my-issuer".into()),
            ..AuthConfig::default()
        };
        assert!(validation_failure(&auth).contains("jwks_url is required"));
    }

    #[test]
    fn test_forge_config_validation_fails_on_empty_url() {
        let message = parse_failure(
            r#"
        [database]

        url = ""
        "#,
        );
        assert!(message.contains("database.url is required"));
    }

    #[test]
    fn test_forge_config_validation_fails_on_invalid_auth() {
        let message = parse_failure(
            r#"
        [database]

        url = "postgres://localhost/test"

        [auth]
        jwt_issuer = "my-issuer"
        jwt_algorithm = "RS256"
        "#,
        );
        assert!(message.contains("jwks_url is required"));
    }

    #[test]
    fn test_env_var_default_used_when_unset() {
        unset_env("TEST_FORGE_OTEL_UNSET");

        assert_eq!(
            substitute_env_vars("enabled = ${TEST_FORGE_OTEL_UNSET-false}"),
            "enabled = false"
        );
    }

    #[test]
    fn test_env_var_default_overridden_when_set() {
        set_env("TEST_FORGE_OTEL_SET", "true");
        let substituted = substitute_env_vars("enabled = ${TEST_FORGE_OTEL_SET-false}");
        unset_env("TEST_FORGE_OTEL_SET");

        assert_eq!(substituted, "enabled = true");
    }

    #[test]
    fn test_env_var_colon_dash_default() {
        unset_env("TEST_FORGE_ENDPOINT_UNSET");

        assert_eq!(
            substitute_env_vars(r#"endpoint = "${TEST_FORGE_ENDPOINT_UNSET:-http://localhost:4318}""#),
            r#"endpoint = "http://localhost:4318""#
        );
    }

    #[test]
    fn test_env_var_no_default_preserves_literal() {
        unset_env("TEST_FORGE_MISSING");

        let input = r#"url = "${TEST_FORGE_MISSING}""#;
        assert_eq!(substitute_env_vars(input), input);
    }

    #[test]
    fn test_env_var_default_empty_string() {
        unset_env("TEST_FORGE_EMPTY_DEFAULT");

        assert_eq!(
            substitute_env_vars(r#"val = "${TEST_FORGE_EMPTY_DEFAULT-}""#),
            r#"val = """#
        );
    }

    #[test]
    fn test_observability_config_default_disabled() {
        let toml = r#"
        [database]
        url = "postgres://localhost/test"
        "#;

        let observability = ForgeConfig::parse_toml(toml).unwrap().observability;

        assert!(!observability.enabled);
        assert!(!observability.otlp_active());
    }

    #[test]
    fn test_observability_config_with_env_default() {
        unset_env("TEST_OTEL_ENABLED");

        let toml = r#"
        [database]
        url = "postgres://localhost/test"

        [observability]
        enabled = ${TEST_OTEL_ENABLED-false}
        "#;

        let observability = ForgeConfig::parse_toml(toml).unwrap().observability;
        assert!(!observability.enabled);
    }

    #[test]
    fn test_mcp_config_validation_rejects_invalid_path() {
        let message = parse_failure(
            r#"
        [database]

        url = "postgres://localhost/test"

        [mcp]
        enabled = true
        path = "mcp"
        "#,
        );
        assert!(message.contains("mcp.path must start with '/'"));
    }
}