1mod cluster;
2mod database;
3
4pub use cluster::ClusterConfig;
5pub use database::{DatabaseConfig, DatabaseSource};
6
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9
10use crate::error::{ForgeError, Result};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ForgeConfig {
15 #[serde(default)]
17 pub project: ProjectConfig,
18
19 pub database: DatabaseConfig,
21
22 #[serde(default)]
24 pub node: NodeConfig,
25
26 #[serde(default)]
28 pub gateway: GatewayConfig,
29
30 #[serde(default)]
32 pub function: FunctionConfig,
33
34 #[serde(default)]
36 pub worker: WorkerConfig,
37
38 #[serde(default)]
40 pub cluster: ClusterConfig,
41
42 #[serde(default)]
44 pub security: SecurityConfig,
45
46 #[serde(default)]
48 pub auth: AuthConfig,
49
50 #[serde(default)]
52 pub observability: ObservabilityConfig,
53}
54
55impl ForgeConfig {
56 pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
58 let content = std::fs::read_to_string(path.as_ref())
59 .map_err(|e| ForgeError::Config(format!("Failed to read config file: {}", e)))?;
60
61 Self::parse_toml(&content)
62 }
63
64 pub fn parse_toml(content: &str) -> Result<Self> {
66 let content = substitute_env_vars(content);
68
69 let config: Self = toml::from_str(&content)
70 .map_err(|e| ForgeError::Config(format!("Failed to parse config: {}", e)))?;
71
72 config.validate()?;
73 Ok(config)
74 }
75
76 pub fn validate(&self) -> Result<()> {
78 self.database.validate()?;
79 self.auth.validate()?;
80 Ok(())
81 }
82
83 pub fn default_with_database_url(url: &str) -> Self {
85 Self {
86 project: ProjectConfig::default(),
87 database: DatabaseConfig::remote(url),
88 node: NodeConfig::default(),
89 gateway: GatewayConfig::default(),
90 function: FunctionConfig::default(),
91 worker: WorkerConfig::default(),
92 cluster: ClusterConfig::default(),
93 security: SecurityConfig::default(),
94 auth: AuthConfig::default(),
95 observability: ObservabilityConfig::default(),
96 }
97 }
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct ProjectConfig {
103 #[serde(default = "default_project_name")]
105 pub name: String,
106
107 #[serde(default = "default_version")]
109 pub version: String,
110}
111
112impl Default for ProjectConfig {
113 fn default() -> Self {
114 Self {
115 name: default_project_name(),
116 version: default_version(),
117 }
118 }
119}
120
/// Serde default for `ProjectConfig::name`.
fn default_project_name() -> String {
    String::from("forge-app")
}

/// Serde default for `ProjectConfig::version`.
fn default_version() -> String {
    String::from("0.1.0")
}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct NodeConfig {
132 #[serde(default = "default_roles")]
134 pub roles: Vec<NodeRole>,
135
136 #[serde(default = "default_capabilities")]
138 pub worker_capabilities: Vec<String>,
139}
140
141impl Default for NodeConfig {
142 fn default() -> Self {
143 Self {
144 roles: default_roles(),
145 worker_capabilities: default_capabilities(),
146 }
147 }
148}
149
150fn default_roles() -> Vec<NodeRole> {
151 vec![
152 NodeRole::Gateway,
153 NodeRole::Function,
154 NodeRole::Worker,
155 NodeRole::Scheduler,
156 ]
157}
158
/// Serde default for `NodeConfig::worker_capabilities`: the single tag `"general"`.
fn default_capabilities() -> Vec<String> {
    let general = String::from("general");
    vec![general]
}
162
163#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
165#[serde(rename_all = "lowercase")]
166pub enum NodeRole {
167 Gateway,
168 Function,
169 Worker,
170 Scheduler,
171}
172
173#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct GatewayConfig {
176 #[serde(default = "default_http_port")]
178 pub port: u16,
179
180 #[serde(default = "default_grpc_port")]
182 pub grpc_port: u16,
183
184 #[serde(default = "default_max_connections")]
186 pub max_connections: usize,
187
188 #[serde(default = "default_request_timeout")]
190 pub request_timeout_secs: u64,
191
192 #[serde(default = "default_cors_enabled")]
194 pub cors_enabled: bool,
195
196 #[serde(default = "default_cors_origins")]
198 pub cors_origins: Vec<String>,
199}
200
201impl Default for GatewayConfig {
202 fn default() -> Self {
203 Self {
204 port: default_http_port(),
205 grpc_port: default_grpc_port(),
206 max_connections: default_max_connections(),
207 request_timeout_secs: default_request_timeout(),
208 cors_enabled: default_cors_enabled(),
209 cors_origins: default_cors_origins(),
210 }
211 }
212}
213
/// Serde default for `GatewayConfig::port`.
fn default_http_port() -> u16 {
    const HTTP_PORT: u16 = 8080;
    HTTP_PORT
}

