Skip to main content

forge_core/config/
mod.rs

1pub mod cluster;
2mod database;
3pub mod signals;
4
5pub use cluster::ClusterConfig;
6pub use database::{DatabaseConfig, PoolConfig};
7pub use signals::SignalsConfig;
8
9use serde::{Deserialize, Serialize};
10use std::path::Path;
11
12use crate::error::{ForgeError, Result};
13
14/// Root configuration for FORGE.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ForgeConfig {
17    /// Project metadata.
18    #[serde(default)]
19    pub project: ProjectConfig,
20
21    /// Database configuration.
22    pub database: DatabaseConfig,
23
24    /// Node configuration.
25    #[serde(default)]
26    pub node: NodeConfig,
27
28    /// Gateway configuration.
29    #[serde(default)]
30    pub gateway: GatewayConfig,
31
32    /// Function execution configuration.
33    #[serde(default)]
34    pub function: FunctionConfig,
35
36    /// Worker configuration.
37    #[serde(default)]
38    pub worker: WorkerConfig,
39
40    /// Cluster configuration.
41    #[serde(default)]
42    pub cluster: ClusterConfig,
43
44    /// Security configuration.
45    #[serde(default)]
46    pub security: SecurityConfig,
47
48    /// Authentication configuration.
49    #[serde(default)]
50    pub auth: AuthConfig,
51
52    /// Observability configuration.
53    #[serde(default)]
54    pub observability: ObservabilityConfig,
55
56    /// MCP server configuration.
57    #[serde(default)]
58    pub mcp: McpConfig,
59
60    /// Signals configuration for product analytics and diagnostics.
61    #[serde(default)]
62    pub signals: SignalsConfig,
63}
64
65impl ForgeConfig {
66    /// Load configuration from a TOML file.
67    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
68        let content = std::fs::read_to_string(path.as_ref())
69            .map_err(|e| ForgeError::Config(format!("Failed to read config file: {}", e)))?;
70
71        Self::parse_toml(&content)
72    }
73
74    /// Parse configuration from a TOML string.
75    pub fn parse_toml(content: &str) -> Result<Self> {
76        // Substitute environment variables
77        let content = substitute_env_vars(content);
78
79        let config: Self = toml::from_str(&content)
80            .map_err(|e| ForgeError::Config(format!("Failed to parse config: {}", e)))?;
81
82        config.validate()?;
83        Ok(config)
84    }
85
86    /// Validate the configuration for invalid combinations.
87    pub fn validate(&self) -> Result<()> {
88        self.database.validate()?;
89        self.auth.validate()?;
90        self.mcp.validate()?;
91
92        // Cross-field: OAuth requires jwt_secret for signing tokens
93        if self.mcp.oauth && self.auth.jwt_secret.is_none() {
94            return Err(ForgeError::Config(
95                "mcp.oauth = true requires auth.jwt_secret to be set. \
96                 OAuth-issued tokens are signed with this secret, even when using \
97                 an external provider (JWKS) for identity verification."
98                    .into(),
99            ));
100        }
101        if self.mcp.oauth && !self.mcp.enabled {
102            return Err(ForgeError::Config(
103                "mcp.oauth = true requires mcp.enabled = true".into(),
104            ));
105        }
106
107        Ok(())
108    }
109
110    /// Load configuration with defaults.
111    pub fn default_with_database_url(url: &str) -> Self {
112        Self {
113            project: ProjectConfig::default(),
114            database: DatabaseConfig::new(url),
115            node: NodeConfig::default(),
116            gateway: GatewayConfig::default(),
117            function: FunctionConfig::default(),
118            worker: WorkerConfig::default(),
119            cluster: ClusterConfig::default(),
120            security: SecurityConfig::default(),
121            auth: AuthConfig::default(),
122            observability: ObservabilityConfig::default(),
123            mcp: McpConfig::default(),
124            signals: SignalsConfig::default(),
125        }
126    }
127}
128
129/// Project metadata.
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct ProjectConfig {
132    /// Project name.
133    #[serde(default = "default_project_name")]
134    pub name: String,
135
136    /// Project version.
137    #[serde(default = "default_version")]
138    pub version: String,
139}
140
141impl Default for ProjectConfig {
142    fn default() -> Self {
143        Self {
144            name: default_project_name(),
145            version: default_version(),
146        }
147    }
148}
149
/// Fallback for `project.name`.
fn default_project_name() -> String {
    String::from("forge-app")
}

