Skip to main content

forge_core/config/
mod.rs

1mod cluster;
2mod database;
3
4pub use cluster::ClusterConfig;
5pub use database::DatabaseConfig;
6
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9
10use crate::error::{ForgeError, Result};
11
/// Root configuration for FORGE.
///
/// Built from TOML with [`ForgeConfig::from_file`] / [`ForgeConfig::parse_toml`],
/// or in code with [`ForgeConfig::default_with_database_url`]. Every section
/// except `database` carries `#[serde(default)]`, so `[database]` is the only
/// table a config file must contain; omitted tables take their `Default` value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForgeConfig {
    /// Project metadata (`[project]` table); optional.
    #[serde(default)]
    pub project: ProjectConfig,

    /// Database configuration (`[database]` table). Mandatory: this field has
    /// no serde default, so deserialization fails when the table is absent.
    pub database: DatabaseConfig,

    /// Node roles and worker capabilities (`[node]` table); optional.
    #[serde(default)]
    pub node: NodeConfig,

    /// HTTP/gRPC gateway settings (`[gateway]` table); optional.
    #[serde(default)]
    pub gateway: GatewayConfig,

    /// Function execution limits (`[function]` table); optional.
    #[serde(default)]
    pub function: FunctionConfig,

    /// Job worker settings (`[worker]` table); optional.
    #[serde(default)]
    pub worker: WorkerConfig,

    /// Cluster configuration (`[cluster]` table); optional.
    #[serde(default)]
    pub cluster: ClusterConfig,

    /// Security configuration (`[security]` table); optional.
    #[serde(default)]
    pub security: SecurityConfig,

    /// Authentication / JWT settings (`[auth]` table); optional. Parsing does
    /// not validate this section — call [`AuthConfig::validate`] explicitly.
    #[serde(default)]
    pub auth: AuthConfig,
}
50
51impl ForgeConfig {
52    /// Load configuration from a TOML file.
53    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
54        let content = std::fs::read_to_string(path.as_ref())
55            .map_err(|e| ForgeError::Config(format!("Failed to read config file: {}", e)))?;
56
57        Self::parse_toml(&content)
58    }
59
60    /// Parse configuration from a TOML string.
61    pub fn parse_toml(content: &str) -> Result<Self> {
62        // Substitute environment variables
63        let content = substitute_env_vars(content);
64
65        toml::from_str(&content)
66            .map_err(|e| ForgeError::Config(format!("Failed to parse config: {}", e)))
67    }
68
69    /// Load configuration with defaults.
70    pub fn default_with_database_url(url: &str) -> Self {
71        Self {
72            project: ProjectConfig::default(),
73            database: DatabaseConfig {
74                url: url.to_string(),
75                ..Default::default()
76            },
77            node: NodeConfig::default(),
78            gateway: GatewayConfig::default(),
79            function: FunctionConfig::default(),
80            worker: WorkerConfig::default(),
81            cluster: ClusterConfig::default(),
82            security: SecurityConfig::default(),
83            auth: AuthConfig::default(),
84        }
85    }
86}
87
/// Project metadata, read from the `[project]` table.
///
/// Both keys are optional in TOML; missing values fall back to
/// `default_project_name()` and `default_version()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectConfig {
    /// Project name; defaults to `"forge-app"`.
    #[serde(default = "default_project_name")]
    pub name: String,

    /// Project version string; defaults to `"0.1.0"`.
    #[serde(default = "default_version")]
    pub version: String,
}
99
100impl Default for ProjectConfig {
101    fn default() -> Self {
102        Self {
103            name: default_project_name(),
104            version: default_version(),
105        }
106    }
107}
108
/// Fallback for `[project].name` when the key is absent.
fn default_project_name() -> String {
    String::from("forge-app")
}

