pub mod cluster;
mod database;
pub mod signals;
pub use cluster::ClusterConfig;
pub use database::{DatabaseConfig, PoolConfig};
pub use signals::SignalsConfig;
use serde::{Deserialize, Serialize};
use std::path::Path;
use crate::error::{ForgeError, Result};
/// Top-level configuration, deserialized from TOML by `ForgeConfig::parse_toml`.
///
/// `database` is the only field without `#[serde(default)]`; every other
/// section falls back to its type's `Default` implementation when omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForgeConfig {
/// `[project]` table.
#[serde(default)]
pub project: ProjectConfig,
/// `[database]` table; checked by `DatabaseConfig::validate` during `ForgeConfig::validate`.
pub database: DatabaseConfig,
/// `[node]` table.
#[serde(default)]
pub node: NodeConfig,
/// `[gateway]` table; its size limits are cross-checked in `ForgeConfig::validate`.
#[serde(default)]
pub gateway: GatewayConfig,
/// `[function]` table.
#[serde(default)]
pub function: FunctionConfig,
/// `[worker]` table.
#[serde(default)]
pub worker: WorkerConfig,
/// `[cluster]` table.
#[serde(default)]
pub cluster: ClusterConfig,
/// `[security]` table.
#[serde(default)]
pub security: SecurityConfig,
/// `[auth]` table; checked by `AuthConfig::validate`.
#[serde(default)]
pub auth: AuthConfig,
/// `[observability]` table.
#[serde(default)]
pub observability: ObservabilityConfig,
/// `[mcp]` table; checked by `McpConfig::validate`.
#[serde(default)]
pub mcp: McpConfig,
/// `[signals]` table.
#[serde(default)]
pub signals: SignalsConfig,
}
impl ForgeConfig {
    /// Reads the file at `path` and hands its contents to [`ForgeConfig::parse_toml`].
    ///
    /// # Errors
    /// Returns `ForgeError::Config` when the file cannot be read, or any error
    /// produced by `parse_toml`.
    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
        match std::fs::read_to_string(path.as_ref()) {
            Ok(raw) => Self::parse_toml(&raw),
            Err(e) => Err(ForgeError::Config(format!(
                "Failed to read config file: {}",
                e
            ))),
        }
    }

    /// Parses a TOML document into a validated configuration.
    ///
    /// Environment placeholders are expanded by `substitute_env_vars` before
    /// the text reaches the TOML parser, and the result is run through
    /// [`ForgeConfig::validate`].
    ///
    /// # Errors
    /// Returns `ForgeError::Config` on a TOML/deserialization failure, or the
    /// first validation error.
    pub fn parse_toml(content: &str) -> Result<Self> {
        let expanded = substitute_env_vars(content);
        let parsed = toml::from_str::<Self>(&expanded)
            .map_err(|e| ForgeError::Config(format!("Failed to parse config: {}", e)))?;
        parsed.validate()?;
        Ok(parsed)
    }

    /// Runs the per-section validators, then the cross-section rules.
    ///
    /// Order matters for which error is reported first: database, auth, mcp,
    /// gateway size limits, then the `mcp.oauth` requirements.
    ///
    /// # Errors
    /// Returns the first `ForgeError::Config` encountered.
    pub fn validate(&self) -> Result<()> {
        self.database.validate()?;
        self.auth.validate()?;
        self.mcp.validate()?;

        // Both limits must parse before they can be compared.
        let gateway = &self.gateway;
        let body_limit = gateway.max_body_size_bytes()?;
        let file_limit = gateway.max_file_size_bytes()?;
        if file_limit > body_limit {
            return Err(ForgeError::Config(format!(
                "gateway.max_file_size ({}) cannot exceed gateway.max_body_size ({})",
                gateway.max_file_size, gateway.max_body_size
            )));
        }

        if self.mcp.oauth {
            // The missing-secret check is reported before the disabled-MCP check.
            if self.auth.jwt_secret.is_none() {
                return Err(ForgeError::Config(
                    "mcp.oauth = true requires auth.jwt_secret to be set. \
                     OAuth-issued tokens are signed with this secret, even when using \
                     an external provider (JWKS) for identity verification."
                        .into(),
                ));
            }
            if !self.mcp.enabled {
                return Err(ForgeError::Config(
                    "mcp.oauth = true requires mcp.enabled = true".into(),
                ));
            }
        }

        Ok(())
    }

    /// Builds a configuration whose database section is `DatabaseConfig::new(url)`
    /// and whose remaining sections are all defaults. No validation is run.
    pub fn default_with_database_url(url: &str) -> Self {
        let database = DatabaseConfig::new(url);
        Self {
            project: Default::default(),
            database,
            node: Default::default(),
            gateway: Default::default(),
            function: Default::default(),
            worker: Default::default(),
            cluster: Default::default(),
            security: Default::default(),
            auth: Default::default(),
            observability: Default::default(),
            mcp: Default::default(),
            signals: Default::default(),
        }
    }
}
/// `[project]` section: project name and version.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectConfig {
/// Project name; defaults to `"forge-app"`.
#[serde(default = "default_project_name")]
pub name: String,
/// Project version string; defaults to `"0.1.0"`.
#[serde(default = "default_version")]
pub version: String,
}
impl Default for ProjectConfig {
    /// Uses the same helper functions serde calls for missing keys, so an
    /// omitted `[project]` table and an empty one produce the same value.
    fn default() -> Self {
        let name = default_project_name();
        let version = default_version();
        Self { name, version }
    }
}
/// Fallback for `project.name`.
fn default_project_name() -> String {
    String::from("forge-app")
}

