Skip to main content

forge_core/config/
mod.rs

1mod cluster;
2mod database;
3
4pub use cluster::ClusterConfig;
5pub use database::{DatabaseConfig, DatabaseSource};
6
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9
10use crate::error::{ForgeError, Result};
11
/// Root configuration for FORGE.
///
/// Only the `[database]` table is required when deserializing; every other
/// section is marked `#[serde(default)]` and falls back to its `Default` impl
/// when omitted from the TOML.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForgeConfig {
    /// Project metadata (`[project]`).
    #[serde(default)]
    pub project: ProjectConfig,

    /// Database configuration (`[database]`). Required: has no serde default.
    pub database: DatabaseConfig,

    /// Node configuration (`[node]`): roles and worker capabilities.
    #[serde(default)]
    pub node: NodeConfig,

    /// Gateway configuration (`[gateway]`).
    #[serde(default)]
    pub gateway: GatewayConfig,

    /// Function execution configuration (`[function]`).
    #[serde(default)]
    pub function: FunctionConfig,

    /// Worker configuration (`[worker]`).
    #[serde(default)]
    pub worker: WorkerConfig,

    /// Cluster configuration (`[cluster]`).
    #[serde(default)]
    pub cluster: ClusterConfig,

    /// Security configuration (`[security]`).
    #[serde(default)]
    pub security: SecurityConfig,

    /// Authentication configuration (`[auth]`).
    #[serde(default)]
    pub auth: AuthConfig,

    /// Observability configuration (`[observability]`).
    #[serde(default)]
    pub observability: ObservabilityConfig,
}
54
55impl ForgeConfig {
56    /// Load configuration from a TOML file.
57    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
58        let content = std::fs::read_to_string(path.as_ref())
59            .map_err(|e| ForgeError::Config(format!("Failed to read config file: {}", e)))?;
60
61        Self::parse_toml(&content)
62    }
63
64    /// Parse configuration from a TOML string.
65    pub fn parse_toml(content: &str) -> Result<Self> {
66        // Substitute environment variables
67        let content = substitute_env_vars(content);
68
69        let config: Self = toml::from_str(&content)
70            .map_err(|e| ForgeError::Config(format!("Failed to parse config: {}", e)))?;
71
72        config.validate()?;
73        Ok(config)
74    }
75
76    /// Validate the configuration for invalid combinations.
77    pub fn validate(&self) -> Result<()> {
78        self.database.validate()?;
79        self.auth.validate()?;
80        Ok(())
81    }
82
83    /// Load configuration with defaults.
84    pub fn default_with_database_url(url: &str) -> Self {
85        Self {
86            project: ProjectConfig::default(),
87            database: DatabaseConfig::remote(url),
88            node: NodeConfig::default(),
89            gateway: GatewayConfig::default(),
90            function: FunctionConfig::default(),
91            worker: WorkerConfig::default(),
92            cluster: ClusterConfig::default(),
93            security: SecurityConfig::default(),
94            auth: AuthConfig::default(),
95            observability: ObservabilityConfig::default(),
96        }
97    }
98}
99
100/// Project metadata.
101#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct ProjectConfig {
103    /// Project name.
104    #[serde(default = "default_project_name")]
105    pub name: String,
106
107    /// Project version.
108    #[serde(default = "default_version")]
109    pub version: String,
110}
111
112impl Default for ProjectConfig {
113    fn default() -> Self {
114        Self {
115            name: default_project_name(),
116            version: default_version(),
117        }
118    }
119}
120
/// Fallback for `[project].name`.
fn default_project_name() -> String {
    String::from("forge-app")
}

