Skip to main content

forge_core/config/
mod.rs

1pub mod cluster;
2mod database;
3pub mod signals;
4
5pub use cluster::ClusterConfig;
6pub use database::{DatabaseConfig, PoolConfig};
7pub use signals::SignalsConfig;
8
9use serde::{Deserialize, Serialize};
10use std::path::Path;
11
12use crate::error::{ForgeError, Result};
13
14/// Root configuration for FORGE.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ForgeConfig {
17    /// Project metadata.
18    #[serde(default)]
19    pub project: ProjectConfig,
20
21    /// Database configuration.
22    pub database: DatabaseConfig,
23
24    /// Node configuration.
25    #[serde(default)]
26    pub node: NodeConfig,
27
28    /// Gateway configuration.
29    #[serde(default)]
30    pub gateway: GatewayConfig,
31
32    /// Function execution configuration.
33    #[serde(default)]
34    pub function: FunctionConfig,
35
36    /// Worker configuration.
37    #[serde(default)]
38    pub worker: WorkerConfig,
39
40    /// Cluster configuration.
41    #[serde(default)]
42    pub cluster: ClusterConfig,
43
44    /// Security configuration.
45    #[serde(default)]
46    pub security: SecurityConfig,
47
48    /// Authentication configuration.
49    #[serde(default)]
50    pub auth: AuthConfig,
51
52    /// Observability configuration.
53    #[serde(default)]
54    pub observability: ObservabilityConfig,
55
56    /// MCP server configuration.
57    #[serde(default)]
58    pub mcp: McpConfig,
59
60    /// Signals configuration for product analytics and diagnostics.
61    #[serde(default)]
62    pub signals: SignalsConfig,
63}
64
65impl ForgeConfig {
66    /// Load configuration from a TOML file.
67    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
68        let content = std::fs::read_to_string(path.as_ref())
69            .map_err(|e| ForgeError::Config(format!("Failed to read config file: {}", e)))?;
70
71        Self::parse_toml(&content)
72    }
73
74    /// Parse configuration from a TOML string.
75    pub fn parse_toml(content: &str) -> Result<Self> {
76        // Substitute environment variables
77        let content = substitute_env_vars(content);
78
79        let config: Self = toml::from_str(&content)
80            .map_err(|e| ForgeError::Config(format!("Failed to parse config: {}", e)))?;
81
82        config.validate()?;
83        Ok(config)
84    }
85
86    /// Validate the configuration for invalid combinations.
87    pub fn validate(&self) -> Result<()> {
88        self.database.validate()?;
89        self.auth.validate()?;
90        self.mcp.validate()?;
91        let body_limit = self.gateway.max_body_size_bytes()?;
92        let file_limit = self.gateway.max_file_size_bytes()?;
93        if file_limit > body_limit {
94            return Err(ForgeError::Config(format!(
95                "gateway.max_file_size ({}) cannot exceed gateway.max_body_size ({})",
96                self.gateway.max_file_size, self.gateway.max_body_size
97            )));
98        }
99
100        // Cross-field: OAuth requires jwt_secret for signing tokens
101        if self.mcp.oauth && self.auth.jwt_secret.is_none() {
102            return Err(ForgeError::Config(
103                "mcp.oauth = true requires auth.jwt_secret to be set. \
104                 OAuth-issued tokens are signed with this secret, even when using \
105                 an external provider (JWKS) for identity verification."
106                    .into(),
107            ));
108        }
109        if self.mcp.oauth && !self.mcp.enabled {
110            return Err(ForgeError::Config(
111                "mcp.oauth = true requires mcp.enabled = true".into(),
112            ));
113        }
114
115        Ok(())
116    }
117
118    /// Load configuration with defaults.
119    pub fn default_with_database_url(url: &str) -> Self {
120        Self {
121            project: ProjectConfig::default(),
122            database: DatabaseConfig::new(url),
123            node: NodeConfig::default(),
124            gateway: GatewayConfig::default(),
125            function: FunctionConfig::default(),
126            worker: WorkerConfig::default(),
127            cluster: ClusterConfig::default(),
128            security: SecurityConfig::default(),
129            auth: AuthConfig::default(),
130            observability: ObservabilityConfig::default(),
131            mcp: McpConfig::default(),
132            signals: SignalsConfig::default(),
133        }
134    }
135}
136
137/// Project metadata.
138#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct ProjectConfig {
140    /// Project name.
141    #[serde(default = "default_project_name")]
142    pub name: String,
143
144    /// Project version.
145    #[serde(default = "default_version")]
146    pub version: String,
147}
148
149impl Default for ProjectConfig {
150    fn default() -> Self {
151        Self {
152            name: default_project_name(),
153            version: default_version(),
154        }
155    }
156}
157
/// Default for `project.name`.
fn default_project_name() -> String {
    String::from("forge-app")
}