/// Fallback for `project.version`.
fn default_version() -> String {
    String::from("0.1.0")
}
/// `[node]` section: which roles this node takes on and its worker capabilities.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeConfig {
/// Roles enabled on this node; defaults to all four `NodeRole` variants.
#[serde(default = "default_roles")]
pub roles: Vec<NodeRole>,
/// Capability labels for the worker; defaults to `["general"]`.
#[serde(default = "default_capabilities")]
pub worker_capabilities: Vec<String>,
}
impl Default for NodeConfig {
    /// Mirrors the serde field defaults: every role, and the `general` capability.
    fn default() -> Self {
        let roles = default_roles();
        let worker_capabilities = default_capabilities();
        Self {
            roles,
            worker_capabilities,
        }
    }
}
/// Fallback for `node.roles`: every role, in declaration order.
fn default_roles() -> Vec<NodeRole> {
    [
        NodeRole::Gateway,
        NodeRole::Function,
        NodeRole::Worker,
        NodeRole::Scheduler,
    ]
    .to_vec()
}
/// Fallback for `node.worker_capabilities`: the single label `general`.
fn default_capabilities() -> Vec<String> {
    vec![String::from("general")]
}
/// Role a node can take on.
///
/// Serialized in lowercase (`"gateway"`, `"function"`, `"worker"`,
/// `"scheduler"`) because of `rename_all = "lowercase"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum NodeRole {
Gateway,
Function,
Worker,
Scheduler,
}
/// `[gateway]` section. Every field has a serde default, so the whole table is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfig {
/// Defaults to 9081 (`default_http_port`).
#[serde(default = "default_http_port")]
pub port: u16,
/// Defaults to 9000 (`default_grpc_port`).
#[serde(default = "default_grpc_port")]
pub grpc_port: u16,
/// Defaults to 4096.
#[serde(default = "default_max_connections")]
pub max_connections: usize,
/// Defaults to 10_000.
#[serde(default = "default_sse_max_sessions")]
pub sse_max_sessions: usize,
/// Request timeout in seconds; defaults to 30.
#[serde(default = "default_request_timeout")]
pub request_timeout_secs: u64,
/// Defaults to `false`.
#[serde(default = "default_cors_enabled")]
pub cors_enabled: bool,
/// Defaults to an empty list.
#[serde(default = "default_cors_origins")]
pub cors_origins: Vec<String>,
/// Defaults to the `/_api/health`, `/_api/ready` and `/_api/signal/*` routes
/// listed in `default_quiet_routes`.
// NOTE(review): presumably routes excluded from request logging; confirm against the gateway.
#[serde(default = "default_quiet_routes")]
pub quiet_routes: Vec<String>,
/// Size string such as `"20mb"` (the default); parsed by `max_body_size_bytes`.
#[serde(default = "default_max_body_size")]
pub max_body_size: String,
/// Size string such as `"10mb"` (the default); parsed by `max_file_size_bytes`.
/// `ForgeConfig::validate` rejects a value larger than `max_body_size`.
#[serde(default = "default_max_file_size")]
pub max_file_size: String,
}
impl Default for GatewayConfig {
    /// Assembles the section from the same helper functions serde uses for
    /// missing keys.
    fn default() -> Self {
        let port = default_http_port();
        let grpc_port = default_grpc_port();
        let max_connections = default_max_connections();
        let sse_max_sessions = default_sse_max_sessions();
        let request_timeout_secs = default_request_timeout();
        let cors_enabled = default_cors_enabled();
        let cors_origins = default_cors_origins();
        let quiet_routes = default_quiet_routes();
        let max_body_size = default_max_body_size();
        let max_file_size = default_max_file_size();
        Self {
            port,
            grpc_port,
            max_connections,
            sse_max_sessions,
            request_timeout_secs,
            cors_enabled,
            cors_origins,
            quiet_routes,
            max_body_size,
            max_file_size,
        }
    }
}
impl GatewayConfig {
    /// Parses `max_body_size` into a byte count via `crate::util::parse_size`.
    ///
    /// # Errors
    /// Returns `ForgeError::Config` when the string is not a recognised size.
    pub fn max_body_size_bytes(&self) -> crate::Result<usize> {
        match crate::util::parse_size(&self.max_body_size) {
            Some(bytes) => Ok(bytes),
            None => Err(crate::ForgeError::Config(format!(
                "invalid gateway.max_body_size '{}'. Expected a size like '20mb', '1gb', or '1048576'",
                self.max_body_size
            ))),
        }
    }