/// Serde default for `GatewayConfig::grpc_port`.
fn default_grpc_port() -> u16 {
    const GRPC_PORT: u16 = 9000;
    GRPC_PORT
}

/// Serde default for `GatewayConfig::max_connections`.
fn default_max_connections() -> usize {
    const MAX_CONNECTIONS: usize = 512;
    MAX_CONNECTIONS
}

/// Serde default for `GatewayConfig::request_timeout_secs` (seconds).
fn default_request_timeout() -> u64 {
    const TIMEOUT_SECS: u64 = 30;
    TIMEOUT_SECS
}

/// Serde default for `GatewayConfig::cors_enabled`: CORS is off.
fn default_cors_enabled() -> bool {
    false
}

/// Serde default for `GatewayConfig::cors_origins`: no origins.
fn default_cors_origins() -> Vec<String> {
    vec![]
}
237
238#[derive(Debug, Clone, Serialize, Deserialize)]
240pub struct FunctionConfig {
241 #[serde(default = "default_max_concurrent")]
243 pub max_concurrent: usize,
244
245 #[serde(default = "default_function_timeout")]
247 pub timeout_secs: u64,
248
249 #[serde(default = "default_memory_limit")]
251 pub memory_limit: usize,
252}
253
254impl Default for FunctionConfig {
255 fn default() -> Self {
256 Self {
257 max_concurrent: default_max_concurrent(),
258 timeout_secs: default_function_timeout(),
259 memory_limit: default_memory_limit(),
260 }
261 }
262}
263
/// Serde default for `FunctionConfig::max_concurrent`.
fn default_max_concurrent() -> usize {
    1_000
}

/// Serde default for `FunctionConfig::timeout_secs` (seconds).
fn default_function_timeout() -> u64 {
    30
}

/// Serde default for `FunctionConfig::memory_limit`: 512 MiB expressed in bytes.
fn default_memory_limit() -> usize {
    const MIB: usize = 1024 * 1024;
    512 * MIB
}
275
276#[derive(Debug, Clone, Serialize, Deserialize)]
278pub struct WorkerConfig {
279 #[serde(default = "default_max_concurrent_jobs")]
281 pub max_concurrent_jobs: usize,
282
283 #[serde(default = "default_job_timeout")]
285 pub job_timeout_secs: u64,
286
287 #[serde(default = "default_poll_interval")]
289 pub poll_interval_ms: u64,
290}
291
292impl Default for WorkerConfig {
293 fn default() -> Self {
294 Self {
295 max_concurrent_jobs: default_max_concurrent_jobs(),
296 job_timeout_secs: default_job_timeout(),
297 poll_interval_ms: default_poll_interval(),
298 }
299 }
300}
301
/// Serde default for `WorkerConfig::max_concurrent_jobs`.
fn default_max_concurrent_jobs() -> usize {
    50
}

/// Serde default for `WorkerConfig::job_timeout_secs`: one hour, in seconds.
fn default_job_timeout() -> u64 {
    const SECS_PER_HOUR: u64 = 60 * 60;
    SECS_PER_HOUR
}