/// Fallback for `[project].version` when the key is absent.
fn default_version() -> String {
    String::from("0.1.0")
}
116
/// Node role configuration, read from the `[node]` table.
///
/// Both keys are optional in TOML and fall back to `default_roles()` and
/// `default_capabilities()` respectively.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeConfig {
    /// Roles this node should assume, written in lowercase
    /// (e.g. `roles = ["gateway", "worker"]`). Defaults to all four
    /// [`NodeRole`] variants.
    #[serde(default = "default_roles")]
    pub roles: Vec<NodeRole>,

    /// Worker capabilities for job routing. Defaults to `["general"]`.
    #[serde(default = "default_capabilities")]
    pub worker_capabilities: Vec<String>,
}
128
129impl Default for NodeConfig {
130    fn default() -> Self {
131        Self {
132            roles: default_roles(),
133            worker_capabilities: default_capabilities(),
134        }
135    }
136}
137
138fn default_roles() -> Vec<NodeRole> {
139    vec![
140        NodeRole::Gateway,
141        NodeRole::Function,
142        NodeRole::Worker,
143        NodeRole::Scheduler,
144    ]
145}
146
/// Capabilities used when `[node].worker_capabilities` is omitted.
fn default_capabilities() -> Vec<String> {
    vec![String::from("general")]
}
150
/// Available node roles.
///
/// Serialized in lowercase — `"gateway"`, `"function"`, `"worker"`,
/// `"scheduler"` — as in `roles = ["gateway", "worker"]` under `[node]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum NodeRole {
    Gateway,
    Function,
    Worker,
    Scheduler,
}
160
/// Gateway configuration, read from the `[gateway]` table.
///
/// Every key is optional in TOML and falls back to the matching
/// `default_*` helper.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfig {
    /// HTTP port. Defaults to 8080.
    #[serde(default = "default_http_port")]
    pub port: u16,

    /// gRPC port for inter-node communication. Defaults to 9000.
    #[serde(default = "default_grpc_port")]
    pub grpc_port: u16,

    /// Maximum concurrent connections. Defaults to 10000.
    #[serde(default = "default_max_connections")]
    pub max_connections: usize,

    /// Request timeout in seconds. Defaults to 30.
    #[serde(default = "default_request_timeout")]
    pub request_timeout_secs: u64,
}
180
181impl Default for GatewayConfig {
182    fn default() -> Self {
183        Self {
184            port: default_http_port(),
185            grpc_port: default_grpc_port(),
186            max_connections: default_max_connections(),
187            request_timeout_secs: default_request_timeout(),
188        }
189    }
190}
191
/// Fallback HTTP listen port for `[gateway].port`.
fn default_http_port() -> u16 {
    8_080
}

/// Fallback inter-node gRPC port for `[gateway].grpc_port`.
fn default_grpc_port() -> u16 {
    9_000
}

/// Fallback connection cap for `[gateway].max_connections`.
fn default_max_connections() -> usize {
    10_000
}

/// Fallback request timeout, in seconds, for `[gateway].request_timeout_secs`.
fn default_request_timeout() -> u64 {
    30
}
207
/// Function execution configuration, read from the `[function]` table.
///
/// Every key is optional in TOML and falls back to the matching
/// `default_*` helper.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionConfig {
    /// Maximum concurrent function executions. Defaults to 1000.
    #[serde(default = "default_max_concurrent")]
    pub max_concurrent: usize,

    /// Function timeout in seconds. Defaults to 30.
    #[serde(default = "default_function_timeout")]
    pub timeout_secs: u64,

    /// Memory limit per function, in bytes. Defaults to 512 MiB.
    #[serde(default = "default_memory_limit")]
    pub memory_limit: usize,
}
223
224impl Default for FunctionConfig {
225    fn default() -> Self {
226        Self {
227            max_concurrent: default_max_concurrent(),
228            timeout_secs: default_function_timeout(),
229            memory_limit: default_memory_limit(),
230        }
231    }
232}
233
/// Fallback for `[function].max_concurrent`.
fn default_max_concurrent() -> usize {
    1_000
}

/// Fallback for `[function].timeout_secs`, in seconds.
fn default_function_timeout() -> u64 {
    30
}

/// Fallback for `[function].memory_limit`: 512 MiB expressed in bytes.
fn default_memory_limit() -> usize {
    const MIB: usize = 1 << 20;
    512 * MIB
}
245
/// Worker configuration, read from the `[worker]` table.
///
/// Every key is optional in TOML and falls back to the matching
/// `default_*` helper.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerConfig {
    /// Maximum concurrent jobs. Defaults to 50.
    #[serde(default = "default_max_concurrent_jobs")]
    pub max_concurrent_jobs: usize,

    /// Job timeout in seconds. Defaults to 3600 (one hour).
    #[serde(default = "default_job_timeout")]
    pub job_timeout_secs: u64,

    /// Poll interval in milliseconds. Defaults to 100.
    #[serde(default = "default_poll_interval")]
    pub poll_interval_ms: u64,
}
261
262impl Default for WorkerConfig {
263    fn default() -> Self {
264        Self {
265            max_concurrent_jobs: default_max_concurrent_jobs(),
266            job_timeout_secs: default_job_timeout(),
267            poll_interval_ms: default_poll_interval(),
268        }
269    }
270}
271
/// Fallback for `[worker].max_concurrent_jobs`.
fn default_max_concurrent_jobs() -> usize {
    50
}