/// Fallback for `[project].version`.
fn default_version() -> String {
    String::from("0.1.0")
}
128
129/// Node role configuration.
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct NodeConfig {
132    /// Roles this node should assume.
133    #[serde(default = "default_roles")]
134    pub roles: Vec<NodeRole>,
135
136    /// Worker capabilities for job routing.
137    #[serde(default = "default_capabilities")]
138    pub worker_capabilities: Vec<String>,
139}
140
141impl Default for NodeConfig {
142    fn default() -> Self {
143        Self {
144            roles: default_roles(),
145            worker_capabilities: default_capabilities(),
146        }
147    }
148}
149
150fn default_roles() -> Vec<NodeRole> {
151    vec![
152        NodeRole::Gateway,
153        NodeRole::Function,
154        NodeRole::Worker,
155        NodeRole::Scheduler,
156    ]
157}
158
/// Fallback for `[node].worker_capabilities`: a single `"general"` capability.
fn default_capabilities() -> Vec<String> {
    ["general"].into_iter().map(String::from).collect()
}
162
/// Available node roles.
///
/// Serialized in lower case, e.g. `roles = ["gateway", "worker"]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum NodeRole {
    /// Written as `"gateway"` in config.
    Gateway,
    /// Written as `"function"` in config.
    Function,
    /// Written as `"worker"` in config.
    Worker,
    /// Written as `"scheduler"` in config.
    Scheduler,
}
172
173/// Gateway configuration.
174#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct GatewayConfig {
176    /// HTTP port.
177    #[serde(default = "default_http_port")]
178    pub port: u16,
179
180    /// gRPC port for inter-node communication.
181    #[serde(default = "default_grpc_port")]
182    pub grpc_port: u16,
183
184    /// Maximum concurrent connections.
185    #[serde(default = "default_max_connections")]
186    pub max_connections: usize,
187
188    /// Request timeout in seconds.
189    #[serde(default = "default_request_timeout")]
190    pub request_timeout_secs: u64,
191
192    /// Enable CORS handling.
193    #[serde(default = "default_cors_enabled")]
194    pub cors_enabled: bool,
195
196    /// Allowed CORS origins.
197    #[serde(default = "default_cors_origins")]
198    pub cors_origins: Vec<String>,
199}
200
201impl Default for GatewayConfig {
202    fn default() -> Self {
203        Self {
204            port: default_http_port(),
205            grpc_port: default_grpc_port(),
206            max_connections: default_max_connections(),
207            request_timeout_secs: default_request_timeout(),
208            cors_enabled: default_cors_enabled(),
209            cors_origins: default_cors_origins(),
210        }
211    }
212}
213
/// Fallback for `[gateway].port`.
fn default_http_port() -> u16 {
    8_080
}

/// Fallback for `[gateway].grpc_port`.
fn default_grpc_port() -> u16 {
    9_000
}

/// Fallback for `[gateway].max_connections`.
fn default_max_connections() -> usize {
    512
}

/// Fallback for `[gateway].request_timeout_secs`.
fn default_request_timeout() -> u64 {
    30
}

/// Fallback for `[gateway].cors_enabled`: CORS is off unless requested.
fn default_cors_enabled() -> bool {
    bool::default()
}

/// Fallback for `[gateway].cors_origins`: no origins.
fn default_cors_origins() -> Vec<String> {
    Vec::default()
}
237
238/// Function execution configuration.
239#[derive(Debug, Clone, Serialize, Deserialize)]
240pub struct FunctionConfig {
241    /// Maximum concurrent function executions.
242    #[serde(default = "default_max_concurrent")]
243    pub max_concurrent: usize,
244
245    /// Function timeout in seconds.
246    #[serde(default = "default_function_timeout")]
247    pub timeout_secs: u64,
248
249    /// Memory limit per function (in bytes).
250    #[serde(default = "default_memory_limit")]
251    pub memory_limit: usize,
252}
253
254impl Default for FunctionConfig {
255    fn default() -> Self {
256        Self {
257            max_concurrent: default_max_concurrent(),
258            timeout_secs: default_function_timeout(),
259            memory_limit: default_memory_limit(),
260        }
261    }
262}
263
/// Fallback for `[function].max_concurrent`.
fn default_max_concurrent() -> usize {
    1_000
}

/// Fallback for `[function].timeout_secs`.
fn default_function_timeout() -> u64 {
    30
}

/// Fallback for `[function].memory_limit`: 512 MiB expressed in bytes.
fn default_memory_limit() -> usize {
    512 << 20
}
275
276/// Worker configuration.
277#[derive(Debug, Clone, Serialize, Deserialize)]
278pub struct WorkerConfig {
279    /// Maximum concurrent jobs.
280    #[serde(default = "default_max_concurrent_jobs")]
281    pub max_concurrent_jobs: usize,
282
283    /// Job timeout in seconds.
284    #[serde(default = "default_job_timeout")]
285    pub job_timeout_secs: u64,
286
287    /// Poll interval in milliseconds.
288    #[serde(default = "default_poll_interval")]
289    pub poll_interval_ms: u64,
290}
291
292impl Default for WorkerConfig {
293    fn default() -> Self {
294        Self {
295            max_concurrent_jobs: default_max_concurrent_jobs(),
296            job_timeout_secs: default_job_timeout(),
297            poll_interval_ms: default_poll_interval(),
298        }
299    }
300}
301
/// Fallback for `[worker].max_concurrent_jobs`.
fn default_max_concurrent_jobs() -> usize {
    50
}

