1pub mod cluster;
2mod database;
3
4pub use cluster::ClusterConfig;
5pub use database::{DatabaseConfig, PoolConfig};
6
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9
10use crate::error::{ForgeError, Result};
11
/// Top-level Forge configuration, deserialized from TOML by
/// `ForgeConfig::parse_toml` / `ForgeConfig::from_file`.
///
/// Every section except `[database]` carries `#[serde(default)]` and falls back
/// to its `Default` implementation when the table is omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForgeConfig {
    /// `[project]` metadata (name and version).
    #[serde(default)]
    pub project: ProjectConfig,

    /// `[database]` connection settings. This is the only section without a
    /// serde default, so parsing fails when it is missing.
    pub database: DatabaseConfig,

    /// `[node]` roles and worker capability tags.
    #[serde(default)]
    pub node: NodeConfig,

    /// `[gateway]` ports, connection limits, timeouts and CORS settings.
    #[serde(default)]
    pub gateway: GatewayConfig,

    /// `[function]` concurrency, timeout and memory-limit settings.
    #[serde(default)]
    pub function: FunctionConfig,

    /// `[worker]` job concurrency, timeout and polling settings.
    #[serde(default)]
    pub worker: WorkerConfig,

    /// `[cluster]` settings (type defined in the `cluster` submodule).
    #[serde(default)]
    pub cluster: ClusterConfig,

    /// `[security]` settings.
    #[serde(default)]
    pub security: SecurityConfig,

    /// `[auth]` JWT settings; checked by `AuthConfig::validate`.
    #[serde(default)]
    pub auth: AuthConfig,

    /// `[observability]` OTLP export, sampling and log-level settings.
    #[serde(default)]
    pub observability: ObservabilityConfig,

    /// `[mcp]` endpoint settings; checked by `McpConfig::validate`.
    #[serde(default)]
    pub mcp: McpConfig,
}
58
59impl ForgeConfig {
60 pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
62 let content = std::fs::read_to_string(path.as_ref())
63 .map_err(|e| ForgeError::Config(format!("Failed to read config file: {}", e)))?;
64
65 Self::parse_toml(&content)
66 }
67
68 pub fn parse_toml(content: &str) -> Result<Self> {
70 let content = substitute_env_vars(content);
72
73 let config: Self = toml::from_str(&content)
74 .map_err(|e| ForgeError::Config(format!("Failed to parse config: {}", e)))?;
75
76 config.validate()?;
77 Ok(config)
78 }
79
80 pub fn validate(&self) -> Result<()> {
82 self.database.validate()?;
83 self.auth.validate()?;
84 self.mcp.validate()?;
85 Ok(())
86 }
87
88 pub fn default_with_database_url(url: &str) -> Self {
90 Self {
91 project: ProjectConfig::default(),
92 database: DatabaseConfig::new(url),
93 node: NodeConfig::default(),
94 gateway: GatewayConfig::default(),
95 function: FunctionConfig::default(),
96 worker: WorkerConfig::default(),
97 cluster: ClusterConfig::default(),
98 security: SecurityConfig::default(),
99 auth: AuthConfig::default(),
100 observability: ObservabilityConfig::default(),
101 mcp: McpConfig::default(),
102 }
103 }
104}
105
/// `[project]` section: the application's name and version.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectConfig {
    /// Project name; defaults to `"forge-app"`.
    #[serde(default = "default_project_name")]
    pub name: String,

    /// Project version string; defaults to `"0.1.0"`.
    #[serde(default = "default_version")]
    pub version: String,
}
117
118impl Default for ProjectConfig {
119 fn default() -> Self {
120 Self {
121 name: default_project_name(),
122 version: default_version(),
123 }
124 }
125}
126
/// Serde default for `project.name`.
fn default_project_name() -> String {
    String::from("forge-app")
}

