Skip to main content

forge_core/config/
mod.rs

1pub mod cluster;
2mod database;
3
4pub use cluster::ClusterConfig;
5pub use database::{DatabaseConfig, PoolConfig};
6
7use serde::{Deserialize, Serialize};
8use std::path::Path;
9
10use crate::error::{ForgeError, Result};
11
/// Root configuration for FORGE.
///
/// Deserialized from TOML by [`ForgeConfig::parse_toml`]. Every section except
/// `database` carries `#[serde(default)]`, so a minimal file only needs a
/// `[database]` table; omitted sections take their `Default` values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForgeConfig {
    /// Project metadata.
    #[serde(default)]
    pub project: ProjectConfig,

    /// Database configuration.
    ///
    /// Required: this field has no serde default, so parsing fails when the
    /// `[database]` table is absent.
    pub database: DatabaseConfig,

    /// Node configuration.
    #[serde(default)]
    pub node: NodeConfig,

    /// Gateway configuration.
    #[serde(default)]
    pub gateway: GatewayConfig,

    /// Function execution configuration.
    #[serde(default)]
    pub function: FunctionConfig,

    /// Worker configuration.
    #[serde(default)]
    pub worker: WorkerConfig,

    /// Cluster configuration.
    #[serde(default)]
    pub cluster: ClusterConfig,

    /// Security configuration.
    #[serde(default)]
    pub security: SecurityConfig,

    /// Authentication configuration.
    #[serde(default)]
    pub auth: AuthConfig,

    /// Observability configuration.
    #[serde(default)]
    pub observability: ObservabilityConfig,

    /// MCP server configuration.
    #[serde(default)]
    pub mcp: McpConfig,
}
58
59impl ForgeConfig {
60    /// Load configuration from a TOML file.
61    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
62        let content = std::fs::read_to_string(path.as_ref())
63            .map_err(|e| ForgeError::Config(format!("Failed to read config file: {}", e)))?;
64
65        Self::parse_toml(&content)
66    }
67
68    /// Parse configuration from a TOML string.
69    pub fn parse_toml(content: &str) -> Result<Self> {
70        // Substitute environment variables
71        let content = substitute_env_vars(content);
72
73        let config: Self = toml::from_str(&content)
74            .map_err(|e| ForgeError::Config(format!("Failed to parse config: {}", e)))?;
75
76        config.validate()?;
77        Ok(config)
78    }
79
80    /// Validate the configuration for invalid combinations.
81    pub fn validate(&self) -> Result<()> {
82        self.database.validate()?;
83        self.auth.validate()?;
84        self.mcp.validate()?;
85
86        // Cross-field: OAuth requires jwt_secret for signing tokens
87        if self.mcp.oauth && self.auth.jwt_secret.is_none() {
88            return Err(ForgeError::Config(
89                "mcp.oauth = true requires auth.jwt_secret to be set. \
90                 OAuth-issued tokens are signed with this secret, even when using \
91                 an external provider (JWKS) for identity verification."
92                    .into(),
93            ));
94        }
95        if self.mcp.oauth && !self.mcp.enabled {
96            return Err(ForgeError::Config(
97                "mcp.oauth = true requires mcp.enabled = true".into(),
98            ));
99        }
100
101        Ok(())
102    }
103
104    /// Load configuration with defaults.
105    pub fn default_with_database_url(url: &str) -> Self {
106        Self {
107            project: ProjectConfig::default(),
108            database: DatabaseConfig::new(url),
109            node: NodeConfig::default(),
110            gateway: GatewayConfig::default(),
111            function: FunctionConfig::default(),
112            worker: WorkerConfig::default(),
113            cluster: ClusterConfig::default(),
114            security: SecurityConfig::default(),
115            auth: AuthConfig::default(),
116            observability: ObservabilityConfig::default(),
117            mcp: McpConfig::default(),
118        }
119    }
120}
121
/// Project metadata.
///
/// Both fields fall back to their helper defaults when omitted from the
/// `[project]` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectConfig {
    /// Project name (defaults to `default_project_name()`).
    #[serde(default = "default_project_name")]
    pub name: String,

    /// Project version string (defaults to `default_version()`).
    #[serde(default = "default_version")]
    pub version: String,
}
133
134impl Default for ProjectConfig {
135    fn default() -> Self {
136        Self {
137            name: default_project_name(),
138            version: default_version(),
139        }
140    }
141}
142
/// Project name used when `project.name` is omitted.
fn default_project_name() -> String {
    String::from("forge-app")
}

