Skip to main content

forge_core/config/
mod.rs

1mod cluster;
2mod database;
3
4pub use cluster::ClusterConfig;
5pub use database::{DatabaseConfig, DatabaseSource};
6
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9
10use crate::error::{ForgeError, Result};
11
12/// Root configuration for FORGE.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ForgeConfig {
15    /// Project metadata.
16    #[serde(default)]
17    pub project: ProjectConfig,
18
19    /// Database configuration.
20    pub database: DatabaseConfig,
21
22    /// Node configuration.
23    #[serde(default)]
24    pub node: NodeConfig,
25
26    /// Gateway configuration.
27    #[serde(default)]
28    pub gateway: GatewayConfig,
29
30    /// Function execution configuration.
31    #[serde(default)]
32    pub function: FunctionConfig,
33
34    /// Worker configuration.
35    #[serde(default)]
36    pub worker: WorkerConfig,
37
38    /// Cluster configuration.
39    #[serde(default)]
40    pub cluster: ClusterConfig,
41
42    /// Security configuration.
43    #[serde(default)]
44    pub security: SecurityConfig,
45
46    /// Authentication configuration.
47    #[serde(default)]
48    pub auth: AuthConfig,
49
50    /// Observability configuration.
51    #[serde(default)]
52    pub observability: ObservabilityConfig,
53
54    /// MCP server configuration.
55    #[serde(default)]
56    pub mcp: McpConfig,
57}
58
59impl ForgeConfig {
60    /// Load configuration from a TOML file.
61    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
62        let content = std::fs::read_to_string(path.as_ref())
63            .map_err(|e| ForgeError::Config(format!("Failed to read config file: {}", e)))?;
64
65        Self::parse_toml(&content)
66    }
67
68    /// Parse configuration from a TOML string.
69    pub fn parse_toml(content: &str) -> Result<Self> {
70        // Substitute environment variables
71        let content = substitute_env_vars(content);
72
73        let config: Self = toml::from_str(&content)
74            .map_err(|e| ForgeError::Config(format!("Failed to parse config: {}", e)))?;
75
76        config.validate()?;
77        Ok(config)
78    }
79
80    /// Validate the configuration for invalid combinations.
81    pub fn validate(&self) -> Result<()> {
82        self.database.validate()?;
83        self.auth.validate()?;
84        self.mcp.validate()?;
85        Ok(())
86    }
87
88    /// Load configuration with defaults.
89    pub fn default_with_database_url(url: &str) -> Self {
90        Self {
91            project: ProjectConfig::default(),
92            database: DatabaseConfig::remote(url),
93            node: NodeConfig::default(),
94            gateway: GatewayConfig::default(),
95            function: FunctionConfig::default(),
96            worker: WorkerConfig::default(),
97            cluster: ClusterConfig::default(),
98            security: SecurityConfig::default(),
99            auth: AuthConfig::default(),
100            observability: ObservabilityConfig::default(),
101            mcp: McpConfig::default(),
102        }
103    }
104}
105
106/// Project metadata.
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct ProjectConfig {
109    /// Project name.
110    #[serde(default = "default_project_name")]
111    pub name: String,
112
113    /// Project version.
114    #[serde(default = "default_version")]
115    pub version: String,
116}
117
118impl Default for ProjectConfig {
119    fn default() -> Self {
120        Self {
121            name: default_project_name(),
122            version: default_version(),
123        }
124    }
125}
126
/// Project name used when `[project].name` is not given.
fn default_project_name() -> String {
    String::from("forge-app")
}

