Skip to main content

forge_core/config/
mod.rs

1mod cluster;
2mod database;
3
4pub use cluster::ClusterConfig;
5pub use database::{DatabaseConfig, PoolConfig};
6
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9
10use crate::error::{ForgeError, Result};
11
12/// Root configuration for FORGE.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ForgeConfig {
15    /// Project metadata.
16    #[serde(default)]
17    pub project: ProjectConfig,
18
19    /// Database configuration.
20    pub database: DatabaseConfig,
21
22    /// Node configuration.
23    #[serde(default)]
24    pub node: NodeConfig,
25
26    /// Gateway configuration.
27    #[serde(default)]
28    pub gateway: GatewayConfig,
29
30    /// Function execution configuration.
31    #[serde(default)]
32    pub function: FunctionConfig,
33
34    /// Worker configuration.
35    #[serde(default)]
36    pub worker: WorkerConfig,
37
38    /// Cluster configuration.
39    #[serde(default)]
40    pub cluster: ClusterConfig,
41
42    /// Security configuration.
43    #[serde(default)]
44    pub security: SecurityConfig,
45
46    /// Authentication configuration.
47    #[serde(default)]
48    pub auth: AuthConfig,
49
50    /// Observability configuration.
51    #[serde(default)]
52    pub observability: ObservabilityConfig,
53
54    /// MCP server configuration.
55    #[serde(default)]
56    pub mcp: McpConfig,
57}
58
59impl ForgeConfig {
60    /// Load configuration from a TOML file.
61    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
62        let content = std::fs::read_to_string(path.as_ref())
63            .map_err(|e| ForgeError::Config(format!("Failed to read config file: {}", e)))?;
64
65        Self::parse_toml(&content)
66    }
67
68    /// Parse configuration from a TOML string.
69    pub fn parse_toml(content: &str) -> Result<Self> {
70        // Substitute environment variables
71        let content = substitute_env_vars(content);
72
73        let config: Self = toml::from_str(&content)
74            .map_err(|e| ForgeError::Config(format!("Failed to parse config: {}", e)))?;
75
76        config.validate()?;
77        Ok(config)
78    }
79
80    /// Validate the configuration for invalid combinations.
81    pub fn validate(&self) -> Result<()> {
82        self.database.validate()?;
83        self.auth.validate()?;
84        self.mcp.validate()?;
85        Ok(())
86    }
87
88    /// Load configuration with defaults.
89    pub fn default_with_database_url(url: &str) -> Self {
90        Self {
91            project: ProjectConfig::default(),
92            database: DatabaseConfig::new(url),
93            node: NodeConfig::default(),
94            gateway: GatewayConfig::default(),
95            function: FunctionConfig::default(),
96            worker: WorkerConfig::default(),
97            cluster: ClusterConfig::default(),
98            security: SecurityConfig::default(),
99            auth: AuthConfig::default(),
100            observability: ObservabilityConfig::default(),
101            mcp: McpConfig::default(),
102        }
103    }
104}
105
106/// Project metadata.
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct ProjectConfig {
109    /// Project name.
110    #[serde(default = "default_project_name")]
111    pub name: String,
112
113    /// Project version.
114    #[serde(default = "default_version")]
115    pub version: String,
116}
117
118impl Default for ProjectConfig {
119    fn default() -> Self {
120        Self {
121            name: default_project_name(),
122            version: default_version(),
123        }
124    }
125}
126
/// Fallback for `project.name`.
fn default_project_name() -> String {
    String::from("forge-app")
}