/// Project version used when `project.version` is omitted.
fn default_version() -> String {
    String::from("0.1.0")
}
150
/// Node role configuration.
///
/// Controls which subsystems this process runs and which job capabilities it
/// advertises. Both fields have helper defaults when omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeConfig {
    /// Roles this node should assume (defaults to `default_roles()`, i.e. all roles).
    #[serde(default = "default_roles")]
    pub roles: Vec<NodeRole>,

    /// Worker capabilities for job routing (defaults to `default_capabilities()`).
    #[serde(default = "default_capabilities")]
    pub worker_capabilities: Vec<String>,
}
162
163impl Default for NodeConfig {
164    fn default() -> Self {
165        Self {
166            roles: default_roles(),
167            worker_capabilities: default_capabilities(),
168        }
169    }
170}
171
172fn default_roles() -> Vec<NodeRole> {
173    vec![
174        NodeRole::Gateway,
175        NodeRole::Function,
176        NodeRole::Worker,
177        NodeRole::Scheduler,
178    ]
179}
180
/// Capabilities advertised when `node.worker_capabilities` is omitted.
fn default_capabilities() -> Vec<String> {
    vec![String::from("general")]
}
184
/// Available node roles.
///
/// Serialized in lowercase (`"gateway"`, `"function"`, `"worker"`,
/// `"scheduler"`), which is the spelling expected in `node.roles`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum NodeRole {
    Gateway,
    Function,
    Worker,
    Scheduler,
}
194
/// Gateway configuration.
///
/// Every field has a helper default, so an empty or missing `[gateway]` table
/// is valid.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfig {
    /// HTTP port (defaults to `default_http_port()`).
    #[serde(default = "default_http_port")]
    pub port: u16,

    /// gRPC port for inter-node communication (defaults to `default_grpc_port()`).
    #[serde(default = "default_grpc_port")]
    pub grpc_port: u16,

    /// Maximum concurrent connections.
    #[serde(default = "default_max_connections")]
    pub max_connections: usize,

    /// Maximum active SSE sessions.
    #[serde(default = "default_sse_max_sessions")]
    pub sse_max_sessions: usize,

    /// Request timeout in seconds.
    #[serde(default = "default_request_timeout")]
    pub request_timeout_secs: u64,

    /// Enable CORS handling (off unless explicitly set).
    #[serde(default = "default_cors_enabled")]
    pub cors_enabled: bool,

    /// Allowed CORS origins (empty by default).
    #[serde(default = "default_cors_origins")]
    pub cors_origins: Vec<String>,

    /// Routes excluded from request logs, metrics, and traces.
    /// Defaults to `["/_api/health", "/_api/ready"]`. Set to `[]` to monitor everything.
    #[serde(default = "default_quiet_routes")]
    pub quiet_routes: Vec<String>,
}
231
232impl Default for GatewayConfig {
233    fn default() -> Self {
234        Self {
235            port: default_http_port(),
236            grpc_port: default_grpc_port(),
237            max_connections: default_max_connections(),
238            sse_max_sessions: default_sse_max_sessions(),
239            request_timeout_secs: default_request_timeout(),
240            cors_enabled: default_cors_enabled(),
241            cors_origins: default_cors_origins(),
242            quiet_routes: default_quiet_routes(),
243        }
244    }
245}
246
/// Default HTTP port for `gateway.port`.
fn default_http_port() -> u16 {
    9_081
}

/// Default inter-node gRPC port for `gateway.grpc_port`.
fn default_grpc_port() -> u16 {
    9_000
}

/// Default cap for `gateway.max_connections`.
fn default_max_connections() -> usize {
    4_096
}

/// Default cap for `gateway.sse_max_sessions`.
fn default_sse_max_sessions() -> usize {
    10_000
}

/// Default `gateway.request_timeout_secs` (30 seconds).
fn default_request_timeout() -> u64 {
    30
}

/// CORS handling is off unless configured.
fn default_cors_enabled() -> bool {
    false
}

/// No CORS origins are allowed by default.
fn default_cors_origins() -> Vec<String> {
    vec![]
}

/// Health and readiness probes are kept out of logs, metrics and traces by default.
fn default_quiet_routes() -> Vec<String> {
    ["/_api/health", "/_api/ready"]
        .iter()
        .map(|route| route.to_string())
        .collect()
}
278
/// Function execution configuration.
///
/// Every field has a helper default, so an empty or missing `[function]` table
/// is valid.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionConfig {
    /// Maximum concurrent function executions.
    #[serde(default = "default_max_concurrent")]
    pub max_concurrent: usize,

    /// Function timeout in seconds.
    #[serde(default = "default_function_timeout")]
    pub timeout_secs: u64,

    /// Memory limit per function (in bytes; default 512 MiB).
    #[serde(default = "default_memory_limit")]
    pub memory_limit: usize,
}
294
295impl Default for FunctionConfig {
296    fn default() -> Self {
297        Self {
298            max_concurrent: default_max_concurrent(),
299            timeout_secs: default_function_timeout(),
300            memory_limit: default_memory_limit(),
301        }
302    }
303}
304
/// Default cap for `function.max_concurrent`.
fn default_max_concurrent() -> usize {
    1_000
}