/// Project version used when `[project].version` is not given.
fn default_version() -> String {
    String::from("0.1.0")
}
134
135/// Node role configuration.
136#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct NodeConfig {
138    /// Roles this node should assume.
139    #[serde(default = "default_roles")]
140    pub roles: Vec<NodeRole>,
141
142    /// Worker capabilities for job routing.
143    #[serde(default = "default_capabilities")]
144    pub worker_capabilities: Vec<String>,
145}
146
147impl Default for NodeConfig {
148    fn default() -> Self {
149        Self {
150            roles: default_roles(),
151            worker_capabilities: default_capabilities(),
152        }
153    }
154}
155
156fn default_roles() -> Vec<NodeRole> {
157    vec![
158        NodeRole::Gateway,
159        NodeRole::Function,
160        NodeRole::Worker,
161        NodeRole::Scheduler,
162    ]
163}
164
/// Capability tags advertised when `node.worker_capabilities` is omitted.
fn default_capabilities() -> Vec<String> {
    vec![String::from("general")]
}
168
169/// Available node roles.
170#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
171#[serde(rename_all = "lowercase")]
172pub enum NodeRole {
173    Gateway,
174    Function,
175    Worker,
176    Scheduler,
177}
178
179/// Gateway configuration.
180#[derive(Debug, Clone, Serialize, Deserialize)]
181pub struct GatewayConfig {
182    /// HTTP port.
183    #[serde(default = "default_http_port")]
184    pub port: u16,
185
186    /// gRPC port for inter-node communication.
187    #[serde(default = "default_grpc_port")]
188    pub grpc_port: u16,
189
190    /// Maximum concurrent connections.
191    #[serde(default = "default_max_connections")]
192    pub max_connections: usize,
193
194    /// Request timeout in seconds.
195    #[serde(default = "default_request_timeout")]
196    pub request_timeout_secs: u64,
197
198    /// Enable CORS handling.
199    #[serde(default = "default_cors_enabled")]
200    pub cors_enabled: bool,
201
202    /// Allowed CORS origins.
203    #[serde(default = "default_cors_origins")]
204    pub cors_origins: Vec<String>,
205}
206
207impl Default for GatewayConfig {
208    fn default() -> Self {
209        Self {
210            port: default_http_port(),
211            grpc_port: default_grpc_port(),
212            max_connections: default_max_connections(),
213            request_timeout_secs: default_request_timeout(),
214            cors_enabled: default_cors_enabled(),
215            cors_origins: default_cors_origins(),
216        }
217    }
218}
219
/// Default HTTP listen port.
fn default_http_port() -> u16 {
    8_080
}

/// Default gRPC port for inter-node traffic.
fn default_grpc_port() -> u16 {
    9_000
}

/// Default cap on concurrent gateway connections.
fn default_max_connections() -> usize {
    512
}

/// Default request timeout, in seconds.
fn default_request_timeout() -> u64 {
    30
}

/// CORS handling is off unless explicitly enabled.
fn default_cors_enabled() -> bool {
    false
}

/// No CORS origins are configured by default.
fn default_cors_origins() -> Vec<String> {
    vec![]
}
243
244/// Function execution configuration.
245#[derive(Debug, Clone, Serialize, Deserialize)]
246pub struct FunctionConfig {
247    /// Maximum concurrent function executions.
248    #[serde(default = "default_max_concurrent")]
249    pub max_concurrent: usize,
250
251    /// Function timeout in seconds.
252    #[serde(default = "default_function_timeout")]
253    pub timeout_secs: u64,
254
255    /// Memory limit per function (in bytes).
256    #[serde(default = "default_memory_limit")]
257    pub memory_limit: usize,
258}
259
260impl Default for FunctionConfig {
261    fn default() -> Self {
262        Self {
263            max_concurrent: default_max_concurrent(),
264            timeout_secs: default_function_timeout(),
265            memory_limit: default_memory_limit(),
266        }
267    }
268}
269
/// Default cap on simultaneous function executions.
fn default_max_concurrent() -> usize {
    1_000
}

/// Default function timeout, in seconds.
fn default_function_timeout() -> u64 {
    30
}