/// Serde default for `project.version`.
fn default_version() -> String {
    String::from("0.1.0")
}
134
/// `[node]` section: which roles this node runs and its worker capability tags.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeConfig {
    /// Roles to run; defaults to all four roles (gateway, function, worker,
    /// scheduler).
    #[serde(default = "default_roles")]
    pub roles: Vec<NodeRole>,

    /// Capability tags for workers on this node; defaults to `["general"]`.
    /// NOTE(review): presumably matched against job requirements — confirm.
    #[serde(default = "default_capabilities")]
    pub worker_capabilities: Vec<String>,
}
146
147impl Default for NodeConfig {
148 fn default() -> Self {
149 Self {
150 roles: default_roles(),
151 worker_capabilities: default_capabilities(),
152 }
153 }
154}
155
156fn default_roles() -> Vec<NodeRole> {
157 vec![
158 NodeRole::Gateway,
159 NodeRole::Function,
160 NodeRole::Worker,
161 NodeRole::Scheduler,
162 ]
163}
164
/// Serde default for `node.worker_capabilities`: the single `general` tag.
fn default_capabilities() -> Vec<String> {
    let general = String::from("general");
    vec![general]
}
168
/// A role a node can run. In TOML each variant is written in lowercase,
/// e.g. `roles = ["gateway", "worker"]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum NodeRole {
    /// `"gateway"`
    Gateway,
    /// `"function"`
    Function,
    /// `"worker"`
    Worker,
    /// `"scheduler"`
    Scheduler,
}
178
/// `[gateway]` section. Every field has a serde default, so an empty or
/// missing table yields `GatewayConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfig {
    /// HTTP port; defaults to 8080.
    #[serde(default = "default_http_port")]
    pub port: u16,

    /// gRPC port; defaults to 9000.
    #[serde(default = "default_grpc_port")]
    pub grpc_port: u16,

    /// Maximum number of connections; defaults to 4096.
    #[serde(default = "default_max_connections")]
    pub max_connections: usize,

    /// Maximum number of SSE sessions; defaults to 10 000.
    #[serde(default = "default_sse_max_sessions")]
    pub sse_max_sessions: usize,

    /// Request timeout in seconds; defaults to 30.
    #[serde(default = "default_request_timeout")]
    pub request_timeout_secs: u64,

    /// Whether CORS is enabled; defaults to `false`.
    #[serde(default = "default_cors_enabled")]
    pub cors_enabled: bool,

    /// Allowed CORS origins; defaults to an empty list.
    #[serde(default = "default_cors_origins")]
    pub cors_origins: Vec<String>,

    /// Routes treated as "quiet"; defaults to `/_api/health` and `/_api/ready`.
    /// NOTE(review): presumably excluded from request logging — confirm in the
    /// gateway implementation.
    #[serde(default = "default_quiet_routes")]
    pub quiet_routes: Vec<String>,
}
215
216impl Default for GatewayConfig {
217 fn default() -> Self {
218 Self {
219 port: default_http_port(),
220 grpc_port: default_grpc_port(),
221 max_connections: default_max_connections(),
222 sse_max_sessions: default_sse_max_sessions(),
223 request_timeout_secs: default_request_timeout(),
224 cors_enabled: default_cors_enabled(),
225 cors_origins: default_cors_origins(),
226 quiet_routes: default_quiet_routes(),
227 }
228 }
229}
230
/// Serde default for `gateway.port`.
fn default_http_port() -> u16 {
    8_080
}

/// Serde default for `gateway.grpc_port`.
fn default_grpc_port() -> u16 {
    9_000
}

/// Serde default for `gateway.max_connections`.
fn default_max_connections() -> usize {
    4_096
}

/// Serde default for `gateway.sse_max_sessions`.
fn default_sse_max_sessions() -> usize {
    10_000
}

/// Serde default for `gateway.request_timeout_secs` (seconds).
fn default_request_timeout() -> u64 {
    30
}