/// Default `function.timeout_secs` (30 seconds).
fn default_function_timeout() -> u64 {
    30
}

/// Default `function.memory_limit`: 512 MiB expressed in bytes.
fn default_memory_limit() -> usize {
    512 << 20
}
316
/// Worker configuration.
///
/// Every field has a helper default, so an empty or missing `[worker]` table
/// is valid.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerConfig {
    /// Maximum concurrent jobs.
    #[serde(default = "default_max_concurrent_jobs")]
    pub max_concurrent_jobs: usize,

    /// Job timeout in seconds (default: one hour).
    #[serde(default = "default_job_timeout")]
    pub job_timeout_secs: u64,

    /// Poll interval in milliseconds.
    #[serde(default = "default_poll_interval")]
    pub poll_interval_ms: u64,
}
332
333impl Default for WorkerConfig {
334    fn default() -> Self {
335        Self {
336            max_concurrent_jobs: default_max_concurrent_jobs(),
337            job_timeout_secs: default_job_timeout(),
338            poll_interval_ms: default_poll_interval(),
339        }
340    }
341}
342
/// Default cap for `worker.max_concurrent_jobs`.
fn default_max_concurrent_jobs() -> usize {
    50
}

/// Default `worker.job_timeout_secs`: one hour.
fn default_job_timeout() -> u64 {
    60 * 60
}