/// Default per-function memory limit: 512 MiB expressed in bytes (512 * 2^20).
fn default_memory_limit() -> usize {
    512 << 20
}
281
282/// Worker configuration.
283#[derive(Debug, Clone, Serialize, Deserialize)]
284pub struct WorkerConfig {
285    /// Maximum concurrent jobs.
286    #[serde(default = "default_max_concurrent_jobs")]
287    pub max_concurrent_jobs: usize,
288
289    /// Job timeout in seconds.
290    #[serde(default = "default_job_timeout")]
291    pub job_timeout_secs: u64,
292
293    /// Poll interval in milliseconds.
294    #[serde(default = "default_poll_interval")]
295    pub poll_interval_ms: u64,
296}
297
298impl Default for WorkerConfig {
299    fn default() -> Self {
300        Self {
301            max_concurrent_jobs: default_max_concurrent_jobs(),
302            job_timeout_secs: default_job_timeout(),
303            poll_interval_ms: default_poll_interval(),
304        }
305    }
306}
307
/// Default cap on simultaneously running jobs.
fn default_max_concurrent_jobs() -> usize {
    50
}

/// Default job timeout: one hour, in seconds.
fn default_job_timeout() -> u64 {
    60 * 60
}

/// Default poll interval, in milliseconds.
fn default_poll_interval() -> u64 {
    100
}
319
320/// Security configuration.
321#[derive(Debug, Clone, Serialize, Deserialize, Default)]
322pub struct SecurityConfig {
323    /// Secret key for signing.
324    pub secret_key: Option<String>,
325}
326
327/// JWT signing algorithm.
328#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
329#[serde(rename_all = "UPPERCASE")]
330pub enum JwtAlgorithm {
331    /// HMAC using SHA-256 (symmetric, requires jwt_secret).
332    #[default]
333    HS256,
334    /// HMAC using SHA-384 (symmetric, requires jwt_secret).
335    HS384,
336    /// HMAC using SHA-512 (symmetric, requires jwt_secret).
337    HS512,
338    /// RSA using SHA-256 (asymmetric, requires jwks_url).
339    RS256,
340    /// RSA using SHA-384 (asymmetric, requires jwks_url).
341    RS384,
342    /// RSA using SHA-512 (asymmetric, requires jwks_url).
343    RS512,
344}
345
346/// Authentication configuration.
347#[derive(Debug, Clone, Serialize, Deserialize)]
348pub struct AuthConfig {
349    /// JWT secret for HMAC algorithms (HS256, HS384, HS512).
350    /// Required when using HMAC algorithms.
351    pub jwt_secret: Option<String>,
352
353    /// JWT signing algorithm.
354    /// HMAC algorithms (HS256, HS384, HS512) require jwt_secret.
355    /// RSA algorithms (RS256, RS384, RS512) require jwks_url.
356    #[serde(default)]
357    pub jwt_algorithm: JwtAlgorithm,
358
359    /// Expected token issuer (iss claim).
360    /// If set, tokens with a different issuer are rejected.
361    pub jwt_issuer: Option<String>,
362
363    /// Expected audience (aud claim).
364    /// If set, tokens with a different audience are rejected.
365    pub jwt_audience: Option<String>,
366
367    /// Token expiry duration (e.g., "15m", "1h", "7d").
368    pub token_expiry: Option<String>,
369
370    /// JWKS URL for RSA algorithms (RS256, RS384, RS512).
371    /// Keys are fetched and cached automatically.
372    pub jwks_url: Option<String>,
373
374    /// JWKS cache TTL in seconds.
375    #[serde(default = "default_jwks_cache_ttl")]
376    pub jwks_cache_ttl_secs: u64,
377
378    /// Session TTL in seconds (for WebSocket sessions).
379    #[serde(default = "default_session_ttl")]
380    pub session_ttl_secs: u64,
381}
382
383impl Default for AuthConfig {
384    fn default() -> Self {
385        Self {
386            jwt_secret: None,
387            jwt_algorithm: JwtAlgorithm::default(),
388            jwt_issuer: None,
389            jwt_audience: None,
390            token_expiry: None,
391            jwks_url: None,
392            jwks_cache_ttl_secs: default_jwks_cache_ttl(),
393            session_ttl_secs: default_session_ttl(),
394        }
395    }
396}
397
398impl AuthConfig {
399    /// Check if auth is configured (any credential or claim validation is set).
400    fn is_configured(&self) -> bool {
401        self.jwt_secret.is_some()
402            || self.jwks_url.is_some()
403            || self.jwt_issuer.is_some()
404            || self.jwt_audience.is_some()
405    }
406
407    /// Validate that the configuration is complete for the chosen algorithm.
408    /// Skips validation if no auth settings are configured (auth disabled).
409    pub fn validate(&self) -> Result<()> {
410        if !self.is_configured() {
411            return Ok(());
412        }
413
414        match self.jwt_algorithm {
415            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512 => {
416                if self.jwt_secret.is_none() {
417                    return Err(ForgeError::Config(
418                        "auth.jwt_secret is required for HMAC algorithms (HS256, HS384, HS512). \
419                         Set auth.jwt_secret to a secure random string, \
420                         or switch to RS256 and provide auth.jwks_url for external identity providers."
421                            .into(),
422                    ));
423                }
424            }
425            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512 => {
426                if self.jwks_url.is_none() {
427                    return Err(ForgeError::Config(
428                        "auth.jwks_url is required for RSA algorithms (RS256, RS384, RS512). \
429                         Set auth.jwks_url to your identity provider's JWKS endpoint, \
430                         or switch to HS256 and provide auth.jwt_secret for symmetric signing."
431                            .into(),
432                    ));
433                }
434            }
435        }
436        Ok(())
437    }
438
439    /// Check if this config uses HMAC (symmetric) algorithms.
440    pub fn is_hmac(&self) -> bool {
441        matches!(
442            self.jwt_algorithm,
443            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
444        )
445    }
446
447    /// Check if this config uses RSA (asymmetric) algorithms.
448    pub fn is_rsa(&self) -> bool {
449        matches!(
450            self.jwt_algorithm,
451            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
452        )
453    }
454}
455
/// Default JWKS cache lifetime: one hour, in seconds.
fn default_jwks_cache_ttl() -> u64 {
    60 * 60
}

