Skip to main content

forge_core/config/
mod.rs

1mod cluster;
2mod database;
3mod observability;
4
5pub use cluster::ClusterConfig;
6pub use database::DatabaseConfig;
7pub use observability::ObservabilityConfig;
8
9use serde::{Deserialize, Serialize};
10use std::path::Path;
11
12use crate::error::{ForgeError, Result};
13
14/// Root configuration for FORGE.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ForgeConfig {
17    /// Project metadata.
18    #[serde(default)]
19    pub project: ProjectConfig,
20
21    /// Database configuration.
22    pub database: DatabaseConfig,
23
24    /// Node configuration.
25    #[serde(default)]
26    pub node: NodeConfig,
27
28    /// Gateway configuration.
29    #[serde(default)]
30    pub gateway: GatewayConfig,
31
32    /// Function execution configuration.
33    #[serde(default)]
34    pub function: FunctionConfig,
35
36    /// Worker configuration.
37    #[serde(default)]
38    pub worker: WorkerConfig,
39
40    /// Cluster configuration.
41    #[serde(default)]
42    pub cluster: ClusterConfig,
43
44    /// Observability configuration.
45    #[serde(default)]
46    pub observability: ObservabilityConfig,
47
48    /// Security configuration.
49    #[serde(default)]
50    pub security: SecurityConfig,
51
52    /// Authentication configuration.
53    #[serde(default)]
54    pub auth: AuthConfig,
55}
56
57impl ForgeConfig {
58    /// Load configuration from a TOML file.
59    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
60        let content = std::fs::read_to_string(path.as_ref())
61            .map_err(|e| ForgeError::Config(format!("Failed to read config file: {}", e)))?;
62
63        Self::parse_toml(&content)
64    }
65
66    /// Parse configuration from a TOML string.
67    pub fn parse_toml(content: &str) -> Result<Self> {
68        // Substitute environment variables
69        let content = substitute_env_vars(content);
70
71        toml::from_str(&content)
72            .map_err(|e| ForgeError::Config(format!("Failed to parse config: {}", e)))
73    }
74
75    /// Load configuration with defaults.
76    pub fn default_with_database_url(url: &str) -> Self {
77        Self {
78            project: ProjectConfig::default(),
79            database: DatabaseConfig {
80                url: url.to_string(),
81                ..Default::default()
82            },
83            node: NodeConfig::default(),
84            gateway: GatewayConfig::default(),
85            function: FunctionConfig::default(),
86            worker: WorkerConfig::default(),
87            cluster: ClusterConfig::default(),
88            observability: ObservabilityConfig::default(),
89            security: SecurityConfig::default(),
90            auth: AuthConfig::default(),
91        }
92    }
93}
94
95/// Project metadata.
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct ProjectConfig {
98    /// Project name.
99    #[serde(default = "default_project_name")]
100    pub name: String,
101
102    /// Project version.
103    #[serde(default = "default_version")]
104    pub version: String,
105}
106
107impl Default for ProjectConfig {
108    fn default() -> Self {
109        Self {
110            name: default_project_name(),
111            version: default_version(),
112        }
113    }
114}
115
/// Fallback project name used when the configuration does not provide one.
fn default_project_name() -> String {
    String::from("forge-app")
}
119
/// Fallback project version used when the configuration does not provide one.
fn default_version() -> String {
    String::from("0.1.0")
}
123
124/// Node role configuration.
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct NodeConfig {
127    /// Roles this node should assume.
128    #[serde(default = "default_roles")]
129    pub roles: Vec<NodeRole>,
130
131    /// Worker capabilities for job routing.
132    #[serde(default = "default_capabilities")]
133    pub worker_capabilities: Vec<String>,
134}
135
136impl Default for NodeConfig {
137    fn default() -> Self {
138        Self {
139            roles: default_roles(),
140            worker_capabilities: default_capabilities(),
141        }
142    }
143}
144
145fn default_roles() -> Vec<NodeRole> {
146    vec![
147        NodeRole::Gateway,
148        NodeRole::Function,
149        NodeRole::Worker,
150        NodeRole::Scheduler,
151    ]
152}
153
/// Fallback worker capabilities: the single label `"general"`.
fn default_capabilities() -> Vec<String> {
    Vec::from([String::from("general")])
}
157
158/// Available node roles.
159#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
160#[serde(rename_all = "lowercase")]
161pub enum NodeRole {
162    Gateway,
163    Function,
164    Worker,
165    Scheduler,
166}
167
168/// Gateway configuration.
169#[derive(Debug, Clone, Serialize, Deserialize)]
170pub struct GatewayConfig {
171    /// HTTP port.
172    #[serde(default = "default_http_port")]
173    pub port: u16,
174
175    /// gRPC port for inter-node communication.
176    #[serde(default = "default_grpc_port")]
177    pub grpc_port: u16,
178
179    /// Maximum concurrent connections.
180    #[serde(default = "default_max_connections")]
181    pub max_connections: usize,
182
183    /// Request timeout in seconds.
184    #[serde(default = "default_request_timeout")]
185    pub request_timeout_secs: u64,
186}
187
188impl Default for GatewayConfig {
189    fn default() -> Self {
190        Self {
191            port: default_http_port(),
192            grpc_port: default_grpc_port(),
193            max_connections: default_max_connections(),
194            request_timeout_secs: default_request_timeout(),
195        }
196    }
197}
198
/// Fallback HTTP port for the gateway.
fn default_http_port() -> u16 {
    const HTTP_PORT: u16 = 8080;
    HTTP_PORT
}
202
/// Fallback gRPC port for the gateway.
fn default_grpc_port() -> u16 {
    const GRPC_PORT: u16 = 9000;
    GRPC_PORT
}
206
/// Fallback cap on concurrent gateway connections.
fn default_max_connections() -> usize {
    const MAX_CONNECTIONS: usize = 10_000;
    MAX_CONNECTIONS
}
210
/// Fallback gateway request timeout, in seconds.
fn default_request_timeout() -> u64 {
    const REQUEST_TIMEOUT_SECS: u64 = 30;
    REQUEST_TIMEOUT_SECS
}
214
215/// Function execution configuration.
216#[derive(Debug, Clone, Serialize, Deserialize)]
217pub struct FunctionConfig {
218    /// Maximum concurrent function executions.
219    #[serde(default = "default_max_concurrent")]
220    pub max_concurrent: usize,
221
222    /// Function timeout in seconds.
223    #[serde(default = "default_function_timeout")]
224    pub timeout_secs: u64,
225
226    /// Memory limit per function (in bytes).
227    #[serde(default = "default_memory_limit")]
228    pub memory_limit: usize,
229}
230
231impl Default for FunctionConfig {
232    fn default() -> Self {
233        Self {
234            max_concurrent: default_max_concurrent(),
235            timeout_secs: default_function_timeout(),
236            memory_limit: default_memory_limit(),
237        }
238    }
239}
240
/// Fallback cap on concurrent function executions.
fn default_max_concurrent() -> usize {
    const MAX_CONCURRENT: usize = 1_000;
    MAX_CONCURRENT
}
244
/// Fallback function timeout, in seconds.
fn default_function_timeout() -> u64 {
    const FUNCTION_TIMEOUT_SECS: u64 = 30;
    FUNCTION_TIMEOUT_SECS
}
248
/// Fallback per-function memory limit: 512 MiB, expressed in bytes.
fn default_memory_limit() -> usize {
    const MIB: usize = 1024 * 1024;
    512 * MIB
}
252
253/// Worker configuration.
254#[derive(Debug, Clone, Serialize, Deserialize)]
255pub struct WorkerConfig {
256    /// Maximum concurrent jobs.
257    #[serde(default = "default_max_concurrent_jobs")]
258    pub max_concurrent_jobs: usize,
259
260    /// Job timeout in seconds.
261    #[serde(default = "default_job_timeout")]
262    pub job_timeout_secs: u64,
263
264    /// Poll interval in milliseconds.
265    #[serde(default = "default_poll_interval")]
266    pub poll_interval_ms: u64,
267}
268
269impl Default for WorkerConfig {
270    fn default() -> Self {
271        Self {
272            max_concurrent_jobs: default_max_concurrent_jobs(),
273            job_timeout_secs: default_job_timeout(),
274            poll_interval_ms: default_poll_interval(),
275        }
276    }
277}
278
/// Fallback cap on concurrent jobs per worker.
fn default_max_concurrent_jobs() -> usize {
    const MAX_CONCURRENT_JOBS: usize = 50;
    MAX_CONCURRENT_JOBS
}
282
/// Fallback job timeout: one hour, expressed in seconds.
fn default_job_timeout() -> u64 {
    const SECS_PER_MINUTE: u64 = 60;
    const MINUTES_PER_HOUR: u64 = 60;
    SECS_PER_MINUTE * MINUTES_PER_HOUR
}
286
/// Fallback worker poll interval, in milliseconds.
fn default_poll_interval() -> u64 {
    const POLL_INTERVAL_MS: u64 = 100;
    POLL_INTERVAL_MS
}
290
291/// Security configuration.
292#[derive(Debug, Clone, Serialize, Deserialize, Default)]
293pub struct SecurityConfig {
294    /// Secret key for signing.
295    pub secret_key: Option<String>,
296}
297
298/// JWT signing algorithm.
299#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
300#[serde(rename_all = "UPPERCASE")]
301pub enum JwtAlgorithm {
302    /// HMAC using SHA-256 (symmetric, requires jwt_secret).
303    #[default]
304    HS256,
305    /// HMAC using SHA-384 (symmetric, requires jwt_secret).
306    HS384,
307    /// HMAC using SHA-512 (symmetric, requires jwt_secret).
308    HS512,
309    /// RSA using SHA-256 (asymmetric, requires jwks_url).
310    RS256,
311    /// RSA using SHA-384 (asymmetric, requires jwks_url).
312    RS384,
313    /// RSA using SHA-512 (asymmetric, requires jwks_url).
314    RS512,
315}
316
317/// Authentication configuration.
318#[derive(Debug, Clone, Serialize, Deserialize)]
319pub struct AuthConfig {
320    /// JWT secret for HMAC algorithms (HS256, HS384, HS512).
321    /// Required when using HMAC algorithms.
322    pub jwt_secret: Option<String>,
323
324    /// JWT signing algorithm.
325    /// HMAC algorithms (HS256, HS384, HS512) require jwt_secret.
326    /// RSA algorithms (RS256, RS384, RS512) require jwks_url.
327    #[serde(default)]
328    pub jwt_algorithm: JwtAlgorithm,
329
330    /// Expected token issuer (iss claim).
331    /// If set, tokens with a different issuer are rejected.
332    pub jwt_issuer: Option<String>,
333
334    /// Expected audience (aud claim).
335    /// If set, tokens with a different audience are rejected.
336    pub jwt_audience: Option<String>,
337
338    /// Token expiry duration (e.g., "15m", "1h", "7d").
339    pub token_expiry: Option<String>,
340
341    /// JWKS URL for RSA algorithms (RS256, RS384, RS512).
342    /// Keys are fetched and cached automatically.
343    pub jwks_url: Option<String>,
344
345    /// JWKS cache TTL in seconds.
346    #[serde(default = "default_jwks_cache_ttl")]
347    pub jwks_cache_ttl_secs: u64,
348
349    /// Session TTL in seconds (for WebSocket sessions).
350    #[serde(default = "default_session_ttl")]
351    pub session_ttl_secs: u64,
352}
353
354impl Default for AuthConfig {
355    fn default() -> Self {
356        Self {
357            jwt_secret: None,
358            jwt_algorithm: JwtAlgorithm::default(),
359            jwt_issuer: None,
360            jwt_audience: None,
361            token_expiry: None,
362            jwks_url: None,
363            jwks_cache_ttl_secs: default_jwks_cache_ttl(),
364            session_ttl_secs: default_session_ttl(),
365        }
366    }
367}
368
369impl AuthConfig {
370    /// Validate that the configuration is complete for the chosen algorithm.
371    pub fn validate(&self) -> Result<()> {
372        match self.jwt_algorithm {
373            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512 => {
374                if self.jwt_secret.is_none() {
375                    return Err(ForgeError::Config(
376                        "jwt_secret is required for HMAC algorithms (HS256, HS384, HS512)".into(),
377                    ));
378                }
379            }
380            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512 => {
381                if self.jwks_url.is_none() {
382                    return Err(ForgeError::Config(
383                        "jwks_url is required for RSA algorithms (RS256, RS384, RS512)".into(),
384                    ));
385                }
386            }
387        }
388        Ok(())
389    }
390
391    /// Check if this config uses HMAC (symmetric) algorithms.
392    pub fn is_hmac(&self) -> bool {
393        matches!(
394            self.jwt_algorithm,
395            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
396        )
397    }
398
399    /// Check if this config uses RSA (asymmetric) algorithms.
400    pub fn is_rsa(&self) -> bool {
401        matches!(
402            self.jwt_algorithm,
403            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
404        )
405    }
406}
407
/// Fallback JWKS cache lifetime: one hour, expressed in seconds.
fn default_jwks_cache_ttl() -> u64 {
    const SECS_PER_MINUTE: u64 = 60;
    const MINUTES_PER_HOUR: u64 = 60;
    SECS_PER_MINUTE * MINUTES_PER_HOUR
}
411
/// Fallback session lifetime: seven days, expressed in seconds.
fn default_session_ttl() -> u64 {
    const SECS_PER_DAY: u64 = 24 * 60 * 60;
    const DAYS: u64 = 7;
    DAYS * SECS_PER_DAY
}
415
416/// Substitute environment variables in the format ${VAR_NAME}.
417fn substitute_env_vars(content: &str) -> String {
418    let mut result = content.to_string();
419    let re = regex_lite::Regex::new(r"\$\{([A-Z_][A-Z0-9_]*)\}").unwrap();
420
421    for cap in re.captures_iter(content) {
422        let var_name = &cap[1];
423        if let Ok(value) = std::env::var(var_name) {
424            result = result.replace(&cap[0], &value);
425        }
426    }
427
428    result
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults built around an explicit database URL expose the stock
    /// gateway port and all four node roles.
    #[test]
    fn test_default_config() {
        let cfg = ForgeConfig::default_with_database_url("postgres://localhost/test");

        assert_eq!(cfg.gateway.port, 8080);
        assert_eq!(cfg.node.roles.len(), 4);
    }

    /// A document containing only `[database]` parses, and the omitted
    /// sections fall back to their defaults.
    #[test]
    fn test_parse_minimal_config() {
        let source = r#"
            [database]
            url = "postgres://localhost/myapp"
        "#;

        let parsed = ForgeConfig::parse_toml(source).unwrap();

        assert_eq!(parsed.database.url, "postgres://localhost/myapp");
        assert_eq!(parsed.gateway.port, 8080);
    }

    /// Explicit values in several sections override the defaults.
    #[test]
    fn test_parse_full_config() {
        let source = r#"
            [project]
            name = "my-app"
            version = "1.0.0"

            [database]
            url = "postgres://localhost/myapp"
            pool_size = 100

            [node]
            roles = ["gateway", "worker"]
            worker_capabilities = ["media", "general"]

            [gateway]
            port = 3000
            grpc_port = 9001
        "#;

        let parsed = ForgeConfig::parse_toml(source).unwrap();

        assert_eq!(parsed.project.name, "my-app");
        assert_eq!(parsed.database.pool_size, 100);
        assert_eq!(parsed.node.roles.len(), 2);
        assert_eq!(parsed.gateway.port, 3000);
    }

    /// A `${VAR}` placeholder is replaced by the variable's value before the
    /// TOML is parsed.
    #[test]
    fn test_env_var_substitution() {
        const VAR_NAME: &str = "TEST_DB_URL";
        const VAR_VALUE: &str = "postgres://test:test@localhost/test";

        // SAFETY: mutating the process environment is only sound while no
        // other thread reads or writes it. The other tests in this module
        // contain no `${...}` placeholders, so they never query the
        // environment. NOTE(review): tests elsewhere in the crate share this
        // process — confirm none of them access the environment concurrently.
        unsafe {
            std::env::set_var(VAR_NAME, VAR_VALUE);
        }

        let source = r#"
            [database]
            url = "${TEST_DB_URL}"
        "#;

        let parsed = ForgeConfig::parse_toml(source).unwrap();
        assert_eq!(parsed.database.url, VAR_VALUE);

        // SAFETY: same reasoning as for `set_var` above. This cleanup is
        // skipped if an assertion above panics.
        unsafe {
            std::env::remove_var(VAR_NAME);
        }
    }
}