/// Default `worker.poll_interval_ms` (100 ms).
fn default_poll_interval() -> u64 {
    100
}
354
355/// Security configuration.
356#[derive(Debug, Clone, Serialize, Deserialize, Default)]
357pub struct SecurityConfig {
358    /// Secret key for signing.
359    pub secret_key: Option<String>,
360}
361
/// JWT signing algorithm.
///
/// Serialized with upper-case variant names (e.g. `jwt_algorithm = "RS256"`).
/// Defaults to `HS256`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "UPPERCASE")]
pub enum JwtAlgorithm {
    /// HMAC using SHA-256 (symmetric, requires jwt_secret).
    #[default]
    HS256,
    /// HMAC using SHA-384 (symmetric, requires jwt_secret).
    HS384,
    /// HMAC using SHA-512 (symmetric, requires jwt_secret).
    HS512,
    /// RSA using SHA-256 (asymmetric, requires jwks_url).
    RS256,
    /// RSA using SHA-384 (asymmetric, requires jwks_url).
    RS384,
    /// RSA using SHA-512 (asymmetric, requires jwks_url).
    RS512,
}
380
381/// Authentication configuration.
382#[derive(Debug, Clone, Serialize, Deserialize)]
383pub struct AuthConfig {
384    /// JWT secret for HMAC algorithms (HS256, HS384, HS512).
385    /// Required when using HMAC algorithms.
386    pub jwt_secret: Option<String>,
387
388    /// JWT signing algorithm.
389    /// HMAC algorithms (HS256, HS384, HS512) require jwt_secret.
390    /// RSA algorithms (RS256, RS384, RS512) require jwks_url.
391    #[serde(default)]
392    pub jwt_algorithm: JwtAlgorithm,
393
394    /// Expected token issuer (iss claim).
395    /// If set, tokens with a different issuer are rejected.
396    pub jwt_issuer: Option<String>,
397
398    /// Expected audience (aud claim).
399    /// If set, tokens with a different audience are rejected.
400    pub jwt_audience: Option<String>,
401
402    /// Token expiry duration (e.g., "15m", "1h", "7d").
403    /// Deprecated: use `access_token_ttl` instead.
404    pub token_expiry: Option<String>,
405
406    /// Access token lifetime (e.g., "15m", "1h").
407    /// Used by `ctx.issue_token_pair()`. Defaults to "1h".
408    pub access_token_ttl: Option<String>,
409
410    /// Refresh token lifetime (e.g., "7d", "30d").
411    /// Used by `ctx.issue_token_pair()`. Defaults to "30d".
412    pub refresh_token_ttl: Option<String>,
413
414    /// JWKS URL for RSA algorithms (RS256, RS384, RS512).
415    /// Keys are fetched and cached automatically.
416    pub jwks_url: Option<String>,
417
418    /// JWKS cache TTL in seconds.
419    #[serde(default = "default_jwks_cache_ttl")]
420    pub jwks_cache_ttl_secs: u64,
421
422    /// Session TTL in seconds (for WebSocket sessions).
423    #[serde(default = "default_session_ttl")]
424    pub session_ttl_secs: u64,
425}
426
427impl Default for AuthConfig {
428    fn default() -> Self {
429        Self {
430            jwt_secret: None,
431            jwt_algorithm: JwtAlgorithm::default(),
432            jwt_issuer: None,
433            jwt_audience: None,
434            token_expiry: None,
435            access_token_ttl: None,
436            refresh_token_ttl: None,
437            jwks_url: None,
438            jwks_cache_ttl_secs: default_jwks_cache_ttl(),
439            session_ttl_secs: default_session_ttl(),
440        }
441    }
442}
443
444impl AuthConfig {
445    /// Resolved access token TTL in seconds.
446    /// Checks `access_token_ttl`, falls back to `token_expiry`, then default 3600s (1h).
447    /// Minimum 1 second to prevent zero-lifetime tokens.
448    pub fn access_token_ttl_secs(&self) -> i64 {
449        self.access_token_ttl
450            .as_deref()
451            .or(self.token_expiry.as_deref())
452            .and_then(crate::util::parse_duration)
453            .map(|d| (d.as_secs() as i64).max(1))
454            .unwrap_or(3600)
455    }
456
457    /// Resolved refresh token TTL in days.
458    /// Parses `refresh_token_ttl`, default 30 days.
459    pub fn refresh_token_ttl_days(&self) -> i64 {
460        self.refresh_token_ttl
461            .as_deref()
462            .and_then(crate::util::parse_duration)
463            .map(|d| (d.as_secs() / 86400) as i64)
464            .map(|d| if d == 0 { 1 } else { d })
465            .unwrap_or(30)
466    }
467
468    /// Check if auth is configured (any credential or claim validation is set).
469    fn is_configured(&self) -> bool {
470        self.jwt_secret.is_some()
471            || self.jwks_url.is_some()
472            || self.jwt_issuer.is_some()
473            || self.jwt_audience.is_some()
474    }
475
476    /// Validate that the configuration is complete for the chosen algorithm.
477    /// Skips validation if no auth settings are configured (auth disabled).
478    pub fn validate(&self) -> Result<()> {
479        if !self.is_configured() {
480            return Ok(());
481        }
482
483        match self.jwt_algorithm {
484            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512 => {
485                if self.jwt_secret.is_none() {
486                    return Err(ForgeError::Config(
487                        "auth.jwt_secret is required for HMAC algorithms (HS256, HS384, HS512). \
488                         Set auth.jwt_secret to a secure random string, \
489                         or switch to RS256 and provide auth.jwks_url for external identity providers."
490                            .into(),
491                    ));
492                }
493            }
494            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512 => {
495                if self.jwks_url.is_none() {
496                    return Err(ForgeError::Config(
497                        "auth.jwks_url is required for RSA algorithms (RS256, RS384, RS512). \
498                         Set auth.jwks_url to your identity provider's JWKS endpoint, \
499                         or switch to HS256 and provide auth.jwt_secret for symmetric signing."
500                            .into(),
501                    ));
502                }
503            }
504        }
505        Ok(())
506    }
507
508    /// Check if this config uses HMAC (symmetric) algorithms.
509    pub fn is_hmac(&self) -> bool {
510        matches!(
511            self.jwt_algorithm,
512            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
513        )
514    }
515
516    /// Check if this config uses RSA (asymmetric) algorithms.
517    pub fn is_rsa(&self) -> bool {
518        matches!(
519            self.jwt_algorithm,
520            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
521        )
522    }
523}
524
/// Default `auth.jwks_cache_ttl_secs`: one hour.
fn default_jwks_cache_ttl() -> u64 {
    60 * 60
}