/// Fallback for `project.version`.
fn default_version() -> String {
    String::from("0.1.0")
}
134
135/// Node role configuration.
136#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct NodeConfig {
138    /// Roles this node should assume.
139    #[serde(default = "default_roles")]
140    pub roles: Vec<NodeRole>,
141
142    /// Worker capabilities for job routing.
143    #[serde(default = "default_capabilities")]
144    pub worker_capabilities: Vec<String>,
145}
146
147impl Default for NodeConfig {
148    fn default() -> Self {
149        Self {
150            roles: default_roles(),
151            worker_capabilities: default_capabilities(),
152        }
153    }
154}
155
156fn default_roles() -> Vec<NodeRole> {
157    vec![
158        NodeRole::Gateway,
159        NodeRole::Function,
160        NodeRole::Worker,
161        NodeRole::Scheduler,
162    ]
163}
164
/// Capabilities used when `node.worker_capabilities` is omitted.
fn default_capabilities() -> Vec<String> {
    let general = String::from("general");
    vec![general]
}
168
169/// Available node roles.
170#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
171#[serde(rename_all = "lowercase")]
172pub enum NodeRole {
173    Gateway,
174    Function,
175    Worker,
176    Scheduler,
177}
178
179/// Gateway configuration.
180#[derive(Debug, Clone, Serialize, Deserialize)]
181pub struct GatewayConfig {
182    /// HTTP port.
183    #[serde(default = "default_http_port")]
184    pub port: u16,
185
186    /// gRPC port for inter-node communication.
187    #[serde(default = "default_grpc_port")]
188    pub grpc_port: u16,
189
190    /// Maximum concurrent connections.
191    #[serde(default = "default_max_connections")]
192    pub max_connections: usize,
193
194    /// Request timeout in seconds.
195    #[serde(default = "default_request_timeout")]
196    pub request_timeout_secs: u64,
197
198    /// Enable CORS handling.
199    #[serde(default = "default_cors_enabled")]
200    pub cors_enabled: bool,
201
202    /// Allowed CORS origins.
203    #[serde(default = "default_cors_origins")]
204    pub cors_origins: Vec<String>,
205
206    /// Routes excluded from request logs, metrics, and traces.
207    /// Defaults to `["/_api/health", "/_api/ready"]`. Set to `[]` to monitor everything.
208    #[serde(default = "default_quiet_routes")]
209    pub quiet_routes: Vec<String>,
210}
211
212impl Default for GatewayConfig {
213    fn default() -> Self {
214        Self {
215            port: default_http_port(),
216            grpc_port: default_grpc_port(),
217            max_connections: default_max_connections(),
218            request_timeout_secs: default_request_timeout(),
219            cors_enabled: default_cors_enabled(),
220            cors_origins: default_cors_origins(),
221            quiet_routes: default_quiet_routes(),
222        }
223    }
224}
225
/// Default HTTP listen port.
fn default_http_port() -> u16 {
    8_080
}

/// Default gRPC port for inter-node traffic.
fn default_grpc_port() -> u16 {
    9_000
}

/// Default cap on concurrent gateway connections.
fn default_max_connections() -> usize {
    512
}

/// Default request timeout, in seconds.
fn default_request_timeout() -> u64 {
    30
}

/// CORS handling is off unless explicitly enabled.
fn default_cors_enabled() -> bool {
    false
}

/// No CORS origins are allowed by default.
fn default_cors_origins() -> Vec<String> {
    vec![]
}

/// Health/readiness endpoints are kept out of logs, metrics, and traces by default.
fn default_quiet_routes() -> Vec<String> {
    ["/_api/health", "/_api/ready"]
        .iter()
        .map(|route| route.to_string())
        .collect()
}
253
254/// Function execution configuration.
255#[derive(Debug, Clone, Serialize, Deserialize)]
256pub struct FunctionConfig {
257    /// Maximum concurrent function executions.
258    #[serde(default = "default_max_concurrent")]
259    pub max_concurrent: usize,
260
261    /// Function timeout in seconds.
262    #[serde(default = "default_function_timeout")]
263    pub timeout_secs: u64,
264
265    /// Memory limit per function (in bytes).
266    #[serde(default = "default_memory_limit")]
267    pub memory_limit: usize,
268}
269
270impl Default for FunctionConfig {
271    fn default() -> Self {
272        Self {
273            max_concurrent: default_max_concurrent(),
274            timeout_secs: default_function_timeout(),
275            memory_limit: default_memory_limit(),
276        }
277    }
278}
279
/// Default cap on concurrent function executions.
fn default_max_concurrent() -> usize {
    1_000
}

/// Default function timeout, in seconds.
fn default_function_timeout() -> u64 {
    30
}