/// Default WebSocket session lifetime: seven days, in seconds.
fn default_session_ttl() -> u64 {
    const SECS_PER_DAY: u64 = 24 * 60 * 60;
    7 * SECS_PER_DAY
}
463
464/// Observability configuration for OTLP telemetry.
465#[derive(Debug, Clone, Serialize, Deserialize)]
466pub struct ObservabilityConfig {
467    /// Enable observability (traces, metrics, logs).
468    #[serde(default)]
469    pub enabled: bool,
470
471    /// OTLP endpoint for telemetry export.
472    #[serde(default = "default_otlp_endpoint")]
473    pub otlp_endpoint: String,
474
475    /// Service name for telemetry identification.
476    pub service_name: Option<String>,
477
478    /// Enable distributed tracing.
479    #[serde(default = "default_true")]
480    pub enable_traces: bool,
481
482    /// Enable metrics collection.
483    #[serde(default = "default_true")]
484    pub enable_metrics: bool,
485
486    /// Enable log export via OTLP.
487    #[serde(default = "default_true")]
488    pub enable_logs: bool,
489
490    /// Trace sampling ratio (0.0 to 1.0).
491    #[serde(default = "default_sampling_ratio")]
492    pub sampling_ratio: f64,
493}
494
495impl Default for ObservabilityConfig {
496    fn default() -> Self {
497        Self {
498            enabled: false,
499            otlp_endpoint: default_otlp_endpoint(),
500            service_name: None,
501            enable_traces: true,
502            enable_metrics: true,
503            enable_logs: true,
504            sampling_ratio: default_sampling_ratio(),
505        }
506    }
507}
508
/// Default OTLP collector endpoint.
fn default_otlp_endpoint() -> String {
    String::from("http://localhost:4317")
}

/// Serde helper for boolean settings that default to on.
fn default_true() -> bool {
    !false
}