/// Fallback for `project.version`.
fn default_version() -> String {
    String::from("0.1.0")
}
157
158/// Node role configuration.
159#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct NodeConfig {
161    /// Roles this node should assume.
162    #[serde(default = "default_roles")]
163    pub roles: Vec<NodeRole>,
164
165    /// Worker capabilities for job routing.
166    #[serde(default = "default_capabilities")]
167    pub worker_capabilities: Vec<String>,
168}
169
170impl Default for NodeConfig {
171    fn default() -> Self {
172        Self {
173            roles: default_roles(),
174            worker_capabilities: default_capabilities(),
175        }
176    }
177}
178
179fn default_roles() -> Vec<NodeRole> {
180    vec![
181        NodeRole::Gateway,
182        NodeRole::Function,
183        NodeRole::Worker,
184        NodeRole::Scheduler,
185    ]
186}
187
188fn default_capabilities() -> Vec<String> {
189    vec!["general".to_string()]
190}
191
192/// Available node roles.
193#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
194#[serde(rename_all = "lowercase")]
195pub enum NodeRole {
196    Gateway,
197    Function,
198    Worker,
199    Scheduler,
200}
201
202/// Gateway configuration.
203#[derive(Debug, Clone, Serialize, Deserialize)]
204pub struct GatewayConfig {
205    /// HTTP port.
206    #[serde(default = "default_http_port")]
207    pub port: u16,
208
209    /// gRPC port for inter-node communication (reserved for future use).
210    ///
211    /// This port is registered in the cluster node info but a gRPC listener
212    /// is not yet started. It will be used for efficient binary inter-node
213    /// RPC in a future release.
214    #[serde(default = "default_grpc_port")]
215    pub grpc_port: u16,
216
217    /// Maximum concurrent connections.
218    #[serde(default = "default_max_connections")]
219    pub max_connections: usize,
220
221    /// Maximum active SSE sessions.
222    #[serde(default = "default_sse_max_sessions")]
223    pub sse_max_sessions: usize,
224
225    /// Request timeout in seconds.
226    #[serde(default = "default_request_timeout")]
227    pub request_timeout_secs: u64,
228
229    /// Enable CORS handling.
230    #[serde(default = "default_cors_enabled")]
231    pub cors_enabled: bool,
232
233    /// Allowed CORS origins.
234    #[serde(default = "default_cors_origins")]
235    pub cors_origins: Vec<String>,
236
237    /// Routes excluded from request logs, metrics, and traces.
238    /// Defaults to `["/_api/health", "/_api/ready"]`. Set to `[]` to monitor everything.
239    #[serde(default = "default_quiet_routes")]
240    pub quiet_routes: Vec<String>,
241}
242
243impl Default for GatewayConfig {
244    fn default() -> Self {
245        Self {
246            port: default_http_port(),
247            grpc_port: default_grpc_port(),
248            max_connections: default_max_connections(),
249            sse_max_sessions: default_sse_max_sessions(),
250            request_timeout_secs: default_request_timeout(),
251            cors_enabled: default_cors_enabled(),
252            cors_origins: default_cors_origins(),
253            quiet_routes: default_quiet_routes(),
254        }
255    }
256}
257
/// Fallback for `gateway.port`.
fn default_http_port() -> u16 {
    9_081
}

/// Fallback for `gateway.grpc_port`.
fn default_grpc_port() -> u16 {
    9_000
}

/// Fallback for `gateway.max_connections`.
fn default_max_connections() -> usize {
    4_096
}

/// Fallback for `gateway.sse_max_sessions`.
fn default_sse_max_sessions() -> usize {
    10_000
}

/// Fallback for `gateway.request_timeout_secs`.
fn default_request_timeout() -> u64 {
    30
}

/// CORS handling is off unless explicitly enabled.
fn default_cors_enabled() -> bool {
    false
}

/// No CORS origins are allowed by default.
fn default_cors_origins() -> Vec<String> {
    vec![]
}

/// Health/readiness probes and the signal routes are kept out of telemetry.
fn default_quiet_routes() -> Vec<String> {
    const QUIET: [&str; 6] = [
        "/_api/health",
        "/_api/ready",
        "/_api/signal/event",
        "/_api/signal/view",
        "/_api/signal/user",
        "/_api/signal/report",
    ];
    QUIET.iter().copied().map(String::from).collect()
}
296
297/// Function execution configuration.
298#[derive(Debug, Clone, Serialize, Deserialize)]
299pub struct FunctionConfig {
300    /// Maximum concurrent function executions.
301    #[serde(default = "default_max_concurrent")]
302    pub max_concurrent: usize,
303
304    /// Function timeout in seconds.
305    #[serde(default = "default_function_timeout")]
306    pub timeout_secs: u64,
307
308    /// Advisory memory limit per function execution (in bytes).
309    ///
310    /// This value is exposed as configuration metadata for orchestrators
311    /// (e.g., Kubernetes resource requests) and observability dashboards.
312    /// It is not enforced at the process level since Rust does not provide
313    /// per-function memory sandboxing. Use container-level limits for hard
314    /// enforcement.
315    #[serde(default = "default_memory_limit")]
316    pub memory_limit: usize,
317}
318
319impl Default for FunctionConfig {
320    fn default() -> Self {
321        Self {
322            max_concurrent: default_max_concurrent(),
323            timeout_secs: default_function_timeout(),
324            memory_limit: default_memory_limit(),
325        }
326    }
327}
328
/// Fallback for `function.max_concurrent`.
fn default_max_concurrent() -> usize {
    1_000
}