/// Serde default for `gateway.cors_enabled`: CORS stays off unless configured.
fn default_cors_enabled() -> bool {
    false
}

/// Serde default for `gateway.cors_origins`: no origins.
fn default_cors_origins() -> Vec<String> {
    vec![]
}

/// Serde default for `gateway.quiet_routes`: the health and readiness routes.
fn default_quiet_routes() -> Vec<String> {
    ["/_api/health", "/_api/ready"]
        .iter()
        .map(|route| route.to_string())
        .collect()
}
262
/// `[function]` section. Every field has a serde default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionConfig {
    /// Maximum concurrent function executions; defaults to 1000.
    #[serde(default = "default_max_concurrent")]
    pub max_concurrent: usize,

    /// Function timeout in seconds; defaults to 30.
    #[serde(default = "default_function_timeout")]
    pub timeout_secs: u64,

    /// Memory limit; defaults to `512 * 1024 * 1024`.
    /// NOTE(review): presumably bytes (i.e. 512 MiB) — confirm at the use site.
    #[serde(default = "default_memory_limit")]
    pub memory_limit: usize,
}
278
279impl Default for FunctionConfig {
280 fn default() -> Self {
281 Self {
282 max_concurrent: default_max_concurrent(),
283 timeout_secs: default_function_timeout(),
284 memory_limit: default_memory_limit(),
285 }
286 }
287}
288
/// Serde default for `function.max_concurrent`.
fn default_max_concurrent() -> usize {
    1_000
}

/// Serde default for `function.timeout_secs` (seconds).
fn default_function_timeout() -> u64 {
    30
}

/// Serde default for `function.memory_limit`: 512 << 20, i.e. 512 * 1024 * 1024.
fn default_memory_limit() -> usize {
    512 << 20
}
300
/// `[worker]` section. Every field has a serde default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerConfig {
    /// Maximum concurrently running jobs; defaults to 50.
    #[serde(default = "default_max_concurrent_jobs")]
    pub max_concurrent_jobs: usize,

    /// Job timeout in seconds; defaults to 3600 (one hour).
    #[serde(default = "default_job_timeout")]
    pub job_timeout_secs: u64,

    /// Poll interval in milliseconds; defaults to 100.
    #[serde(default = "default_poll_interval")]
    pub poll_interval_ms: u64,
}
316
317impl Default for WorkerConfig {
318 fn default() -> Self {
319 Self {
320 max_concurrent_jobs: default_max_concurrent_jobs(),
321 job_timeout_secs: default_job_timeout(),
322 poll_interval_ms: default_poll_interval(),
323 }
324 }
325}
326
/// Serde default for `worker.max_concurrent_jobs`.
fn default_max_concurrent_jobs() -> usize {
    50
}

/// Serde default for `worker.job_timeout_secs`: one hour, in seconds.
fn default_job_timeout() -> u64 {
    60 * 60
}