/// Serde default for `WorkerConfig::poll_interval_ms` (milliseconds).
fn default_poll_interval() -> u64 {
    100
}
313
314#[derive(Debug, Clone, Serialize, Deserialize, Default)]
316pub struct SecurityConfig {
317 pub secret_key: Option<String>,
319}
320
321#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
323#[serde(rename_all = "UPPERCASE")]
324pub enum JwtAlgorithm {
325 #[default]
327 HS256,
328 HS384,
330 HS512,
332 RS256,
334 RS384,
336 RS512,
338}
339
340#[derive(Debug, Clone, Serialize, Deserialize)]
342pub struct AuthConfig {
343 pub jwt_secret: Option<String>,
346
347 #[serde(default)]
351 pub jwt_algorithm: JwtAlgorithm,
352
353 pub jwt_issuer: Option<String>,
356
357 pub jwt_audience: Option<String>,
360
361 pub token_expiry: Option<String>,
363
364 pub jwks_url: Option<String>,
367
368 #[serde(default = "default_jwks_cache_ttl")]
370 pub jwks_cache_ttl_secs: u64,
371
372 #[serde(default = "default_session_ttl")]
374 pub session_ttl_secs: u64,
375}
376
377impl Default for AuthConfig {
378 fn default() -> Self {
379 Self {
380 jwt_secret: None,
381 jwt_algorithm: JwtAlgorithm::default(),
382 jwt_issuer: None,
383 jwt_audience: None,
384 token_expiry: None,
385 jwks_url: None,
386 jwks_cache_ttl_secs: default_jwks_cache_ttl(),
387 session_ttl_secs: default_session_ttl(),
388 }
389 }
390}
391
392impl AuthConfig {
393 fn is_configured(&self) -> bool {
395 self.jwt_secret.is_some()
396 || self.jwks_url.is_some()
397 || self.jwt_issuer.is_some()
398 || self.jwt_audience.is_some()
399 }
400
401 pub fn validate(&self) -> Result<()> {
404 if !self.is_configured() {
405 return Ok(());
406 }
407
408 match self.jwt_algorithm {
409 JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512 => {
410 if self.jwt_secret.is_none() {
411 return Err(ForgeError::Config(
412 "auth.jwt_secret is required for HMAC algorithms (HS256, HS384, HS512). \
413 Set auth.jwt_secret to a secure random string, \
414 or switch to RS256 and provide auth.jwks_url for external identity providers."
415 .into(),
416 ));
417 }
418 }
419 JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512 => {
420 if self.jwks_url.is_none() {
421 return Err(ForgeError::Config(
422 "auth.jwks_url is required for RSA algorithms (RS256, RS384, RS512). \
423 Set auth.jwks_url to your identity provider's JWKS endpoint, \
424 or switch to HS256 and provide auth.jwt_secret for symmetric signing."
425 .into(),
426 ));
427 }
428 }
429 }
430 Ok(())
431 }
432
433 pub fn is_hmac(&self) -> bool {
435 matches!(
436 self.jwt_algorithm,
437 JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
438 )
439 }
440
441 pub fn is_rsa(&self) -> bool {
443 matches!(
444 self.jwt_algorithm,
445 JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
446 )
447 }
448}
449
/// Serde default for `AuthConfig::jwks_cache_ttl_secs`: one hour, in seconds.
fn default_jwks_cache_ttl() -> u64 {
    const SECS_PER_HOUR: u64 = 60 * 60;
    SECS_PER_HOUR
}

/// Serde default for `AuthConfig::session_ttl_secs`: seven days, in seconds.
fn default_session_ttl() -> u64 {
    const SECS_PER_DAY: u64 = 24 * 60 * 60;
    7 * SECS_PER_DAY
}
457
458#[derive(Debug, Clone, Serialize, Deserialize)]
460pub struct ObservabilityConfig {
461 #[serde(default)]
463 pub enabled: bool,
464
465 #[serde(default = "default_otlp_endpoint")]
467 pub otlp_endpoint: String,
468
469 pub service_name: Option<String>,
471
472 #[serde(default = "default_true")]
474 pub enable_traces: bool,
475
476 #[serde(default = "default_true")]
478 pub enable_metrics: bool,
479
480 #[serde(default = "default_true")]
482 pub enable_logs: bool,
483
484 #[serde(default = "default_sampling_ratio")]
486 pub sampling_ratio: f64,
487}
488
489impl Default for ObservabilityConfig {
490 fn default() -> Self {
491 Self {
492 enabled: false,
493 otlp_endpoint: default_otlp_endpoint(),
494 service_name: None,
495 enable_traces: true,
496 enable_metrics: true,
497 enable_logs: true,
498 sampling_ratio: default_sampling_ratio(),
499 }
500 }
501}
502
/// Serde default for `ObservabilityConfig::otlp_endpoint`.
fn default_otlp_endpoint() -> String {
    String::from("http://localhost:4317")
}

/// Serde helper for boolean fields that default to enabled.
fn default_true() -> bool {
    !false
}