/// Fallback for `function.timeout_secs`.
fn default_function_timeout() -> u64 {
    30
}

/// Fallback for `function.memory_limit`: 512 MiB expressed in bytes.
fn default_memory_limit() -> usize {
    512 << 20
}
340
341/// Worker configuration.
342#[derive(Debug, Clone, Serialize, Deserialize)]
343pub struct WorkerConfig {
344    /// Maximum concurrent jobs.
345    #[serde(default = "default_max_concurrent_jobs")]
346    pub max_concurrent_jobs: usize,
347
348    /// Job timeout in seconds.
349    #[serde(default = "default_job_timeout")]
350    pub job_timeout_secs: u64,
351
352    /// Poll interval in milliseconds.
353    #[serde(default = "default_poll_interval")]
354    pub poll_interval_ms: u64,
355}
356
357impl Default for WorkerConfig {
358    fn default() -> Self {
359        Self {
360            max_concurrent_jobs: default_max_concurrent_jobs(),
361            job_timeout_secs: default_job_timeout(),
362            poll_interval_ms: default_poll_interval(),
363        }
364    }
365}
366
/// Fallback for `worker.max_concurrent_jobs`.
fn default_max_concurrent_jobs() -> usize {
    50
}

/// Fallback for `worker.job_timeout_secs`: one hour.
fn default_job_timeout() -> u64 {
    60 * 60
}