/// Fallback for `[worker].job_timeout_secs`: one hour, in seconds.
fn default_job_timeout() -> u64 {
    60 * 60
}

/// Fallback for `[worker].poll_interval_ms`, in milliseconds.
fn default_poll_interval() -> u64 {
    100
}
283
284/// Security configuration.
285#[derive(Debug, Clone, Serialize, Deserialize, Default)]
286pub struct SecurityConfig {
287    /// Secret key for signing.
288    pub secret_key: Option<String>,
289}
290
/// JWT signing algorithm.
///
/// Written in config by variant name, e.g. `jwt_algorithm = "RS256"`
/// (`rename_all = "UPPERCASE"` leaves these already upper-case names
/// unchanged). Defaults to [`JwtAlgorithm::HS256`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "UPPERCASE")]
pub enum JwtAlgorithm {
    /// HMAC using SHA-256 (symmetric, requires jwt_secret). The default.
    #[default]
    HS256,
    /// HMAC using SHA-384 (symmetric, requires jwt_secret).
    HS384,
    /// HMAC using SHA-512 (symmetric, requires jwt_secret).
    HS512,
    /// RSA using SHA-256 (asymmetric, requires jwks_url).
    RS256,
    /// RSA using SHA-384 (asymmetric, requires jwks_url).
    RS384,
    /// RSA using SHA-512 (asymmetric, requires jwks_url).
    RS512,
}
309
310/// Authentication configuration.
311#[derive(Debug, Clone, Serialize, Deserialize)]
312pub struct AuthConfig {
313    /// JWT secret for HMAC algorithms (HS256, HS384, HS512).
314    /// Required when using HMAC algorithms.
315    pub jwt_secret: Option<String>,
316
317    /// JWT signing algorithm.
318    /// HMAC algorithms (HS256, HS384, HS512) require jwt_secret.
319    /// RSA algorithms (RS256, RS384, RS512) require jwks_url.
320    #[serde(default)]
321    pub jwt_algorithm: JwtAlgorithm,
322
323    /// Expected token issuer (iss claim).
324    /// If set, tokens with a different issuer are rejected.
325    pub jwt_issuer: Option<String>,
326
327    /// Expected audience (aud claim).
328    /// If set, tokens with a different audience are rejected.
329    pub jwt_audience: Option<String>,
330
331    /// Token expiry duration (e.g., "15m", "1h", "7d").
332    pub token_expiry: Option<String>,
333
334    /// JWKS URL for RSA algorithms (RS256, RS384, RS512).
335    /// Keys are fetched and cached automatically.
336    pub jwks_url: Option<String>,
337
338    /// JWKS cache TTL in seconds.
339    #[serde(default = "default_jwks_cache_ttl")]
340    pub jwks_cache_ttl_secs: u64,
341
342    /// Session TTL in seconds (for WebSocket sessions).
343    #[serde(default = "default_session_ttl")]
344    pub session_ttl_secs: u64,
345}
346
347impl Default for AuthConfig {
348    fn default() -> Self {
349        Self {
350            jwt_secret: None,
351            jwt_algorithm: JwtAlgorithm::default(),
352            jwt_issuer: None,
353            jwt_audience: None,
354            token_expiry: None,
355            jwks_url: None,
356            jwks_cache_ttl_secs: default_jwks_cache_ttl(),
357            session_ttl_secs: default_session_ttl(),
358        }
359    }
360}
361
362impl AuthConfig {
363    /// Validate that the configuration is complete for the chosen algorithm.
364    pub fn validate(&self) -> Result<()> {
365        match self.jwt_algorithm {
366            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512 => {
367                if self.jwt_secret.is_none() {
368                    return Err(ForgeError::Config(
369                        "jwt_secret is required for HMAC algorithms (HS256, HS384, HS512)".into(),
370                    ));
371                }
372            }
373            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512 => {
374                if self.jwks_url.is_none() {
375                    return Err(ForgeError::Config(
376                        "jwks_url is required for RSA algorithms (RS256, RS384, RS512)".into(),
377                    ));
378                }
379            }
380        }
381        Ok(())
382    }
383
384    /// Check if this config uses HMAC (symmetric) algorithms.
385    pub fn is_hmac(&self) -> bool {
386        matches!(
387            self.jwt_algorithm,
388            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
389        )
390    }
391
392    /// Check if this config uses RSA (asymmetric) algorithms.
393    pub fn is_rsa(&self) -> bool {
394        matches!(
395            self.jwt_algorithm,
396            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
397        )
398    }
399}
400
/// Fallback for `[auth].jwks_cache_ttl_secs`: one hour, in seconds.
fn default_jwks_cache_ttl() -> u64 {
    60 * 60
}

