Skip to main content

forge_core/config/
mod.rs

1mod cluster;
2mod database;
3
4pub use cluster::ClusterConfig;
5pub use database::DatabaseConfig;
6
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9
10use crate::error::{ForgeError, Result};
11
12/// Root configuration for FORGE.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ForgeConfig {
15    /// Project metadata.
16    #[serde(default)]
17    pub project: ProjectConfig,
18
19    /// Database configuration.
20    pub database: DatabaseConfig,
21
22    /// Node configuration.
23    #[serde(default)]
24    pub node: NodeConfig,
25
26    /// Gateway configuration.
27    #[serde(default)]
28    pub gateway: GatewayConfig,
29
30    /// Function execution configuration.
31    #[serde(default)]
32    pub function: FunctionConfig,
33
34    /// Worker configuration.
35    #[serde(default)]
36    pub worker: WorkerConfig,
37
38    /// Cluster configuration.
39    #[serde(default)]
40    pub cluster: ClusterConfig,
41
42    /// Security configuration.
43    #[serde(default)]
44    pub security: SecurityConfig,
45
46    /// Authentication configuration.
47    #[serde(default)]
48    pub auth: AuthConfig,
49
50    /// Observability configuration.
51    #[serde(default)]
52    pub observability: ObservabilityConfig,
53
54    /// MCP server configuration.
55    #[serde(default)]
56    pub mcp: McpConfig,
57}
58
59impl ForgeConfig {
60    /// Load configuration from a TOML file.
61    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
62        let content = std::fs::read_to_string(path.as_ref())
63            .map_err(|e| ForgeError::Config(format!("Failed to read config file: {}", e)))?;
64
65        Self::parse_toml(&content)
66    }
67
68    /// Parse configuration from a TOML string.
69    pub fn parse_toml(content: &str) -> Result<Self> {
70        // Substitute environment variables
71        let content = substitute_env_vars(content);
72
73        let config: Self = toml::from_str(&content)
74            .map_err(|e| ForgeError::Config(format!("Failed to parse config: {}", e)))?;
75
76        config.validate()?;
77        Ok(config)
78    }
79
80    /// Validate the configuration for invalid combinations.
81    pub fn validate(&self) -> Result<()> {
82        self.database.validate()?;
83        self.auth.validate()?;
84        self.mcp.validate()?;
85        Ok(())
86    }
87
88    /// Load configuration with defaults.
89    pub fn default_with_database_url(url: &str) -> Self {
90        Self {
91            project: ProjectConfig::default(),
92            database: DatabaseConfig::new(url),
93            node: NodeConfig::default(),
94            gateway: GatewayConfig::default(),
95            function: FunctionConfig::default(),
96            worker: WorkerConfig::default(),
97            cluster: ClusterConfig::default(),
98            security: SecurityConfig::default(),
99            auth: AuthConfig::default(),
100            observability: ObservabilityConfig::default(),
101            mcp: McpConfig::default(),
102        }
103    }
104}
105
106/// Project metadata.
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct ProjectConfig {
109    /// Project name.
110    #[serde(default = "default_project_name")]
111    pub name: String,
112
113    /// Project version.
114    #[serde(default = "default_version")]
115    pub version: String,
116}
117
118impl Default for ProjectConfig {
119    fn default() -> Self {
120        Self {
121            name: default_project_name(),
122            version: default_version(),
123        }
124    }
125}
126
127fn default_project_name() -> String {
128    "forge-app".to_string()
129}
130
131fn default_version() -> String {
132    "0.1.0".to_string()
133}
134
135/// Node role configuration.
136#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct NodeConfig {
138    /// Roles this node should assume.
139    #[serde(default = "default_roles")]
140    pub roles: Vec<NodeRole>,
141
142    /// Worker capabilities for job routing.
143    #[serde(default = "default_capabilities")]
144    pub worker_capabilities: Vec<String>,
145}
146
147impl Default for NodeConfig {
148    fn default() -> Self {
149        Self {
150            roles: default_roles(),
151            worker_capabilities: default_capabilities(),
152        }
153    }
154}
155
156fn default_roles() -> Vec<NodeRole> {
157    vec![
158        NodeRole::Gateway,
159        NodeRole::Function,
160        NodeRole::Worker,
161        NodeRole::Scheduler,
162    ]
163}
164
165fn default_capabilities() -> Vec<String> {
166    vec!["general".to_string()]
167}
168
169/// Available node roles.
170#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
171#[serde(rename_all = "lowercase")]
172pub enum NodeRole {
173    Gateway,
174    Function,
175    Worker,
176    Scheduler,
177}
178
179/// Gateway configuration.
180#[derive(Debug, Clone, Serialize, Deserialize)]
181pub struct GatewayConfig {
182    /// HTTP port.
183    #[serde(default = "default_http_port")]
184    pub port: u16,
185
186    /// gRPC port for inter-node communication.
187    #[serde(default = "default_grpc_port")]
188    pub grpc_port: u16,
189
190    /// Maximum concurrent connections.
191    #[serde(default = "default_max_connections")]
192    pub max_connections: usize,
193
194    /// Request timeout in seconds.
195    #[serde(default = "default_request_timeout")]
196    pub request_timeout_secs: u64,
197
198    /// Enable CORS handling.
199    #[serde(default = "default_cors_enabled")]
200    pub cors_enabled: bool,
201
202    /// Allowed CORS origins.
203    #[serde(default = "default_cors_origins")]
204    pub cors_origins: Vec<String>,
205}
206
207impl Default for GatewayConfig {
208    fn default() -> Self {
209        Self {
210            port: default_http_port(),
211            grpc_port: default_grpc_port(),
212            max_connections: default_max_connections(),
213            request_timeout_secs: default_request_timeout(),
214            cors_enabled: default_cors_enabled(),
215            cors_origins: default_cors_origins(),
216        }
217    }
218}
219
220fn default_http_port() -> u16 {
221    8080
222}
223
224fn default_grpc_port() -> u16 {
225    9000
226}
227
228fn default_max_connections() -> usize {
229    512
230}
231
232fn default_request_timeout() -> u64 {
233    30
234}
235
236fn default_cors_enabled() -> bool {
237    false
238}
239
240fn default_cors_origins() -> Vec<String> {
241    Vec::new()
242}
243
244/// Function execution configuration.
245#[derive(Debug, Clone, Serialize, Deserialize)]
246pub struct FunctionConfig {
247    /// Maximum concurrent function executions.
248    #[serde(default = "default_max_concurrent")]
249    pub max_concurrent: usize,
250
251    /// Function timeout in seconds.
252    #[serde(default = "default_function_timeout")]
253    pub timeout_secs: u64,
254
255    /// Memory limit per function (in bytes).
256    #[serde(default = "default_memory_limit")]
257    pub memory_limit: usize,
258}
259
260impl Default for FunctionConfig {
261    fn default() -> Self {
262        Self {
263            max_concurrent: default_max_concurrent(),
264            timeout_secs: default_function_timeout(),
265            memory_limit: default_memory_limit(),
266        }
267    }
268}
269
270fn default_max_concurrent() -> usize {
271    1000
272}
273
274fn default_function_timeout() -> u64 {
275    30
276}
277
278fn default_memory_limit() -> usize {
279    512 * 1024 * 1024 // 512 MiB
280}
281
282/// Worker configuration.
283#[derive(Debug, Clone, Serialize, Deserialize)]
284pub struct WorkerConfig {
285    /// Maximum concurrent jobs.
286    #[serde(default = "default_max_concurrent_jobs")]
287    pub max_concurrent_jobs: usize,
288
289    /// Job timeout in seconds.
290    #[serde(default = "default_job_timeout")]
291    pub job_timeout_secs: u64,
292
293    /// Poll interval in milliseconds.
294    #[serde(default = "default_poll_interval")]
295    pub poll_interval_ms: u64,
296}
297
298impl Default for WorkerConfig {
299    fn default() -> Self {
300        Self {
301            max_concurrent_jobs: default_max_concurrent_jobs(),
302            job_timeout_secs: default_job_timeout(),
303            poll_interval_ms: default_poll_interval(),
304        }
305    }
306}
307
308fn default_max_concurrent_jobs() -> usize {
309    50
310}
311
312fn default_job_timeout() -> u64 {
313    3600 // 1 hour
314}
315
316fn default_poll_interval() -> u64 {
317    100
318}
319
320/// Security configuration.
321#[derive(Debug, Clone, Serialize, Deserialize, Default)]
322pub struct SecurityConfig {
323    /// Secret key for signing.
324    pub secret_key: Option<String>,
325}
326
327/// JWT signing algorithm.
328#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
329#[serde(rename_all = "UPPERCASE")]
330pub enum JwtAlgorithm {
331    /// HMAC using SHA-256 (symmetric, requires jwt_secret).
332    #[default]
333    HS256,
334    /// HMAC using SHA-384 (symmetric, requires jwt_secret).
335    HS384,
336    /// HMAC using SHA-512 (symmetric, requires jwt_secret).
337    HS512,
338    /// RSA using SHA-256 (asymmetric, requires jwks_url).
339    RS256,
340    /// RSA using SHA-384 (asymmetric, requires jwks_url).
341    RS384,
342    /// RSA using SHA-512 (asymmetric, requires jwks_url).
343    RS512,
344}
345
346/// Authentication configuration.
347#[derive(Debug, Clone, Serialize, Deserialize)]
348pub struct AuthConfig {
349    /// JWT secret for HMAC algorithms (HS256, HS384, HS512).
350    /// Required when using HMAC algorithms.
351    pub jwt_secret: Option<String>,
352
353    /// JWT signing algorithm.
354    /// HMAC algorithms (HS256, HS384, HS512) require jwt_secret.
355    /// RSA algorithms (RS256, RS384, RS512) require jwks_url.
356    #[serde(default)]
357    pub jwt_algorithm: JwtAlgorithm,
358
359    /// Expected token issuer (iss claim).
360    /// If set, tokens with a different issuer are rejected.
361    pub jwt_issuer: Option<String>,
362
363    /// Expected audience (aud claim).
364    /// If set, tokens with a different audience are rejected.
365    pub jwt_audience: Option<String>,
366
367    /// Token expiry duration (e.g., "15m", "1h", "7d").
368    pub token_expiry: Option<String>,
369
370    /// JWKS URL for RSA algorithms (RS256, RS384, RS512).
371    /// Keys are fetched and cached automatically.
372    pub jwks_url: Option<String>,
373
374    /// JWKS cache TTL in seconds.
375    #[serde(default = "default_jwks_cache_ttl")]
376    pub jwks_cache_ttl_secs: u64,
377
378    /// Session TTL in seconds (for WebSocket sessions).
379    #[serde(default = "default_session_ttl")]
380    pub session_ttl_secs: u64,
381}
382
383impl Default for AuthConfig {
384    fn default() -> Self {
385        Self {
386            jwt_secret: None,
387            jwt_algorithm: JwtAlgorithm::default(),
388            jwt_issuer: None,
389            jwt_audience: None,
390            token_expiry: None,
391            jwks_url: None,
392            jwks_cache_ttl_secs: default_jwks_cache_ttl(),
393            session_ttl_secs: default_session_ttl(),
394        }
395    }
396}
397
398impl AuthConfig {
399    /// Check if auth is configured (any credential or claim validation is set).
400    fn is_configured(&self) -> bool {
401        self.jwt_secret.is_some()
402            || self.jwks_url.is_some()
403            || self.jwt_issuer.is_some()
404            || self.jwt_audience.is_some()
405    }
406
407    /// Validate that the configuration is complete for the chosen algorithm.
408    /// Skips validation if no auth settings are configured (auth disabled).
409    pub fn validate(&self) -> Result<()> {
410        if !self.is_configured() {
411            return Ok(());
412        }
413
414        match self.jwt_algorithm {
415            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512 => {
416                if self.jwt_secret.is_none() {
417                    return Err(ForgeError::Config(
418                        "auth.jwt_secret is required for HMAC algorithms (HS256, HS384, HS512). \
419                         Set auth.jwt_secret to a secure random string, \
420                         or switch to RS256 and provide auth.jwks_url for external identity providers."
421                            .into(),
422                    ));
423                }
424            }
425            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512 => {
426                if self.jwks_url.is_none() {
427                    return Err(ForgeError::Config(
428                        "auth.jwks_url is required for RSA algorithms (RS256, RS384, RS512). \
429                         Set auth.jwks_url to your identity provider's JWKS endpoint, \
430                         or switch to HS256 and provide auth.jwt_secret for symmetric signing."
431                            .into(),
432                    ));
433                }
434            }
435        }
436        Ok(())
437    }
438
439    /// Check if this config uses HMAC (symmetric) algorithms.
440    pub fn is_hmac(&self) -> bool {
441        matches!(
442            self.jwt_algorithm,
443            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
444        )
445    }
446
447    /// Check if this config uses RSA (asymmetric) algorithms.
448    pub fn is_rsa(&self) -> bool {
449        matches!(
450            self.jwt_algorithm,
451            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
452        )
453    }
454}
455
456fn default_jwks_cache_ttl() -> u64 {
457    3600 // 1 hour
458}
459
460fn default_session_ttl() -> u64 {
461    7 * 24 * 60 * 60 // 7 days
462}
463
464/// Observability configuration for OTLP telemetry.
465#[derive(Debug, Clone, Serialize, Deserialize)]
466pub struct ObservabilityConfig {
467    /// Enable observability (traces, metrics, logs).
468    #[serde(default)]
469    pub enabled: bool,
470
471    /// OTLP endpoint for telemetry export.
472    #[serde(default = "default_otlp_endpoint")]
473    pub otlp_endpoint: String,
474
475    /// Service name for telemetry identification.
476    pub service_name: Option<String>,
477
478    /// Enable distributed tracing.
479    #[serde(default = "default_true")]
480    pub enable_traces: bool,
481
482    /// Enable metrics collection.
483    #[serde(default = "default_true")]
484    pub enable_metrics: bool,
485
486    /// Enable log export via OTLP.
487    #[serde(default = "default_true")]
488    pub enable_logs: bool,
489
490    /// Trace sampling ratio (0.0 to 1.0).
491    #[serde(default = "default_sampling_ratio")]
492    pub sampling_ratio: f64,
493
494    /// Log level for the tracing subscriber (e.g., "debug", "info", "warn").
495    #[serde(default = "default_log_level")]
496    pub log_level: String,
497}
498
499impl Default for ObservabilityConfig {
500    fn default() -> Self {
501        Self {
502            enabled: false,
503            otlp_endpoint: default_otlp_endpoint(),
504            service_name: None,
505            enable_traces: true,
506            enable_metrics: true,
507            enable_logs: true,
508            sampling_ratio: default_sampling_ratio(),
509            log_level: default_log_level(),
510        }
511    }
512}
513
514fn default_otlp_endpoint() -> String {
515    "http://localhost:4317".to_string()
516}
517
518fn default_true() -> bool {
519    true
520}
521
522fn default_sampling_ratio() -> f64 {
523    1.0
524}
525
526fn default_log_level() -> String {
527    "info".to_string()
528}
529
530/// MCP server configuration.
531#[derive(Debug, Clone, Serialize, Deserialize)]
532pub struct McpConfig {
533    /// Enable MCP endpoint exposure.
534    #[serde(default)]
535    pub enabled: bool,
536
537    /// MCP endpoint path under the gateway API namespace.
538    #[serde(default = "default_mcp_path")]
539    pub path: String,
540
541    /// Session TTL in seconds.
542    #[serde(default = "default_mcp_session_ttl_secs")]
543    pub session_ttl_secs: u64,
544
545    /// Allowed origins for Origin header validation.
546    #[serde(default)]
547    pub allowed_origins: Vec<String>,
548
549    /// Enforce MCP-Protocol-Version header on post-initialize requests.
550    #[serde(default = "default_true")]
551    pub require_protocol_version_header: bool,
552}
553
554impl Default for McpConfig {
555    fn default() -> Self {
556        Self {
557            enabled: false,
558            path: default_mcp_path(),
559            session_ttl_secs: default_mcp_session_ttl_secs(),
560            allowed_origins: Vec::new(),
561            require_protocol_version_header: default_true(),
562        }
563    }
564}
565
566impl McpConfig {
567    pub fn validate(&self) -> Result<()> {
568        if self.path.is_empty() || !self.path.starts_with('/') {
569            return Err(ForgeError::Config(
570                "mcp.path must start with '/' (example: /mcp)".to_string(),
571            ));
572        }
573        if self.path.contains(' ') {
574            return Err(ForgeError::Config(
575                "mcp.path cannot contain spaces".to_string(),
576            ));
577        }
578        if self.session_ttl_secs == 0 {
579            return Err(ForgeError::Config(
580                "mcp.session_ttl_secs must be greater than 0".to_string(),
581            ));
582        }
583        Ok(())
584    }
585}
586
587fn default_mcp_path() -> String {
588    "/mcp".to_string()
589}
590
591fn default_mcp_session_ttl_secs() -> u64 {
592    60 * 60
593}
594
595/// Substitute environment variables in the format ${VAR_NAME}.
596fn substitute_env_vars(content: &str) -> String {
597    let mut result = content.to_string();
598    let re = regex_lite::Regex::new(r"\$\{([A-Z_][A-Z0-9_]*)\}").expect("valid regex pattern");
599
600    for cap in re.captures_iter(content) {
601        let var_name = &cap[1];
602        if let Ok(value) = std::env::var(var_name) {
603            result = result.replace(&cap[0], &value);
604        }
605    }
606
607    result
608}
609
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, unsafe_code)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = ForgeConfig::default_with_database_url("postgres://localhost/test");
        assert_eq!(config.gateway.port, 8080);
        assert_eq!(config.node.roles.len(), 4);
        assert_eq!(config.mcp.path, "/mcp");
        assert!(!config.mcp.enabled);
    }

    #[test]
    fn test_parse_minimal_config() {
        let toml = r#"
            [database]
            url = "postgres://localhost/myapp"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.database.url(), "postgres://localhost/myapp");
        assert_eq!(config.gateway.port, 8080);
    }

    #[test]
    fn test_parse_full_config() {
        let toml = r#"
            [project]
            name = "my-app"
            version = "1.0.0"

            [database]
            url = "postgres://localhost/myapp"
            pool_size = 100

            [node]
            roles = ["gateway", "worker"]
            worker_capabilities = ["media", "general"]

            [gateway]
            port = 3000
            grpc_port = 9001
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.project.name, "my-app");
        assert_eq!(config.database.pool_size, 100);
        assert_eq!(config.node.roles.len(), 2);
        assert_eq!(config.gateway.port, 3000);
    }

    #[test]
    fn test_env_var_substitution() {
        // Unique to this test so it cannot clash with variables used elsewhere.
        const VAR: &str = "FORGE_CONFIG_TEST_DB_URL";

        // SAFETY: std's own environment accessors are internally synchronized;
        // the remaining hazard is foreign (non-std) code calling `getenv` on
        // another test thread at the same moment, which nothing in this module does.
        unsafe {
            std::env::set_var(VAR, "postgres://test:test@localhost/test");
        }

        let toml = r#"
            [database]
            url = "${FORGE_CONFIG_TEST_DB_URL}"
        "#;

        let result = ForgeConfig::parse_toml(toml);

        // Clean up before asserting so a failure cannot leak the variable.
        // SAFETY: same reasoning as for `set_var` above.
        unsafe {
            std::env::remove_var(VAR);
        }

        let config = result.unwrap();
        assert_eq!(config.database.url(), "postgres://test:test@localhost/test");
    }

    #[test]
    fn test_env_var_substitution_leaves_unset_placeholder() {
        const VAR: &str = "FORGE_CONFIG_TEST_UNSET_VAR";
        assert!(std::env::var(VAR).is_err(), "{VAR} must not be set for this test");

        let toml = r#"
            [project]
            name = "${FORGE_CONFIG_TEST_UNSET_VAR}"

            [database]
            url = "postgres://localhost/test"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.project.name, "${FORGE_CONFIG_TEST_UNSET_VAR}");
    }

    #[test]
    fn test_from_file_missing_file_is_config_error() {
        let result = ForgeConfig::from_file("/nonexistent/forge-config-test/forge.toml");
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("Failed to read config file"));
    }

    #[test]
    fn test_auth_validation_no_config() {
        let auth = AuthConfig::default();
        assert!(auth.validate().is_ok());
    }

    #[test]
    fn test_auth_validation_hmac_with_secret() {
        let auth = AuthConfig {
            jwt_secret: Some("my-secret".into()),
            jwt_algorithm: JwtAlgorithm::HS256,
            ..Default::default()
        };
        assert!(auth.validate().is_ok());
    }

    #[test]
    fn test_auth_validation_hmac_missing_secret() {
        let auth = AuthConfig {
            jwt_issuer: Some("my-issuer".into()),
            jwt_algorithm: JwtAlgorithm::HS256,
            ..Default::default()
        };
        let result = auth.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwt_secret is required"));
    }

    #[test]
    fn test_auth_validation_rsa_with_jwks() {
        let auth = AuthConfig {
            jwks_url: Some("https://example.com/.well-known/jwks.json".into()),
            jwt_algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(auth.validate().is_ok());
    }

    #[test]
    fn test_auth_validation_rsa_missing_jwks() {
        let auth = AuthConfig {
            jwt_issuer: Some("my-issuer".into()),
            jwt_algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        let result = auth.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwks_url is required"));
    }

    #[test]
    fn test_forge_config_validation_fails_on_empty_url() {
        let toml = r#"
            [database]
            url = ""
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("database.url is required"));
    }

    #[test]
    fn test_forge_config_validation_fails_on_invalid_auth() {
        let toml = r#"
            [database]
            url = "postgres://localhost/test"

            [auth]
            jwt_issuer = "my-issuer"
            jwt_algorithm = "RS256"
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwks_url is required"));
    }

    #[test]
    fn test_mcp_config_validation_rejects_invalid_path() {
        let toml = r#"
            [database]
            url = "postgres://localhost/test"

            [mcp]
            enabled = true
            path = "mcp"
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("mcp.path must start with '/'"));
    }
}