/// Default per-function memory limit: 512 MiB, expressed in bytes.
fn default_memory_limit() -> usize {
    512 << 20
}
291
292/// Worker configuration.
293#[derive(Debug, Clone, Serialize, Deserialize)]
294pub struct WorkerConfig {
295    /// Maximum concurrent jobs.
296    #[serde(default = "default_max_concurrent_jobs")]
297    pub max_concurrent_jobs: usize,
298
299    /// Job timeout in seconds.
300    #[serde(default = "default_job_timeout")]
301    pub job_timeout_secs: u64,
302
303    /// Poll interval in milliseconds.
304    #[serde(default = "default_poll_interval")]
305    pub poll_interval_ms: u64,
306}
307
308impl Default for WorkerConfig {
309    fn default() -> Self {
310        Self {
311            max_concurrent_jobs: default_max_concurrent_jobs(),
312            job_timeout_secs: default_job_timeout(),
313            poll_interval_ms: default_poll_interval(),
314        }
315    }
316}
317
/// Default cap on concurrent worker jobs.
fn default_max_concurrent_jobs() -> usize {
    50
}

/// Default job timeout: one hour, in seconds.
fn default_job_timeout() -> u64 {
    60 * 60
}

/// Default worker poll interval, in milliseconds.
fn default_poll_interval() -> u64 {
    100
}
329
330/// Security configuration.
331#[derive(Debug, Clone, Serialize, Deserialize, Default)]
332pub struct SecurityConfig {
333    /// Secret key for signing.
334    pub secret_key: Option<String>,
335}
336
337/// JWT signing algorithm.
338#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
339#[serde(rename_all = "UPPERCASE")]
340pub enum JwtAlgorithm {
341    /// HMAC using SHA-256 (symmetric, requires jwt_secret).
342    #[default]
343    HS256,
344    /// HMAC using SHA-384 (symmetric, requires jwt_secret).
345    HS384,
346    /// HMAC using SHA-512 (symmetric, requires jwt_secret).
347    HS512,
348    /// RSA using SHA-256 (asymmetric, requires jwks_url).
349    RS256,
350    /// RSA using SHA-384 (asymmetric, requires jwks_url).
351    RS384,
352    /// RSA using SHA-512 (asymmetric, requires jwks_url).
353    RS512,
354}
355
356/// Authentication configuration.
357#[derive(Debug, Clone, Serialize, Deserialize)]
358pub struct AuthConfig {
359    /// JWT secret for HMAC algorithms (HS256, HS384, HS512).
360    /// Required when using HMAC algorithms.
361    pub jwt_secret: Option<String>,
362
363    /// JWT signing algorithm.
364    /// HMAC algorithms (HS256, HS384, HS512) require jwt_secret.
365    /// RSA algorithms (RS256, RS384, RS512) require jwks_url.
366    #[serde(default)]
367    pub jwt_algorithm: JwtAlgorithm,
368
369    /// Expected token issuer (iss claim).
370    /// If set, tokens with a different issuer are rejected.
371    pub jwt_issuer: Option<String>,
372
373    /// Expected audience (aud claim).
374    /// If set, tokens with a different audience are rejected.
375    pub jwt_audience: Option<String>,
376
377    /// Token expiry duration (e.g., "15m", "1h", "7d").
378    pub token_expiry: Option<String>,
379
380    /// JWKS URL for RSA algorithms (RS256, RS384, RS512).
381    /// Keys are fetched and cached automatically.
382    pub jwks_url: Option<String>,
383
384    /// JWKS cache TTL in seconds.
385    #[serde(default = "default_jwks_cache_ttl")]
386    pub jwks_cache_ttl_secs: u64,
387
388    /// Session TTL in seconds (for WebSocket sessions).
389    #[serde(default = "default_session_ttl")]
390    pub session_ttl_secs: u64,
391}
392
393impl Default for AuthConfig {
394    fn default() -> Self {
395        Self {
396            jwt_secret: None,
397            jwt_algorithm: JwtAlgorithm::default(),
398            jwt_issuer: None,
399            jwt_audience: None,
400            token_expiry: None,
401            jwks_url: None,
402            jwks_cache_ttl_secs: default_jwks_cache_ttl(),
403            session_ttl_secs: default_session_ttl(),
404        }
405    }
406}
407
408impl AuthConfig {
409    /// Check if auth is configured (any credential or claim validation is set).
410    fn is_configured(&self) -> bool {
411        self.jwt_secret.is_some()
412            || self.jwks_url.is_some()
413            || self.jwt_issuer.is_some()
414            || self.jwt_audience.is_some()
415    }
416
417    /// Validate that the configuration is complete for the chosen algorithm.
418    /// Skips validation if no auth settings are configured (auth disabled).
419    pub fn validate(&self) -> Result<()> {
420        if !self.is_configured() {
421            return Ok(());
422        }
423
424        match self.jwt_algorithm {
425            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512 => {
426                if self.jwt_secret.is_none() {
427                    return Err(ForgeError::Config(
428                        "auth.jwt_secret is required for HMAC algorithms (HS256, HS384, HS512). \
429                         Set auth.jwt_secret to a secure random string, \
430                         or switch to RS256 and provide auth.jwks_url for external identity providers."
431                            .into(),
432                    ));
433                }
434            }
435            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512 => {
436                if self.jwks_url.is_none() {
437                    return Err(ForgeError::Config(
438                        "auth.jwks_url is required for RSA algorithms (RS256, RS384, RS512). \
439                         Set auth.jwks_url to your identity provider's JWKS endpoint, \
440                         or switch to HS256 and provide auth.jwt_secret for symmetric signing."
441                            .into(),
442                    ));
443                }
444            }
445        }
446        Ok(())
447    }
448
449    /// Check if this config uses HMAC (symmetric) algorithms.
450    pub fn is_hmac(&self) -> bool {
451        matches!(
452            self.jwt_algorithm,
453            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
454        )
455    }
456
457    /// Check if this config uses RSA (asymmetric) algorithms.
458    pub fn is_rsa(&self) -> bool {
459        matches!(
460            self.jwt_algorithm,
461            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
462        )
463    }
464}
465
/// Default JWKS cache TTL: one hour, in seconds.
fn default_jwks_cache_ttl() -> u64 {
    const SECS_PER_HOUR: u64 = 60 * 60;
    SECS_PER_HOUR
}