/// Default trace sampling ratio: sample everything.
fn default_sampling_ratio() -> f64 {
    1.0_f64
}
520
521/// MCP server configuration.
522#[derive(Debug, Clone, Serialize, Deserialize)]
523pub struct McpConfig {
524    /// Enable MCP endpoint exposure.
525    #[serde(default)]
526    pub enabled: bool,
527
528    /// MCP endpoint path under the gateway API namespace.
529    #[serde(default = "default_mcp_path")]
530    pub path: String,
531
532    /// Session TTL in seconds.
533    #[serde(default = "default_mcp_session_ttl_secs")]
534    pub session_ttl_secs: u64,
535
536    /// Allowed origins for Origin header validation.
537    #[serde(default)]
538    pub allowed_origins: Vec<String>,
539
540    /// Enforce MCP-Protocol-Version header on post-initialize requests.
541    #[serde(default = "default_true")]
542    pub require_protocol_version_header: bool,
543}
544
545impl Default for McpConfig {
546    fn default() -> Self {
547        Self {
548            enabled: false,
549            path: default_mcp_path(),
550            session_ttl_secs: default_mcp_session_ttl_secs(),
551            allowed_origins: Vec::new(),
552            require_protocol_version_header: default_true(),
553        }
554    }
555}
556
557impl McpConfig {
558    pub fn validate(&self) -> Result<()> {
559        if self.path.is_empty() || !self.path.starts_with('/') {
560            return Err(ForgeError::Config(
561                "mcp.path must start with '/' (example: /mcp)".to_string(),
562            ));
563        }
564        if self.path.contains(' ') {
565            return Err(ForgeError::Config(
566                "mcp.path cannot contain spaces".to_string(),
567            ));
568        }
569        if self.session_ttl_secs == 0 {
570            return Err(ForgeError::Config(
571                "mcp.session_ttl_secs must be greater than 0".to_string(),
572            ));
573        }
574        Ok(())
575    }
576}
577
/// Default MCP endpoint path.
fn default_mcp_path() -> String {
    String::from("/mcp")
}

/// Default MCP session lifetime: one hour, in seconds.
fn default_mcp_session_ttl_secs() -> u64 {
    3_600
}
585
586/// Substitute environment variables in the format ${VAR_NAME}.
587fn substitute_env_vars(content: &str) -> String {
588    let mut result = content.to_string();
589    let re = regex_lite::Regex::new(r"\$\{([A-Z_][A-Z0-9_]*)\}").expect("valid regex pattern");
590
591    for cap in re.captures_iter(content) {
592        let var_name = &cap[1];
593        if let Ok(value) = std::env::var(var_name) {
594            result = result.replace(&cap[0], &value);
595        }
596    }
597
598    result
599}
600
// Unit tests for TOML parsing, defaults, environment-variable substitution,
// and the auth/MCP/database validation paths.
#[cfg(test)]
// Tests may unwrap and index freely; `unsafe_code` is allowed for the
// `std::env::set_var` / `remove_var` calls in `test_env_var_substitution`.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, unsafe_code)]
mod tests {
    use super::*;

    // `default_with_database_url` fills every non-database section with defaults.
    #[test]
    fn test_default_config() {
        let config = ForgeConfig::default_with_database_url("postgres://localhost/test");
        assert_eq!(config.gateway.port, 8080);
        assert_eq!(config.node.roles.len(), 4);
        assert_eq!(config.mcp.path, "/mcp");
        assert!(!config.mcp.enabled);
    }