/// Default for `project.version`.
fn default_version() -> String {
    String::from("0.1.0")
}
165
166/// Node role configuration.
167#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct NodeConfig {
169    /// Roles this node should assume.
170    #[serde(default = "default_roles")]
171    pub roles: Vec<NodeRole>,
172
173    /// Worker capabilities for job routing.
174    #[serde(default = "default_capabilities")]
175    pub worker_capabilities: Vec<String>,
176}
177
178impl Default for NodeConfig {
179    fn default() -> Self {
180        Self {
181            roles: default_roles(),
182            worker_capabilities: default_capabilities(),
183        }
184    }
185}
186
187fn default_roles() -> Vec<NodeRole> {
188    vec![
189        NodeRole::Gateway,
190        NodeRole::Function,
191        NodeRole::Worker,
192        NodeRole::Scheduler,
193    ]
194}
195
196fn default_capabilities() -> Vec<String> {
197    vec!["general".to_string()]
198}
199
200/// Available node roles.
201#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
202#[serde(rename_all = "lowercase")]
203pub enum NodeRole {
204    Gateway,
205    Function,
206    Worker,
207    Scheduler,
208}
209
210/// Gateway configuration.
211#[derive(Debug, Clone, Serialize, Deserialize)]
212pub struct GatewayConfig {
213    /// HTTP port.
214    #[serde(default = "default_http_port")]
215    pub port: u16,
216
217    /// gRPC port for inter-node communication (reserved for future use).
218    ///
219    /// This port is registered in the cluster node info but a gRPC listener
220    /// is not yet started. It will be used for efficient binary inter-node
221    /// RPC in a future release.
222    #[serde(default = "default_grpc_port")]
223    pub grpc_port: u16,
224
225    /// Maximum concurrent connections.
226    #[serde(default = "default_max_connections")]
227    pub max_connections: usize,
228
229    /// Maximum active SSE sessions.
230    #[serde(default = "default_sse_max_sessions")]
231    pub sse_max_sessions: usize,
232
233    /// Request timeout in seconds.
234    #[serde(default = "default_request_timeout")]
235    pub request_timeout_secs: u64,
236
237    /// Enable CORS handling.
238    #[serde(default = "default_cors_enabled")]
239    pub cors_enabled: bool,
240
241    /// Allowed CORS origins.
242    #[serde(default = "default_cors_origins")]
243    pub cors_origins: Vec<String>,
244
245    /// Routes excluded from request logs, metrics, and traces.
246    /// Defaults to `["/_api/health", "/_api/ready"]`. Set to `[]` to monitor everything.
247    #[serde(default = "default_quiet_routes")]
248    pub quiet_routes: Vec<String>,
249
250    /// Maximum request body size (e.g. "100mb", "1gb"). Defaults to "20mb".
251    #[serde(default = "default_max_body_size")]
252    pub max_body_size: String,
253
254    /// Default per-file cap for multipart uploads (e.g. "10mb", "200mb").
255    /// Applies when a mutation does not declare its own `max_size`. Set to
256    /// the same value as `max_body_size` to disable the per-file guard.
257    /// Defaults to "10mb".
258    #[serde(default = "default_max_file_size")]
259    pub max_file_size: String,
260}
261
262impl Default for GatewayConfig {
263    fn default() -> Self {
264        Self {
265            port: default_http_port(),
266            grpc_port: default_grpc_port(),
267            max_connections: default_max_connections(),
268            sse_max_sessions: default_sse_max_sessions(),
269            request_timeout_secs: default_request_timeout(),
270            cors_enabled: default_cors_enabled(),
271            cors_origins: default_cors_origins(),
272            quiet_routes: default_quiet_routes(),
273            max_body_size: default_max_body_size(),
274            max_file_size: default_max_file_size(),
275        }
276    }
277}
278
279impl GatewayConfig {
280    /// Parse `max_body_size` into bytes.
281    pub fn max_body_size_bytes(&self) -> crate::Result<usize> {
282        crate::util::parse_size(&self.max_body_size).ok_or_else(|| {
283            crate::ForgeError::Config(format!(
284                "invalid gateway.max_body_size '{}'. Expected a size like '20mb', '1gb', or '1048576'",
285                self.max_body_size
286            ))
287        })
288    }
289
290    /// Parse `max_file_size` into bytes.
291    pub fn max_file_size_bytes(&self) -> crate::Result<usize> {
292        crate::util::parse_size(&self.max_file_size).ok_or_else(|| {
293            crate::ForgeError::Config(format!(
294                "invalid gateway.max_file_size '{}'. Expected a size like '10mb', '200mb', or '1048576'",
295                self.max_file_size
296            ))
297        })
298    }
299}
300
/// Default for `gateway.port`.
fn default_http_port() -> u16 {
    9_081
}

/// Default for `gateway.grpc_port`.
fn default_grpc_port() -> u16 {
    9_000
}

/// Default for `gateway.max_connections`.
fn default_max_connections() -> usize {
    4_096
}

/// Default for `gateway.sse_max_sessions`.
fn default_sse_max_sessions() -> usize {
    10_000
}

/// Default for `gateway.request_timeout_secs`.
fn default_request_timeout() -> u64 {
    30
}

/// Default for `gateway.cors_enabled` (CORS off).
fn default_cors_enabled() -> bool {
    bool::default()
}

/// Default for `gateway.cors_origins` (no origins).
fn default_cors_origins() -> Vec<String> {
    vec![]
}

/// Default for `gateway.quiet_routes`: probes plus the signal endpoints.
fn default_quiet_routes() -> Vec<String> {
    const QUIET: [&str; 6] = [
        "/_api/health",
        "/_api/ready",
        "/_api/signal/event",
        "/_api/signal/view",
        "/_api/signal/user",
        "/_api/signal/report",
    ];
    QUIET.iter().map(|route| route.to_string()).collect()
}