/// Default WebSocket session TTL: seven days, in seconds.
fn default_session_ttl() -> u64 {
    const SECS_PER_DAY: u64 = 24 * 60 * 60;
    7 * SECS_PER_DAY
}
473
474/// Observability configuration for OTLP telemetry.
475#[derive(Debug, Clone, Serialize, Deserialize)]
476pub struct ObservabilityConfig {
477    /// Enable observability (traces, metrics, logs).
478    #[serde(default)]
479    pub enabled: bool,
480
481    /// OTLP endpoint for telemetry export.
482    #[serde(default = "default_otlp_endpoint")]
483    pub otlp_endpoint: String,
484
485    /// Service name for telemetry identification.
486    pub service_name: Option<String>,
487
488    /// Enable distributed tracing.
489    #[serde(default = "default_true")]
490    pub enable_traces: bool,
491
492    /// Enable metrics collection.
493    #[serde(default = "default_true")]
494    pub enable_metrics: bool,
495
496    /// Enable log export via OTLP.
497    #[serde(default = "default_true")]
498    pub enable_logs: bool,
499
500    /// Trace sampling ratio (0.0 to 1.0).
501    #[serde(default = "default_sampling_ratio")]
502    pub sampling_ratio: f64,
503
504    /// Log level for the tracing subscriber (e.g., "debug", "info", "warn").
505    #[serde(default = "default_log_level")]
506    pub log_level: String,
507}
508
509impl Default for ObservabilityConfig {
510    fn default() -> Self {
511        Self {
512            enabled: false,
513            otlp_endpoint: default_otlp_endpoint(),
514            service_name: None,
515            enable_traces: true,
516            enable_metrics: true,
517            enable_logs: true,
518            sampling_ratio: default_sampling_ratio(),
519            log_level: default_log_level(),
520        }
521    }
522}
523
/// Default OTLP export endpoint.
fn default_otlp_endpoint() -> String {
    String::from("http://localhost:4318")
}

/// Serde default helper for boolean fields that should start enabled.
fn default_true() -> bool {
    !false
}

/// Default trace sampling ratio: sample everything.
fn default_sampling_ratio() -> f64 {
    1.0_f64
}