/// Fallback for `worker.poll_interval_ms`.
fn default_poll_interval() -> u64 {
    100
}
378
379/// Security configuration.
380#[derive(Debug, Clone, Serialize, Deserialize, Default)]
381pub struct SecurityConfig {
382    /// Secret key for signing.
383    pub secret_key: Option<String>,
384}
385
386/// JWT signing algorithm.
387#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
388#[serde(rename_all = "UPPERCASE")]
389pub enum JwtAlgorithm {
390    /// HMAC using SHA-256 (symmetric, requires jwt_secret).
391    #[default]
392    HS256,
393    /// HMAC using SHA-384 (symmetric, requires jwt_secret).
394    HS384,
395    /// HMAC using SHA-512 (symmetric, requires jwt_secret).
396    HS512,
397    /// RSA using SHA-256 (asymmetric, requires jwks_url).
398    RS256,
399    /// RSA using SHA-384 (asymmetric, requires jwks_url).
400    RS384,
401    /// RSA using SHA-512 (asymmetric, requires jwks_url).
402    RS512,
403}
404
405/// Authentication configuration.
406#[derive(Debug, Clone, Serialize, Deserialize)]
407pub struct AuthConfig {
408    /// JWT secret for HMAC algorithms (HS256, HS384, HS512).
409    /// Required when using HMAC algorithms.
410    pub jwt_secret: Option<String>,
411
412    /// JWT signing algorithm.
413    /// HMAC algorithms (HS256, HS384, HS512) require jwt_secret.
414    /// RSA algorithms (RS256, RS384, RS512) require jwks_url.
415    #[serde(default)]
416    pub jwt_algorithm: JwtAlgorithm,
417
418    /// Expected token issuer (iss claim).
419    /// If set, tokens with a different issuer are rejected.
420    pub jwt_issuer: Option<String>,
421
422    /// Expected audience (aud claim).
423    /// If set, tokens with a different audience are rejected.
424    pub jwt_audience: Option<String>,
425
426    /// Access token lifetime (e.g., "15m", "1h").
427    /// Used by `ctx.issue_token_pair()`. Defaults to "1h".
428    pub access_token_ttl: Option<String>,
429
430    /// Refresh token lifetime (e.g., "7d", "30d").
431    /// Used by `ctx.issue_token_pair()`. Defaults to "30d".
432    pub refresh_token_ttl: Option<String>,
433
434    /// JWKS URL for RSA algorithms (RS256, RS384, RS512).
435    /// Keys are fetched and cached automatically.
436    pub jwks_url: Option<String>,
437
438    /// JWKS cache TTL in seconds.
439    #[serde(default = "default_jwks_cache_ttl")]
440    pub jwks_cache_ttl_secs: u64,
441
442    /// Session TTL in seconds (for WebSocket sessions).
443    #[serde(default = "default_session_ttl")]
444    pub session_ttl_secs: u64,
445}
446
447impl Default for AuthConfig {
448    fn default() -> Self {
449        Self {
450            jwt_secret: None,
451            jwt_algorithm: JwtAlgorithm::default(),
452            jwt_issuer: None,
453            jwt_audience: None,
454            access_token_ttl: None,
455            refresh_token_ttl: None,
456            jwks_url: None,
457            jwks_cache_ttl_secs: default_jwks_cache_ttl(),
458            session_ttl_secs: default_session_ttl(),
459        }
460    }
461}
462
463impl AuthConfig {
464    /// Resolved access token TTL in seconds.
465    /// Parses `access_token_ttl`, default 3600s (1h).
466    /// Minimum 1 second to prevent zero-lifetime tokens.
467    pub fn access_token_ttl_secs(&self) -> i64 {
468        self.access_token_ttl
469            .as_deref()
470            .and_then(crate::util::parse_duration)
471            .map(|d| (d.as_secs() as i64).max(1))
472            .unwrap_or(3600)
473    }
474
475    /// Resolved refresh token TTL in days.
476    /// Parses `refresh_token_ttl`, default 30 days.
477    pub fn refresh_token_ttl_days(&self) -> i64 {
478        self.refresh_token_ttl
479            .as_deref()
480            .and_then(crate::util::parse_duration)
481            .map(|d| (d.as_secs() / 86400) as i64)
482            .map(|d| if d == 0 { 1 } else { d })
483            .unwrap_or(30)
484    }
485
486    /// Check if auth is configured (any credential or claim validation is set).
487    fn is_configured(&self) -> bool {
488        self.jwt_secret.is_some()
489            || self.jwks_url.is_some()
490            || self.jwt_issuer.is_some()
491            || self.jwt_audience.is_some()
492    }
493
494    /// Validate that the configuration is complete for the chosen algorithm.
495    /// Skips validation if no auth settings are configured (auth disabled).
496    pub fn validate(&self) -> Result<()> {
497        if !self.is_configured() {
498            return Ok(());
499        }
500
501        match self.jwt_algorithm {
502            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512 => {
503                if self.jwt_secret.is_none() {
504                    return Err(ForgeError::Config(
505                        "auth.jwt_secret is required for HMAC algorithms (HS256, HS384, HS512). \
506                         Set auth.jwt_secret to a secure random string, \
507                         or switch to RS256 and provide auth.jwks_url for external identity providers."
508                            .into(),
509                    ));
510                }
511            }
512            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512 => {
513                if self.jwks_url.is_none() {
514                    return Err(ForgeError::Config(
515                        "auth.jwks_url is required for RSA algorithms (RS256, RS384, RS512). \
516                         Set auth.jwks_url to your identity provider's JWKS endpoint, \
517                         or switch to HS256 and provide auth.jwt_secret for symmetric signing."
518                            .into(),
519                    ));
520                }
521            }
522        }
523        Ok(())
524    }
525
526    /// Check if this config uses HMAC (symmetric) algorithms.
527    pub fn is_hmac(&self) -> bool {
528        matches!(
529            self.jwt_algorithm,
530            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
531        )
532    }
533
534    /// Check if this config uses RSA (asymmetric) algorithms.
535    pub fn is_rsa(&self) -> bool {
536        matches!(
537            self.jwt_algorithm,
538            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
539        )
540    }
541}
542
/// Fallback for `auth.jwks_cache_ttl_secs`: one hour.
fn default_jwks_cache_ttl() -> u64 {
    60 * 60
}