/// Default for `gateway.max_body_size`.
fn default_max_body_size() -> String {
    String::from("20mb")
}

/// Default for `gateway.max_file_size`.
fn default_max_file_size() -> String {
    String::from("10mb")
}
347
348/// Function execution configuration.
349#[derive(Debug, Clone, Serialize, Deserialize)]
350pub struct FunctionConfig {
351    /// Maximum concurrent function executions.
352    #[serde(default = "default_max_concurrent")]
353    pub max_concurrent: usize,
354
355    /// Function timeout in seconds.
356    #[serde(default = "default_function_timeout")]
357    pub timeout_secs: u64,
358
359    /// Advisory memory limit per function execution (in bytes).
360    ///
361    /// This value is exposed as configuration metadata for orchestrators
362    /// (e.g., Kubernetes resource requests) and observability dashboards.
363    /// It is not enforced at the process level since Rust does not provide
364    /// per-function memory sandboxing. Use container-level limits for hard
365    /// enforcement.
366    #[serde(default = "default_memory_limit")]
367    pub memory_limit: usize,
368}
369
370impl Default for FunctionConfig {
371    fn default() -> Self {
372        Self {
373            max_concurrent: default_max_concurrent(),
374            timeout_secs: default_function_timeout(),
375            memory_limit: default_memory_limit(),
376        }
377    }
378}
379
/// Default for `function.max_concurrent`.
fn default_max_concurrent() -> usize {
    1_000
}

/// Default for `function.timeout_secs`.
fn default_function_timeout() -> u64 {
    30
}

/// Default for `function.memory_limit`: 512 MiB expressed in bytes.
fn default_memory_limit() -> usize {
    const MIB: usize = 1024 * 1024;
    512 * MIB
}
391
392/// Worker configuration.
393#[derive(Debug, Clone, Serialize, Deserialize)]
394pub struct WorkerConfig {
395    /// Maximum concurrent jobs.
396    #[serde(default = "default_max_concurrent_jobs")]
397    pub max_concurrent_jobs: usize,
398
399    /// Job timeout in seconds.
400    #[serde(default = "default_job_timeout")]
401    pub job_timeout_secs: u64,
402
403    /// Poll interval in milliseconds.
404    #[serde(default = "default_poll_interval")]
405    pub poll_interval_ms: u64,
406}
407
408impl Default for WorkerConfig {
409    fn default() -> Self {
410        Self {
411            max_concurrent_jobs: default_max_concurrent_jobs(),
412            job_timeout_secs: default_job_timeout(),
413            poll_interval_ms: default_poll_interval(),
414        }
415    }
416}
417
/// Default for `worker.max_concurrent_jobs`.
fn default_max_concurrent_jobs() -> usize {
    50
}

/// Default for `worker.job_timeout_secs`: one hour.
fn default_job_timeout() -> u64 {
    60 * 60
}