/// Serde default for `worker.poll_interval_ms` (milliseconds).
fn default_poll_interval() -> u64 {
    100
}
338
/// `[security]` section; all fields default to unset.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SecurityConfig {
    /// Optional secret key; `None` when not configured.
    /// NOTE(review): its consumers live outside this file — document the
    /// expected format/length there.
    pub secret_key: Option<String>,
}
345
/// JWT signing algorithm. In TOML it is written as its uppercase name, e.g.
/// `jwt_algorithm = "RS256"`.
///
/// `AuthConfig::validate` requires `auth.jwt_secret` for the HMAC variants
/// (`HS*`) and `auth.jwks_url` for the RSA variants (`RS*`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "UPPERCASE")]
pub enum JwtAlgorithm {
    /// HMAC variant; the default when `jwt_algorithm` is omitted.
    #[default]
    HS256,
    /// HMAC variant.
    HS384,
    /// HMAC variant.
    HS512,
    /// RSA variant.
    RS256,
    /// RSA variant.
    RS384,
    /// RSA variant.
    RS512,
}
364
/// `[auth]` section: JWT settings.
///
/// Every field is optional: the `Option` fields deserialize to `None` when
/// missing and the rest have serde defaults. If none of `jwt_secret`,
/// `jwks_url`, `jwt_issuer` or `jwt_audience` is set, `AuthConfig::validate`
/// performs no checks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthConfig {
    /// Shared secret; required by `validate` for `HS*` algorithms.
    pub jwt_secret: Option<String>,

    /// Signing algorithm; defaults to `HS256`.
    #[serde(default)]
    pub jwt_algorithm: JwtAlgorithm,

    /// Expected token issuer, if any.
    /// NOTE(review): presumably compared with the `iss` claim — confirm.
    pub jwt_issuer: Option<String>,

    /// Expected token audience, if any.
    /// NOTE(review): presumably compared with the `aud` claim — confirm.
    pub jwt_audience: Option<String>,

    /// Token expiry as a string. The accepted syntax is not defined in this
    /// file — NOTE(review): document it where it is parsed.
    pub token_expiry: Option<String>,

    /// JWKS endpoint of the identity provider; required by `validate` for
    /// `RS*` algorithms.
    pub jwks_url: Option<String>,

    /// JWKS cache TTL in seconds; defaults to 3600 (one hour).
    #[serde(default = "default_jwks_cache_ttl")]
    pub jwks_cache_ttl_secs: u64,

    /// Session TTL in seconds; defaults to 7 days.
    #[serde(default = "default_session_ttl")]
    pub session_ttl_secs: u64,
}
401
402impl Default for AuthConfig {
403 fn default() -> Self {
404 Self {
405 jwt_secret: None,
406 jwt_algorithm: JwtAlgorithm::default(),
407 jwt_issuer: None,
408 jwt_audience: None,
409 token_expiry: None,
410 jwks_url: None,
411 jwks_cache_ttl_secs: default_jwks_cache_ttl(),
412 session_ttl_secs: default_session_ttl(),
413 }
414 }
415}
416
417impl AuthConfig {
418 fn is_configured(&self) -> bool {
420 self.jwt_secret.is_some()
421 || self.jwks_url.is_some()
422 || self.jwt_issuer.is_some()
423 || self.jwt_audience.is_some()
424 }
425
426 pub fn validate(&self) -> Result<()> {
429 if !self.is_configured() {
430 return Ok(());
431 }
432
433 match self.jwt_algorithm {
434 JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512 => {
435 if self.jwt_secret.is_none() {
436 return Err(ForgeError::Config(
437 "auth.jwt_secret is required for HMAC algorithms (HS256, HS384, HS512). \
438 Set auth.jwt_secret to a secure random string, \
439 or switch to RS256 and provide auth.jwks_url for external identity providers."
440 .into(),
441 ));
442 }
443 }
444 JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512 => {
445 if self.jwks_url.is_none() {
446 return Err(ForgeError::Config(
447 "auth.jwks_url is required for RSA algorithms (RS256, RS384, RS512). \
448 Set auth.jwks_url to your identity provider's JWKS endpoint, \
449 or switch to HS256 and provide auth.jwt_secret for symmetric signing."
450 .into(),
451 ));
452 }
453 }
454 }
455 Ok(())
456 }
457
458 pub fn is_hmac(&self) -> bool {
460 matches!(
461 self.jwt_algorithm,
462 JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
463 )
464 }
465
466 pub fn is_rsa(&self) -> bool {
468 matches!(
469 self.jwt_algorithm,
470 JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
471 )
472 }
473}
474
/// Serde default for `auth.jwks_cache_ttl_secs`: one hour, in seconds.
fn default_jwks_cache_ttl() -> u64 {
    60 * 60
}