/// Default tracing log level.
fn default_log_level() -> String {
    String::from("info")
}
539
540/// MCP server configuration.
541#[derive(Debug, Clone, Serialize, Deserialize)]
542pub struct McpConfig {
543    /// Enable MCP endpoint exposure.
544    #[serde(default)]
545    pub enabled: bool,
546
547    /// MCP endpoint path under the gateway API namespace.
548    #[serde(default = "default_mcp_path")]
549    pub path: String,
550
551    /// Session TTL in seconds.
552    #[serde(default = "default_mcp_session_ttl_secs")]
553    pub session_ttl_secs: u64,
554
555    /// Allowed origins for Origin header validation.
556    #[serde(default)]
557    pub allowed_origins: Vec<String>,
558
559    /// Enforce MCP-Protocol-Version header on post-initialize requests.
560    #[serde(default = "default_true")]
561    pub require_protocol_version_header: bool,
562}
563
564impl Default for McpConfig {
565    fn default() -> Self {
566        Self {
567            enabled: false,
568            path: default_mcp_path(),
569            session_ttl_secs: default_mcp_session_ttl_secs(),
570            allowed_origins: Vec::new(),
571            require_protocol_version_header: default_true(),
572        }
573    }
574}
575
576impl McpConfig {
577    pub fn validate(&self) -> Result<()> {
578        if self.path.is_empty() || !self.path.starts_with('/') {
579            return Err(ForgeError::Config(
580                "mcp.path must start with '/' (example: /mcp)".to_string(),
581            ));
582        }
583        if self.path.contains(' ') {
584            return Err(ForgeError::Config(
585                "mcp.path cannot contain spaces".to_string(),
586            ));
587        }
588        if self.session_ttl_secs == 0 {
589            return Err(ForgeError::Config(
590                "mcp.session_ttl_secs must be greater than 0".to_string(),
591            ));
592        }
593        Ok(())
594    }
595}
596
/// Default MCP endpoint path.
fn default_mcp_path() -> String {
    String::from("/mcp")
}

/// Default MCP session TTL: one hour, in seconds.
fn default_mcp_session_ttl_secs() -> u64 {
    3_600
}
604
/// Substitute environment variables written as `${VAR_NAME}`.
///
/// Only names accepted by [`is_valid_env_var_name`] are substituted. A valid
/// placeholder whose variable is unset (or not valid Unicode) is left verbatim.
/// A `${` with no closing `}`, or with an invalid name, is copied through
/// unchanged. Substituted values are inserted as-is and are not re-scanned.
///
/// Text outside placeholders is copied as whole string slices, so multi-byte
/// UTF-8 characters (e.g. `é`) are preserved exactly.
// Slicing only happens at offsets returned by `str::find` for ASCII patterns
// (`"${"`, `'}'`) plus their ASCII lengths, which are always char boundaries.
#[allow(clippy::indexing_slicing)]
fn substitute_env_vars(content: &str) -> String {
    let mut result = String::with_capacity(content.len());
    let mut rest = content;

    while let Some(start) = rest.find("${") {
        // Everything before the placeholder is copied untouched.
        result.push_str(&rest[..start]);
        let after_open = &rest[start + 2..];

        match after_open.find('}') {
            Some(end) if is_valid_env_var_name(&after_open[..end]) => {
                match std::env::var(&after_open[..end]) {
                    Ok(value) => result.push_str(&value),
                    // Unset variable: keep the whole `${NAME}` placeholder.
                    Err(_) => result.push_str(&rest[start..start + 2 + end + 1]),
                }
                rest = &after_open[end + 1..];
            }
            _ => {
                // Not a substitutable placeholder: emit the `$` literally and
                // rescan from the following character.
                result.push('$');
                rest = &rest[start + 1..];
            }
        }
    }

    result.push_str(rest);
    result
}

/// Return `true` if `name` is a conventional environment variable name.
///
/// The name must be non-empty. Its first byte must be an uppercase ASCII letter
/// or `_`, and the remaining bytes must be uppercase ASCII letters, ASCII
/// digits, or `_`.
fn is_valid_env_var_name(name: &str) -> bool {
    let mut bytes = name.bytes();
    match bytes.next() {
        Some(first) if first.is_ascii_uppercase() || first == b'_' => {
            bytes.all(|b| b.is_ascii_uppercase() || b.is_ascii_digit() || b == b'_')
        }
        _ => false,
    }
}
647
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, unsafe_code)]
mod tests {
    use super::*;