/// Default for `worker.poll_interval_ms`.
fn default_poll_interval() -> u64 {
    100
}
429
430/// Security configuration.
431#[derive(Debug, Clone, Serialize, Deserialize, Default)]
432pub struct SecurityConfig {
433    /// Secret key for signing.
434    pub secret_key: Option<String>,
435}
436
437/// JWT signing algorithm.
438#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
439#[serde(rename_all = "UPPERCASE")]
440pub enum JwtAlgorithm {
441    /// HMAC using SHA-256 (symmetric, requires jwt_secret).
442    #[default]
443    HS256,
444    /// HMAC using SHA-384 (symmetric, requires jwt_secret).
445    HS384,
446    /// HMAC using SHA-512 (symmetric, requires jwt_secret).
447    HS512,
448    /// RSA using SHA-256 (asymmetric, requires jwks_url).
449    RS256,
450    /// RSA using SHA-384 (asymmetric, requires jwks_url).
451    RS384,
452    /// RSA using SHA-512 (asymmetric, requires jwks_url).
453    RS512,
454}
455
456/// Authentication configuration.
457#[derive(Debug, Clone, Serialize, Deserialize)]
458pub struct AuthConfig {
459    /// JWT secret for HMAC algorithms (HS256, HS384, HS512).
460    /// Required when using HMAC algorithms.
461    pub jwt_secret: Option<String>,
462
463    /// JWT signing algorithm.
464    /// HMAC algorithms (HS256, HS384, HS512) require jwt_secret.
465    /// RSA algorithms (RS256, RS384, RS512) require jwks_url.
466    #[serde(default)]
467    pub jwt_algorithm: JwtAlgorithm,
468
469    /// Expected token issuer (iss claim).
470    /// If set, tokens with a different issuer are rejected.
471    pub jwt_issuer: Option<String>,
472
473    /// Expected audience (aud claim).
474    /// If set, tokens with a different audience are rejected.
475    pub jwt_audience: Option<String>,
476
477    /// Access token lifetime (e.g., "15m", "1h").
478    /// Used by `ctx.issue_token_pair()`. Defaults to "1h".
479    pub access_token_ttl: Option<String>,
480
481    /// Refresh token lifetime (e.g., "7d", "30d").
482    /// Used by `ctx.issue_token_pair()`. Defaults to "30d".
483    pub refresh_token_ttl: Option<String>,
484
485    /// JWKS URL for RSA algorithms (RS256, RS384, RS512).
486    /// Keys are fetched and cached automatically.
487    pub jwks_url: Option<String>,
488
489    /// JWKS cache TTL in seconds.
490    #[serde(default = "default_jwks_cache_ttl")]
491    pub jwks_cache_ttl_secs: u64,
492
493    /// Session TTL in seconds (for WebSocket sessions).
494    #[serde(default = "default_session_ttl")]
495    pub session_ttl_secs: u64,
496}
497
498impl Default for AuthConfig {
499    fn default() -> Self {
500        Self {
501            jwt_secret: None,
502            jwt_algorithm: JwtAlgorithm::default(),
503            jwt_issuer: None,
504            jwt_audience: None,
505            access_token_ttl: None,
506            refresh_token_ttl: None,
507            jwks_url: None,
508            jwks_cache_ttl_secs: default_jwks_cache_ttl(),
509            session_ttl_secs: default_session_ttl(),
510        }
511    }
512}
513
514impl AuthConfig {
515    /// Resolved access token TTL in seconds.
516    /// Parses `access_token_ttl`, default 3600s (1h).
517    /// Minimum 1 second to prevent zero-lifetime tokens.
518    pub fn access_token_ttl_secs(&self) -> i64 {
519        self.access_token_ttl
520            .as_deref()
521            .and_then(crate::util::parse_duration)
522            .map(|d| (d.as_secs() as i64).max(1))
523            .unwrap_or(3600)
524    }
525
526    /// Resolved refresh token TTL in days.
527    /// Parses `refresh_token_ttl`, default 30 days.
528    pub fn refresh_token_ttl_days(&self) -> i64 {
529        self.refresh_token_ttl
530            .as_deref()
531            .and_then(crate::util::parse_duration)
532            .map(|d| (d.as_secs() / 86400) as i64)
533            .map(|d| if d == 0 { 1 } else { d })
534            .unwrap_or(30)
535    }
536
537    /// Check if auth is configured (any credential or claim validation is set).
538    fn is_configured(&self) -> bool {
539        self.jwt_secret.is_some()
540            || self.jwks_url.is_some()
541            || self.jwt_issuer.is_some()
542            || self.jwt_audience.is_some()
543    }
544
545    /// Validate that the configuration is complete for the chosen algorithm.
546    /// Skips validation if no auth settings are configured (auth disabled).
547    pub fn validate(&self) -> Result<()> {
548        if !self.is_configured() {
549            return Ok(());
550        }
551
552        match self.jwt_algorithm {
553            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512 => {
554                if self.jwt_secret.is_none() {
555                    return Err(ForgeError::Config(
556                        "auth.jwt_secret is required for HMAC algorithms (HS256, HS384, HS512). \
557                         Set auth.jwt_secret to a secure random string, \
558                         or switch to RS256 and provide auth.jwks_url for external identity providers."
559                            .into(),
560                    ));
561                }
562            }
563            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512 => {
564                if self.jwks_url.is_none() {
565                    return Err(ForgeError::Config(
566                        "auth.jwks_url is required for RSA algorithms (RS256, RS384, RS512). \
567                         Set auth.jwks_url to your identity provider's JWKS endpoint, \
568                         or switch to HS256 and provide auth.jwt_secret for symmetric signing."
569                            .into(),
570                    ));
571                }
572            }
573        }
574        Ok(())
575    }
576
577    /// Check if this config uses HMAC (symmetric) algorithms.
578    pub fn is_hmac(&self) -> bool {
579        matches!(
580            self.jwt_algorithm,
581            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
582        )
583    }
584
585    /// Check if this config uses RSA (asymmetric) algorithms.
586    pub fn is_rsa(&self) -> bool {
587        matches!(
588            self.jwt_algorithm,
589            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
590        )
591    }
592}
593
/// Default for `auth.jwks_cache_ttl_secs`: one hour.
fn default_jwks_cache_ttl() -> u64 {
    60 * 60
}

/// Default for `auth.session_ttl_secs`: seven days.
fn default_session_ttl() -> u64 {
    const SECS_PER_DAY: u64 = 24 * 60 * 60;
    7 * SECS_PER_DAY
}
601
602/// Observability configuration for OTLP telemetry.
603#[derive(Debug, Clone, Serialize, Deserialize)]
604pub struct ObservabilityConfig {
605    /// Enable observability (traces, metrics, logs).
606    #[serde(default)]
607    pub enabled: bool,
608
609    /// OTLP endpoint for telemetry export.
610    #[serde(default = "default_otlp_endpoint")]
611    pub otlp_endpoint: String,
612
613    /// Service name for telemetry identification.
614    pub service_name: Option<String>,
615
616    /// Enable distributed tracing.
617    #[serde(default = "default_true")]
618    pub enable_traces: bool,
619
620    /// Enable metrics collection.
621    #[serde(default = "default_true")]
622    pub enable_metrics: bool,
623
624    /// Enable log export via OTLP.
625    #[serde(default = "default_true")]
626    pub enable_logs: bool,
627
628    /// Trace sampling ratio (0.0 to 1.0).
629    #[serde(default = "default_sampling_ratio")]
630    pub sampling_ratio: f64,
631
632    /// Metrics export interval in seconds. OTLP collectors typically prefer 15s-60s.
633    #[serde(default = "default_metrics_interval_secs")]
634    pub metrics_interval_secs: u64,
635
636    /// Log level for the tracing subscriber (e.g., "debug", "info", "warn").
637    #[serde(default = "default_log_level")]
638    pub log_level: String,
639}
640
641impl Default for ObservabilityConfig {
642    fn default() -> Self {
643        Self {
644            enabled: false,
645            otlp_endpoint: default_otlp_endpoint(),
646            service_name: None,
647            enable_traces: true,
648            enable_metrics: true,
649            enable_logs: true,
650            sampling_ratio: default_sampling_ratio(),
651            metrics_interval_secs: default_metrics_interval_secs(),
652            log_level: default_log_level(),
653        }
654    }
655}
656
657impl ObservabilityConfig {
658    pub fn otlp_active(&self) -> bool {
659        self.enabled && (self.enable_traces || self.enable_metrics || self.enable_logs)
660    }
661}
662
/// Default for `observability.otlp_endpoint`.
fn default_otlp_endpoint() -> String {
    String::from("http://localhost:4318")
}