/// Serde default for `ObservabilityConfig::sampling_ratio`: sample everything.
fn default_sampling_ratio() -> f64 {
    const SAMPLE_ALL: f64 = 1.0;
    SAMPLE_ALL
}
514
515fn substitute_env_vars(content: &str) -> String {
517 let mut result = content.to_string();
518 let re = regex_lite::Regex::new(r"\$\{([A-Z_][A-Z0-9_]*)\}").expect("valid regex pattern");
519
520 for cap in re.captures_iter(content) {
521 let var_name = &cap[1];
522 if let Ok(value) = std::env::var(var_name) {
523 result = result.replace(&cap[0], &value);
524 }
525 }
526
527 result
528}
529
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, unsafe_code)]
mod tests {
    use super::*;

    /// Sets an environment variable for the guard's lifetime and removes it on drop.
    /// This restores the process environment even when an assertion panics.
    struct EnvVarGuard {
        name: &'static str,
    }

    impl EnvVarGuard {
        fn set(name: &'static str, value: &str) -> Self {
            // SAFETY: `set_var` is unsafe because other threads may access the
            // environment concurrently. Rust's own `std::env` accessors synchronise
            // with one another, so the remaining hazard is foreign (non-Rust) code
            // reading the environment at the same time, which nothing in this module
            // does. The variable name is used by a single test only.
            unsafe { std::env::set_var(name, value) };
            Self { name }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same argument as in `EnvVarGuard::set`.
            unsafe { std::env::remove_var(self.name) };
        }
    }

    #[test]
    fn test_default_config() {
        let config = ForgeConfig::default_with_database_url("postgres://localhost/test");
        assert_eq!(config.gateway.port, 8080);
        assert_eq!(config.node.roles.len(), 4);
    }

    #[test]
    fn test_parse_minimal_config() {
        let toml = r#"
            [database]
            mode = "remote"
            url = "postgres://localhost/myapp"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.database.url(), Some("postgres://localhost/myapp"));
        assert_eq!(config.gateway.port, 8080);
    }

    #[test]
    fn test_parse_full_config() {
        let toml = r#"
            [project]
            name = "my-app"
            version = "1.0.0"

            [database]
            mode = "remote"
            url = "postgres://localhost/myapp"
            pool_size = 100

            [node]
            roles = ["gateway", "worker"]
            worker_capabilities = ["media", "general"]

            [gateway]
            port = 3000
            grpc_port = 9001
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.project.name, "my-app");
        assert_eq!(config.database.pool_size, 100);
        assert_eq!(config.node.roles.len(), 2);
        assert_eq!(config.gateway.port, 3000);
    }

    #[test]
    fn test_env_var_substitution() {
        // A test-specific name avoids clobbering a variable a developer may have set.
        let _guard = EnvVarGuard::set(
            "FORGE_TEST_DB_URL",
            "postgres://test:test@localhost/test",
        );

        let toml = r#"
            [database]
            mode = "remote"
            url = "${FORGE_TEST_DB_URL}"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(
            config.database.url(),
            Some("postgres://test:test@localhost/test")
        );
    }

    #[test]
    fn test_auth_validation_no_config() {
        let auth = AuthConfig::default();
        assert!(auth.validate().is_ok());
    }

    #[test]
    fn test_auth_validation_hmac_with_secret() {
        let auth = AuthConfig {
            jwt_secret: Some("my-secret".into()),
            jwt_algorithm: JwtAlgorithm::HS256,
            ..Default::default()
        };
        assert!(auth.validate().is_ok());
    }

    #[test]
    fn test_auth_validation_hmac_missing_secret() {
        let auth = AuthConfig {
            jwt_issuer: Some("my-issuer".into()),
            jwt_algorithm: JwtAlgorithm::HS256,
            ..Default::default()
        };
        let result = auth.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwt_secret is required"));
    }

    #[test]
    fn test_auth_validation_rsa_with_jwks() {
        let auth = AuthConfig {
            jwks_url: Some("https://example.com/.well-known/jwks.json".into()),
            jwt_algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(auth.validate().is_ok());
    }

    #[test]
    fn test_auth_validation_rsa_missing_jwks() {
        let auth = AuthConfig {
            jwt_issuer: Some("my-issuer".into()),
            jwt_algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        let result = auth.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwks_url is required"));
    }

    #[test]
    fn test_forge_config_validation_fails_on_empty_url() {
        let toml = r#"
            [database]
            mode = "remote"
            url = ""
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("database.url is required"));
    }

    #[test]
    fn test_forge_config_validation_fails_on_invalid_auth() {
        let toml = r#"
            [database]
            mode = "remote"
            url = "postgres://localhost/test"

            [auth]
            jwt_issuer = "my-issuer"
            jwt_algorithm = "RS256"
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwks_url is required"));
    }
}
689}