    /// Parses `max_file_size` into a byte count via `crate::util::parse_size`.
    ///
    /// # Errors
    /// Returns `ForgeError::Config` when the string is not a recognised size.
    pub fn max_file_size_bytes(&self) -> crate::Result<usize> {
        match crate::util::parse_size(&self.max_file_size) {
            Some(bytes) => Ok(bytes),
            None => Err(crate::ForgeError::Config(format!(
                "invalid gateway.max_file_size '{}'. Expected a size like '10mb', '200mb', or '1048576'",
                self.max_file_size
            ))),
        }
    }
}
/// Fallback for `gateway.port`.
fn default_http_port() -> u16 {
    9_081
}

/// Fallback for `gateway.grpc_port`.
fn default_grpc_port() -> u16 {
    9_000
}

/// Fallback for `gateway.max_connections`.
fn default_max_connections() -> usize {
    4_096
}

/// Fallback for `gateway.sse_max_sessions`.
fn default_sse_max_sessions() -> usize {
    10_000
}

/// Fallback for `gateway.request_timeout_secs`.
fn default_request_timeout() -> u64 {
    30
}

/// Fallback for `gateway.cors_enabled`: CORS is off unless configured.
fn default_cors_enabled() -> bool {
    false
}

/// Fallback for `gateway.cors_origins`: no origins.
fn default_cors_origins() -> Vec<String> {
    vec![]
}

/// Fallback for `gateway.quiet_routes`.
fn default_quiet_routes() -> Vec<String> {
    [
        "/_api/health",
        "/_api/ready",
        "/_api/signal/event",
        "/_api/signal/view",
        "/_api/signal/user",
        "/_api/signal/report",
    ]
    .iter()
    .map(|route| route.to_string())
    .collect()
}

/// Fallback for `gateway.max_body_size`.
fn default_max_body_size() -> String {
    String::from("20mb")
}

/// Fallback for `gateway.max_file_size`.
fn default_max_file_size() -> String {
    String::from("10mb")
}
/// `[function]` section. Every field has a serde default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionConfig {
/// Defaults to 1000.
#[serde(default = "default_max_concurrent")]
pub max_concurrent: usize,
/// Timeout in seconds; defaults to 30.
#[serde(default = "default_function_timeout")]
pub timeout_secs: u64,
/// Defaults to `512 * 1024 * 1024`.
// NOTE(review): presumably a byte count (512 MiB); confirm against the consumer.
#[serde(default = "default_memory_limit")]
pub memory_limit: usize,
}
impl Default for FunctionConfig {
    /// Mirrors the serde field defaults.
    fn default() -> Self {
        let max_concurrent = default_max_concurrent();
        let timeout_secs = default_function_timeout();
        let memory_limit = default_memory_limit();
        Self {
            max_concurrent,
            timeout_secs,
            memory_limit,
        }
    }
}
/// Fallback for `function.max_concurrent`.
fn default_max_concurrent() -> usize {
    1_000
}

/// Fallback for `function.timeout_secs`.
fn default_function_timeout() -> u64 {
    30
}

/// Fallback for `function.memory_limit`: 512 * 1024 * 1024.
fn default_memory_limit() -> usize {
    const MIB: usize = 1024 * 1024;
    512 * MIB
}
/// `[worker]` section. Every field has a serde default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerConfig {
/// Defaults to 50.
#[serde(default = "default_max_concurrent_jobs")]
pub max_concurrent_jobs: usize,
/// Job timeout in seconds; defaults to 3600.
#[serde(default = "default_job_timeout")]
pub job_timeout_secs: u64,
/// Poll interval in milliseconds; defaults to 100.
#[serde(default = "default_poll_interval")]
pub poll_interval_ms: u64,
}
impl Default for WorkerConfig {
    /// Mirrors the serde field defaults.
    fn default() -> Self {
        let max_concurrent_jobs = default_max_concurrent_jobs();
        let job_timeout_secs = default_job_timeout();
        let poll_interval_ms = default_poll_interval();
        Self {
            max_concurrent_jobs,
            job_timeout_secs,
            poll_interval_ms,
        }
    }
}
/// Fallback for `worker.max_concurrent_jobs`.
fn default_max_concurrent_jobs() -> usize {
    50
}

/// Fallback for `worker.job_timeout_secs`: one hour.
fn default_job_timeout() -> u64 {
    60 * 60
}