/// Serde default helper for boolean fields that should start enabled.
pub(crate) fn default_true() -> bool {
    !false
}

/// Default for `observability.sampling_ratio` (sample everything).
fn default_sampling_ratio() -> f64 {
    1.0_f64
}

/// Default for `observability.metrics_interval_secs`.
fn default_metrics_interval_secs() -> u64 {
    15
}

/// Default for `observability.log_level`.
fn default_log_level() -> String {
    String::from("info")
}
682
683/// MCP server configuration.
684#[derive(Debug, Clone, Serialize, Deserialize)]
685pub struct McpConfig {
686    /// Enable MCP endpoint exposure.
687    #[serde(default)]
688    pub enabled: bool,
689
690    /// Enable OAuth 2.1 Authorization Code + PKCE for MCP clients.
691    /// When true, Forge acts as an OAuth 2.1 Authorization Server so MCP
692    /// clients like Claude Code can auto-authenticate via browser login.
693    /// Requires `auth.jwt_secret` to be set.
694    #[serde(default)]
695    pub oauth: bool,
696
697    /// MCP endpoint path under the gateway API namespace.
698    #[serde(default = "default_mcp_path")]
699    pub path: String,
700
701    /// Session TTL in seconds.
702    #[serde(default = "default_mcp_session_ttl_secs")]
703    pub session_ttl_secs: u64,
704
705    /// Allowed origins for Origin header validation.
706    #[serde(default)]
707    pub allowed_origins: Vec<String>,
708
709    /// Enforce MCP-Protocol-Version header on post-initialize requests.
710    #[serde(default = "default_true")]
711    pub require_protocol_version_header: bool,
712}
713
714impl Default for McpConfig {
715    fn default() -> Self {
716        Self {
717            enabled: false,
718            oauth: false,
719            path: default_mcp_path(),
720            session_ttl_secs: default_mcp_session_ttl_secs(),
721            allowed_origins: Vec::new(),
722            require_protocol_version_header: default_true(),
723        }
724    }
725}
726
727impl McpConfig {
728    /// Paths reserved by the gateway that MCP must not collide with.
729    const RESERVED_PATHS: &[&str] = &[
730        "/health",
731        "/ready",
732        "/rpc",
733        "/events",
734        "/subscribe",
735        "/unsubscribe",
736        "/subscribe-job",
737        "/subscribe-workflow",
738        "/metrics",
739    ];
740
741    pub fn validate(&self) -> Result<()> {
742        if self.path.is_empty() || !self.path.starts_with('/') {
743            return Err(ForgeError::Config(
744                "mcp.path must start with '/' (example: /mcp)".to_string(),
745            ));
746        }
747        if self.path.contains(' ') {
748            return Err(ForgeError::Config(
749                "mcp.path cannot contain spaces".to_string(),
750            ));
751        }
752        if Self::RESERVED_PATHS.contains(&self.path.as_str()) {
753            return Err(ForgeError::Config(format!(
754                "mcp.path '{}' conflicts with a reserved gateway route",
755                self.path
756            )));
757        }
758        if self.session_ttl_secs == 0 {
759            return Err(ForgeError::Config(
760                "mcp.session_ttl_secs must be greater than 0".to_string(),
761            ));
762        }
763        Ok(())
764    }
765}
766
/// Default for `mcp.path`.
fn default_mcp_path() -> String {
    String::from("/mcp")
}