/// Fallback for `[worker].job_timeout_secs`: one hour.
fn default_job_timeout() -> u64 {
    60 * 60
}

/// Fallback for `[worker].poll_interval_ms`.
fn default_poll_interval() -> u64 {
    100
}
313
314/// Security configuration.
315#[derive(Debug, Clone, Serialize, Deserialize, Default)]
316pub struct SecurityConfig {
317    /// Secret key for signing.
318    pub secret_key: Option<String>,
319}
320
/// JWT signing algorithm (`auth.jwt_algorithm`).
///
/// Serialized in upper case, e.g. `jwt_algorithm = "RS256"`. Defaults to `HS256`.
/// `AuthConfig::validate` checks HMAC variants against `auth.jwt_secret` and RSA
/// variants against `auth.jwks_url`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "UPPERCASE")]
pub enum JwtAlgorithm {
    /// HMAC using SHA-256 (symmetric, requires jwt_secret). The default.
    #[default]
    HS256,
    /// HMAC using SHA-384 (symmetric, requires jwt_secret).
    HS384,
    /// HMAC using SHA-512 (symmetric, requires jwt_secret).
    HS512,
    /// RSA using SHA-256 (asymmetric, requires jwks_url).
    RS256,
    /// RSA using SHA-384 (asymmetric, requires jwks_url).
    RS384,
    /// RSA using SHA-512 (asymmetric, requires jwks_url).
    RS512,
}
339
340/// Authentication configuration.
341#[derive(Debug, Clone, Serialize, Deserialize)]
342pub struct AuthConfig {
343    /// JWT secret for HMAC algorithms (HS256, HS384, HS512).
344    /// Required when using HMAC algorithms.
345    pub jwt_secret: Option<String>,
346
347    /// JWT signing algorithm.
348    /// HMAC algorithms (HS256, HS384, HS512) require jwt_secret.
349    /// RSA algorithms (RS256, RS384, RS512) require jwks_url.
350    #[serde(default)]
351    pub jwt_algorithm: JwtAlgorithm,
352
353    /// Expected token issuer (iss claim).
354    /// If set, tokens with a different issuer are rejected.
355    pub jwt_issuer: Option<String>,
356
357    /// Expected audience (aud claim).
358    /// If set, tokens with a different audience are rejected.
359    pub jwt_audience: Option<String>,
360
361    /// Token expiry duration (e.g., "15m", "1h", "7d").
362    pub token_expiry: Option<String>,
363
364    /// JWKS URL for RSA algorithms (RS256, RS384, RS512).
365    /// Keys are fetched and cached automatically.
366    pub jwks_url: Option<String>,
367
368    /// JWKS cache TTL in seconds.
369    #[serde(default = "default_jwks_cache_ttl")]
370    pub jwks_cache_ttl_secs: u64,
371
372    /// Session TTL in seconds (for WebSocket sessions).
373    #[serde(default = "default_session_ttl")]
374    pub session_ttl_secs: u64,
375}
376
377impl Default for AuthConfig {
378    fn default() -> Self {
379        Self {
380            jwt_secret: None,
381            jwt_algorithm: JwtAlgorithm::default(),
382            jwt_issuer: None,
383            jwt_audience: None,
384            token_expiry: None,
385            jwks_url: None,
386            jwks_cache_ttl_secs: default_jwks_cache_ttl(),
387            session_ttl_secs: default_session_ttl(),
388        }
389    }
390}
391
392impl AuthConfig {
393    /// Check if auth is configured (any credential or claim validation is set).
394    fn is_configured(&self) -> bool {
395        self.jwt_secret.is_some()
396            || self.jwks_url.is_some()
397            || self.jwt_issuer.is_some()
398            || self.jwt_audience.is_some()
399    }
400
401    /// Validate that the configuration is complete for the chosen algorithm.
402    /// Skips validation if no auth settings are configured (auth disabled).
403    pub fn validate(&self) -> Result<()> {
404        if !self.is_configured() {
405            return Ok(());
406        }
407
408        match self.jwt_algorithm {
409            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512 => {
410                if self.jwt_secret.is_none() {
411                    return Err(ForgeError::Config(
412                        "auth.jwt_secret is required for HMAC algorithms (HS256, HS384, HS512). \
413                         Set auth.jwt_secret to a secure random string, \
414                         or switch to RS256 and provide auth.jwks_url for external identity providers."
415                            .into(),
416                    ));
417                }
418            }
419            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512 => {
420                if self.jwks_url.is_none() {
421                    return Err(ForgeError::Config(
422                        "auth.jwks_url is required for RSA algorithms (RS256, RS384, RS512). \
423                         Set auth.jwks_url to your identity provider's JWKS endpoint, \
424                         or switch to HS256 and provide auth.jwt_secret for symmetric signing."
425                            .into(),
426                    ));
427                }
428            }
429        }
430        Ok(())
431    }
432
433    /// Check if this config uses HMAC (symmetric) algorithms.
434    pub fn is_hmac(&self) -> bool {
435        matches!(
436            self.jwt_algorithm,
437            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
438        )
439    }
440
441    /// Check if this config uses RSA (asymmetric) algorithms.
442    pub fn is_rsa(&self) -> bool {
443        matches!(
444            self.jwt_algorithm,
445            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
446        )
447    }
448}
449
/// Fallback for `[auth].jwks_cache_ttl_secs`: one hour.
fn default_jwks_cache_ttl() -> u64 {
    60 * 60
}