/// Serde default for `auth.session_ttl_secs`: seven days, in seconds.
fn default_session_ttl() -> u64 {
    let one_day: u64 = 24 * 60 * 60;
    7 * one_day
}
482
/// `[observability]` section. Every field is optional; see
/// `ObservabilityConfig::otlp_active` for when export is considered active.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ObservabilityConfig {
    /// Master switch; defaults to `false`.
    #[serde(default)]
    pub enabled: bool,

    /// OTLP endpoint; defaults to `"http://localhost:4318"`.
    #[serde(default = "default_otlp_endpoint")]
    pub otlp_endpoint: String,

    /// Optional service name.
    /// NOTE(review): fallback when unset is decided elsewhere — confirm.
    pub service_name: Option<String>,

    /// Trace signal toggle; defaults to `true`.
    #[serde(default = "default_true")]
    pub enable_traces: bool,

    /// Metric signal toggle; defaults to `true`.
    #[serde(default = "default_true")]
    pub enable_metrics: bool,

    /// Log signal toggle; defaults to `true`.
    #[serde(default = "default_true")]
    pub enable_logs: bool,

    /// Sampling ratio; defaults to `1.0`. Not range-checked in this file.
    /// NOTE(review): presumably expected in `[0.0, 1.0]` — confirm.
    #[serde(default = "default_sampling_ratio")]
    pub sampling_ratio: f64,

    /// Log level string; defaults to `"info"`.
    #[serde(default = "default_log_level")]
    pub log_level: String,
}
517
518impl Default for ObservabilityConfig {
519 fn default() -> Self {
520 Self {
521 enabled: false,
522 otlp_endpoint: default_otlp_endpoint(),
523 service_name: None,
524 enable_traces: true,
525 enable_metrics: true,
526 enable_logs: true,
527 sampling_ratio: default_sampling_ratio(),
528 log_level: default_log_level(),
529 }
530 }
531}
532
533impl ObservabilityConfig {
534 pub fn otlp_active(&self) -> bool {
535 self.enabled && (self.enable_traces || self.enable_metrics || self.enable_logs)
536 }
537}
538
/// Serde default for `observability.otlp_endpoint`.
fn default_otlp_endpoint() -> String {
    String::from("http://localhost:4318")
}

/// Serde helper for boolean fields that default to on.
fn default_true() -> bool {
    !false
}

/// Serde default for `observability.sampling_ratio`.
fn default_sampling_ratio() -> f64 {
    1.0_f64
}

/// Serde default for `observability.log_level`.
fn default_log_level() -> String {
    String::from("info")
}
554
/// `[mcp]` section. Every field is optional; `McpConfig::validate` checks
/// `path` and `session_ttl_secs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpConfig {
    /// Whether MCP is enabled; defaults to `false`.
    #[serde(default)]
    pub enabled: bool,

    /// Endpoint path; defaults to `"/mcp"` and must start with `/`.
    #[serde(default = "default_mcp_path")]
    pub path: String,

    /// Session TTL in seconds; defaults to 3600 and must be non-zero.
    #[serde(default = "default_mcp_session_ttl_secs")]
    pub session_ttl_secs: u64,

    /// Allowed origins; defaults to an empty list.
    /// NOTE(review): semantics of an empty list (allow none vs. allow all) are
    /// decided outside this file — confirm.
    #[serde(default)]
    pub allowed_origins: Vec<String>,

    /// Whether a protocol-version header is required; defaults to `true`.
    #[serde(default = "default_true")]
    pub require_protocol_version_header: bool,
}
578
579impl Default for McpConfig {
580 fn default() -> Self {
581 Self {
582 enabled: false,
583 path: default_mcp_path(),
584 session_ttl_secs: default_mcp_session_ttl_secs(),
585 allowed_origins: Vec::new(),
586 require_protocol_version_header: default_true(),
587 }
588 }
589}
590
591impl McpConfig {
592 pub fn validate(&self) -> Result<()> {
593 if self.path.is_empty() || !self.path.starts_with('/') {
594 return Err(ForgeError::Config(
595 "mcp.path must start with '/' (example: /mcp)".to_string(),
596 ));
597 }
598 if self.path.contains(' ') {
599 return Err(ForgeError::Config(
600 "mcp.path cannot contain spaces".to_string(),
601 ));
602 }
603 if self.session_ttl_secs == 0 {
604 return Err(ForgeError::Config(
605 "mcp.session_ttl_secs must be greater than 0".to_string(),
606 ));
607 }
608 Ok(())
609 }
610}
611
/// Serde default for `mcp.path`.
fn default_mcp_path() -> String {
    String::from("/mcp")
}