/// Fallback for `auth.session_ttl_secs`: seven days.
fn default_session_ttl() -> u64 {
    7 * 86_400
}
550
551/// Observability configuration for OTLP telemetry.
552#[derive(Debug, Clone, Serialize, Deserialize)]
553pub struct ObservabilityConfig {
554    /// Enable observability (traces, metrics, logs).
555    #[serde(default)]
556    pub enabled: bool,
557
558    /// OTLP endpoint for telemetry export.
559    #[serde(default = "default_otlp_endpoint")]
560    pub otlp_endpoint: String,
561
562    /// Service name for telemetry identification.
563    pub service_name: Option<String>,
564
565    /// Enable distributed tracing.
566    #[serde(default = "default_true")]
567    pub enable_traces: bool,
568
569    /// Enable metrics collection.
570    #[serde(default = "default_true")]
571    pub enable_metrics: bool,
572
573    /// Enable log export via OTLP.
574    #[serde(default = "default_true")]
575    pub enable_logs: bool,
576
577    /// Trace sampling ratio (0.0 to 1.0).
578    #[serde(default = "default_sampling_ratio")]
579    pub sampling_ratio: f64,
580
581    /// Log level for the tracing subscriber (e.g., "debug", "info", "warn").
582    #[serde(default = "default_log_level")]
583    pub log_level: String,
584}
585
586impl Default for ObservabilityConfig {
587    fn default() -> Self {
588        Self {
589            enabled: false,
590            otlp_endpoint: default_otlp_endpoint(),
591            service_name: None,
592            enable_traces: true,
593            enable_metrics: true,
594            enable_logs: true,
595            sampling_ratio: default_sampling_ratio(),
596            log_level: default_log_level(),
597        }
598    }
599}
600
601impl ObservabilityConfig {
602    pub fn otlp_active(&self) -> bool {
603        self.enabled && (self.enable_traces || self.enable_metrics || self.enable_logs)
604    }
605
606    /// Apply FORGE_OTEL_* environment variable overrides.
607    ///
608    /// Supported variables:
609    /// - `FORGE_OTEL_TRACES` - "true"/"false" to enable/disable traces
610    /// - `FORGE_OTEL_METRICS` - "true"/"false" to enable/disable metrics
611    /// - `FORGE_OTEL_LOGS` - "true"/"false" to enable/disable logs
612    pub fn apply_env_overrides(&mut self) {
613        if let Ok(val) = std::env::var("FORGE_OTEL_TRACES") {
614            self.enable_traces = val.eq_ignore_ascii_case("true") || val == "1";
615        }
616        if let Ok(val) = std::env::var("FORGE_OTEL_METRICS") {
617            self.enable_metrics = val.eq_ignore_ascii_case("true") || val == "1";
618        }
619        if let Ok(val) = std::env::var("FORGE_OTEL_LOGS") {
620            self.enable_logs = val.eq_ignore_ascii_case("true") || val == "1";
621        }
622    }
623}
624
/// Fallback for `observability.otlp_endpoint`.
fn default_otlp_endpoint() -> String {
    String::from("http://localhost:4318")
}

/// Serde helper for boolean fields that default to `true`.
pub(crate) fn default_true() -> bool {
    true
}

/// Fallback for `observability.sampling_ratio`: sample every trace.
fn default_sampling_ratio() -> f64 {
    1.0
}

/// Fallback for `observability.log_level`.
fn default_log_level() -> String {
    String::from("info")
}
640
641/// MCP server configuration.
642#[derive(Debug, Clone, Serialize, Deserialize)]
643pub struct McpConfig {
644    /// Enable MCP endpoint exposure.
645    #[serde(default)]
646    pub enabled: bool,
647
648    /// Enable OAuth 2.1 Authorization Code + PKCE for MCP clients.
649    /// When true, Forge acts as an OAuth 2.1 Authorization Server so MCP
650    /// clients like Claude Code can auto-authenticate via browser login.
651    /// Requires `auth.jwt_secret` to be set.
652    #[serde(default)]
653    pub oauth: bool,
654
655    /// MCP endpoint path under the gateway API namespace.
656    #[serde(default = "default_mcp_path")]
657    pub path: String,
658
659    /// Session TTL in seconds.
660    #[serde(default = "default_mcp_session_ttl_secs")]
661    pub session_ttl_secs: u64,
662
663    /// Allowed origins for Origin header validation.
664    #[serde(default)]
665    pub allowed_origins: Vec<String>,
666
667    /// Enforce MCP-Protocol-Version header on post-initialize requests.
668    #[serde(default = "default_true")]
669    pub require_protocol_version_header: bool,
670}
671
672impl Default for McpConfig {
673    fn default() -> Self {
674        Self {
675            enabled: false,
676            oauth: false,
677            path: default_mcp_path(),
678            session_ttl_secs: default_mcp_session_ttl_secs(),
679            allowed_origins: Vec::new(),
680            require_protocol_version_header: default_true(),
681        }
682    }
683}
684
685impl McpConfig {
686    /// Paths reserved by the gateway that MCP must not collide with.
687    const RESERVED_PATHS: &[&str] = &[
688        "/health",
689        "/ready",
690        "/rpc",
691        "/events",
692        "/subscribe",
693        "/unsubscribe",
694        "/subscribe-job",
695        "/subscribe-workflow",
696        "/metrics",
697    ];
698
699    pub fn validate(&self) -> Result<()> {
700        if self.path.is_empty() || !self.path.starts_with('/') {
701            return Err(ForgeError::Config(
702                "mcp.path must start with '/' (example: /mcp)".to_string(),
703            ));
704        }
705        if self.path.contains(' ') {
706            return Err(ForgeError::Config(
707                "mcp.path cannot contain spaces".to_string(),
708            ));
709        }
710        if Self::RESERVED_PATHS.contains(&self.path.as_str()) {
711            return Err(ForgeError::Config(format!(
712                "mcp.path '{}' conflicts with a reserved gateway route",
713                self.path
714            )));
715        }
716        if self.session_ttl_secs == 0 {
717            return Err(ForgeError::Config(
718                "mcp.session_ttl_secs must be greater than 0".to_string(),
719            ));
720        }
721        Ok(())
722    }
723}
724
/// Fallback for `mcp.path`.
fn default_mcp_path() -> String {
    String::from("/mcp")
}