/// Fallback for `[auth].session_ttl_secs`: seven days, in seconds.
fn default_session_ttl() -> u64 {
    604_800
}
457
458/// Observability configuration for OTLP telemetry.
459#[derive(Debug, Clone, Serialize, Deserialize)]
460pub struct ObservabilityConfig {
461    /// Enable observability (traces, metrics, logs).
462    #[serde(default)]
463    pub enabled: bool,
464
465    /// OTLP endpoint for telemetry export.
466    #[serde(default = "default_otlp_endpoint")]
467    pub otlp_endpoint: String,
468
469    /// Service name for telemetry identification.
470    pub service_name: Option<String>,
471
472    /// Enable distributed tracing.
473    #[serde(default = "default_true")]
474    pub enable_traces: bool,
475
476    /// Enable metrics collection.
477    #[serde(default = "default_true")]
478    pub enable_metrics: bool,
479
480    /// Enable log export via OTLP.
481    #[serde(default = "default_true")]
482    pub enable_logs: bool,
483
484    /// Trace sampling ratio (0.0 to 1.0).
485    #[serde(default = "default_sampling_ratio")]
486    pub sampling_ratio: f64,
487}
488
489impl Default for ObservabilityConfig {
490    fn default() -> Self {
491        Self {
492            enabled: false,
493            otlp_endpoint: default_otlp_endpoint(),
494            service_name: None,
495            enable_traces: true,
496            enable_metrics: true,
497            enable_logs: true,
498            sampling_ratio: default_sampling_ratio(),
499        }
500    }
501}
502
/// Fallback for `[observability].otlp_endpoint`.
fn default_otlp_endpoint() -> String {
    String::from("http://localhost:4317")
}

/// Serde default helper for boolean flags that are on unless disabled.
fn default_true() -> bool {
    true
}

/// Fallback for `[observability].sampling_ratio`: sample every trace.
fn default_sampling_ratio() -> f64 {
    1_f64
}
514
515/// Substitute environment variables in the format ${VAR_NAME}.
516fn substitute_env_vars(content: &str) -> String {
517    let mut result = content.to_string();
518    let re = regex_lite::Regex::new(r"\$\{([A-Z_][A-Z0-9_]*)\}").expect("valid regex pattern");
519
520    for cap in re.captures_iter(content) {
521        let var_name = &cap[1];
522        if let Ok(value) = std::env::var(var_name) {
523            result = result.replace(&cap[0], &value);
524        }
525    }
526
527    result
528}
529
// Unit tests for config parsing, env-var substitution and validation.
// `unwrap`/indexing are acceptable in test code. `unsafe_code` is allowed because
// `std::env::set_var` / `remove_var` are called inside `unsafe` blocks below.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, unsafe_code)]
mod tests {
    use super::*;