/// Default `auth.session_ttl_secs`: seven days.
fn default_session_ttl() -> u64 {
    7 * 86_400
}
532
/// Observability configuration for OTLP telemetry.
///
/// Export is off by default (`enabled = false`). The three `enable_*`
/// switches default to `true`, so setting `enabled = true` alone turns on
/// traces, metrics and logs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ObservabilityConfig {
    /// Enable observability (traces, metrics, logs).
    #[serde(default)]
    pub enabled: bool,

    /// OTLP endpoint for telemetry export.
    #[serde(default = "default_otlp_endpoint")]
    pub otlp_endpoint: String,

    /// Service name for telemetry identification.
    pub service_name: Option<String>,

    /// Enable distributed tracing.
    #[serde(default = "default_true")]
    pub enable_traces: bool,

    /// Enable metrics collection.
    #[serde(default = "default_true")]
    pub enable_metrics: bool,

    /// Enable log export via OTLP.
    #[serde(default = "default_true")]
    pub enable_logs: bool,

    /// Trace sampling ratio (0.0 to 1.0).
    /// NOTE(review): the range is not enforced by `ForgeConfig::validate`.
    #[serde(default = "default_sampling_ratio")]
    pub sampling_ratio: f64,

    /// Log level for the tracing subscriber (e.g., "debug", "info", "warn").
    #[serde(default = "default_log_level")]
    pub log_level: String,
}
567
568impl Default for ObservabilityConfig {
569    fn default() -> Self {
570        Self {
571            enabled: false,
572            otlp_endpoint: default_otlp_endpoint(),
573            service_name: None,
574            enable_traces: true,
575            enable_metrics: true,
576            enable_logs: true,
577            sampling_ratio: default_sampling_ratio(),
578            log_level: default_log_level(),
579        }
580    }
581}
582
583impl ObservabilityConfig {
584    pub fn otlp_active(&self) -> bool {
585        self.enabled && (self.enable_traces || self.enable_metrics || self.enable_logs)
586    }
587}
588
/// Default `observability.otlp_endpoint`.
fn default_otlp_endpoint() -> String {
    String::from("http://localhost:4318")
}

/// Serde default helper for boolean fields that are on unless disabled.
fn default_true() -> bool {
    true
}

/// Default `observability.sampling_ratio`: sample every trace.
fn default_sampling_ratio() -> f64 {
    1.0
}

/// Default `observability.log_level`.
fn default_log_level() -> String {
    String::from("info")
}
604
/// MCP server configuration.
///
/// Disabled by default. Field-level checks live in `McpConfig::validate`.
/// Cross-section rules for `oauth` are enforced by `ForgeConfig::validate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpConfig {
    /// Enable MCP endpoint exposure.
    #[serde(default)]
    pub enabled: bool,

    /// Enable OAuth 2.1 Authorization Code + PKCE for MCP clients.
    /// When true, Forge acts as an OAuth 2.1 Authorization Server so MCP
    /// clients like Claude Code can auto-authenticate via browser login.
    /// Requires `auth.jwt_secret` to be set and `mcp.enabled = true`.
    #[serde(default)]
    pub oauth: bool,

    /// MCP endpoint path under the gateway API namespace.
    /// Must start with `/`, contain no spaces, and not equal a reserved
    /// gateway route.
    #[serde(default = "default_mcp_path")]
    pub path: String,

    /// Session TTL in seconds (must be greater than 0).
    #[serde(default = "default_mcp_session_ttl_secs")]
    pub session_ttl_secs: u64,

    /// Allowed origins for Origin header validation.
    #[serde(default)]
    pub allowed_origins: Vec<String>,

    /// Enforce MCP-Protocol-Version header on post-initialize requests.
    #[serde(default = "default_true")]
    pub require_protocol_version_header: bool,
}
635
636impl Default for McpConfig {
637    fn default() -> Self {
638        Self {
639            enabled: false,
640            oauth: false,
641            path: default_mcp_path(),
642            session_ttl_secs: default_mcp_session_ttl_secs(),
643            allowed_origins: Vec::new(),
644            require_protocol_version_header: default_true(),
645        }
646    }
647}
648
649impl McpConfig {
650    /// Paths reserved by the gateway that MCP must not collide with.
651    const RESERVED_PATHS: &[&str] = &[
652        "/health",
653        "/ready",
654        "/rpc",
655        "/events",
656        "/subscribe",
657        "/unsubscribe",
658        "/subscribe-job",
659        "/subscribe-workflow",
660        "/metrics",
661    ];
662
663    pub fn validate(&self) -> Result<()> {
664        if self.path.is_empty() || !self.path.starts_with('/') {
665            return Err(ForgeError::Config(
666                "mcp.path must start with '/' (example: /mcp)".to_string(),
667            ));
668        }
669        if self.path.contains(' ') {
670            return Err(ForgeError::Config(
671                "mcp.path cannot contain spaces".to_string(),
672            ));
673        }
674        if Self::RESERVED_PATHS.contains(&self.path.as_str()) {
675            return Err(ForgeError::Config(format!(
676                "mcp.path '{}' conflicts with a reserved gateway route",
677                self.path
678            )));
679        }
680        if self.session_ttl_secs == 0 {
681            return Err(ForgeError::Config(
682                "mcp.session_ttl_secs must be greater than 0".to_string(),
683            ));
684        }
685        Ok(())
686    }
687}
688
/// Default `mcp.path`.
fn default_mcp_path() -> String {
    String::from("/mcp")
}