/// Fallback for `worker.poll_interval_ms`.
fn default_poll_interval() -> u64 {
    100
}
/// `[security]` section. `Default` is derived, so `secret_key` starts as `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SecurityConfig {
/// Optional secret key; nothing in this file reads it.
pub secret_key: Option<String>,
}
/// JWT signing algorithm selected by `auth.jwt_algorithm`.
///
/// The `HS*` variants are treated as HMAC (`AuthConfig::is_hmac`) and the
/// `RS*` variants as RSA (`AuthConfig::is_rsa`). `HS256` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "UPPERCASE")]
pub enum JwtAlgorithm {
#[default]
HS256,
HS384,
HS512,
RS256,
RS384,
RS512,
}
/// `[auth]` section.
///
/// The section counts as configured once any of `jwt_secret`, `jwks_url`,
/// `jwt_issuer` or `jwt_audience` is set; only then does `validate` enforce
/// the per-algorithm requirements.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthConfig {
/// Required by `validate` for the `HS*` algorithms, and by
/// `ForgeConfig::validate` whenever `mcp.oauth = true`.
pub jwt_secret: Option<String>,
/// Defaults to `HS256`.
#[serde(default)]
pub jwt_algorithm: JwtAlgorithm,
pub jwt_issuer: Option<String>,
pub jwt_audience: Option<String>,
/// Duration string (e.g. `"15m"`); read through `access_token_ttl_secs`.
pub access_token_ttl: Option<String>,
/// Duration string (e.g. `"7d"`); read through `refresh_token_ttl_days`.
pub refresh_token_ttl: Option<String>,
/// Required by `validate` for the `RS*` algorithms.
pub jwks_url: Option<String>,
/// Defaults to 3600.
#[serde(default = "default_jwks_cache_ttl")]
pub jwks_cache_ttl_secs: u64,
/// Defaults to `7 * 24 * 60 * 60` (seven days in seconds).
#[serde(default = "default_session_ttl")]
pub session_ttl_secs: u64,
}
impl Default for AuthConfig {
    /// An unconfigured auth section: no secrets or URLs, the default
    /// algorithm, and the serde default TTLs.
    fn default() -> Self {
        let jwks_cache_ttl_secs = default_jwks_cache_ttl();
        let session_ttl_secs = default_session_ttl();
        Self {
            jwt_secret: None,
            jwt_algorithm: Default::default(),
            jwt_issuer: None,
            jwt_audience: None,
            access_token_ttl: None,
            refresh_token_ttl: None,
            jwks_url: None,
            jwks_cache_ttl_secs,
            session_ttl_secs,
        }
    }
}
impl AuthConfig {
pub fn access_token_ttl_secs(&self) -> i64 {
self.access_token_ttl
.as_deref()
.and_then(crate::util::parse_duration)
.map(|d| (d.as_secs() as i64).max(1))
.unwrap_or(3600)
}
pub fn refresh_token_ttl_days(&self) -> i64 {
self.refresh_token_ttl
.as_deref()
.and_then(crate::util::parse_duration)
.map(|d| (d.as_secs() / 86400) as i64)
.map(|d| if d == 0 { 1 } else { d })
.unwrap_or(30)
}
fn is_configured(&self) -> bool {
self.jwt_secret.is_some()
|| self.jwks_url.is_some()
|| self.jwt_issuer.is_some()
|| self.jwt_audience.is_some()
}
pub fn validate(&self) -> Result<()> {
if !self.is_configured() {
return Ok(());
}
match self.jwt_algorithm {
JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512 => {
if self.jwt_secret.is_none() {
return Err(ForgeError::Config(
"auth.jwt_secret is required for HMAC algorithms (HS256, HS384, HS512). \
Set auth.jwt_secret to a secure random string, \
or switch to RS256 and provide auth.jwks_url for external identity providers."
.into(),
));
}
}
JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512 => {
if self.jwks_url.is_none() {
return Err(ForgeError::Config(
"auth.jwks_url is required for RSA algorithms (RS256, RS384, RS512). \
Set auth.jwks_url to your identity provider's JWKS endpoint, \
or switch to HS256 and provide auth.jwt_secret for symmetric signing."
.into(),
));
}
}
}
Ok(())
}
pub fn is_hmac(&self) -> bool {
matches!(
self.jwt_algorithm,
JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512
)
}
pub fn is_rsa(&self) -> bool {
matches!(
self.jwt_algorithm,
JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512
)
}
}
/// Fallback for `auth.jwks_cache_ttl_secs`: one hour.
fn default_jwks_cache_ttl() -> u64 {
    60 * 60
}