    // `default_with_database_url` fills every non-database section with defaults.
    #[test]
    fn test_default_config() {
        let config = ForgeConfig::default_with_database_url("postgres://localhost/test");
        assert_eq!(config.gateway.port, 8080);
        assert_eq!(config.node.roles.len(), 4);
    }

    // Only `[database]` is required; other sections fall back to their defaults.
    #[test]
    fn test_parse_minimal_config() {
        let toml = r#"
            [database]
            mode = "remote"
            url = "postgres://localhost/myapp"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.database.url(), Some("postgres://localhost/myapp"));
        assert_eq!(config.gateway.port, 8080);
    }

    // Explicit values in several sections override the defaults.
    #[test]
    fn test_parse_full_config() {
        let toml = r#"
            [project]
            name = "my-app"
            version = "1.0.0"

            [database]
            mode = "remote"
            url = "postgres://localhost/myapp"
            pool_size = 100

            [node]
            roles = ["gateway", "worker"]
            worker_capabilities = ["media", "general"]

            [gateway]
            port = 3000
            grpc_port = 9001
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.project.name, "my-app");
        assert_eq!(config.database.pool_size, 100);
        assert_eq!(config.node.roles.len(), 2);
        assert_eq!(config.gateway.port, 3000);
    }

    // `${VAR}` placeholders are replaced before TOML parsing.
    #[test]
    fn test_env_var_substitution() {
        // SAFETY: NOTE(review): `set_var` is only sound if no other thread reads or
        // writes the process environment at the same time. The test harness runs
        // tests on parallel threads, so this relies on no concurrently running test
        // touching the environment — TODO confirm, or serialize env-mutating tests.
        unsafe {
            std::env::set_var("TEST_DB_URL", "postgres://test:test@localhost/test");
        }

        let toml = r#"
            [database]
            mode = "remote"
            url = "${TEST_DB_URL}"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(
            config.database.url(),
            Some("postgres://test:test@localhost/test")
        );

        // SAFETY: same caveat as the `set_var` above.
        // NOTE(review): this cleanup is skipped if an assertion above panics,
        // leaving TEST_DB_URL set for the rest of the test process.
        unsafe {
            std::env::remove_var("TEST_DB_URL");
        }
    }

    // With no auth fields set, auth is disabled and validation passes.
    #[test]
    fn test_auth_validation_no_config() {
        let auth = AuthConfig::default();
        assert!(auth.validate().is_ok());
    }

    // HMAC with a secret is valid.
    #[test]
    fn test_auth_validation_hmac_with_secret() {
        let auth = AuthConfig {
            jwt_secret: Some("my-secret".into()),
            jwt_algorithm: JwtAlgorithm::HS256,
            ..Default::default()
        };
        assert!(auth.validate().is_ok());
    }

    // Setting a claim check (issuer) enables auth, so HMAC then needs a secret.
    #[test]
    fn test_auth_validation_hmac_missing_secret() {
        let auth = AuthConfig {
            jwt_issuer: Some("my-issuer".into()),
            jwt_algorithm: JwtAlgorithm::HS256,
            ..Default::default()
        };
        let result = auth.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwt_secret is required"));
    }

    // RSA with a JWKS URL is valid.
    #[test]
    fn test_auth_validation_rsa_with_jwks() {
        let auth = AuthConfig {
            jwks_url: Some("https://example.com/.well-known/jwks.json".into()),
            jwt_algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(auth.validate().is_ok());
    }

    // Auth enabled via issuer, RSA selected, but no JWKS URL: rejected.
    #[test]
    fn test_auth_validation_rsa_missing_jwks() {
        let auth = AuthConfig {
            jwt_issuer: Some("my-issuer".into()),
            jwt_algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        let result = auth.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwks_url is required"));
    }

    // `parse_toml` runs validation: an empty database URL is rejected.
    #[test]
    fn test_forge_config_validation_fails_on_empty_url() {
        let toml = r#"
            [database]
            mode = "remote"
            url = ""
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("database.url is required"));
    }

    // `parse_toml` runs validation: incomplete `[auth]` settings are rejected.
    #[test]
    fn test_forge_config_validation_fails_on_invalid_auth() {
        let toml = r#"
            [database]
            mode = "remote"
            url = "postgres://localhost/test"

            [auth]
            jwt_issuer = "my-issuer"
            jwt_algorithm = "RS256"
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwks_url is required"));
    }
}