/// Default for `mcp.session_ttl_secs`: one hour.
fn default_mcp_session_ttl_secs() -> u64 {
    3_600
}
774
/// Substitute environment variables in the format `${VAR_NAME}`.
///
/// Default values are supported with `${VAR-default}` or `${VAR:-default}`.
/// When the variable is unset (or its value is not valid Unicode), the
/// default is used. Without a default, the literal `${VAR}` is preserved so
/// that TOML parsing can still fail loudly if a required variable is missing.
///
/// Only names made of `A-Z`, `0-9` and `_` that do not start with a digit
/// count as placeholders. Anything else, including an unterminated `${`, is
/// copied through unchanged. Non-ASCII text is copied verbatim.
// Every slice index below comes from `str::find` on ASCII needles (`${`,
// `}`) or from offsets past those ASCII bytes. Indices are therefore in
// bounds and on char boundaries.
#[allow(clippy::indexing_slicing)]
pub fn substitute_env_vars(content: &str) -> String {
    let mut result = String::with_capacity(content.len());
    let mut rest = content;

    while let Some(start) = rest.find("${") {
        // Copy the text before the candidate placeholder as a whole `&str`
        // slice. Pushing individual bytes as `char`s would corrupt multi-byte
        // UTF-8 characters.
        result.push_str(&rest[..start]);
        let after_open = &rest[start + 2..];

        let Some(end) = after_open.find('}') else {
            // No closing brace remains, so no later `${` can be a placeholder
            // either. Copy the remainder verbatim.
            result.push_str(&rest[start..]);
            return result;
        };

        let (var_name, default_value) = parse_var_with_default(&after_open[..end]);
        if !is_valid_env_var_name(var_name) {
            // Not a placeholder: emit the `$` and resume scanning right after
            // it, so a later placeholder (e.g. `${lower}${UPPER}`) is still found.
            result.push('$');
            rest = &rest[start + 1..];
            continue;
        }

        match std::env::var(var_name) {
            Ok(value) => result.push_str(&value),
            Err(_) => {
                // `${` + inner + `}`: the full original placeholder text.
                let placeholder = &rest[start..start + 2 + end + 1];
                result.push_str(default_value.unwrap_or(placeholder));
            }
        }
        rest = &after_open[end + 1..];
    }

    result.push_str(rest);
    result
}

/// Split the inside of a `${...}` placeholder into `(name, optional default)`.
///
/// The name ends at the first `-`. If that `-` directly follows a `:`, the
/// operator is `:-` and the `:` is dropped from the name. Both forms behave
/// identically here: the fallback applies only when the variable is unset.
/// Valid names cannot contain `-`, so the first `-` is always the operator.
/// Defaults may therefore contain `-` or `:-` themselves
/// (e.g. `${URL-http://a:-b}`).
fn parse_var_with_default(inner: &str) -> (&str, Option<&str>) {
    match inner.split_once('-') {
        Some((head, default)) => {
            let name = head.strip_suffix(':').unwrap_or(head);
            (name, Some(default))
        }
        None => (inner, None),
    }
}

