1mod cluster;
2mod database;
3
4pub use cluster::ClusterConfig;
5pub use database::DatabaseConfig;
6
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9
10use crate::error::{ForgeError, Result};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ForgeConfig {
15 #[serde(default)]
17 pub project: ProjectConfig,
18
19 pub database: DatabaseConfig,
21
22 #[serde(default)]
24 pub node: NodeConfig,
25
26 #[serde(default)]
28 pub gateway: GatewayConfig,
29
30 #[serde(default)]
32 pub function: FunctionConfig,
33
34 #[serde(default)]
36 pub worker: WorkerConfig,
37
38 #[serde(default)]
40 pub cluster: ClusterConfig,
41
42 #[serde(default)]
44 pub security: SecurityConfig,
45
46 #[serde(default)]
48 pub auth: AuthConfig,
49}
50
51impl ForgeConfig {
52 pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
54 let content = std::fs::read_to_string(path.as_ref())
55 .map_err(|e| ForgeError::Config(format!("Failed to read config file: {}", e)))?;
56
57 Self::parse_toml(&content)
58 }
59
60 pub fn parse_toml(content: &str) -> Result<Self> {
62 let content = substitute_env_vars(content);
64
65 toml::from_str(&content)
66 .map_err(|e| ForgeError::Config(format!("Failed to parse config: {}", e)))
67 }
68
69 pub fn default_with_database_url(url: &str) -> Self {
71 Self {
72 project: ProjectConfig::default(),
73 database: DatabaseConfig {
74 url: url.to_string(),
75 ..Default::default()
76 },
77 node: NodeConfig::default(),
78 gateway: GatewayConfig::default(),
79 function: FunctionConfig::default(),
80 worker: WorkerConfig::default(),
81 cluster: ClusterConfig::default(),
82 security: SecurityConfig::default(),
83 auth: AuthConfig::default(),
84 }
85 }
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct ProjectConfig {
91 #[serde(default = "default_project_name")]
93 pub name: String,
94
95 #[serde(default = "default_version")]
97 pub version: String,
98}
99
100impl Default for ProjectConfig {
101 fn default() -> Self {
102 Self {
103 name: default_project_name(),
104 version: default_version(),
105 }
106 }
107}
108
/// Fallback for `ProjectConfig::name`.
fn default_project_name() -> String {
    String::from("forge-app")
}
112
/// Fallback for `ProjectConfig::version`.
fn default_version() -> String {
    String::from("0.1.0")
}
116
117#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct NodeConfig {
120 #[serde(default = "default_roles")]
122 pub roles: Vec<NodeRole>,
123
124 #[serde(default = "default_capabilities")]
126 pub worker_capabilities: Vec<String>,
127}
128
129impl Default for NodeConfig {
130 fn default() -> Self {
131 Self {
132 roles: default_roles(),
133 worker_capabilities: default_capabilities(),
134 }
135 }
136}
137
138fn default_roles() -> Vec<NodeRole> {
139 vec![
140 NodeRole::Gateway,
141 NodeRole::Function,
142 NodeRole::Worker,
143 NodeRole::Scheduler,
144 ]
145}
146
/// Fallback for `NodeConfig::worker_capabilities`: the single label `"general"`.
fn default_capabilities() -> Vec<String> {
    vec![String::from("general")]
}
150
151#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
153#[serde(rename_all = "lowercase")]
154pub enum NodeRole {
155 Gateway,
156 Function,
157 Worker,
158 Scheduler,
159}
160
161#[derive(Debug, Clone, Serialize, Deserialize)]
163pub struct GatewayConfig {
164 #[serde(default = "default_http_port")]
166 pub port: u16,
167
168 #[serde(default = "default_grpc_port")]
170 pub grpc_port: u16,
171
172 #[serde(default = "default_max_connections")]
174 pub max_connections: usize,
175
176 #[serde(default = "default_request_timeout")]
178 pub request_timeout_secs: u64,
179}
180
181impl Default for GatewayConfig {
182 fn default() -> Self {
183 Self {
184 port: default_http_port(),
185 grpc_port: default_grpc_port(),
186 max_connections: default_max_connections(),
187 request_timeout_secs: default_request_timeout(),
188 }
189 }
190}
191
/// Fallback for `GatewayConfig::port`.
fn default_http_port() -> u16 {
    const HTTP_PORT: u16 = 8080;
    HTTP_PORT
}
195
/// Fallback for `GatewayConfig::grpc_port`.
fn default_grpc_port() -> u16 {
    const GRPC_PORT: u16 = 9000;
    GRPC_PORT
}
199
/// Fallback for `GatewayConfig::max_connections`.
fn default_max_connections() -> usize {
    const MAX_CONNECTIONS: usize = 10_000;
    MAX_CONNECTIONS
}
203
/// Fallback for `GatewayConfig::request_timeout_secs`, in seconds.
fn default_request_timeout() -> u64 {
    const REQUEST_TIMEOUT_SECS: u64 = 30;
    REQUEST_TIMEOUT_SECS
}
207
208#[derive(Debug, Clone, Serialize, Deserialize)]
210pub struct FunctionConfig {
211 #[serde(default = "default_max_concurrent")]
213 pub max_concurrent: usize,
214
215 #[serde(default = "default_function_timeout")]
217 pub timeout_secs: u64,
218
219 #[serde(default = "default_memory_limit")]
221 pub memory_limit: usize,
222}
223
224impl Default for FunctionConfig {
225 fn default() -> Self {
226 Self {
227 max_concurrent: default_max_concurrent(),
228 timeout_secs: default_function_timeout(),
229 memory_limit: default_memory_limit(),
230 }
231 }
232}
233
/// Fallback for `FunctionConfig::max_concurrent`.
fn default_max_concurrent() -> usize {
    const MAX_CONCURRENT: usize = 1_000;
    MAX_CONCURRENT
}
237
/// Fallback for `FunctionConfig::timeout_secs`, in seconds.
fn default_function_timeout() -> u64 {
    const FUNCTION_TIMEOUT_SECS: u64 = 30;
    FUNCTION_TIMEOUT_SECS
}
241
/// Fallback for `FunctionConfig::memory_limit`: 512 * 1024 * 1024.
fn default_memory_limit() -> usize {
    const FACTOR_1024_SQUARED: usize = 1024 * 1024;
    512 * FACTOR_1024_SQUARED
}
245
246#[derive(Debug, Clone, Serialize, Deserialize)]
248pub struct WorkerConfig {
249 #[serde(default = "default_max_concurrent_jobs")]
251 pub max_concurrent_jobs: usize,
252
253 #[serde(default = "default_job_timeout")]
255 pub job_timeout_secs: u64,
256
257 #[serde(default = "default_poll_interval")]
259 pub poll_interval_ms: u64,
260}
261
262impl Default for WorkerConfig {
263 fn default() -> Self {
264 Self {
265 max_concurrent_jobs: default_max_concurrent_jobs(),
266 job_timeout_secs: default_job_timeout(),
267 poll_interval_ms: default_poll_interval(),
268 }
269 }
270}
271
/// Fallback for `WorkerConfig::max_concurrent_jobs`.
fn default_max_concurrent_jobs() -> usize {
    const MAX_CONCURRENT_JOBS: usize = 50;
    MAX_CONCURRENT_JOBS
}
275
/// Fallback for `WorkerConfig::job_timeout_secs`: one hour, in seconds.
fn default_job_timeout() -> u64 {
    const SECS_PER_HOUR: u64 = 60 * 60;
    SECS_PER_HOUR
}
279
/// Fallback for `WorkerConfig::poll_interval_ms`, in milliseconds.
fn default_poll_interval() -> u64 {
    const POLL_INTERVAL_MS: u64 = 100;
    POLL_INTERVAL_MS
}
283
284#[derive(Debug, Clone, Serialize, Deserialize, Default)]
286pub struct SecurityConfig {
287 pub secret_key: Option<String>,
289}
290
291#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
293#[serde(rename_all = "UPPERCASE")]
294pub enum JwtAlgorithm {
295 #[default]
297 HS256,
298 HS384,
300 HS512,
302 RS256,
304 RS384,
306 RS512,
308}
309
310#[derive(Debug, Clone, Serialize, Deserialize)]
312pub struct AuthConfig {
313 pub jwt_secret: Option<String>,
316
317 #[serde(default)]
321 pub jwt_algorithm: JwtAlgorithm,
322
323 pub jwt_issuer: Option<String>,
326
327 pub jwt_audience: Option<String>,
330
331 pub token_expiry: Option<String>,
333
334 pub jwks_url: Option<String>,
337
338 #[serde(default = "default_jwks_cache_ttl")]
340 pub jwks_cache_ttl_secs: u64,
341
342 #[serde(default = "default_session_ttl")]
344 pub session_ttl_secs: u64,
345}
346
347impl Default for AuthConfig {
348 fn default() -> Self {
349 Self {
350 jwt_secret: None,
351 jwt_algorithm: JwtAlgorithm::default(),
352 jwt_issuer: None,
353 jwt_audience: None,
354 token_expiry: None,
355 jwks_url: None,
356 jwks_cache_ttl_secs: default_jwks_cache_ttl(),
357 session_ttl_secs: default_session_ttl(),
358 }
359 }
360}
361
362impl AuthConfig {
363 pub fn validate(&self) -> Result<()> {
365 match self.jwt_algorithm {
366 JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512 => {
367 if self.jwt_secret.is_none() {
368 return Err(ForgeError::Config(
369 "jwt_secret is required for HMAC algorithms (HS256, HS384, HS512)".into(),
370 ));
371 }
372 }
373 JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512 => {
374 if self.jwks_url.is_none() {
375 return Err(ForgeError::Config(
376 "jwks_url is required for RSA algorithms (RS256, RS384, RS512)".into(),
377 ));
378 }
379 }
380 }
381 Ok(())
382 }
383
384 pub fn is_hmac(&self) -> bool {
386 matches!(
387 self.jwt_algorithm,
388 JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
389 )
390 }
391
392 pub fn is_rsa(&self) -> bool {
394 matches!(
395 self.jwt_algorithm,
396 JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
397 )
398 }
399}
400
/// Fallback for `AuthConfig::jwks_cache_ttl_secs`: one hour, in seconds.
fn default_jwks_cache_ttl() -> u64 {
    const SECS_PER_HOUR: u64 = 60 * 60;
    SECS_PER_HOUR
}
404
/// Fallback for `AuthConfig::session_ttl_secs`: seven days, in seconds.
fn default_session_ttl() -> u64 {
    const SECS_PER_DAY: u64 = 24 * 60 * 60;
    7 * SECS_PER_DAY
}
408
409fn substitute_env_vars(content: &str) -> String {
411 let mut result = content.to_string();
412 let re = regex_lite::Regex::new(r"\$\{([A-Z_][A-Z0-9_]*)\}").unwrap();
413
414 for cap in re.captures_iter(content) {
415 let var_name = &cap[1];
416 if let Ok(value) = std::env::var(var_name) {
417 result = result.replace(&cap[0], &value);
418 }
419 }
420
421 result
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// A config built from a database URL alone carries the built-in defaults.
    #[test]
    fn test_default_config() {
        let cfg = ForgeConfig::default_with_database_url("postgres://localhost/test");
        assert_eq!(cfg.gateway.port, 8080);
        assert_eq!(cfg.node.roles.len(), 4);
    }

    /// Only `[database]` is given; the gateway section falls back to defaults.
    #[test]
    fn test_parse_minimal_config() {
        let source = r#"
            [database]
            url = "postgres://localhost/myapp"
        "#;

        let cfg = ForgeConfig::parse_toml(source).unwrap();
        assert_eq!(cfg.database.url, "postgres://localhost/myapp");
        assert_eq!(cfg.gateway.port, 8080);
    }

    /// Explicit values in several sections override the defaults.
    #[test]
    fn test_parse_full_config() {
        let source = r#"
            [project]
            name = "my-app"
            version = "1.0.0"

            [database]
            url = "postgres://localhost/myapp"
            pool_size = 100

            [node]
            roles = ["gateway", "worker"]
            worker_capabilities = ["media", "general"]

            [gateway]
            port = 3000
            grpc_port = 9001
        "#;

        let cfg = ForgeConfig::parse_toml(source).unwrap();
        assert_eq!(cfg.project.name, "my-app");
        assert_eq!(cfg.database.pool_size, 100);
        assert_eq!(cfg.node.roles.len(), 2);
        assert_eq!(cfg.gateway.port, 3000);
    }

    /// A `${VAR}` placeholder is replaced by the variable's value before parsing.
    #[test]
    fn test_env_var_substitution() {
        // SAFETY: no other test in this module reads or writes `TEST_DB_URL`.
        // NOTE(review): the test harness runs tests on multiple threads;
        // confirm nothing else in the crate touches the environment concurrently.
        unsafe {
            std::env::set_var("TEST_DB_URL", "postgres://test:test@localhost/test");
        }

        let source = r#"
            [database]
            url = "${TEST_DB_URL}"
        "#;

        let cfg = ForgeConfig::parse_toml(source).unwrap();
        assert_eq!(cfg.database.url, "postgres://test:test@localhost/test");

        // SAFETY: same reasoning as for `set_var` above.
        unsafe {
            std::env::remove_var("TEST_DB_URL");
        }
    }
}