/// Fallback for `auth.session_ttl_secs`: seven days.
fn default_session_ttl() -> u64 {
    const SECS_PER_DAY: u64 = 24 * 60 * 60;
    7 * SECS_PER_DAY
}
/// `[observability]` section.
///
/// `otlp_active` reports true only when `enabled` is set and at least one of
/// the three `enable_*` flags is on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ObservabilityConfig {
/// Master switch; defaults to `false`.
#[serde(default)]
pub enabled: bool,
/// Defaults to `"http://localhost:4318"`.
#[serde(default = "default_otlp_endpoint")]
pub otlp_endpoint: String,
/// Optional; no default is supplied in this file.
pub service_name: Option<String>,
/// Defaults to `true`.
#[serde(default = "default_true")]
pub enable_traces: bool,
/// Defaults to `true`.
#[serde(default = "default_true")]
pub enable_metrics: bool,
/// Defaults to `true`.
#[serde(default = "default_true")]
pub enable_logs: bool,
/// Defaults to `1.0`.
#[serde(default = "default_sampling_ratio")]
pub sampling_ratio: f64,
/// Interval in seconds; defaults to 15.
#[serde(default = "default_metrics_interval_secs")]
pub metrics_interval_secs: u64,
/// Defaults to `"info"`.
#[serde(default = "default_log_level")]
pub log_level: String,
}
impl Default for ObservabilityConfig {
fn default() -> Self {
Self {
enabled: false,
otlp_endpoint: default_otlp_endpoint(),
service_name: None,
enable_traces: true,
enable_metrics: true,
enable_logs: true,
sampling_ratio: default_sampling_ratio(),
metrics_interval_secs: default_metrics_interval_secs(),
log_level: default_log_level(),
}
}
}
impl ObservabilityConfig {
    /// True when `enabled` is set and at least one of traces, metrics or logs
    /// is switched on.
    pub fn otlp_active(&self) -> bool {
        if !self.enabled {
            return false;
        }
        [self.enable_traces, self.enable_metrics, self.enable_logs]
            .iter()
            .any(|&signal_on| signal_on)
    }
}
/// Fallback for `observability.otlp_endpoint`.
fn default_otlp_endpoint() -> String {
    String::from("http://localhost:4318")
}

/// Serde default helper for boolean fields that should start enabled.
pub(crate) fn default_true() -> bool {
    true
}

/// Fallback for `observability.sampling_ratio`.
fn default_sampling_ratio() -> f64 {
    1.0_f64
}

/// Fallback for `observability.metrics_interval_secs`.
fn default_metrics_interval_secs() -> u64 {
    15
}

/// Fallback for `observability.log_level`.
fn default_log_level() -> String {
    String::from("info")
}
/// `[mcp]` section. Checked by `McpConfig::validate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpConfig {
/// Defaults to `false`.
#[serde(default)]
pub enabled: bool,
/// Defaults to `false`. When true, `ForgeConfig::validate` also requires
/// `auth.jwt_secret` to be set and `enabled` to be true.
#[serde(default)]
pub oauth: bool,
/// Defaults to `"/mcp"`. Must start with `/`, contain no spaces, and not
/// equal one of `McpConfig::RESERVED_PATHS`.
#[serde(default = "default_mcp_path")]
pub path: String,
/// Defaults to `60 * 60`; `validate` rejects 0.
#[serde(default = "default_mcp_session_ttl_secs")]
pub session_ttl_secs: u64,
/// Defaults to an empty list.
#[serde(default)]
pub allowed_origins: Vec<String>,
/// Defaults to `true`.
#[serde(default = "default_true")]
pub require_protocol_version_header: bool,
}
impl Default for McpConfig {
    /// MCP is disabled by default, with the standard path and session TTL.
    fn default() -> Self {
        let path = default_mcp_path();
        let session_ttl_secs = default_mcp_session_ttl_secs();
        let require_protocol_version_header = default_true();
        Self {
            enabled: false,
            oauth: false,
            path,
            session_ttl_secs,
            allowed_origins: vec![],
            require_protocol_version_header,
        }
    }
}
impl McpConfig {
    /// Paths `validate` refuses for `mcp.path`; the comparison is an exact
    /// string match.
    const RESERVED_PATHS: &[&str] = &[
        "/health",
        "/ready",
        "/rpc",
        "/events",
        "/subscribe",
        "/unsubscribe",
        "/subscribe-job",
        "/subscribe-workflow",
        "/metrics",
    ];

    /// Checks `path` and `session_ttl_secs`.
    ///
    /// # Errors
    /// Returns `ForgeError::Config` when the path does not start with `/`,
    /// contains a space, equals a reserved route, or when the session TTL is 0.
    pub fn validate(&self) -> Result<()> {
        let path = self.path.as_str();

        // An empty path fails this check too: "" does not start with '/'.
        if !path.starts_with('/') {
            return Err(ForgeError::Config(String::from(
                "mcp.path must start with '/' (example: /mcp)",
            )));
        }

        if path.chars().any(|c| c == ' ') {
            return Err(ForgeError::Config(String::from(
                "mcp.path cannot contain spaces",
            )));
        }

        let is_reserved = Self::RESERVED_PATHS
            .iter()
            .any(|reserved| *reserved == path);
        if is_reserved {
            return Err(ForgeError::Config(format!(
                "mcp.path '{}' conflicts with a reserved gateway route",
                self.path
            )));
        }

        match self.session_ttl_secs {
            0 => Err(ForgeError::Config(String::from(
                "mcp.session_ttl_secs must be greater than 0",
            ))),
            _ => Ok(()),
        }
    }
}
/// Fallback for `mcp.path`.
fn default_mcp_path() -> String {
    String::from("/mcp")
}