/// Fallback for `mcp.session_ttl_secs`: one hour.
fn default_mcp_session_ttl_secs() -> u64 {
    3_600
}
732
/// Substitute environment variables in the format `${VAR_NAME}`.
///
/// Supports default values with `${VAR-default}` or `${VAR:-default}`.
/// When the env var is unset (or not valid Unicode), the default is used.
/// Without a default, the literal `${VAR}` is preserved. Outside a quoted TOML
/// string this makes parsing fail loudly; inside quotes the literal text
/// becomes the value.
///
/// Placeholders whose name is not `[A-Z_][A-Z0-9_]*` are left untouched. All
/// other text, including multi-byte UTF-8, is copied through unchanged.
#[allow(clippy::indexing_slicing)] // slices start at ASCII `$`/`{` offsets, always char boundaries
pub fn substitute_env_vars(content: &str) -> String {
    let mut result = String::with_capacity(content.len());
    let mut rest = content;

    while let Some(start) = rest.find("${") {
        // Copy the text before the placeholder as a `&str` slice. The old
        // byte-as-char push re-encoded every non-ASCII byte as Latin-1.
        let (literal, tail) = rest.split_at(start);
        result.push_str(literal);

        // `tail` starts with the two ASCII bytes `${`.
        let after_open = &tail[2..];
        if let Some((inner, after_close)) = after_open.split_once('}') {
            let (var_name, default_value) = parse_var_with_default(inner);
            if is_valid_env_var_name(var_name) {
                match std::env::var(var_name) {
                    Ok(value) => result.push_str(&value),
                    Err(_) => match default_value {
                        Some(default) => result.push_str(default),
                        // No value and no default: keep `${...}` verbatim.
                        None => {
                            result.push_str("${");
                            result.push_str(inner);
                            result.push('}');
                        }
                    },
                }
                rest = after_close;
                continue;
            }
        }

        // Not a substitutable placeholder: emit the `$` literally and keep
        // scanning right after it, so a nested `${VALID}` still expands.
        result.push('$');
        rest = &tail[1..];
    }

    result.push_str(rest);
    result
}

/// Parse `VAR-default` or `VAR:-default` into (name, optional default).
///
/// Both forms behave identically (fallback when unset). The split happens at
/// the FIRST `-`, dropping a `:` immediately before it, so a default value may
/// itself contain `-` or `:-` (e.g. `VAR-a:-b` yields default `a:-b`).
fn parse_var_with_default(inner: &str) -> (&str, Option<&str>) {
    match inner.split_once('-') {
        Some((head, default)) => (head.strip_suffix(':').unwrap_or(head), Some(default)),
        None => (inner, None),
    }
}