    /// Sets an environment variable and removes it again on drop, so the
    /// variable is cleaned up even when an assertion in the test panics.
    struct EnvVarGuard {
        name: &'static str,
    }

    impl EnvVarGuard {
        fn set(name: &'static str, value: &str) -> Self {
            // SAFETY: mutating the environment is unsafe when other threads may
            // read it concurrently. The variable name is used only by the test
            // that owns this guard. NOTE(review): tests run in parallel threads;
            // confirm no foreign (non-`std::env`) code reads the environment
            // during the test run.
            unsafe {
                std::env::set_var(name, value);
            }
            Self { name }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same reasoning as in `EnvVarGuard::set`.
            unsafe {
                std::env::remove_var(self.name);
            }
        }
    }

    #[test]
    fn test_default_config() {
        let config = ForgeConfig::default_with_database_url("postgres://localhost/test");
        assert_eq!(config.gateway.port, 8080);
        assert_eq!(config.node.roles.len(), 4);
        assert_eq!(config.mcp.path, "/mcp");
        assert!(!config.mcp.enabled);
    }

    #[test]
    fn test_parse_minimal_config() {
        let toml = r#"
            [database]
            url = "postgres://localhost/myapp"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.database.url(), "postgres://localhost/myapp");
        assert_eq!(config.gateway.port, 8080);
    }

    #[test]
    fn test_parse_full_config() {
        let toml = r#"
            [project]
            name = "my-app"
            version = "1.0.0"

            [database]
            url = "postgres://localhost/myapp"
            pool_size = 100

            [node]
            roles = ["gateway", "worker"]
            worker_capabilities = ["media", "general"]

            [gateway]
            port = 3000
            grpc_port = 9001
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.project.name, "my-app");
        assert_eq!(config.database.pool_size, 100);
        assert_eq!(config.node.roles.len(), 2);
        assert_eq!(config.gateway.port, 3000);
    }

    #[test]
    fn test_env_var_substitution() {
        // Removed automatically when `_guard` drops, even if an assert fails.
        let _guard = EnvVarGuard::set("TEST_DB_URL", "postgres://test:test@localhost/test");

        let toml = r#"
            [database]
            url = "${TEST_DB_URL}"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.database.url(), "postgres://test:test@localhost/test");
    }

    #[test]
    fn test_auth_validation_no_config() {
        let auth = AuthConfig::default();
        assert!(auth.validate().is_ok());
    }

    #[test]
    fn test_auth_validation_hmac_with_secret() {
        let auth = AuthConfig {
            jwt_secret: Some("my-secret".into()),
            jwt_algorithm: JwtAlgorithm::HS256,
            ..Default::default()
        };
        assert!(auth.validate().is_ok());
    }

    #[test]
    fn test_auth_validation_hmac_missing_secret() {
        let auth = AuthConfig {
            jwt_issuer: Some("my-issuer".into()),
            jwt_algorithm: JwtAlgorithm::HS256,
            ..Default::default()
        };
        let result = auth.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwt_secret is required"));
    }

    #[test]
    fn test_auth_validation_rsa_with_jwks() {
        let auth = AuthConfig {
            jwks_url: Some("https://example.com/.well-known/jwks.json".into()),
            jwt_algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(auth.validate().is_ok());
    }

    #[test]
    fn test_auth_validation_rsa_missing_jwks() {
        let auth = AuthConfig {
            jwt_issuer: Some("my-issuer".into()),
            jwt_algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        let result = auth.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwks_url is required"));
    }

    #[test]
    fn test_forge_config_validation_fails_on_empty_url() {
        let toml = r#"
            [database]

            url = ""
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("database.url is required"));
    }

    #[test]
    fn test_forge_config_validation_fails_on_invalid_auth() {
        let toml = r#"
            [database]

            url = "postgres://localhost/test"

            [auth]
            jwt_issuer = "my-issuer"
            jwt_algorithm = "RS256"
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwks_url is required"));
    }

    #[test]
    fn test_mcp_config_validation_rejects_invalid_path() {
        let toml = r#"
            [database]

            url = "postgres://localhost/test"

            [mcp]
            enabled = true
            path = "mcp"
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("mcp.path must start with '/'"));
    }
}