/// Fallback for `mcp.session_ttl_secs`: one hour.
fn default_mcp_session_ttl_secs() -> u64 {
    3_600
}
/// Expands `${VAR}`, `${VAR-default}` and `${VAR:-default}` placeholders in `content`.
///
/// * A variable that is set is replaced by its value.
/// * A variable that is unset (or whose value `std::env::var` cannot return
///   as Unicode) is replaced by the default when one is given; without a
///   default the placeholder is left in the output verbatim.
/// * A placeholder whose name fails `is_valid_env_var_name`, and a `${` with
///   no closing `}`, are copied through unchanged.
///
/// `-` and `:-` behave identically here: the default is used only when the
/// variable is unset, not when it is set to an empty string.
///
/// Text outside placeholders is copied as whole `&str` slices, so multi-byte
/// UTF-8 characters pass through intact.
// Slicing is sound: every offset comes from `str::find` on the ASCII patterns
// "${" and '}', so each cut lands on a character boundary inside the string.
#[allow(clippy::indexing_slicing)]
pub fn substitute_env_vars(content: &str) -> String {
    let mut result = String::with_capacity(content.len());
    let mut rest = content;

    while let Some(start) = rest.find("${") {
        // Everything before the marker is plain text.
        result.push_str(&rest[..start]);

        let after_open = &rest[start + 2..];
        let end = match after_open.find('}') {
            Some(end) => end,
            None => {
                // No closing brace anywhere ahead: the remainder is literal.
                rest = &rest[start..];
                break;
            }
        };

        let placeholder_len = 2 + end + 1;
        let (var_name, default_value) = parse_var_with_default(&after_open[..end]);

        if !is_valid_env_var_name(var_name) {
            // Not a placeholder: emit the '$' and rescan from the '{', so a
            // valid `${VAR}` nested inside the rejected span is still expanded.
            result.push('$');
            rest = &rest[start + 1..];
            continue;
        }

        match (std::env::var(var_name), default_value) {
            (Ok(value), _) => result.push_str(&value),
            (Err(_), Some(default)) => result.push_str(default),
            (Err(_), None) => result.push_str(&rest[start..start + placeholder_len]),
        }
        rest = &rest[start + placeholder_len..];
    }

    result.push_str(rest);
    result
}

/// Splits the text between `${` and `}` into a variable name and an optional default.
///
/// The first `:-` anywhere in `inner` is tried before the first `-`; with
/// neither present the whole string is the name and there is no default.
fn parse_var_with_default(inner: &str) -> (&str, Option<&str>) {
    if let Some((name, default)) = inner.split_once(":-") {
        return (name, Some(default));
    }
    match inner.split_once('-') {
        Some((name, default)) => (name, Some(default)),
        None => (inner, None),
    }
}