/// Default `mcp.session_ttl_secs`: one hour.
fn default_mcp_session_ttl_secs() -> u64 {
    3_600
}
696
/// Substitute environment variables in the format `${VAR_NAME}`.
///
/// Supports default values with `${VAR-default}` or `${VAR:-default}`. Both
/// forms use the default only when `std::env::var` fails, i.e. the variable
/// is unset or not valid Unicode. Unlike POSIX shells, `:-` does not also
/// replace an empty value.
///
/// Without a default, the literal `${VAR}` is preserved (so TOML parsing can
/// still fail loudly if a required variable is missing).
///
/// Only names matching `[A-Z_][A-Z0-9_]*` are treated as placeholders. Any
/// other `${...}` text, and all non-ASCII text, is copied through unchanged.
pub fn substitute_env_vars(content: &str) -> String {
    let mut result = String::with_capacity(content.len());
    let mut rest = content;

    while let Some((before, after_open)) = rest.split_once("${") {
        // Copy the text before the candidate placeholder as a `&str` slice.
        // Working on slices, rather than pushing bytes as `char`s, keeps
        // multi-byte UTF-8 sequences intact.
        result.push_str(before);

        if let Some((inner, after_close)) = after_open.split_once('}') {
            let (var_name, default_value) = parse_var_with_default(inner);
            if is_valid_env_var_name(var_name) {
                match (std::env::var(var_name), default_value) {
                    (Ok(value), _) => result.push_str(&value),
                    (Err(_), Some(default)) => result.push_str(default),
                    (Err(_), None) => {
                        // Re-emit the placeholder untouched so the missing
                        // variable remains visible downstream.
                        result.push_str("${");
                        result.push_str(inner);
                        result.push('}');
                    }
                }
                rest = after_close;
                continue;
            }
        }

        // Not a placeholder: emit `${` literally and keep scanning after it.
        // A `{` can never begin another `${`, so skipping both bytes loses
        // nothing.
        result.push_str("${");
        rest = after_open;
    }

    result.push_str(rest);
    result
}

/// Split a placeholder body into `(name, optional default)`.
///
/// The first `-` separates the name from the default. A `:` directly before
/// it belongs to the `:-` operator, so `VAR-default` and `VAR:-default` both
/// yield `("VAR", Some("default"))`.
///
/// Because the split happens at the *first* `-`, the default may itself
/// contain `-` or `:-` (`VAR-a:-b` yields `("VAR", Some("a:-b"))`). Without
/// any `-`, the whole body is the name.
fn parse_var_with_default(inner: &str) -> (&str, Option<&str>) {
    match inner.split_once('-') {
        Some((name, default)) => (name.strip_suffix(':').unwrap_or(name), Some(default)),
        None => (inner, None),
    }
}