/// Fallback for `[auth].session_ttl_secs`: one week, in seconds.
fn default_session_ttl() -> u64 {
    const SECS_PER_DAY: u64 = 86_400;
    7 * SECS_PER_DAY
}
408
409/// Substitute environment variables in the format ${VAR_NAME}.
410fn substitute_env_vars(content: &str) -> String {
411    let mut result = content.to_string();
412    let re = regex_lite::Regex::new(r"\$\{([A-Z_][A-Z0-9_]*)\}").unwrap();
413
414    for cap in re.captures_iter(content) {
415        let var_name = &cap[1];
416        if let Ok(value) = std::env::var(var_name) {
417            result = result.replace(&cap[0], &value);
418        }
419    }
420
421    result
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// Sets an environment variable for the guard's lifetime and removes it on
    /// drop, so the variable is cleaned up even when an assertion panics.
    struct EnvVarGuard {
        name: &'static str,
    }

    impl EnvVarGuard {
        fn set(name: &'static str, value: &str) -> Self {
            // SAFETY: `set_var` requires that no other thread accesses the
            // environment except through `std::env`, whose functions
            // synchronize with each other. The code under test reads variables
            // only via `std::env::var`. NOTE(review): this cannot rule out other
            // tests in the crate calling into foreign code that reads the
            // environment concurrently.
            unsafe { std::env::set_var(name, value) };
            Self { name }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same reasoning as in `EnvVarGuard::set`.
            unsafe { std::env::remove_var(self.name) };
        }
    }

    #[test]
    fn test_default_config() {
        let config = ForgeConfig::default_with_database_url("postgres://localhost/test");
        assert_eq!(config.gateway.port, 8080);
        assert_eq!(config.node.roles.len(), 4);
    }

    #[test]
    fn test_parse_minimal_config() {
        let toml = r#"
            [database]
            url = "postgres://localhost/myapp"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.database.url, "postgres://localhost/myapp");
        assert_eq!(config.gateway.port, 8080);
    }

    #[test]
    fn test_parse_full_config() {
        let toml = r#"
            [project]
            name = "my-app"
            version = "1.0.0"

            [database]
            url = "postgres://localhost/myapp"
            pool_size = 100

            [node]
            roles = ["gateway", "worker"]
            worker_capabilities = ["media", "general"]

            [gateway]
            port = 3000
            grpc_port = 9001
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.project.name, "my-app");
        assert_eq!(config.database.pool_size, 100);
        assert_eq!(config.node.roles.len(), 2);
        assert_eq!(config.gateway.port, 3000);
    }

    #[test]
    fn test_env_var_substitution() {
        // Crate-specific name to avoid clashing with variables set elsewhere.
        let _guard = EnvVarGuard::set("FORGE_TEST_DB_URL", "postgres://test:test@localhost/test");

        let toml = r#"
            [database]
            url = "${FORGE_TEST_DB_URL}"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.database.url, "postgres://test:test@localhost/test");
    }

    #[test]
    fn test_unset_env_var_is_left_verbatim() {
        let toml = r#"
            [database]
            url = "${FORGE_TEST_DEFINITELY_UNSET_VAR}"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.database.url, "${FORGE_TEST_DEFINITELY_UNSET_VAR}");
    }

    #[test]
    fn test_auth_validate_requires_credentials() {
        // The default is HS256 without a secret, which must not validate.
        assert!(AuthConfig::default().validate().is_err());

        let hmac = AuthConfig {
            jwt_secret: Some("s3cret".into()),
            ..Default::default()
        };
        assert!(hmac.validate().is_ok());
        assert!(hmac.is_hmac() && !hmac.is_rsa());

        let rsa_missing = AuthConfig {
            jwt_algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(rsa_missing.validate().is_err());

        let rsa = AuthConfig {
            jwt_algorithm: JwtAlgorithm::RS256,
            jwks_url: Some("https://example.com/.well-known/jwks.json".into()),
            ..Default::default()
        };
        assert!(rsa.validate().is_ok());
        assert!(rsa.is_rsa());
    }
}