/// Accepts non-empty names made only of ASCII uppercase letters, digits and
/// `_`, whose first character is not a digit.
fn is_valid_env_var_name(name: &str) -> bool {
    let mut bytes = name.bytes();
    match bytes.next() {
        Some(first) if first.is_ascii_uppercase() || first == b'_' => {
            bytes.all(|b| b.is_ascii_uppercase() || b.is_ascii_digit() || b == b'_')
        }
        _ => false,
    }
}
// Unit tests for config parsing, validation and env-var substitution.
//
// `unsafe_code` is allowed because several tests wrap `std::env::set_var` /
// `std::env::remove_var` in `unsafe` blocks. Each such test uses a variable
// name that no other test in this module touches.
// NOTE(review): environment mutation is process-wide and the test harness
// runs tests on multiple threads by default; confirm this is acceptable.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, unsafe_code)]
mod tests {
use super::*;
// Defaults built without a TOML file.
#[test]
fn test_default_config() {
let config = ForgeConfig::default_with_database_url("postgres://localhost/test");
assert_eq!(config.gateway.port, 9081);
assert_eq!(config.node.roles.len(), 4);
assert_eq!(config.mcp.path, "/mcp");
assert!(!config.mcp.enabled);
}
// Only `[database]` is required; other sections take their defaults.
#[test]
fn test_parse_minimal_config() {
let toml = r#"
[database]
url = "postgres://localhost/myapp"
"#;
let config = ForgeConfig::parse_toml(toml).unwrap();
assert_eq!(config.database.url(), "postgres://localhost/myapp");
assert_eq!(config.gateway.port, 9081);
}
// Explicit values override the defaults.
#[test]
fn test_parse_full_config() {
let toml = r#"
[project]
name = "my-app"
version = "1.0.0"
[database]
url = "postgres://localhost/myapp"
pool_size = 100
[node]
roles = ["gateway", "worker"]
worker_capabilities = ["media", "general"]
[gateway]
port = 3000
grpc_port = 9001
"#;
let config = ForgeConfig::parse_toml(toml).unwrap();
assert_eq!(config.project.name, "my-app");
assert_eq!(config.database.pool_size, 100);
assert_eq!(config.node.roles.len(), 2);
assert_eq!(config.gateway.port, 3000);
}
// `${VAR}` is expanded before the TOML is parsed.
#[test]
fn test_env_var_substitution() {
unsafe {
std::env::set_var("TEST_DB_URL", "postgres://test:test@localhost/test");
}
let toml = r#"
[database]
url = "${TEST_DB_URL}"
"#;
let config = ForgeConfig::parse_toml(toml).unwrap();
assert_eq!(config.database.url(), "postgres://test:test@localhost/test");
unsafe {
std::env::remove_var("TEST_DB_URL");
}
}
// An entirely unconfigured auth section is valid.
#[test]
fn test_auth_validation_no_config() {
let auth = AuthConfig::default();
assert!(auth.validate().is_ok());
}
// HMAC algorithm with a secret passes.
#[test]
fn test_auth_validation_hmac_with_secret() {
let auth = AuthConfig {
jwt_secret: Some("my-secret".into()),
jwt_algorithm: JwtAlgorithm::HS256,
..Default::default()
};
assert!(auth.validate().is_ok());
}
// Setting the issuer marks auth as configured, so the missing secret is an error.
#[test]
fn test_auth_validation_hmac_missing_secret() {
let auth = AuthConfig {
jwt_issuer: Some("my-issuer".into()),
jwt_algorithm: JwtAlgorithm::HS256,
..Default::default()
};
let result = auth.validate();
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("jwt_secret is required"));
}
// RSA algorithm with a JWKS URL passes.
#[test]
fn test_auth_validation_rsa_with_jwks() {
let auth = AuthConfig {
jwks_url: Some("https://example.com/.well-known/jwks.json".into()),
jwt_algorithm: JwtAlgorithm::RS256,
..Default::default()
};
assert!(auth.validate().is_ok());
}
// RSA algorithm without a JWKS URL is rejected.
#[test]
fn test_auth_validation_rsa_missing_jwks() {
let auth = AuthConfig {
jwt_issuer: Some("my-issuer".into()),
jwt_algorithm: JwtAlgorithm::RS256,
..Default::default()
};
let result = auth.validate();
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("jwks_url is required"));
}
// Database validation runs as part of `parse_toml`.
#[test]
fn test_forge_config_validation_fails_on_empty_url() {
let toml = r#"
[database]
url = ""
"#;
let result = ForgeConfig::parse_toml(toml);
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("database.url is required"));
}
// Auth validation runs as part of `parse_toml`.
#[test]
fn test_forge_config_validation_fails_on_invalid_auth() {
let toml = r#"
[database]
url = "postgres://localhost/test"
[auth]
jwt_issuer = "my-issuer"
jwt_algorithm = "RS256"
"#;
let result = ForgeConfig::parse_toml(toml);
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("jwks_url is required"));
}
// `${VAR-default}` falls back to the default when VAR is unset.
#[test]
fn test_env_var_default_used_when_unset() {
unsafe {
std::env::remove_var("TEST_FORGE_OTEL_UNSET");
}
let input = r#"enabled = ${TEST_FORGE_OTEL_UNSET-false}"#;
let result = substitute_env_vars(input);
assert_eq!(result, "enabled = false");
}
// A set variable wins over the default.
#[test]
fn test_env_var_default_overridden_when_set() {
unsafe {
std::env::set_var("TEST_FORGE_OTEL_SET", "true");
}
let input = r#"enabled = ${TEST_FORGE_OTEL_SET-false}"#;
let result = substitute_env_vars(input);
assert_eq!(result, "enabled = true");
unsafe {
std::env::remove_var("TEST_FORGE_OTEL_SET");
}
}
// `${VAR:-default}` form; the default itself may contain ':' and '/'.
#[test]
fn test_env_var_colon_dash_default() {
unsafe {
std::env::remove_var("TEST_FORGE_ENDPOINT_UNSET");
}
let input = r#"endpoint = "${TEST_FORGE_ENDPOINT_UNSET:-http://localhost:4318}""#;
let result = substitute_env_vars(input);
assert_eq!(result, r#"endpoint = "http://localhost:4318""#);
}
// Unset variable without a default: the placeholder is left untouched.
#[test]
fn test_env_var_no_default_preserves_literal() {
unsafe {
std::env::remove_var("TEST_FORGE_MISSING");
}
let input = r#"url = "${TEST_FORGE_MISSING}""#;
let result = substitute_env_vars(input);
assert_eq!(result, r#"url = "${TEST_FORGE_MISSING}""#);
}
// An empty default (`${VAR-}`) expands to the empty string.
#[test]
fn test_env_var_default_empty_string() {
unsafe {
std::env::remove_var("TEST_FORGE_EMPTY_DEFAULT");
}
let input = r#"val = "${TEST_FORGE_EMPTY_DEFAULT-}""#;
let result = substitute_env_vars(input);
assert_eq!(result, r#"val = """#);
}
// Observability is off unless explicitly enabled.
#[test]
fn test_observability_config_default_disabled() {
let toml = r#"
[database]
url = "postgres://localhost/test"
"#;
let config = ForgeConfig::parse_toml(toml).unwrap();
assert!(!config.observability.enabled);
assert!(!config.observability.otlp_active());
}
// An unquoted placeholder with a default yields a TOML boolean.
#[test]
fn test_observability_config_with_env_default() {
unsafe {
std::env::remove_var("TEST_OTEL_ENABLED");
}
let toml = r#"
[database]
url = "postgres://localhost/test"
[observability]
enabled = ${TEST_OTEL_ENABLED-false}
"#;
let config = ForgeConfig::parse_toml(toml).unwrap();
assert!(!config.observability.enabled);
}
// `mcp.path` must begin with '/'.
#[test]
fn test_mcp_config_validation_rejects_invalid_path() {
let toml = r#"
[database]
url = "postgres://localhost/test"
[mcp]
enabled = true
path = "mcp"
"#;
let result = ForgeConfig::parse_toml(toml);
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("mcp.path must start with '/'"));
}
// Unset TTL strings fall back to 3600 seconds / 30 days.
#[test]
fn test_access_token_ttl_defaults() {
let auth = AuthConfig::default();
assert_eq!(auth.access_token_ttl_secs(), 3600);
assert_eq!(auth.refresh_token_ttl_days(), 30);
}
// Custom duration strings are parsed.
#[test]
fn test_access_token_ttl_custom() {
let auth = AuthConfig {
access_token_ttl: Some("15m".into()),
refresh_token_ttl: Some("7d".into()),
..Default::default()
};
assert_eq!(auth.access_token_ttl_secs(), 900);
assert_eq!(auth.refresh_token_ttl_days(), 7);
}
// A zero-length access TTL is clamped up to one second.
#[test]
fn test_access_token_ttl_minimum_enforced() {
let auth = AuthConfig {
access_token_ttl: Some("0s".into()),
..Default::default()
};
assert_eq!(auth.access_token_ttl_secs(), 1);
}
// A sub-day refresh TTL is clamped up to one day.
#[test]
fn test_refresh_token_ttl_minimum_enforced() {
let auth = AuthConfig {
refresh_token_ttl: Some("1h".into()),
..Default::default()
};
assert_eq!(auth.refresh_token_ttl_days(), 1);
}
// Default body limit is 20 MiB.
#[test]
fn test_max_body_size_defaults() {
let gw = GatewayConfig::default();
assert_eq!(gw.max_body_size_bytes().unwrap(), 20 * 1024 * 1024);
}
// Custom body limit is parsed.
#[test]
fn test_max_body_size_custom() {
let gw = GatewayConfig {
max_body_size: "100mb".into(),
..Default::default()
};
assert_eq!(gw.max_body_size_bytes().unwrap(), 100 * 1024 * 1024);
}
// An unparseable body limit is an error.
#[test]
fn test_max_body_size_invalid_errors() {
let gw = GatewayConfig {
max_body_size: "not-a-size".into(),
..Default::default()
};
assert!(gw.max_body_size_bytes().is_err());
}
// Default file limit is 10 MiB.
#[test]
fn test_max_file_size_defaults() {
let gw = GatewayConfig::default();
assert_eq!(gw.max_file_size_bytes().unwrap(), 10 * 1024 * 1024);
}
// Custom file limit is parsed.
#[test]
fn test_max_file_size_custom() {
let gw = GatewayConfig {
max_file_size: "200mb".into(),
max_body_size: "500mb".into(),
..Default::default()
};
assert_eq!(gw.max_file_size_bytes().unwrap(), 200 * 1024 * 1024);
}
// An unparseable file limit is an error.
#[test]
fn test_max_file_size_invalid_errors() {
let gw = GatewayConfig {
max_file_size: "nope".into(),
..Default::default()
};
assert!(gw.max_file_size_bytes().is_err());
}
// Cross-field rule: the file limit may not exceed the body limit.
#[test]
fn test_validate_rejects_file_larger_than_body() {
let toml = r#"
[database]
url = "postgres://localhost/test"
[gateway]
max_body_size = "10mb"
max_file_size = "20mb"
"#;
let err = ForgeConfig::parse_toml(toml).unwrap_err().to_string();
assert!(
err.contains("max_file_size"),
"Expected max_file_size error, got: {err}"
);
}
// Every entry of `McpConfig::RESERVED_PATHS` is rejected as `mcp.path`.
#[test]
fn test_mcp_config_rejects_reserved_paths() {
for reserved in McpConfig::RESERVED_PATHS {
let toml = format!(
r#"
[database]
url = "postgres://localhost/test"
[mcp]
enabled = true
path = "{reserved}"
"#
);
let result = ForgeConfig::parse_toml(&toml);
assert!(result.is_err(), "Expected {reserved} to be rejected");
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.contains("conflicts with a reserved gateway route"),
"Wrong error for {reserved}: {err_msg}"
);
}
}
}