/// Serde default for `mcp.session_ttl_secs`: one hour, in seconds.
fn default_mcp_session_ttl_secs() -> u64 {
    3_600
}
619
/// Expand `${VAR}`-style environment references in `content`.
///
/// Supported forms (in every case `VAR` must satisfy `is_valid_env_var_name`):
///
/// * `${VAR}`: replaced with the value of `VAR`. If `VAR` is unset (or its
///   value is not valid Unicode), the placeholder is kept verbatim.
/// * `${VAR-default}` and `${VAR:-default}`: replaced with the value of `VAR`,
///   or with `default` when `VAR` is unset (or not valid Unicode). Both
///   spellings behave the same here. A `VAR` that is set but empty is
///   substituted as the empty string.
///
/// Everything else is copied through unchanged. That includes lowercase
/// names, a `${` with no closing `}`, and all text outside placeholders,
/// non-ASCII characters included. Substituted values are not rescanned.
pub fn substitute_env_vars(content: &str) -> String {
    let mut result = String::with_capacity(content.len());
    let mut rest = content;

    while let Some((before, after_open)) = rest.split_once("${") {
        // Plain text is copied as a `&str` slice, so multi-byte UTF-8 characters
        // stay intact (the previous byte-by-byte `u8 as char` push corrupted them).
        result.push_str(before);

        if let Some((inner, after_close)) = after_open.split_once('}') {
            let (var_name, default_value) = parse_var_with_default(inner);
            if is_valid_env_var_name(var_name) {
                match (std::env::var(var_name), default_value) {
                    (Ok(value), _) => result.push_str(&value),
                    (Err(_), Some(default)) => result.push_str(default),
                    (Err(_), None) => {
                        // Unset with no fallback: keep the placeholder as written.
                        result.push_str("${");
                        result.push_str(inner);
                        result.push('}');
                    }
                }
                rest = after_close;
                continue;
            }
        }

        // Not a placeholder we substitute: emit `${` unchanged and resume after
        // it. A new `${` cannot begin at the `{`, so no match is skipped.
        result.push_str("${");
        rest = after_open;
    }

    result.push_str(rest);
    result
}

/// Split the inside of a `${...}` placeholder into `(name, default)`.
///
/// The name ends at the first `-`. If that `-` is preceded by `:`, the pair
/// is the `:-` separator and the `:` is dropped from the name. Everything
/// after the separator is the default, verbatim: it may be empty and may
/// itself contain `-` or `:-`, e.g. `VAR-a:-b` gives `("VAR", Some("a:-b"))`.
/// Without any `-`, the whole input is the name and there is no default.
fn parse_var_with_default(inner: &str) -> (&str, Option<&str>) {
    match inner.split_once('-') {
        Some((name, default)) => (name.strip_suffix(':').unwrap_or(name), Some(default)),
        None => (inner, None),
    }
}