    // Only `[database]` is required; everything else comes from serde defaults.
    #[test]
    fn test_parse_minimal_config() {
        let toml = r#"
            [database]
            mode = "remote"
            url = "postgres://localhost/myapp"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.database.url(), Some("postgres://localhost/myapp"));
        assert_eq!(config.gateway.port, 8080);
    }

    // Explicit values in several sections override the defaults.
    #[test]
    fn test_parse_full_config() {
        let toml = r#"
            [project]
            name = "my-app"
            version = "1.0.0"

            [database]
            mode = "remote"
            url = "postgres://localhost/myapp"
            pool_size = 100

            [node]
            roles = ["gateway", "worker"]
            worker_capabilities = ["media", "general"]

            [gateway]
            port = 3000
            grpc_port = 9001
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.project.name, "my-app");
        assert_eq!(config.database.pool_size, 100);
        assert_eq!(config.node.roles.len(), 2);
        assert_eq!(config.gateway.port, 3000);
    }

    // `${TEST_DB_URL}` is expanded before the TOML is parsed.
    #[test]
    fn test_env_var_substitution() {
        // SAFETY: NOTE(review) — not strictly guaranteed. libtest runs tests on
        // parallel threads by default, and other tests in this module read the
        // environment through `parse_toml` -> `substitute_env_vars`, so a
        // concurrent environment read is possible while this write happens.
        // Accepted as a test-only risk; consider serializing env-mutating tests.
        unsafe {
            std::env::set_var("TEST_DB_URL", "postgres://test:test@localhost/test");
        }

        let toml = r#"
            [database]
            mode = "remote"
            url = "${TEST_DB_URL}"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(
            config.database.url(),
            Some("postgres://test:test@localhost/test")
        );

        // SAFETY: same caveat as the `set_var` above. NOTE(review): this cleanup
        // is skipped if an assertion above panics, leaving TEST_DB_URL set.
        unsafe {
            std::env::remove_var("TEST_DB_URL");
        }
    }

    // With no auth keys present, validation is skipped (auth disabled).
    #[test]
    fn test_auth_validation_no_config() {
        let auth = AuthConfig::default();
        assert!(auth.validate().is_ok());
    }

    #[test]
    fn test_auth_validation_hmac_with_secret() {
        let auth = AuthConfig {
            jwt_secret: Some("my-secret".into()),
            jwt_algorithm: JwtAlgorithm::HS256,
            ..Default::default()
        };
        assert!(auth.validate().is_ok());
    }

    // Setting only `jwt_issuer` marks auth as configured, so the missing
    // HMAC secret is reported.
    #[test]
    fn test_auth_validation_hmac_missing_secret() {
        let auth = AuthConfig {
            jwt_issuer: Some("my-issuer".into()),
            jwt_algorithm: JwtAlgorithm::HS256,
            ..Default::default()
        };
        let result = auth.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwt_secret is required"));
    }

    #[test]
    fn test_auth_validation_rsa_with_jwks() {
        let auth = AuthConfig {
            jwks_url: Some("https://example.com/.well-known/jwks.json".into()),
            jwt_algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(auth.validate().is_ok());
    }

    #[test]
    fn test_auth_validation_rsa_missing_jwks() {
        let auth = AuthConfig {
            jwt_issuer: Some("my-issuer".into()),
            jwt_algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        let result = auth.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwks_url is required"));
    }

    // Relies on the database module's validator rejecting an empty URL with
    // a "database.url is required" message.
    #[test]
    fn test_forge_config_validation_fails_on_empty_url() {
        let toml = r#"
            [database]
            mode = "remote"
            url = ""
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("database.url is required"));
    }

    // `parse_toml` runs `AuthConfig::validate` as part of loading.
    #[test]
    fn test_forge_config_validation_fails_on_invalid_auth() {
        let toml = r#"
            [database]
            mode = "remote"
            url = "postgres://localhost/test"

            [auth]
            jwt_issuer = "my-issuer"
            jwt_algorithm = "RS256"
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwks_url is required"));
    }

    // `parse_toml` runs `McpConfig::validate`; a path without a leading '/' fails.
    #[test]
    fn test_mcp_config_validation_rejects_invalid_path() {
        let toml = r#"
            [database]
            mode = "remote"
            url = "postgres://localhost/test"

            [mcp]
            enabled = true
            path = "mcp"
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("mcp.path must start with '/'"));
    }
}