/// Returns true for names matching `[A-Z_][A-Z0-9_]*`.
fn is_valid_env_var_name(name: &str) -> bool {
    let mut bytes = name.bytes();
    match bytes.next() {
        Some(first) if first.is_ascii_uppercase() || first == b'_' => {
            bytes.all(|b| b.is_ascii_uppercase() || b.is_ascii_digit() || b == b'_')
        }
        _ => false,
    }
}
799
800#[cfg(test)]
801#[allow(clippy::unwrap_used, clippy::indexing_slicing, unsafe_code)]
802mod tests {
803    use super::*;
804
805    #[test]
806    fn test_default_config() {
807        let config = ForgeConfig::default_with_database_url("postgres://localhost/test");
808        assert_eq!(config.gateway.port, 9081);
809        assert_eq!(config.node.roles.len(), 4);
810        assert_eq!(config.mcp.path, "/mcp");
811        assert!(!config.mcp.enabled);
812    }
813
814    #[test]
815    fn test_parse_minimal_config() {
816        let toml = r#"
817            [database]
818            url = "postgres://localhost/myapp"
819        "#;
820
821        let config = ForgeConfig::parse_toml(toml).unwrap();
822        assert_eq!(config.database.url(), "postgres://localhost/myapp");
823        assert_eq!(config.gateway.port, 9081);
824    }
825
826    #[test]
827    fn test_parse_full_config() {
828        let toml = r#"
829            [project]
830            name = "my-app"
831            version = "1.0.0"
832
833            [database]
834            url = "postgres://localhost/myapp"
835            pool_size = 100
836
837            [node]
838            roles = ["gateway", "worker"]
839            worker_capabilities = ["media", "general"]
840
841            [gateway]
842            port = 3000
843            grpc_port = 9001
844        "#;
845
846        let config = ForgeConfig::parse_toml(toml).unwrap();
847        assert_eq!(config.project.name, "my-app");
848        assert_eq!(config.database.pool_size, 100);
849        assert_eq!(config.node.roles.len(), 2);
850        assert_eq!(config.gateway.port, 3000);
851    }
852
853    #[test]
854    fn test_env_var_substitution() {
855        unsafe {
856            std::env::set_var("TEST_DB_URL", "postgres://test:test@localhost/test");
857        }
858
859        let toml = r#"
860            [database]
861            url = "${TEST_DB_URL}"
862        "#;
863
864        let config = ForgeConfig::parse_toml(toml).unwrap();
865        assert_eq!(config.database.url(), "postgres://test:test@localhost/test");
866
867        unsafe {
868            std::env::remove_var("TEST_DB_URL");
869        }
870    }
871
872    #[test]
873    fn test_auth_validation_no_config() {
874        let auth = AuthConfig::default();
875        assert!(auth.validate().is_ok());
876    }
877
878    #[test]
879    fn test_auth_validation_hmac_with_secret() {
880        let auth = AuthConfig {
881            jwt_secret: Some("my-secret".into()),
882            jwt_algorithm: JwtAlgorithm::HS256,
883            ..Default::default()
884        };
885        assert!(auth.validate().is_ok());
886    }
887
888    #[test]
889    fn test_auth_validation_hmac_missing_secret() {
890        let auth = AuthConfig {
891            jwt_issuer: Some("my-issuer".into()),
892            jwt_algorithm: JwtAlgorithm::HS256,
893            ..Default::default()
894        };
895        let result = auth.validate();
896        assert!(result.is_err());
897        let err_msg = result.unwrap_err().to_string();
898        assert!(err_msg.contains("jwt_secret is required"));
899    }
900
901    #[test]
902    fn test_auth_validation_rsa_with_jwks() {
903        let auth = AuthConfig {
904            jwks_url: Some("https://example.com/.well-known/jwks.json".into()),
905            jwt_algorithm: JwtAlgorithm::RS256,
906            ..Default::default()
907        };
908        assert!(auth.validate().is_ok());
909    }
910
911    #[test]
912    fn test_auth_validation_rsa_missing_jwks() {
913        let auth = AuthConfig {
914            jwt_issuer: Some("my-issuer".into()),
915            jwt_algorithm: JwtAlgorithm::RS256,
916            ..Default::default()
917        };
918        let result = auth.validate();
919        assert!(result.is_err());
920        let err_msg = result.unwrap_err().to_string();
921        assert!(err_msg.contains("jwks_url is required"));
922    }
923
924    #[test]
925    fn test_forge_config_validation_fails_on_empty_url() {
926        let toml = r#"
927            [database]
928
929            url = ""
930        "#;
931
932        let result = ForgeConfig::parse_toml(toml);
933        assert!(result.is_err());
934        let err_msg = result.unwrap_err().to_string();
935        assert!(err_msg.contains("database.url is required"));
936    }
937
938    #[test]
939    fn test_forge_config_validation_fails_on_invalid_auth() {
940        let toml = r#"
941            [database]
942
943            url = "postgres://localhost/test"
944
945            [auth]
946            jwt_issuer = "my-issuer"
947            jwt_algorithm = "RS256"
948        "#;
949
950        let result = ForgeConfig::parse_toml(toml);
951        assert!(result.is_err());
952        let err_msg = result.unwrap_err().to_string();
953        assert!(err_msg.contains("jwks_url is required"));
954    }
955
956    #[test]
957    fn test_env_var_default_used_when_unset() {
958        // Ensure the var is definitely not set
959        unsafe {
960            std::env::remove_var("TEST_FORGE_OTEL_UNSET");
961        }
962
963        let input = r#"enabled = ${TEST_FORGE_OTEL_UNSET-false}"#;
964        let result = substitute_env_vars(input);
965        assert_eq!(result, "enabled = false");
966    }
967
968    #[test]
969    fn test_env_var_default_overridden_when_set() {
970        unsafe {
971            std::env::set_var("TEST_FORGE_OTEL_SET", "true");
972        }
973
974        let input = r#"enabled = ${TEST_FORGE_OTEL_SET-false}"#;
975        let result = substitute_env_vars(input);
976        assert_eq!(result, "enabled = true");
977
978        unsafe {
979            std::env::remove_var("TEST_FORGE_OTEL_SET");
980        }
981    }
982
983    #[test]
984    fn test_env_var_colon_dash_default() {
985        unsafe {
986            std::env::remove_var("TEST_FORGE_ENDPOINT_UNSET");
987        }
988
989        let input = r#"endpoint = "${TEST_FORGE_ENDPOINT_UNSET:-http://localhost:4318}""#;
990        let result = substitute_env_vars(input);
991        assert_eq!(result, r#"endpoint = "http://localhost:4318""#);
992    }
993
994    #[test]
995    fn test_env_var_no_default_preserves_literal() {
996        unsafe {
997            std::env::remove_var("TEST_FORGE_MISSING");
998        }
999
1000        let input = r#"url = "${TEST_FORGE_MISSING}""#;
1001        let result = substitute_env_vars(input);
1002        assert_eq!(result, r#"url = "${TEST_FORGE_MISSING}""#);
1003    }
1004
1005    #[test]
1006    fn test_env_var_default_empty_string() {
1007        unsafe {
1008            std::env::remove_var("TEST_FORGE_EMPTY_DEFAULT");
1009        }
1010
1011        let input = r#"val = "${TEST_FORGE_EMPTY_DEFAULT-}""#;
1012        let result = substitute_env_vars(input);
1013        assert_eq!(result, r#"val = """#);
1014    }
1015
1016    #[test]
1017    fn test_observability_config_default_disabled() {
1018        let toml = r#"
1019            [database]
1020            url = "postgres://localhost/test"
1021        "#;
1022
1023        let config = ForgeConfig::parse_toml(toml).unwrap();
1024        assert!(!config.observability.enabled);
1025        assert!(!config.observability.otlp_active());
1026    }
1027
1028    #[test]
1029    fn test_observability_config_with_env_default() {
1030        // Simulates what the template produces when no env vars are set
1031        unsafe {
1032            std::env::remove_var("TEST_OTEL_ENABLED");
1033        }
1034
1035        let toml = r#"
1036            [database]
1037            url = "postgres://localhost/test"
1038
1039            [observability]
1040            enabled = ${TEST_OTEL_ENABLED-false}
1041        "#;
1042
1043        let config = ForgeConfig::parse_toml(toml).unwrap();
1044        assert!(!config.observability.enabled);
1045    }
1046
1047    #[test]
1048    fn test_mcp_config_validation_rejects_invalid_path() {
1049        let toml = r#"
1050            [database]
1051
1052            url = "postgres://localhost/test"
1053
1054            [mcp]
1055            enabled = true
1056            path = "mcp"
1057        "#;
1058
1059        let result = ForgeConfig::parse_toml(toml);
1060        assert!(result.is_err());
1061        let err_msg = result.unwrap_err().to_string();
1062        assert!(err_msg.contains("mcp.path must start with '/'"));
1063    }
1064
1065    #[test]
1066    fn test_access_token_ttl_defaults() {
1067        let auth = AuthConfig::default();
1068        assert_eq!(auth.access_token_ttl_secs(), 3600);
1069        assert_eq!(auth.refresh_token_ttl_days(), 30);
1070    }
1071
1072    #[test]
1073    fn test_access_token_ttl_custom() {
1074        let auth = AuthConfig {
1075            access_token_ttl: Some("15m".into()),
1076            refresh_token_ttl: Some("7d".into()),
1077            ..Default::default()
1078        };
1079        assert_eq!(auth.access_token_ttl_secs(), 900);
1080        assert_eq!(auth.refresh_token_ttl_days(), 7);
1081    }
1082
1083    #[test]
1084    fn test_access_token_ttl_minimum_enforced() {
1085        let auth = AuthConfig {
1086            access_token_ttl: Some("0s".into()),
1087            ..Default::default()
1088        };
1089        // Should floor at 1, not 0
1090        assert_eq!(auth.access_token_ttl_secs(), 1);
1091    }
1092
1093    #[test]
1094    fn test_refresh_token_ttl_minimum_enforced() {
1095        let auth = AuthConfig {
1096            refresh_token_ttl: Some("1h".into()),
1097            ..Default::default()
1098        };
1099        // 1 hour < 1 day, so should floor at 1 day
1100        assert_eq!(auth.refresh_token_ttl_days(), 1);
1101    }
1102
1103    #[test]
1104    fn test_mcp_config_rejects_reserved_paths() {
1105        for reserved in McpConfig::RESERVED_PATHS {
1106            let toml = format!(
1107                r#"
1108                [database]
1109                url = "postgres://localhost/test"
1110
1111                [mcp]
1112                enabled = true
1113                path = "{reserved}"
1114                "#
1115            );
1116
1117            let result = ForgeConfig::parse_toml(&toml);
1118            assert!(result.is_err(), "Expected {reserved} to be rejected");
1119            let err_msg = result.unwrap_err().to_string();
1120            assert!(
1121                err_msg.contains("conflicts with a reserved gateway route"),
1122                "Wrong error for {reserved}: {err_msg}"
1123            );
1124        }
1125    }
1126}