/// `true` if `name` is non-empty, starts with `A`–`Z` or `_`, and continues
/// with only `A`–`Z`, `0`–`9` or `_`. Lowercase names are rejected, so text
/// such as `${foo}` is never substituted.
fn is_valid_env_var_name(name: &str) -> bool {
    let mut bytes = name.bytes();
    match bytes.next() {
        Some(first) if first.is_ascii_uppercase() || first == b'_' => {
            bytes.all(|b| b.is_ascii_uppercase() || b.is_ascii_digit() || b == b'_')
        }
        _ => false,
    }
}
686
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, unsafe_code)]
mod tests {
    // Unit tests for TOML parsing, section defaults, validation, and `${VAR}`
    // environment substitution.
    //
    // NOTE(review): several tests mutate the process environment with
    // `std::env::set_var` / `remove_var` inside `unsafe` blocks. Each test uses
    // its own variable name, but the test harness runs tests on multiple threads
    // by default, so these calls can still race with environment reads elsewhere
    // in the process. Consider serializing the env-mutating tests.
    use super::*;

    // Defaults produced without parsing any TOML.
    #[test]
    fn test_default_config() {
        let config = ForgeConfig::default_with_database_url("postgres://localhost/test");
        assert_eq!(config.gateway.port, 8080);
        assert_eq!(config.node.roles.len(), 4);
        assert_eq!(config.mcp.path, "/mcp");
        assert!(!config.mcp.enabled);
    }