/// Whether `name` is a conventional upper-case environment variable name:
/// non-empty, first byte `A`-`Z` or `_`, remaining bytes `A`-`Z`, `0`-`9` or `_`.
/// Lower-case names are not matched, so such `${...}` text is left as-is.
fn is_valid_env_var_name(name: &str) -> bool {
    let mut bytes = name.bytes();
    matches!(bytes.next(), Some(b'A'..=b'Z' | b'_'))
        && bytes.all(|b| matches!(b, b'A'..=b'Z' | b'0'..=b'9' | b'_'))
}
763
// Unit tests for TOML parsing, validation, token-TTL resolution and `${VAR}`
// substitution.
//
// The allow list relaxes crate lints that suit production code but not tests:
// `unwrap` on known-good fixtures, indexing, and `unsafe`. `unsafe` is needed
// because this crate calls `std::env::set_var` / `remove_var` inside `unsafe`
// blocks.
//
// NOTE(review): the test harness runs tests on parallel threads. Each env-var
// test uses its own variable name, so tests never observe each other's values.
// However, mutating the process environment while another thread may read it
// is exactly why these calls are `unsafe`. Consider serializing env-touching
// tests behind a shared lock if this ever causes flakiness.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, unsafe_code)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = ForgeConfig::default_with_database_url("postgres://localhost/test");
        assert_eq!(config.gateway.port, 9081);
        assert_eq!(config.node.roles.len(), 4);
        assert_eq!(config.mcp.path, "/mcp");
        assert!(!config.mcp.enabled);
    }

    #[test]
    fn test_parse_minimal_config() {
        // Only `[database]` is required; every other section falls back to defaults.
        let toml = r#"
            [database]
            url = "postgres://localhost/myapp"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.database.url(), "postgres://localhost/myapp");
        assert_eq!(config.gateway.port, 9081);
    }

    #[test]
    fn test_parse_full_config() {
        let toml = r#"
            [project]
            name = "my-app"
            version = "1.0.0"

            [database]
            url = "postgres://localhost/myapp"
            pool_size = 100

            [node]
            roles = ["gateway", "worker"]
            worker_capabilities = ["media", "general"]

            [gateway]
            port = 3000
            grpc_port = 9001
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.project.name, "my-app");
        assert_eq!(config.database.pool_size, 100);
        assert_eq!(config.node.roles.len(), 2);
        assert_eq!(config.gateway.port, 3000);
    }

    #[test]
    fn test_env_var_substitution() {
        // SAFETY: test-only env mutation of a variable no other test uses
        // (see the NOTE above the module about parallel test threads).
        unsafe {
            std::env::set_var("TEST_DB_URL", "postgres://test:test@localhost/test");
        }

        let toml = r#"
            [database]
            url = "${TEST_DB_URL}"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert_eq!(config.database.url(), "postgres://test:test@localhost/test");

        // SAFETY: as above; restores the environment for later tests.
        unsafe {
            std::env::remove_var("TEST_DB_URL");
        }
    }

    #[test]
    fn test_auth_validation_no_config() {
        // An entirely empty auth section counts as "auth disabled" and passes.
        let auth = AuthConfig::default();
        assert!(auth.validate().is_ok());
    }

    #[test]
    fn test_auth_validation_hmac_with_secret() {
        let auth = AuthConfig {
            jwt_secret: Some("my-secret".into()),
            jwt_algorithm: JwtAlgorithm::HS256,
            ..Default::default()
        };
        assert!(auth.validate().is_ok());
    }

    #[test]
    fn test_auth_validation_hmac_missing_secret() {
        // Setting only an issuer marks auth as configured, so the missing
        // HMAC secret must be reported.
        let auth = AuthConfig {
            jwt_issuer: Some("my-issuer".into()),
            jwt_algorithm: JwtAlgorithm::HS256,
            ..Default::default()
        };
        let result = auth.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwt_secret is required"));
    }

    #[test]
    fn test_auth_validation_rsa_with_jwks() {
        let auth = AuthConfig {
            jwks_url: Some("https://example.com/.well-known/jwks.json".into()),
            jwt_algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        assert!(auth.validate().is_ok());
    }

    #[test]
    fn test_auth_validation_rsa_missing_jwks() {
        let auth = AuthConfig {
            jwt_issuer: Some("my-issuer".into()),
            jwt_algorithm: JwtAlgorithm::RS256,
            ..Default::default()
        };
        let result = auth.validate();
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwks_url is required"));
    }

    #[test]
    fn test_forge_config_validation_fails_on_empty_url() {
        // NOTE(review): the expected message presumably comes from
        // `DatabaseConfig::validate` (defined in the `database` submodule).
        let toml = r#"
            [database]

            url = ""
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("database.url is required"));
    }

    #[test]
    fn test_forge_config_validation_fails_on_invalid_auth() {
        // Auth validation must also run when going through `parse_toml`.
        let toml = r#"
            [database]

            url = "postgres://localhost/test"

            [auth]
            jwt_issuer = "my-issuer"
            jwt_algorithm = "RS256"
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("jwks_url is required"));
    }

    #[test]
    fn test_env_var_default_used_when_unset() {
        // Ensure the var is definitely not set
        // SAFETY: test-only env mutation of a variable no other test uses.
        unsafe {
            std::env::remove_var("TEST_FORGE_OTEL_UNSET");
        }

        let input = r#"enabled = ${TEST_FORGE_OTEL_UNSET-false}"#;
        let result = substitute_env_vars(input);
        assert_eq!(result, "enabled = false");
    }

    #[test]
    fn test_env_var_default_overridden_when_set() {
        // SAFETY: test-only env mutation of a variable no other test uses.
        unsafe {
            std::env::set_var("TEST_FORGE_OTEL_SET", "true");
        }

        let input = r#"enabled = ${TEST_FORGE_OTEL_SET-false}"#;
        let result = substitute_env_vars(input);
        assert_eq!(result, "enabled = true");

        // SAFETY: as above; restores the environment for later tests.
        unsafe {
            std::env::remove_var("TEST_FORGE_OTEL_SET");
        }
    }

    #[test]
    fn test_env_var_colon_dash_default() {
        // SAFETY: test-only env mutation of a variable no other test uses.
        unsafe {
            std::env::remove_var("TEST_FORGE_ENDPOINT_UNSET");
        }

        // The default itself contains ':' characters; only the leading `:-` is syntax.
        let input = r#"endpoint = "${TEST_FORGE_ENDPOINT_UNSET:-http://localhost:4318}""#;
        let result = substitute_env_vars(input);
        assert_eq!(result, r#"endpoint = "http://localhost:4318""#);
    }

    #[test]
    fn test_env_var_no_default_preserves_literal() {
        // SAFETY: test-only env mutation of a variable no other test uses.
        unsafe {
            std::env::remove_var("TEST_FORGE_MISSING");
        }

        let input = r#"url = "${TEST_FORGE_MISSING}""#;
        let result = substitute_env_vars(input);
        assert_eq!(result, r#"url = "${TEST_FORGE_MISSING}""#);
    }

    #[test]
    fn test_env_var_default_empty_string() {
        // SAFETY: test-only env mutation of a variable no other test uses.
        unsafe {
            std::env::remove_var("TEST_FORGE_EMPTY_DEFAULT");
        }

        let input = r#"val = "${TEST_FORGE_EMPTY_DEFAULT-}""#;
        let result = substitute_env_vars(input);
        assert_eq!(result, r#"val = """#);
    }

    #[test]
    fn test_observability_config_default_disabled() {
        let toml = r#"
            [database]
            url = "postgres://localhost/test"
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert!(!config.observability.enabled);
        assert!(!config.observability.otlp_active());
    }

    #[test]
    fn test_observability_config_with_env_default() {
        // Simulates what the template produces when no env vars are set
        // SAFETY: test-only env mutation of a variable no other test uses.
        unsafe {
            std::env::remove_var("TEST_OTEL_ENABLED");
        }

        let toml = r#"
            [database]
            url = "postgres://localhost/test"

            [observability]
            enabled = ${TEST_OTEL_ENABLED-false}
        "#;

        let config = ForgeConfig::parse_toml(toml).unwrap();
        assert!(!config.observability.enabled);
    }

    #[test]
    fn test_mcp_config_validation_rejects_invalid_path() {
        let toml = r#"
            [database]

            url = "postgres://localhost/test"

            [mcp]
            enabled = true
            path = "mcp"
        "#;

        let result = ForgeConfig::parse_toml(toml);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("mcp.path must start with '/'"));
    }

    #[test]
    fn test_access_token_ttl_defaults() {
        let auth = AuthConfig::default();
        assert_eq!(auth.access_token_ttl_secs(), 3600);
        assert_eq!(auth.refresh_token_ttl_days(), 30);
    }

    #[test]
    fn test_access_token_ttl_custom() {
        let auth = AuthConfig {
            access_token_ttl: Some("15m".into()),
            refresh_token_ttl: Some("7d".into()),
            ..Default::default()
        };
        assert_eq!(auth.access_token_ttl_secs(), 900);
        assert_eq!(auth.refresh_token_ttl_days(), 7);
    }

    #[test]
    fn test_access_token_ttl_minimum_enforced() {
        let auth = AuthConfig {
            access_token_ttl: Some("0s".into()),
            ..Default::default()
        };
        // Should floor at 1, not 0
        assert_eq!(auth.access_token_ttl_secs(), 1);
    }

    #[test]
    fn test_refresh_token_ttl_minimum_enforced() {
        let auth = AuthConfig {
            refresh_token_ttl: Some("1h".into()),
            ..Default::default()
        };
        // 1 hour < 1 day, so should floor at 1 day
        assert_eq!(auth.refresh_token_ttl_days(), 1);
    }

    #[test]
    fn test_mcp_config_rejects_reserved_paths() {
        // Every entry in the reserved list must be rejected end-to-end via `parse_toml`.
        for reserved in McpConfig::RESERVED_PATHS {
            let toml = format!(
                r#"
                [database]
                url = "postgres://localhost/test"

                [mcp]
                enabled = true
                path = "{reserved}"
                "#
            );

            let result = ForgeConfig::parse_toml(&toml);
            assert!(result.is_err(), "Expected {reserved} to be rejected");
            let err_msg = result.unwrap_err().to_string();
            assert!(
                err_msg.contains("conflicts with a reserved gateway route"),
                "Wrong error for {reserved}: {err_msg}"
            );
        }
    }
}