/// Whether `name` is an uppercase env-var identifier.
///
/// The first character must be `A-Z` or `_`; the rest may be `A-Z`, `0-9`
/// or `_`. An empty name is rejected.
fn is_valid_env_var_name(name: &str) -> bool {
    let mut bytes = name.bytes();
    match bytes.next() {
        Some(first) if first.is_ascii_uppercase() || first == b'_' => {
            bytes.all(|b| b.is_ascii_uppercase() || b.is_ascii_digit() || b == b'_')
        }
        _ => false,
    }
}
841
842#[cfg(test)]
843#[allow(clippy::unwrap_used, clippy::indexing_slicing, unsafe_code)]
844mod tests {
845    use super::*;
846
847    #[test]
848    fn test_default_config() {
849        let config = ForgeConfig::default_with_database_url("postgres://localhost/test");
850        assert_eq!(config.gateway.port, 9081);
851        assert_eq!(config.node.roles.len(), 4);
852        assert_eq!(config.mcp.path, "/mcp");
853        assert!(!config.mcp.enabled);
854    }
855
856    #[test]
857    fn test_parse_minimal_config() {
858        let toml = r#"
859            [database]
860            url = "postgres://localhost/myapp"
861        "#;
862
863        let config = ForgeConfig::parse_toml(toml).unwrap();
864        assert_eq!(config.database.url(), "postgres://localhost/myapp");
865        assert_eq!(config.gateway.port, 9081);
866    }
867
868    #[test]
869    fn test_parse_full_config() {
870        let toml = r#"
871            [project]
872            name = "my-app"
873            version = "1.0.0"
874
875            [database]
876            url = "postgres://localhost/myapp"
877            pool_size = 100
878
879            [node]
880            roles = ["gateway", "worker"]
881            worker_capabilities = ["media", "general"]
882
883            [gateway]
884            port = 3000
885            grpc_port = 9001
886        "#;
887
888        let config = ForgeConfig::parse_toml(toml).unwrap();
889        assert_eq!(config.project.name, "my-app");
890        assert_eq!(config.database.pool_size, 100);
891        assert_eq!(config.node.roles.len(), 2);
892        assert_eq!(config.gateway.port, 3000);
893    }
894
895    #[test]
896    fn test_env_var_substitution() {
897        unsafe {
898            std::env::set_var("TEST_DB_URL", "postgres://test:test@localhost/test");
899        }
900
901        let toml = r#"
902            [database]
903            url = "${TEST_DB_URL}"
904        "#;
905
906        let config = ForgeConfig::parse_toml(toml).unwrap();
907        assert_eq!(config.database.url(), "postgres://test:test@localhost/test");
908
909        unsafe {
910            std::env::remove_var("TEST_DB_URL");
911        }
912    }
913
914    #[test]
915    fn test_auth_validation_no_config() {
916        let auth = AuthConfig::default();
917        assert!(auth.validate().is_ok());
918    }
919
920    #[test]
921    fn test_auth_validation_hmac_with_secret() {
922        let auth = AuthConfig {
923            jwt_secret: Some("my-secret".into()),
924            jwt_algorithm: JwtAlgorithm::HS256,
925            ..Default::default()
926        };
927        assert!(auth.validate().is_ok());
928    }
929
930    #[test]
931    fn test_auth_validation_hmac_missing_secret() {
932        let auth = AuthConfig {
933            jwt_issuer: Some("my-issuer".into()),
934            jwt_algorithm: JwtAlgorithm::HS256,
935            ..Default::default()
936        };
937        let result = auth.validate();
938        assert!(result.is_err());
939        let err_msg = result.unwrap_err().to_string();
940        assert!(err_msg.contains("jwt_secret is required"));
941    }
942
943    #[test]
944    fn test_auth_validation_rsa_with_jwks() {
945        let auth = AuthConfig {
946            jwks_url: Some("https://example.com/.well-known/jwks.json".into()),
947            jwt_algorithm: JwtAlgorithm::RS256,
948            ..Default::default()
949        };
950        assert!(auth.validate().is_ok());
951    }
952
953    #[test]
954    fn test_auth_validation_rsa_missing_jwks() {
955        let auth = AuthConfig {
956            jwt_issuer: Some("my-issuer".into()),
957            jwt_algorithm: JwtAlgorithm::RS256,
958            ..Default::default()
959        };
960        let result = auth.validate();
961        assert!(result.is_err());
962        let err_msg = result.unwrap_err().to_string();
963        assert!(err_msg.contains("jwks_url is required"));
964    }
965
966    #[test]
967    fn test_forge_config_validation_fails_on_empty_url() {
968        let toml = r#"
969            [database]
970
971            url = ""
972        "#;
973
974        let result = ForgeConfig::parse_toml(toml);
975        assert!(result.is_err());
976        let err_msg = result.unwrap_err().to_string();
977        assert!(err_msg.contains("database.url is required"));
978    }
979
980    #[test]
981    fn test_forge_config_validation_fails_on_invalid_auth() {
982        let toml = r#"
983            [database]
984
985            url = "postgres://localhost/test"
986
987            [auth]
988            jwt_issuer = "my-issuer"
989            jwt_algorithm = "RS256"
990        "#;
991
992        let result = ForgeConfig::parse_toml(toml);
993        assert!(result.is_err());
994        let err_msg = result.unwrap_err().to_string();
995        assert!(err_msg.contains("jwks_url is required"));
996    }
997
998    #[test]
999    fn test_env_var_default_used_when_unset() {
1000        // Ensure the var is definitely not set
1001        unsafe {
1002            std::env::remove_var("TEST_FORGE_OTEL_UNSET");
1003        }
1004
1005        let input = r#"enabled = ${TEST_FORGE_OTEL_UNSET-false}"#;
1006        let result = substitute_env_vars(input);
1007        assert_eq!(result, "enabled = false");
1008    }
1009
1010    #[test]
1011    fn test_env_var_default_overridden_when_set() {
1012        unsafe {
1013            std::env::set_var("TEST_FORGE_OTEL_SET", "true");
1014        }
1015
1016        let input = r#"enabled = ${TEST_FORGE_OTEL_SET-false}"#;
1017        let result = substitute_env_vars(input);
1018        assert_eq!(result, "enabled = true");
1019
1020        unsafe {
1021            std::env::remove_var("TEST_FORGE_OTEL_SET");
1022        }
1023    }
1024
1025    #[test]
1026    fn test_env_var_colon_dash_default() {
1027        unsafe {
1028            std::env::remove_var("TEST_FORGE_ENDPOINT_UNSET");
1029        }
1030
1031        let input = r#"endpoint = "${TEST_FORGE_ENDPOINT_UNSET:-http://localhost:4318}""#;
1032        let result = substitute_env_vars(input);
1033        assert_eq!(result, r#"endpoint = "http://localhost:4318""#);
1034    }
1035
1036    #[test]
1037    fn test_env_var_no_default_preserves_literal() {
1038        unsafe {
1039            std::env::remove_var("TEST_FORGE_MISSING");
1040        }
1041
1042        let input = r#"url = "${TEST_FORGE_MISSING}""#;
1043        let result = substitute_env_vars(input);
1044        assert_eq!(result, r#"url = "${TEST_FORGE_MISSING}""#);
1045    }
1046
1047    #[test]
1048    fn test_env_var_default_empty_string() {
1049        unsafe {
1050            std::env::remove_var("TEST_FORGE_EMPTY_DEFAULT");
1051        }
1052
1053        let input = r#"val = "${TEST_FORGE_EMPTY_DEFAULT-}""#;
1054        let result = substitute_env_vars(input);
1055        assert_eq!(result, r#"val = """#);
1056    }
1057
1058    #[test]
1059    fn test_observability_config_default_disabled() {
1060        let toml = r#"
1061            [database]
1062            url = "postgres://localhost/test"
1063        "#;
1064
1065        let config = ForgeConfig::parse_toml(toml).unwrap();
1066        assert!(!config.observability.enabled);
1067        assert!(!config.observability.otlp_active());
1068    }
1069
1070    #[test]
1071    fn test_observability_config_with_env_default() {
1072        // Simulates what the template produces when no env vars are set
1073        unsafe {
1074            std::env::remove_var("TEST_OTEL_ENABLED");
1075        }
1076
1077        let toml = r#"
1078            [database]
1079            url = "postgres://localhost/test"
1080
1081            [observability]
1082            enabled = ${TEST_OTEL_ENABLED-false}
1083        "#;
1084
1085        let config = ForgeConfig::parse_toml(toml).unwrap();
1086        assert!(!config.observability.enabled);
1087    }
1088
1089    #[test]
1090    fn test_mcp_config_validation_rejects_invalid_path() {
1091        let toml = r#"
1092            [database]
1093
1094            url = "postgres://localhost/test"
1095
1096            [mcp]
1097            enabled = true
1098            path = "mcp"
1099        "#;
1100
1101        let result = ForgeConfig::parse_toml(toml);
1102        assert!(result.is_err());
1103        let err_msg = result.unwrap_err().to_string();
1104        assert!(err_msg.contains("mcp.path must start with '/'"));
1105    }
1106
1107    #[test]
1108    fn test_access_token_ttl_defaults() {
1109        let auth = AuthConfig::default();
1110        assert_eq!(auth.access_token_ttl_secs(), 3600);
1111        assert_eq!(auth.refresh_token_ttl_days(), 30);
1112    }
1113
1114    #[test]
1115    fn test_access_token_ttl_custom() {
1116        let auth = AuthConfig {
1117            access_token_ttl: Some("15m".into()),
1118            refresh_token_ttl: Some("7d".into()),
1119            ..Default::default()
1120        };
1121        assert_eq!(auth.access_token_ttl_secs(), 900);
1122        assert_eq!(auth.refresh_token_ttl_days(), 7);
1123    }
1124
1125    #[test]
1126    fn test_access_token_ttl_minimum_enforced() {
1127        let auth = AuthConfig {
1128            access_token_ttl: Some("0s".into()),
1129            ..Default::default()
1130        };
1131        // Should floor at 1, not 0
1132        assert_eq!(auth.access_token_ttl_secs(), 1);
1133    }
1134
1135    #[test]
1136    fn test_refresh_token_ttl_minimum_enforced() {
1137        let auth = AuthConfig {
1138            refresh_token_ttl: Some("1h".into()),
1139            ..Default::default()
1140        };
1141        // 1 hour < 1 day, so should floor at 1 day
1142        assert_eq!(auth.refresh_token_ttl_days(), 1);
1143    }
1144
1145    #[test]
1146    fn test_max_body_size_defaults() {
1147        let gw = GatewayConfig::default();
1148        assert_eq!(gw.max_body_size_bytes().unwrap(), 20 * 1024 * 1024);
1149    }
1150
1151    #[test]
1152    fn test_max_body_size_custom() {
1153        let gw = GatewayConfig {
1154            max_body_size: "100mb".into(),
1155            ..Default::default()
1156        };
1157        assert_eq!(gw.max_body_size_bytes().unwrap(), 100 * 1024 * 1024);
1158    }
1159
1160    #[test]
1161    fn test_max_body_size_invalid_errors() {
1162        let gw = GatewayConfig {
1163            max_body_size: "not-a-size".into(),
1164            ..Default::default()
1165        };
1166        assert!(gw.max_body_size_bytes().is_err());
1167    }
1168
1169    #[test]
1170    fn test_max_file_size_defaults() {
1171        let gw = GatewayConfig::default();
1172        assert_eq!(gw.max_file_size_bytes().unwrap(), 10 * 1024 * 1024);
1173    }
1174
1175    #[test]
1176    fn test_max_file_size_custom() {
1177        let gw = GatewayConfig {
1178            max_file_size: "200mb".into(),
1179            max_body_size: "500mb".into(),
1180            ..Default::default()
1181        };
1182        assert_eq!(gw.max_file_size_bytes().unwrap(), 200 * 1024 * 1024);
1183    }
1184
1185    #[test]
1186    fn test_max_file_size_invalid_errors() {
1187        let gw = GatewayConfig {
1188            max_file_size: "nope".into(),
1189            ..Default::default()
1190        };
1191        assert!(gw.max_file_size_bytes().is_err());
1192    }
1193
1194    #[test]
1195    fn test_validate_rejects_file_larger_than_body() {
1196        let toml = r#"
1197            [database]
1198            url = "postgres://localhost/test"
1199
1200            [gateway]
1201            max_body_size = "10mb"
1202            max_file_size = "20mb"
1203        "#;
1204        let err = ForgeConfig::parse_toml(toml).unwrap_err().to_string();
1205        assert!(
1206            err.contains("max_file_size"),
1207            "Expected max_file_size error, got: {err}"
1208        );
1209    }
1210
1211    #[test]
1212    fn test_mcp_config_rejects_reserved_paths() {
1213        for reserved in McpConfig::RESERVED_PATHS {
1214            let toml = format!(
1215                r#"
1216                [database]
1217                url = "postgres://localhost/test"
1218
1219                [mcp]
1220                enabled = true
1221                path = "{reserved}"
1222                "#
1223            );
1224
1225            let result = ForgeConfig::parse_toml(&toml);
1226            assert!(result.is_err(), "Expected {reserved} to be rejected");
1227            let err_msg = result.unwrap_err().to_string();
1228            assert!(
1229                err_msg.contains("conflicts with a reserved gateway route"),
1230                "Wrong error for {reserved}: {err_msg}"
1231            );
1232        }
1233    }
1234}