    // Only `[database]` is given; every other section falls back to its default.
    #[test]
    fn test_parse_minimal_config() {
        let toml = r#"
            [database]
            url = "postgres://localhost/myapp"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.database.url(), "postgres://localhost/myapp");
        assert_eq!(config.gateway.port, 8080);
    }

    // Explicit values in several sections override the defaults.
    #[test]
    fn test_parse_full_config() {
        let toml = r#"
            [project]
            name = "my-app"
            version = "1.0.0"

            [database]
            url = "postgres://localhost/myapp"
            pool_size = 100

            [node]
            roles = ["gateway", "worker"]
            worker_capabilities = ["media", "general"]

            [gateway]
            port = 3000
            grpc_port = 9001
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.project.name, "my-app");
        assert_eq!(config.database.pool_size, 100);
        assert_eq!(config.node.roles.len(), 2);
        assert_eq!(config.gateway.port, 3000);
    }

    // `${VAR}` in a TOML value is replaced before parsing.
    #[test]
    fn test_env_var_substitution() {
        // Environment mutation is `unsafe` (see the module-level note).
        unsafe {
            std::env::set_var("TEST_DB_URL", "postgres://test:test@localhost/test");
        }

        let toml = r#"
            [database]
            url = "${TEST_DB_URL}"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.database.url(), "postgres://test:test@localhost/test");

        unsafe {
            std::env::remove_var("TEST_DB_URL");
        }
    }

    // With no auth settings at all, validation is a no-op.
    #[test]
    fn test_auth_validation_no_config() {
        let auth = AuthConfig::default();
        assert!(auth.validate().is_ok());
    }

    #[test]
    fn test_auth_validation_hmac_with_secret() {
        let auth = AuthConfig {
            jwt_secret: Some("my-secret".into()),
            jwt_algorithm: JwtAlgorithm::HS256,
            ..Default::default()
        };
        assert!(auth.validate().is_ok());
    }

    // Setting only an issuer "configures" auth, so an HMAC algorithm then needs a secret.
    #[test]
    fn test_auth_validation_hmac_missing_secret() {
        let auth = AuthConfig {
            jwt_issuer: Some("my-issuer".into()),
            jwt_algorithm: JwtAlgorithm::HS256,
            ..Default::default()
        };
        let result = auth.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwt_secret is required"));
    }

    #[test]
    fn test_auth_validation_rsa_with_jwks() {
        let auth = AuthConfig {
            jwks_url: Some("https://example.com/.well-known/jwks.json".into()),
            jwt_algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(auth.validate().is_ok());
    }

    // An RSA algorithm without `jwks_url` is rejected.
    #[test]
    fn test_auth_validation_rsa_missing_jwks() {
        let auth = AuthConfig {
            jwt_issuer: Some("my-issuer".into()),
            jwt_algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        let result = auth.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwks_url is required"));
    }

    // `parse_toml` runs `DatabaseConfig::validate`, which rejects an empty URL.
    #[test]
    fn test_forge_config_validation_fails_on_empty_url() {
        let toml = r#"
            [database]

            url = ""
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("database.url is required"));
    }

    // Auth validation errors surface through `parse_toml`.
    #[test]
    fn test_forge_config_validation_fails_on_invalid_auth() {
        let toml = r#"
            [database]

            url = "postgres://localhost/test"

            [auth]
            jwt_issuer = "my-issuer"
            jwt_algorithm = "RS256"
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwks_url is required"));
    }

    // `${VAR-default}` yields the default when VAR is unset.
    #[test]
    fn test_env_var_default_used_when_unset() {
        unsafe {
            std::env::remove_var("TEST_FORGE_OTEL_UNSET");
        }

        let input = r#"enabled = ${TEST_FORGE_OTEL_UNSET-false}"#;
        let result = substitute_env_vars(input);
        assert_eq!(result, "enabled = false");
    }

    // A set variable wins over its default.
    #[test]
    fn test_env_var_default_overridden_when_set() {
        unsafe {
            std::env::set_var("TEST_FORGE_OTEL_SET", "true");
        }

        let input = r#"enabled = ${TEST_FORGE_OTEL_SET-false}"#;
        let result = substitute_env_vars(input);
        assert_eq!(result, "enabled = true");

        unsafe {
            std::env::remove_var("TEST_FORGE_OTEL_SET");
        }
    }

    // `${VAR:-default}` form, with a default that itself contains `:`.
    #[test]
    fn test_env_var_colon_dash_default() {
        unsafe {
            std::env::remove_var("TEST_FORGE_ENDPOINT_UNSET");
        }

        let input = r#"endpoint = "${TEST_FORGE_ENDPOINT_UNSET:-http://localhost:4318}""#;
        let result = substitute_env_vars(input);
        assert_eq!(result, r#"endpoint = "http://localhost:4318""#);
    }

    // Unset variable with no default: the placeholder is left verbatim.
    #[test]
    fn test_env_var_no_default_preserves_literal() {
        unsafe {
            std::env::remove_var("TEST_FORGE_MISSING");
        }

        let input = r#"url = "${TEST_FORGE_MISSING}""#;
        let result = substitute_env_vars(input);
        assert_eq!(result, r#"url = "${TEST_FORGE_MISSING}""#);
    }

    // An empty default (`${VAR-}`) substitutes the empty string.
    #[test]
    fn test_env_var_default_empty_string() {
        unsafe {
            std::env::remove_var("TEST_FORGE_EMPTY_DEFAULT");
        }

        let input = r#"val = "${TEST_FORGE_EMPTY_DEFAULT-}""#;
        let result = substitute_env_vars(input);
        assert_eq!(result, r#"val = """#);
    }

    #[test]
    fn test_observability_config_default_disabled() {
        let toml = r#"
            [database]
            url = "postgres://localhost/test"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert!(!config.observability.enabled);
        assert!(!config.observability.otlp_active());
    }

    // An unquoted substitution can produce a non-string TOML value (a bool here).
    #[test]
    fn test_observability_config_with_env_default() {
        unsafe {
            // Make sure the variable is unset so the `-false` default applies.
            std::env::remove_var("TEST_OTEL_ENABLED");
        }

        let toml = r#"
            [database]
            url = "postgres://localhost/test"

            [observability]
            enabled = ${TEST_OTEL_ENABLED-false}
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert!(!config.observability.enabled);
    }

    // `McpConfig::validate` rejects a path without a leading '/'.
    #[test]
    fn test_mcp_config_validation_rejects_invalid_path() {
        let toml = r#"
            [database]

            url = "postgres://localhost/test"

            [mcp]
            enabled = true
            path = "mcp"
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("mcp.path must start with '/'"));
    }
}