1use crate::router::DEFAULT_STREAM_BASE_PATH;
8use axum::http::{HeaderName, HeaderValue};
9use figment::{
10 Figment,
11 providers::{Format, Toml},
12};
13use serde::{Deserialize, Serialize};
14use std::env;
15use std::fmt;
16use std::net::{IpAddr, SocketAddr};
17use std::path::{Path, PathBuf};
18use std::time::Duration;
19use thiserror::Error;
20
/// Storage mode selected through `storage.mode`.
///
/// Variants serialize in kebab-case (`memory`, `file-fast`, `file-durable`,
/// `acid`); the aliases listed on each variant are additionally accepted when
/// deserializing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum StorageMode {
    /// Serialized as `memory`.
    Memory,
    /// Serialized as `file-fast`; `fast` is accepted as an alias.
    #[serde(alias = "fast")]
    FileFast,
    /// Serialized as `file-durable`; `file` and `durable` are accepted as aliases.
    #[serde(alias = "file", alias = "durable")]
    FileDurable,
    /// Serialized as `acid`; `redb` is accepted as an alias.
    #[serde(alias = "redb")]
    Acid,
}
37
/// Backend selected through `storage.acid_backend`.
///
/// Variants serialize in kebab-case (`file`, `in-memory`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum AcidBackend {
    /// Serialized as `file`.
    File,
    /// Serialized as `in-memory`; `memory` and `inmemory` are accepted as aliases.
    #[serde(alias = "memory", alias = "inmemory")]
    InMemory,
}
49
50impl AcidBackend {
51 #[must_use]
52 pub fn as_str(self) -> &'static str {
53 match self {
54 Self::File => "file",
55 Self::InMemory => "in-memory",
56 }
57 }
58}
59
60impl fmt::Display for StorageMode {
61 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
62 formatter.write_str(self.as_str())
63 }
64}
65
66impl fmt::Display for TransportMode {
67 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
68 formatter.write_str(self.as_str())
69 }
70}
71
72impl fmt::Display for HttpVersion {
73 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
74 formatter.write_str(self.as_str())
75 }
76}
77
78impl fmt::Display for AlpnProtocol {
79 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
80 formatter.write_str(self.as_str())
81 }
82}
83
84impl StorageMode {
85 #[must_use]
86 pub fn as_str(self) -> &'static str {
87 match self {
88 Self::Memory => "memory",
89 Self::FileFast => "file-fast",
90 Self::FileDurable => "file-durable",
91 Self::Acid => "acid",
92 }
93 }
94
95 #[must_use]
96 pub fn uses_file_backend(self) -> bool {
97 matches!(self, Self::FileFast | Self::FileDurable)
98 }
99
100 #[must_use]
101 pub fn sync_on_append(self) -> bool {
102 matches!(self, Self::FileDurable)
103 }
104}
105
/// Transport mode selected through `transport.mode`.
///
/// Variants serialize in kebab-case (`http`, `tls`, `mtls`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum TransportMode {
    /// Serialized as `http`; the only mode for which `uses_tls` is `false`.
    Http,
    /// Serialized as `tls`.
    Tls,
    /// Serialized as `mtls`.
    Mtls,
}
117
118impl TransportMode {
119 #[must_use]
120 pub fn as_str(self) -> &'static str {
121 match self {
122 Self::Http => "http",
123 Self::Tls => "tls",
124 Self::Mtls => "mtls",
125 }
126 }
127
128 #[must_use]
129 pub fn uses_tls(self) -> bool {
130 matches!(self, Self::Tls | Self::Mtls)
131 }
132}
133
/// HTTP protocol version listed in `transport.http.versions`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum HttpVersion {
    /// Serialized as `http1`; `1.1`, `http1.1`, `http/1.1` and `h1` are
    /// accepted as aliases.
    #[serde(
        rename = "http1",
        alias = "1.1",
        alias = "http1.1",
        alias = "http/1.1",
        alias = "h1"
    )]
    Http1,
    /// Serialized as `http2`; `2` and `h2` are accepted as aliases.
    #[serde(rename = "http2", alias = "2", alias = "h2")]
    Http2,
}
150
151impl HttpVersion {
152 #[must_use]
153 pub fn as_str(self) -> &'static str {
154 match self {
155 Self::Http1 => "http1",
156 Self::Http2 => "http2",
157 }
158 }
159}
160
/// TLS protocol version used for `transport.tls.min_version` and
/// `transport.tls.max_version`.
///
/// The derived ordering follows declaration order, so `V1_2 < V1_3`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum TlsVersion {
    /// Serialized as `1.2`; `tls1.2` and `tls-1.2` are accepted as aliases.
    #[serde(rename = "1.2", alias = "tls1.2", alias = "tls-1.2")]
    V1_2,
    /// Serialized as `1.3`; `tls1.3` and `tls-1.3` are accepted as aliases.
    #[serde(rename = "1.3", alias = "tls1.3", alias = "tls-1.3")]
    V1_3,
}
171
172impl TlsVersion {
173 #[must_use]
174 pub fn as_str(self) -> &'static str {
175 match self {
176 Self::V1_2 => "1.2",
177 Self::V1_3 => "1.3",
178 }
179 }
180}
181
/// ALPN protocol identifier listed in `transport.tls.alpn_protocols`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AlpnProtocol {
    /// Serialized as `http/1.1`; `http1` and `h1` are accepted as aliases.
    #[serde(rename = "http/1.1", alias = "http1", alias = "h1")]
    Http1_1,
    /// Serialized as `h2`; `http2` is accepted as an alias.
    #[serde(rename = "h2", alias = "http2")]
    H2,
}
192
193impl AlpnProtocol {
194 #[must_use]
195 pub fn as_str(self) -> &'static str {
196 match self {
197 Self::Http1_1 => "http/1.1",
198 Self::H2 => "h2",
199 }
200 }
201}
202
/// Value of `proxy.forwarded_headers`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ForwardedHeadersMode {
    /// Serialized as `none`.
    #[serde(rename = "none")]
    None,
    /// Serialized as `x-forwarded`; `xforwarded` is accepted as an alias.
    #[serde(rename = "x-forwarded", alias = "xforwarded")]
    XForwarded,
    /// Serialized as `forwarded`.
    #[serde(rename = "forwarded")]
    Forwarded,
}
216
/// Value of `proxy.identity.mode`.
///
/// Variants serialize in kebab-case (`none`, `header`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum ProxyIdentityMode {
    /// Serialized as `none`.
    None,
    /// Serialized as `header`.
    Header,
}
226
/// Deployment profile: one of the well-known profiles or a custom name.
///
/// [`DeploymentProfile::as_str`] yields the name used for the
/// `<profile>.toml` configuration layer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeploymentProfile {
    /// The `default` profile; also produced by an empty or blank name.
    Default,
    /// The `dev` profile.
    Dev,
    /// The `prod` profile.
    Prod,
    /// The `prod-tls` profile.
    ProdTls,
    /// The `prod-mtls` profile.
    ProdMtls,
    /// Any other profile name, stored trimmed and ASCII-lowercased when built
    /// through the `From` conversions.
    Named(String),
}

impl DeploymentProfile {
    /// Returns the profile name as it is written in configuration.
    #[must_use]
    pub fn as_str(&self) -> &str {
        match self {
            Self::Named(custom) => custom,
            Self::ProdMtls => "prod-mtls",
            Self::ProdTls => "prod-tls",
            Self::Prod => "prod",
            Self::Dev => "dev",
            Self::Default => "default",
        }
    }
}

impl From<&str> for DeploymentProfile {
    /// Parses a profile name, ignoring surrounding whitespace and ASCII case.
    ///
    /// Unrecognised names become [`DeploymentProfile::Named`] holding the
    /// normalised (trimmed, lowercased) text.
    fn from(raw: &str) -> Self {
        let normalized = raw.trim().to_ascii_lowercase();
        match normalized.as_str() {
            "prod-mtls" => Self::ProdMtls,
            "prod-tls" => Self::ProdTls,
            "prod" => Self::Prod,
            "dev" => Self::Dev,
            "" | "default" => Self::Default,
            // Reuse the already-normalised buffer for custom names.
            _ => Self::Named(normalized),
        }
    }
}

impl From<String> for DeploymentProfile {
    /// Same parsing rules as the `&str` conversion.
    fn from(raw: String) -> Self {
        raw.as_str().into()
    }
}
270
/// Fully merged configuration, grouped by section.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct Config {
    /// `server.*` settings.
    pub server: ServerConfig,
    /// `limits.*` settings.
    pub limits: LimitsConfig,
    /// `http.*` settings.
    pub http: HttpConfig,
    /// `storage.*` settings.
    pub storage: StorageConfig,
    /// `transport.*` settings.
    pub transport: TransportConfig,
    /// `proxy.*` settings.
    pub proxy: ProxyConfig,
    /// `observability.*` settings.
    pub observability: ObservabilityConfig,
}
289
/// `server.*` settings.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct ServerConfig {
    /// Listen address; checked by `validate_socket_addr` in `Config::validate`.
    /// When only a port is configured it is set to `0.0.0.0:<port>`.
    pub bind_address: String,
}
296
/// `limits.*` settings; validation requires every value to be at least 1.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct LimitsConfig {
    /// `limits.max_memory_bytes`.
    pub max_memory_bytes: u64,
    /// `limits.max_stream_bytes`.
    pub max_stream_bytes: u64,
    /// `limits.max_stream_name_bytes`.
    pub max_stream_name_bytes: usize,
    /// `limits.max_stream_name_segments`.
    pub max_stream_name_segments: usize,
}
309
/// `http.*` settings.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct HttpConfig {
    /// Allowed CORS origins; checked by `validate_cors_origins`.
    // NOTE(review): the separator between multiple origins is not visible
    // here; confirm against `validate_cors_origins`.
    pub cors_origins: String,
    /// Stream base path; checked by `validate_stream_base_path`.
    pub stream_base_path: String,
    /// Opt-in that, per the `WildcardCorsOriginsProd` error message, permits
    /// `cors_origins = "*"` for profiles that otherwise reject it.
    pub allow_wildcard_cors: bool,
}
321
/// `storage.*` settings.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct StorageConfig {
    /// Selected storage mode.
    pub mode: StorageMode,
    /// Data directory; must be non-blank unless `mode` is `Memory`.
    pub data_dir: String,
    /// Shard count; checked by `valid_acid_shard_count` only when `mode` is
    /// `Acid`.
    pub acid_shard_count: usize,
    /// Backend selected through `storage.acid_backend`.
    pub acid_backend: AcidBackend,
}
334
/// `transport.*` settings.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct TransportConfig {
    /// Selected transport mode.
    pub mode: TransportMode,
    /// `transport.http.*` settings.
    pub http: TransportHttpConfig,
    /// `transport.tls.*` settings.
    pub tls: TransportTlsConfig,
    /// `transport.connection.*` settings.
    pub connection: TransportConnectionConfig,
}
347
/// `transport.http.*` settings.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct TransportHttpConfig {
    /// Enabled HTTP versions; must be non-empty, and must not contain
    /// `Http2` when the transport mode is `Http`.
    pub versions: Vec<HttpVersion>,
}
354
/// `transport.tls.*` settings.
///
/// Any path that is set must be non-blank, and none may be set when the
/// transport mode is `Http`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct TransportTlsConfig {
    /// Server certificate path; required when the transport mode is `Tls`.
    pub cert_path: Option<String>,
    /// Server private key path; required when the transport mode is `Tls`.
    pub key_path: Option<String>,
    /// Client CA path (`transport.tls.client_ca_path`).
    pub client_ca_path: Option<String>,
    /// Lowest accepted TLS version; must not exceed `max_version`.
    pub min_version: TlsVersion,
    /// Highest accepted TLS version.
    pub max_version: TlsVersion,
    /// ALPN protocols; reset to `default_alpn_protocols(versions)` whenever a
    /// configuration layer sets `transport.http.versions`, unless that layer
    /// also sets this list explicitly.
    pub alpn_protocols: Vec<AlpnProtocol>,
}
371
372impl TransportTlsConfig {
373 #[must_use]
374 pub fn has_server_credentials(&self) -> bool {
375 self.cert_path.is_some() && self.key_path.is_some()
376 }
377}
378
/// `transport.connection.*` settings.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct TransportConnectionConfig {
    /// Long-poll timeout in seconds; validation requires at least 1.
    pub long_poll_timeout_secs: u64,
    /// SSE reconnect interval in seconds.
    pub sse_reconnect_interval_secs: u64,
}
387
/// `proxy.*` settings; cross-field rules are enforced by `validate_proxy`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct ProxyConfig {
    /// `proxy.enabled`.
    pub enabled: bool,
    /// `proxy.forwarded_headers`.
    pub forwarded_headers: ForwardedHeadersMode,
    /// `proxy.trusted_proxies` entries, kept as raw strings.
    pub trusted_proxies: Vec<String>,
    /// `proxy.identity.*` settings.
    pub identity: ProxyIdentityConfig,
}
400
/// `proxy.identity.*` settings.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct ProxyIdentityConfig {
    /// `proxy.identity.mode`.
    pub mode: ProxyIdentityMode,
    /// `proxy.identity.header_name`, if configured.
    pub header_name: Option<String>,
    /// `proxy.identity.require_tls`.
    pub require_tls: bool,
}
411
/// `observability.*` settings.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct ObservabilityConfig {
    /// Log filter string; also settable through the legacy `log.rust_log`
    /// key, `DS_OBSERVABILITY__RUST_LOG`, or legacy `DS_LOG__RUST_LOG`.
    pub rust_log: String,
}
418
/// Inputs that control where `Config::from_sources` looks for configuration.
#[derive(Debug, Clone)]
pub struct ConfigLoadOptions {
    /// Directory searched for `default.toml`, `<profile>.toml` and
    /// `local.toml` (each is optional).
    pub config_dir: PathBuf,
    /// Profile selecting the built-in patch and the `<profile>.toml` layer.
    pub profile: DeploymentProfile,
    /// Optional extra TOML file applied after the directory layers; loading
    /// fails if it is set but is not an existing file.
    pub config_override: Option<PathBuf>,
}
428
429impl Default for ConfigLoadOptions {
430 fn default() -> Self {
431 Self {
432 config_dir: PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("config"),
433 profile: DeploymentProfile::Default,
434 config_override: None,
435 }
436 }
437}
438
/// Errors raised while loading configuration from files and environment.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum ConfigLoadError {
    /// The explicitly requested override file is not an existing file.
    #[error("config override file not found: '{path}'")]
    OverrideFileNotFound { path: PathBuf },
    /// A TOML configuration file could not be parsed.
    #[error("failed to parse TOML config: {message}")]
    TomlParse { message: String },
    /// A value from `input_source` under `key` could not be interpreted.
    #[error("invalid {input_source} value for {key}: '{value}' ({reason})")]
    InvalidValue {
        input_source: &'static str,
        key: &'static str,
        value: String,
        reason: String,
    },
}
454
/// Errors reported when a merged [`Config`] is validated.
///
/// Each variant's message names the configuration key(s) involved using the
/// dotted `section.key` notation.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum ConfigValidationError {
    #[error("server.bind_address is invalid: '{value}' ({reason})")]
    InvalidBindAddress { value: String, reason: String },
    #[error("http.stream_base_path is invalid: '{value}' ({reason})")]
    InvalidStreamBasePath { value: String, reason: String },
    #[error("http.cors_origins contains an empty origin entry")]
    EmptyCorsOrigin,
    #[error("http.cors_origins entry is invalid: '{value}'")]
    InvalidCorsOrigin { value: String },
    #[error("limits.max_memory_bytes must be at least 1")]
    MaxMemoryBytesTooSmall,
    #[error("limits.max_stream_bytes must be at least 1")]
    MaxStreamBytesTooSmall,
    #[error("limits.max_stream_name_bytes must be at least 1")]
    MaxStreamNameBytesTooSmall,
    #[error("limits.max_stream_name_segments must be at least 1")]
    MaxStreamNameSegmentsTooSmall,
    #[error("storage.data_dir must be a non-empty path when storage.mode is '{mode}'")]
    EmptyStorageDataDir { mode: StorageMode },
    #[error(
        "storage.acid_shard_count must be a power of two in 1..=256 when storage.mode is 'acid'"
    )]
    InvalidAcidShardCount,
    #[error("transport.connection.long_poll_timeout_secs must be at least 1")]
    LongPollTimeoutTooSmall,
    #[error("transport.http.versions must include at least one version")]
    EmptyHttpVersions,
    #[error("transport.mode='http' does not support transport.http.versions containing http2")]
    HttpModeDoesNotSupportHttp2,
    #[error("transport.tls.min_version must be less than or equal to transport.tls.max_version")]
    InvalidTlsVersionRange,
    #[error("transport.mode='{mode}' requires transport.tls.{field}")]
    MissingTlsField {
        mode: TransportMode,
        field: &'static str,
    },
    #[error("transport.mode='http' cannot be combined with transport.tls.{field}")]
    HttpModeDisallowsTlsField { field: &'static str },
    #[error("transport.mode='tls' cannot be combined with transport.tls.client_ca_path")]
    ClientCaRequiresMtls,
    #[error("transport.tls.{field} must be a non-empty path when set")]
    EmptyPath { field: &'static str },
    #[error(
        "transport.http.versions includes '{version}', but transport.tls.alpn_protocols is missing '{alpn}'"
    )]
    MissingAlpnProtocol {
        version: HttpVersion,
        alpn: AlpnProtocol,
    },
    #[error(
        "transport.tls.alpn_protocols includes '{alpn}', but transport.http.versions does not enable the matching HTTP version"
    )]
    UnexpectedAlpnProtocol { alpn: AlpnProtocol },
    #[error(
        "proxy.enabled=true requires proxy.forwarded_headers to be set to 'x-forwarded' or 'forwarded'"
    )]
    ProxyEnabledRequiresForwardedHeaders,
    #[error("proxy.enabled=true requires at least one entry in proxy.trusted_proxies")]
    ProxyEnabledRequiresTrustedProxies,
    #[error("proxy.enabled=false cannot be combined with proxy.trusted_proxies")]
    ProxyDisabledDisallowsTrustedProxies,
    // NOTE(review): `{mode:?}` prints the Rust variant name (e.g. `XForwarded`),
    // not the serialized configuration value.
    #[error("proxy.enabled=false cannot be combined with proxy.forwarded_headers='{mode:?}'")]
    ProxyDisabledDisallowsForwardedHeaders { mode: ForwardedHeadersMode },
    #[error("proxy.enabled=false cannot be combined with proxy.identity.mode='{mode:?}'")]
    ProxyDisabledDisallowsIdentityMode { mode: ProxyIdentityMode },
    #[error("proxy.enabled=false cannot be combined with proxy.identity.header_name")]
    ProxyDisabledDisallowsIdentityHeader,
    #[error("proxy.trusted_proxies entry is invalid: '{value}'")]
    InvalidTrustedProxy { value: String },
    #[error("proxy.identity.mode='header' requires proxy.identity.header_name")]
    HeaderIdentityRequiresHeaderName,
    #[error("proxy.identity.mode='header' requires transport.mode='mtls'")]
    HeaderIdentityRequiresMtls,
    #[error("proxy.identity.mode='none' cannot be combined with proxy.identity.header_name")]
    IdentityHeaderRequiresHeaderMode,
    #[error("proxy.identity.header_name is invalid: '{value}'")]
    InvalidIdentityHeaderName { value: String },
    #[error(
        "http.cors_origins='*' is not allowed for the '{profile}' deployment profile; \
         set http.allow_wildcard_cors=true to override, or specify explicit origins"
    )]
    WildcardCorsOriginsProd { profile: String },
}
540
/// Partial configuration deserialized from one layer (for example a TOML
/// file); missing sections default to empty patches.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct ConfigPatch {
    server: ServerConfigPatch,
    limits: LimitsConfigPatch,
    http: HttpConfigPatch,
    storage: StorageConfigPatch,
    transport: TransportConfigPatch,
    proxy: ProxyConfigPatch,
    observability: ObservabilityConfigPatch,
    // Legacy top-level `tls` section; used as a fallback for `transport.tls`.
    tls: LegacyTlsPatch,
    // Legacy top-level `log` section; used as a fallback for `observability`.
    log: LegacyLogPatch,
}
554
/// Optional overrides for the `server` section.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct ServerConfigPatch {
    bind_address: Option<String>,
    // Only used when `bind_address` is absent; becomes `0.0.0.0:<port>`.
    port: Option<u16>,
    // Fallback for `transport.connection.long_poll_timeout_secs`.
    long_poll_timeout_secs: Option<u64>,
    // Fallback for `transport.connection.sse_reconnect_interval_secs`.
    sse_reconnect_interval_secs: Option<u64>,
}
563
/// Optional overrides for the `limits` section.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
// Every field intentionally shares the `max_` prefix of its config key.
#[allow(clippy::struct_field_names)]
struct LimitsConfigPatch {
    max_memory_bytes: Option<u64>,
    max_stream_bytes: Option<u64>,
    max_stream_name_bytes: Option<usize>,
    max_stream_name_segments: Option<usize>,
}
573
/// Optional overrides for the `http` section.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct HttpConfigPatch {
    cors_origins: Option<String>,
    stream_base_path: Option<String>,
    allow_wildcard_cors: Option<bool>,
}
581
/// Optional overrides for the `storage` section.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct StorageConfigPatch {
    mode: Option<StorageMode>,
    data_dir: Option<String>,
    acid_shard_count: Option<usize>,
    acid_backend: Option<AcidBackend>,
}
590
/// Optional overrides for the `transport` section.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct TransportConfigPatch {
    // Setting this marks the transport mode as explicitly chosen.
    mode: Option<TransportMode>,
    http: TransportHttpConfigPatch,
    tls: TransportTlsConfigPatch,
    connection: TransportConnectionConfigPatch,
}
599
/// Optional overrides for the `transport.http` section.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct TransportHttpConfigPatch {
    // Setting this also resets the ALPN list to the defaults for the versions.
    versions: Option<Vec<HttpVersion>>,
}
605
/// Optional overrides for the `transport.tls` section.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct TransportTlsConfigPatch {
    cert_path: Option<String>,
    key_path: Option<String>,
    client_ca_path: Option<String>,
    min_version: Option<TlsVersion>,
    max_version: Option<TlsVersion>,
    alpn_protocols: Option<Vec<AlpnProtocol>>,
}
616
/// Optional overrides for the `transport.connection` section.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct TransportConnectionConfigPatch {
    long_poll_timeout_secs: Option<u64>,
    sse_reconnect_interval_secs: Option<u64>,
}
623
/// Optional overrides for the `proxy` section.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct ProxyConfigPatch {
    enabled: Option<bool>,
    forwarded_headers: Option<ForwardedHeadersMode>,
    trusted_proxies: Option<Vec<String>>,
    identity: ProxyIdentityConfigPatch,
}
632
/// Optional overrides for the `proxy.identity` section.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct ProxyIdentityConfigPatch {
    mode: Option<ProxyIdentityMode>,
    header_name: Option<String>,
    require_tls: Option<bool>,
}
640
/// Optional overrides for the `observability` section.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct ObservabilityConfigPatch {
    rust_log: Option<String>,
}
646
/// Legacy top-level `tls` section; its paths are used only when the
/// corresponding `transport.tls` path is absent from the same patch.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct LegacyTlsPatch {
    cert_path: Option<String>,
    key_path: Option<String>,
}
653
/// Legacy top-level `log` section; `rust_log` is used only when
/// `observability.rust_log` is absent from the same patch.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct LegacyLogPatch {
    rust_log: Option<String>,
}
659
/// Facts collected while layers are merged, consumed by `finalize` once all
/// layers have been applied.
#[derive(Debug, Default)]
struct MergeContext {
    // Set when any layer (built-in profile, file, or environment) sets the
    // transport mode.
    explicit_transport_mode: bool,
    // Set when a legacy `tls.*` key or `DS_TLS__*` variable was seen.
    legacy_tls_seen: bool,
}
665
666impl Config {
667 pub fn from_env() -> Result<Self, ConfigLoadError> {
675 let mut config = Self::default();
676 let mut ctx = MergeContext::default();
677 config.apply_env_overrides(&|key| env::var(key).ok(), &mut ctx)?;
678 ctx.finalize(&mut config);
679 Ok(config)
680 }
681
682 pub fn from_sources(options: &ConfigLoadOptions) -> Result<Self, ConfigLoadError> {
698 let get = |key: &str| env::var(key).ok();
699 Self::from_sources_with_lookup(options, &get)
700 }
701
702 fn from_sources_with_lookup(
703 options: &ConfigLoadOptions,
704 get: &impl Fn(&str) -> Option<String>,
705 ) -> Result<Self, ConfigLoadError> {
706 let mut config = Self::default();
707 let mut ctx = MergeContext::default();
708
709 if let Some(profile_patch) = built_in_profile_patch(&options.profile) {
710 if profile_patch.transport.mode.is_some() {
711 ctx.explicit_transport_mode = true;
712 }
713 config.apply_patch(profile_patch, &mut ctx);
714 }
715
716 let default_path = options.config_dir.join("default.toml");
717 if default_path.is_file() {
718 let patch = extract_toml_patch(&default_path)?;
719 config.apply_patch(patch, &mut ctx);
720 }
721
722 let profile_path = options
723 .config_dir
724 .join(format!("{}.toml", options.profile.as_str()));
725 if profile_path.is_file() {
726 let patch = extract_toml_patch(&profile_path)?;
727 config.apply_patch(patch, &mut ctx);
728 }
729
730 let local_path = options.config_dir.join("local.toml");
731 if local_path.is_file() {
732 let patch = extract_toml_patch(&local_path)?;
733 config.apply_patch(patch, &mut ctx);
734 }
735
736 if let Some(override_path) = &options.config_override {
737 if !override_path.is_file() {
738 return Err(ConfigLoadError::OverrideFileNotFound {
739 path: override_path.clone(),
740 });
741 }
742 let patch = extract_toml_patch(override_path)?;
743 config.apply_patch(patch, &mut ctx);
744 }
745
746 config.apply_env_overrides(get, &mut ctx)?;
747 ctx.finalize(&mut config);
748 Ok(config)
749 }
750
751 fn apply_patch(&mut self, patch: ConfigPatch, ctx: &mut MergeContext) {
752 self.apply_server_patch(&patch.server);
753 self.apply_limits_patch(&patch.limits);
754 self.apply_http_patch(&patch.http);
755 self.apply_storage_patch(&patch.storage);
756 self.apply_transport_patch(&patch.transport, &patch.tls, &patch.server, ctx);
757 self.apply_proxy_patch(&patch.proxy);
758
759 let rust_log = patch.observability.rust_log.or(patch.log.rust_log);
760 if let Some(rust_log) = rust_log {
761 self.observability.rust_log = rust_log;
762 }
763 }
764
765 fn apply_server_patch(&mut self, patch: &ServerConfigPatch) {
766 if let Some(bind_address) = &patch.bind_address {
767 self.server.bind_address.clone_from(bind_address);
768 } else if let Some(port) = patch.port {
769 self.server.bind_address = format!("0.0.0.0:{port}");
770 }
771 }
772
773 fn apply_limits_patch(&mut self, patch: &LimitsConfigPatch) {
774 if let Some(max_memory_bytes) = patch.max_memory_bytes {
775 self.limits.max_memory_bytes = max_memory_bytes;
776 }
777 if let Some(max_stream_bytes) = patch.max_stream_bytes {
778 self.limits.max_stream_bytes = max_stream_bytes;
779 }
780 if let Some(max_stream_name_bytes) = patch.max_stream_name_bytes {
781 self.limits.max_stream_name_bytes = max_stream_name_bytes;
782 }
783 if let Some(max_stream_name_segments) = patch.max_stream_name_segments {
784 self.limits.max_stream_name_segments = max_stream_name_segments;
785 }
786 }
787
788 fn apply_http_patch(&mut self, patch: &HttpConfigPatch) {
789 if let Some(cors_origins) = &patch.cors_origins {
790 self.http.cors_origins.clone_from(cors_origins);
791 }
792 if let Some(stream_base_path) = &patch.stream_base_path {
793 self.http.stream_base_path.clone_from(stream_base_path);
794 }
795 if let Some(allow_wildcard_cors) = patch.allow_wildcard_cors {
796 self.http.allow_wildcard_cors = allow_wildcard_cors;
797 }
798 }
799
800 fn apply_storage_patch(&mut self, patch: &StorageConfigPatch) {
801 if let Some(mode) = patch.mode {
802 self.storage.mode = mode;
803 }
804 if let Some(data_dir) = &patch.data_dir {
805 self.storage.data_dir.clone_from(data_dir);
806 }
807 if let Some(acid_shard_count) = patch.acid_shard_count {
808 self.storage.acid_shard_count = acid_shard_count;
809 }
810 if let Some(acid_backend) = patch.acid_backend {
811 self.storage.acid_backend = acid_backend;
812 }
813 }
814
815 fn apply_transport_patch(
816 &mut self,
817 patch: &TransportConfigPatch,
818 legacy_tls: &LegacyTlsPatch,
819 server_patch: &ServerConfigPatch,
820 ctx: &mut MergeContext,
821 ) {
822 if let Some(mode) = patch.mode {
823 self.transport.mode = mode;
824 ctx.explicit_transport_mode = true;
825 }
826 if let Some(versions) = &patch.http.versions {
827 self.transport.http.versions.clone_from(versions);
828 self.transport.tls.alpn_protocols =
829 default_alpn_protocols(&self.transport.http.versions);
830 }
831
832 let legacy_tls_cert_path = &legacy_tls.cert_path;
833 let legacy_tls_key_path = &legacy_tls.key_path;
834 let saw_legacy_tls = legacy_tls_cert_path.is_some() || legacy_tls_key_path.is_some();
835 let tls_cert_path = patch
836 .tls
837 .cert_path
838 .as_ref()
839 .or(legacy_tls_cert_path.as_ref());
840 let tls_key_path = patch.tls.key_path.as_ref().or(legacy_tls_key_path.as_ref());
841 if tls_cert_path.is_some() || tls_key_path.is_some() {
842 ctx.legacy_tls_seen |= saw_legacy_tls;
843 }
844 if let Some(cert_path) = tls_cert_path {
845 self.transport.tls.cert_path = Some(cert_path.clone());
846 }
847 if let Some(key_path) = tls_key_path {
848 self.transport.tls.key_path = Some(key_path.clone());
849 }
850 if let Some(client_ca_path) = &patch.tls.client_ca_path {
851 self.transport.tls.client_ca_path = Some(client_ca_path.clone());
852 }
853 if let Some(min_version) = patch.tls.min_version {
854 self.transport.tls.min_version = min_version;
855 }
856 if let Some(max_version) = patch.tls.max_version {
857 self.transport.tls.max_version = max_version;
858 }
859 if let Some(alpn_protocols) = &patch.tls.alpn_protocols {
860 self.transport.tls.alpn_protocols.clone_from(alpn_protocols);
861 }
862
863 let long_poll_timeout_secs = patch
864 .connection
865 .long_poll_timeout_secs
866 .or(server_patch.long_poll_timeout_secs);
867 if let Some(long_poll_timeout_secs) = long_poll_timeout_secs {
868 self.transport.connection.long_poll_timeout_secs = long_poll_timeout_secs;
869 }
870
871 let sse_reconnect_interval_secs = patch
872 .connection
873 .sse_reconnect_interval_secs
874 .or(server_patch.sse_reconnect_interval_secs);
875 if let Some(sse_reconnect_interval_secs) = sse_reconnect_interval_secs {
876 self.transport.connection.sse_reconnect_interval_secs = sse_reconnect_interval_secs;
877 }
878 }
879
880 fn apply_proxy_patch(&mut self, patch: &ProxyConfigPatch) {
881 if let Some(enabled) = patch.enabled {
882 self.proxy.enabled = enabled;
883 }
884 if let Some(forwarded_headers) = patch.forwarded_headers {
885 self.proxy.forwarded_headers = forwarded_headers;
886 }
887 if let Some(trusted_proxies) = &patch.trusted_proxies {
888 self.proxy.trusted_proxies.clone_from(trusted_proxies);
889 }
890 if let Some(mode) = patch.identity.mode {
891 self.proxy.identity.mode = mode;
892 }
893 if let Some(header_name) = &patch.identity.header_name {
894 self.proxy.identity.header_name = Some(header_name.clone());
895 }
896 if let Some(require_tls) = patch.identity.require_tls {
897 self.proxy.identity.require_tls = require_tls;
898 }
899 }
900
901 fn apply_env_overrides(
903 &mut self,
904 get: &impl Fn(&str) -> Option<String>,
905 ctx: &mut MergeContext,
906 ) -> Result<(), ConfigLoadError> {
907 self.apply_server_env(get)?;
908 self.apply_limits_env(get)?;
909 self.apply_http_env(get)?;
910 self.apply_storage_env(get)?;
911 self.apply_transport_env(get, ctx)?;
912 self.apply_proxy_env(get)?;
913
914 if let Some(rust_log) =
915 get("DS_OBSERVABILITY__RUST_LOG").or_else(|| get("DS_LOG__RUST_LOG"))
916 {
917 self.observability.rust_log = rust_log;
918 }
919
920 Ok(())
921 }
922
923 fn apply_server_env(
924 &mut self,
925 get: &impl Fn(&str) -> Option<String>,
926 ) -> Result<(), ConfigLoadError> {
927 if let Some(bind_address) = get("DS_SERVER__BIND_ADDRESS") {
928 self.server.bind_address = bind_address;
929 } else if let Some(port) = parse_env::<u16>(get, "DS_SERVER__PORT")? {
930 self.server.bind_address = format!("0.0.0.0:{port}");
931 }
932
933 if let Some(long_poll_timeout_secs) =
934 parse_env::<u64>(get, "DS_TRANSPORT__CONNECTION__LONG_POLL_TIMEOUT_SECS")?
935 .or(parse_env::<u64>(get, "DS_SERVER__LONG_POLL_TIMEOUT_SECS")?)
936 {
937 self.transport.connection.long_poll_timeout_secs = long_poll_timeout_secs;
938 }
939
940 if let Some(sse_reconnect_interval_secs) =
941 parse_env::<u64>(get, "DS_TRANSPORT__CONNECTION__SSE_RECONNECT_INTERVAL_SECS")?.or(
942 parse_env::<u64>(get, "DS_SERVER__SSE_RECONNECT_INTERVAL_SECS")?,
943 )
944 {
945 self.transport.connection.sse_reconnect_interval_secs = sse_reconnect_interval_secs;
946 }
947
948 Ok(())
949 }
950
951 fn apply_limits_env(
952 &mut self,
953 get: &impl Fn(&str) -> Option<String>,
954 ) -> Result<(), ConfigLoadError> {
955 if let Some(max_memory_bytes) = parse_env::<u64>(get, "DS_LIMITS__MAX_MEMORY_BYTES")? {
956 self.limits.max_memory_bytes = max_memory_bytes;
957 }
958 if let Some(max_stream_bytes) = parse_env::<u64>(get, "DS_LIMITS__MAX_STREAM_BYTES")? {
959 self.limits.max_stream_bytes = max_stream_bytes;
960 }
961 if let Some(max_stream_name_bytes) =
962 parse_env::<usize>(get, "DS_LIMITS__MAX_STREAM_NAME_BYTES")?
963 {
964 self.limits.max_stream_name_bytes = max_stream_name_bytes;
965 }
966 if let Some(max_stream_name_segments) =
967 parse_env::<usize>(get, "DS_LIMITS__MAX_STREAM_NAME_SEGMENTS")?
968 {
969 self.limits.max_stream_name_segments = max_stream_name_segments;
970 }
971 Ok(())
972 }
973
974 fn apply_http_env(
975 &mut self,
976 get: &impl Fn(&str) -> Option<String>,
977 ) -> Result<(), ConfigLoadError> {
978 if let Some(cors_origins) = get("DS_HTTP__CORS_ORIGINS") {
979 self.http.cors_origins = cors_origins;
980 }
981 if let Some(stream_base_path) = get("DS_HTTP__STREAM_BASE_PATH") {
982 self.http.stream_base_path = stream_base_path;
983 }
984 if let Some(allow_wildcard_cors) = parse_env::<bool>(get, "DS_HTTP__ALLOW_WILDCARD_CORS")? {
985 self.http.allow_wildcard_cors = allow_wildcard_cors;
986 }
987 Ok(())
988 }
989
990 fn apply_storage_env(
991 &mut self,
992 get: &impl Fn(&str) -> Option<String>,
993 ) -> Result<(), ConfigLoadError> {
994 if let Some(storage_mode) = parse_env_with(get, "DS_STORAGE__MODE", parse_storage_mode_env)?
995 {
996 self.storage.mode = storage_mode;
997 }
998 if let Some(data_dir) = get("DS_STORAGE__DATA_DIR") {
999 self.storage.data_dir = data_dir;
1000 }
1001 if let Some(acid_shard_count) = parse_env::<usize>(get, "DS_STORAGE__ACID_SHARD_COUNT")? {
1002 self.storage.acid_shard_count = acid_shard_count;
1003 }
1004 if let Some(acid_backend) =
1005 parse_env_with(get, "DS_STORAGE__ACID_BACKEND", parse_acid_backend_env)?
1006 {
1007 self.storage.acid_backend = acid_backend;
1008 }
1009 Ok(())
1010 }
1011
1012 fn apply_transport_env(
1013 &mut self,
1014 get: &impl Fn(&str) -> Option<String>,
1015 ctx: &mut MergeContext,
1016 ) -> Result<(), ConfigLoadError> {
1017 if let Some(mode) = parse_env_with(get, "DS_TRANSPORT__MODE", parse_transport_mode_env)? {
1018 self.transport.mode = mode;
1019 ctx.explicit_transport_mode = true;
1020 }
1021 if let Some(versions) =
1022 parse_env_list_with(get, "DS_TRANSPORT__HTTP__VERSIONS", parse_http_version_env)?
1023 {
1024 self.transport.http.versions = versions;
1025 self.transport.tls.alpn_protocols =
1026 default_alpn_protocols(&self.transport.http.versions);
1027 }
1028
1029 let tls_cert_path =
1030 get("DS_TRANSPORT__TLS__CERT_PATH").or_else(|| get("DS_TLS__CERT_PATH"));
1031 let tls_key_path = get("DS_TRANSPORT__TLS__KEY_PATH").or_else(|| get("DS_TLS__KEY_PATH"));
1032 if get("DS_TLS__CERT_PATH").is_some() || get("DS_TLS__KEY_PATH").is_some() {
1033 ctx.legacy_tls_seen = true;
1034 }
1035 if let Some(cert_path) = tls_cert_path {
1036 self.transport.tls.cert_path = Some(cert_path);
1037 }
1038 if let Some(key_path) = tls_key_path {
1039 self.transport.tls.key_path = Some(key_path);
1040 }
1041 if let Some(client_ca_path) = get("DS_TRANSPORT__TLS__CLIENT_CA_PATH") {
1042 self.transport.tls.client_ca_path = Some(client_ca_path);
1043 }
1044 if let Some(min_version) =
1045 parse_env_with(get, "DS_TRANSPORT__TLS__MIN_VERSION", parse_tls_version_env)?
1046 {
1047 self.transport.tls.min_version = min_version;
1048 }
1049 if let Some(max_version) =
1050 parse_env_with(get, "DS_TRANSPORT__TLS__MAX_VERSION", parse_tls_version_env)?
1051 {
1052 self.transport.tls.max_version = max_version;
1053 }
1054 if let Some(alpn_protocols) = parse_env_list_with(
1055 get,
1056 "DS_TRANSPORT__TLS__ALPN_PROTOCOLS",
1057 parse_alpn_protocol_env,
1058 )? {
1059 self.transport.tls.alpn_protocols = alpn_protocols;
1060 }
1061 Ok(())
1062 }
1063
1064 fn apply_proxy_env(
1065 &mut self,
1066 get: &impl Fn(&str) -> Option<String>,
1067 ) -> Result<(), ConfigLoadError> {
1068 if let Some(enabled) = parse_env::<bool>(get, "DS_PROXY__ENABLED")? {
1069 self.proxy.enabled = enabled;
1070 }
1071 if let Some(forwarded_headers) = parse_env_with(
1072 get,
1073 "DS_PROXY__FORWARDED_HEADERS",
1074 parse_forwarded_headers_mode_env,
1075 )? {
1076 self.proxy.forwarded_headers = forwarded_headers;
1077 }
1078 if let Some(trusted_proxies) = parse_env_csv_strings(get, "DS_PROXY__TRUSTED_PROXIES")? {
1079 self.proxy.trusted_proxies = trusted_proxies;
1080 }
1081 if let Some(mode) = parse_env_with(
1082 get,
1083 "DS_PROXY__IDENTITY__MODE",
1084 parse_proxy_identity_mode_env,
1085 )? {
1086 self.proxy.identity.mode = mode;
1087 }
1088 if let Some(header_name) = get("DS_PROXY__IDENTITY__HEADER_NAME") {
1089 self.proxy.identity.header_name = Some(header_name);
1090 }
1091 if let Some(require_tls) = parse_env::<bool>(get, "DS_PROXY__IDENTITY__REQUIRE_TLS")? {
1092 self.proxy.identity.require_tls = require_tls;
1093 }
1094 Ok(())
1095 }
1096
/// Checks the merged configuration for consistency.
///
/// Checks run in a fixed order (bind address, CORS origins, stream base
/// path, limits, storage, transport, proxy) and stop at the first failure,
/// so only one problem is reported per call.
///
/// # Errors
///
/// Returns the first [`ConfigValidationError`] encountered.
pub fn validate(&self) -> Result<(), ConfigValidationError> {
    validate_socket_addr(&self.server.bind_address)?;
    validate_cors_origins(&self.http.cors_origins)?;
    validate_stream_base_path(&self.http.stream_base_path)?;
    self.validate_limits()?;
    self.validate_storage()?;
    self.validate_transport()?;
    validate_proxy(self)?;
    Ok(())
}
1112
1113 fn validate_limits(&self) -> Result<(), ConfigValidationError> {
1114 if self.limits.max_memory_bytes == 0 {
1115 return Err(ConfigValidationError::MaxMemoryBytesTooSmall);
1116 }
1117 if self.limits.max_stream_bytes == 0 {
1118 return Err(ConfigValidationError::MaxStreamBytesTooSmall);
1119 }
1120 if self.limits.max_stream_name_bytes == 0 {
1121 return Err(ConfigValidationError::MaxStreamNameBytesTooSmall);
1122 }
1123 if self.limits.max_stream_name_segments == 0 {
1124 return Err(ConfigValidationError::MaxStreamNameSegmentsTooSmall);
1125 }
1126 Ok(())
1127 }
1128
1129 fn validate_storage(&self) -> Result<(), ConfigValidationError> {
1130 if self.storage.mode != StorageMode::Memory && self.storage.data_dir.trim().is_empty() {
1131 return Err(ConfigValidationError::EmptyStorageDataDir {
1132 mode: self.storage.mode,
1133 });
1134 }
1135 if self.storage.mode == StorageMode::Acid
1136 && !valid_acid_shard_count(self.storage.acid_shard_count)
1137 {
1138 return Err(ConfigValidationError::InvalidAcidShardCount);
1139 }
1140 Ok(())
1141 }
1142
1143 fn validate_transport(&self) -> Result<(), ConfigValidationError> {
1144 if self.transport.connection.long_poll_timeout_secs == 0 {
1145 return Err(ConfigValidationError::LongPollTimeoutTooSmall);
1146 }
1147
1148 if self.transport.http.versions.is_empty() {
1149 return Err(ConfigValidationError::EmptyHttpVersions);
1150 }
1151 if self.transport.mode == TransportMode::Http
1152 && self.transport.http.versions.contains(&HttpVersion::Http2)
1153 {
1154 return Err(ConfigValidationError::HttpModeDoesNotSupportHttp2);
1155 }
1156 if self.transport.tls.min_version > self.transport.tls.max_version {
1157 return Err(ConfigValidationError::InvalidTlsVersionRange);
1158 }
1159
1160 for (field, value) in [
1161 ("cert_path", self.transport.tls.cert_path.as_deref()),
1162 ("key_path", self.transport.tls.key_path.as_deref()),
1163 (
1164 "client_ca_path",
1165 self.transport.tls.client_ca_path.as_deref(),
1166 ),
1167 ] {
1168 if matches!(value, Some(path) if path.trim().is_empty()) {
1169 return Err(ConfigValidationError::EmptyPath { field });
1170 }
1171 }
1172
1173 self.validate_transport_mode_tls()?;
1174 self.validate_alpn_protocols()?;
1175 Ok(())
1176 }
1177
1178 fn validate_transport_mode_tls(&self) -> Result<(), ConfigValidationError> {
1179 match self.transport.mode {
1180 TransportMode::Http => {
1181 if self.transport.tls.cert_path.is_some() {
1182 return Err(ConfigValidationError::HttpModeDisallowsTlsField {
1183 field: "cert_path",
1184 });
1185 }
1186 if self.transport.tls.key_path.is_some() {
1187 return Err(ConfigValidationError::HttpModeDisallowsTlsField {
1188 field: "key_path",
1189 });
1190 }
1191 if self.transport.tls.client_ca_path.is_some() {
1192 return Err(ConfigValidationError::HttpModeDisallowsTlsField {
1193 field: "client_ca_path",
1194 });
1195 }
1196 }
1197 TransportMode::Tls => {
1198 if self.transport.tls.cert_path.is_none() {
1199 return Err(ConfigValidationError::MissingTlsField {
1200 mode: self.transport.mode,
1201 field: "cert_path",
1202 });
1203 }
1204 if self.transport.tls.key_path.is_none() {
1205 return Err(ConfigValidationError::MissingTlsField {
1206 mode: self.transport.mode,
1207 field: "key_path",
1208 });
1209 }
1210 if self.transport.tls.client_ca_path.is_some() {
1211 return Err(ConfigValidationError::ClientCaRequiresMtls);
1212 }
1213 }
1214 TransportMode::Mtls => {
1215 if self.transport.tls.cert_path.is_none() {
1216 return Err(ConfigValidationError::MissingTlsField {
1217 mode: self.transport.mode,
1218 field: "cert_path",
1219 });
1220 }
1221 if self.transport.tls.key_path.is_none() {
1222 return Err(ConfigValidationError::MissingTlsField {
1223 mode: self.transport.mode,
1224 field: "key_path",
1225 });
1226 }
1227 if self.transport.tls.client_ca_path.is_none() {
1228 return Err(ConfigValidationError::MissingTlsField {
1229 mode: self.transport.mode,
1230 field: "client_ca_path",
1231 });
1232 }
1233 }
1234 }
1235 Ok(())
1236 }
1237
1238 fn validate_alpn_protocols(&self) -> Result<(), ConfigValidationError> {
1239 let expected_alpn = default_alpn_protocols(&self.transport.http.versions);
1240 for (version, alpn) in expected_alpn.iter().map(|alpn| {
1241 let version = match alpn {
1242 AlpnProtocol::Http1_1 => HttpVersion::Http1,
1243 AlpnProtocol::H2 => HttpVersion::Http2,
1244 };
1245 (version, *alpn)
1246 }) {
1247 if !self.transport.tls.alpn_protocols.contains(&alpn) {
1248 return Err(ConfigValidationError::MissingAlpnProtocol { version, alpn });
1249 }
1250 }
1251 for alpn in &self.transport.tls.alpn_protocols {
1252 let expected_version = match alpn {
1253 AlpnProtocol::Http1_1 => HttpVersion::Http1,
1254 AlpnProtocol::H2 => HttpVersion::Http2,
1255 };
1256 if !self.transport.http.versions.contains(&expected_version) {
1257 return Err(ConfigValidationError::UnexpectedAlpnProtocol { alpn: *alpn });
1258 }
1259 }
1260 Ok(())
1261 }
1262
1263 pub fn validate_profile(
1274 &self,
1275 profile: &DeploymentProfile,
1276 ) -> Result<(), ConfigValidationError> {
1277 let is_prod = matches!(
1278 profile,
1279 DeploymentProfile::Prod | DeploymentProfile::ProdTls | DeploymentProfile::ProdMtls
1280 );
1281
1282 if is_prod && self.http.cors_origins == "*" && !self.http.allow_wildcard_cors {
1283 return Err(ConfigValidationError::WildcardCorsOriginsProd {
1284 profile: profile.as_str().to_string(),
1285 });
1286 }
1287
1288 Ok(())
1289 }
1290
1291 #[must_use]
1297 pub fn warnings(&self) -> Vec<String> {
1298 let mut w = Vec::new();
1299 if self.http.cors_origins == "*" && !self.http.allow_wildcard_cors {
1300 w.push(
1301 "http.cors_origins is set to '*' (allows all origins); \
1302 consider restricting for production use"
1303 .to_string(),
1304 );
1305 }
1306 w
1307 }
1308
1309 #[must_use]
1311 pub fn tls_enabled(&self) -> bool {
1312 self.transport.mode.uses_tls() && self.transport.tls.has_server_credentials()
1313 }
1314
1315 pub fn bind_socket_addr(&self) -> Result<SocketAddr, ConfigValidationError> {
1322 validate_socket_addr(&self.server.bind_address)
1323 }
1324
1325 #[must_use]
1327 pub fn long_poll_timeout(&self) -> Duration {
1328 Duration::from_secs(self.transport.connection.long_poll_timeout_secs)
1329 }
1330
1331 pub fn render_effective_json(&self) -> Result<String, serde_json::Error> {
1337 serde_json::to_string_pretty(self)
1338 }
1339}
1340
1341impl Default for Config {
1342 fn default() -> Self {
1343 let versions = vec![HttpVersion::Http1];
1344 Self {
1345 server: ServerConfig {
1346 bind_address: "0.0.0.0:4437".to_string(),
1347 },
1348 limits: LimitsConfig {
1349 max_memory_bytes: 100 * 1024 * 1024,
1350 max_stream_bytes: 10 * 1024 * 1024,
1351 max_stream_name_bytes: 1024,
1352 max_stream_name_segments: 8,
1353 },
1354 http: HttpConfig {
1355 cors_origins: "*".to_string(),
1356 stream_base_path: DEFAULT_STREAM_BASE_PATH.to_string(),
1357 allow_wildcard_cors: false,
1358 },
1359 storage: StorageConfig {
1360 mode: StorageMode::Memory,
1361 data_dir: "./data/streams".to_string(),
1362 acid_shard_count: 16,
1363 acid_backend: AcidBackend::File,
1364 },
1365 transport: TransportConfig {
1366 mode: TransportMode::Http,
1367 http: TransportHttpConfig {
1368 versions: versions.clone(),
1369 },
1370 tls: TransportTlsConfig {
1371 cert_path: None,
1372 key_path: None,
1373 client_ca_path: None,
1374 min_version: TlsVersion::V1_3,
1375 max_version: TlsVersion::V1_3,
1376 alpn_protocols: default_alpn_protocols(&versions),
1377 },
1378 connection: TransportConnectionConfig {
1379 long_poll_timeout_secs: 30,
1380 sse_reconnect_interval_secs: 60,
1381 },
1382 },
1383 proxy: ProxyConfig {
1384 enabled: false,
1385 forwarded_headers: ForwardedHeadersMode::None,
1386 trusted_proxies: Vec::new(),
1387 identity: ProxyIdentityConfig {
1388 mode: ProxyIdentityMode::None,
1389 header_name: None,
1390 require_tls: true,
1391 },
1392 },
1393 observability: ObservabilityConfig {
1394 rust_log: "info".to_string(),
1395 },
1396 }
1397 }
1398}
1399
/// Newtype wrapping the [`Duration`] used as the long-poll timeout.
///
/// NOTE(review): presumably built from `Config::long_poll_timeout` and shared
/// with request handlers; confirm against the call sites.
#[derive(Debug, Clone, Copy)]
pub struct LongPollTimeout(pub Duration);
1403
/// Newtype wrapping the SSE reconnect interval as a raw `u64`.
///
/// NOTE(review): the unit is presumably seconds, mirroring
/// `transport.connection.sse_reconnect_interval_secs`; confirm against the
/// call sites.
#[derive(Debug, Clone, Copy)]
pub struct SseReconnectInterval(pub u64);
1409
1410fn built_in_profile_patch(profile: &DeploymentProfile) -> Option<ConfigPatch> {
1411 match profile {
1412 DeploymentProfile::Default | DeploymentProfile::Named(_) => None,
1413 DeploymentProfile::Dev => Some(ConfigPatch {
1414 server: ServerConfigPatch {
1415 bind_address: Some("127.0.0.1:4437".to_string()),
1416 ..ServerConfigPatch::default()
1417 },
1418 observability: ObservabilityConfigPatch {
1419 rust_log: Some("debug".to_string()),
1420 },
1421 ..ConfigPatch::default()
1422 }),
1423 DeploymentProfile::Prod => Some(ConfigPatch {
1424 limits: LimitsConfigPatch {
1425 max_memory_bytes: Some(512 * 1024 * 1024),
1426 max_stream_bytes: Some(256 * 1024 * 1024),
1427 ..LimitsConfigPatch::default()
1428 },
1429 storage: StorageConfigPatch {
1430 mode: Some(StorageMode::FileDurable),
1431 data_dir: Some("/var/lib/durable-streams".to_string()),
1432 acid_shard_count: Some(16),
1433 ..StorageConfigPatch::default()
1434 },
1435 ..ConfigPatch::default()
1436 }),
1437 DeploymentProfile::ProdTls => Some(ConfigPatch {
1438 limits: LimitsConfigPatch {
1439 max_memory_bytes: Some(512 * 1024 * 1024),
1440 max_stream_bytes: Some(256 * 1024 * 1024),
1441 ..LimitsConfigPatch::default()
1442 },
1443 storage: StorageConfigPatch {
1444 mode: Some(StorageMode::FileDurable),
1445 data_dir: Some("/var/lib/durable-streams".to_string()),
1446 acid_shard_count: Some(16),
1447 ..StorageConfigPatch::default()
1448 },
1449 transport: TransportConfigPatch {
1450 mode: Some(TransportMode::Tls),
1451 http: TransportHttpConfigPatch {
1452 versions: Some(vec![HttpVersion::Http1, HttpVersion::Http2]),
1453 },
1454 ..TransportConfigPatch::default()
1455 },
1456 ..ConfigPatch::default()
1457 }),
1458 DeploymentProfile::ProdMtls => Some(ConfigPatch {
1459 limits: LimitsConfigPatch {
1460 max_memory_bytes: Some(512 * 1024 * 1024),
1461 max_stream_bytes: Some(256 * 1024 * 1024),
1462 ..LimitsConfigPatch::default()
1463 },
1464 storage: StorageConfigPatch {
1465 mode: Some(StorageMode::FileDurable),
1466 data_dir: Some("/var/lib/durable-streams".to_string()),
1467 acid_shard_count: Some(16),
1468 ..StorageConfigPatch::default()
1469 },
1470 transport: TransportConfigPatch {
1471 mode: Some(TransportMode::Mtls),
1472 http: TransportHttpConfigPatch {
1473 versions: Some(vec![HttpVersion::Http1, HttpVersion::Http2]),
1474 },
1475 ..TransportConfigPatch::default()
1476 },
1477 ..ConfigPatch::default()
1478 }),
1479 }
1480}
1481
1482fn extract_toml_patch(path: &Path) -> Result<ConfigPatch, ConfigLoadError> {
1483 Figment::from(Toml::file(path))
1484 .extract()
1485 .map_err(|error| ConfigLoadError::TomlParse {
1486 message: error.to_string(),
1487 })
1488}
1489
1490fn parse_env<T>(
1491 get: &impl Fn(&str) -> Option<String>,
1492 key: &'static str,
1493) -> Result<Option<T>, ConfigLoadError>
1494where
1495 T: std::str::FromStr,
1496 <T as std::str::FromStr>::Err: std::fmt::Display,
1497{
1498 get(key)
1499 .map(|value| {
1500 value
1501 .parse::<T>()
1502 .map_err(|error| ConfigLoadError::InvalidValue {
1503 input_source: "environment",
1504 key,
1505 value,
1506 reason: error.to_string(),
1507 })
1508 })
1509 .transpose()
1510}
1511
1512fn parse_env_with<T>(
1513 get: &impl Fn(&str) -> Option<String>,
1514 key: &'static str,
1515 parser: impl Fn(&str) -> Option<T>,
1516) -> Result<Option<T>, ConfigLoadError> {
1517 get(key)
1518 .map(|value| {
1519 parser(&value).ok_or_else(|| ConfigLoadError::InvalidValue {
1520 input_source: "environment",
1521 key,
1522 value,
1523 reason: "unrecognized value".to_string(),
1524 })
1525 })
1526 .transpose()
1527}
1528
1529fn parse_env_list_with<T>(
1530 get: &impl Fn(&str) -> Option<String>,
1531 key: &'static str,
1532 parser: impl Fn(&str) -> Option<T>,
1533) -> Result<Option<Vec<T>>, ConfigLoadError> {
1534 get(key)
1535 .map(|value| {
1536 value
1537 .split(',')
1538 .map(str::trim)
1539 .filter(|item| !item.is_empty())
1540 .map(|item| {
1541 parser(item).ok_or_else(|| ConfigLoadError::InvalidValue {
1542 input_source: "environment",
1543 key,
1544 value: value.clone(),
1545 reason: format!("unrecognized list item '{item}'"),
1546 })
1547 })
1548 .collect::<Result<Vec<_>, _>>()
1549 })
1550 .transpose()
1551}
1552
1553fn parse_env_csv_strings(
1554 get: &impl Fn(&str) -> Option<String>,
1555 key: &'static str,
1556) -> Result<Option<Vec<String>>, ConfigLoadError> {
1557 get(key)
1558 .map(|value| {
1559 if value.trim().is_empty() {
1560 return Ok(Vec::new());
1561 }
1562 Ok(value
1563 .split(',')
1564 .map(str::trim)
1565 .filter(|item| !item.is_empty())
1566 .map(ToOwned::to_owned)
1567 .collect())
1568 })
1569 .transpose()
1570}
1571
1572impl MergeContext {
1573 fn finalize(self, config: &mut Config) {
1574 if !self.explicit_transport_mode
1575 && self.legacy_tls_seen
1576 && config.transport.tls.has_server_credentials()
1577 {
1578 config.transport.mode = TransportMode::Tls;
1579 }
1580 }
1581}
1582
1583fn parse_storage_mode_env(raw: &str) -> Option<StorageMode> {
1584 match raw.to_ascii_lowercase().as_str() {
1585 "memory" => Some(StorageMode::Memory),
1586 "file" | "file-durable" | "durable" => Some(StorageMode::FileDurable),
1587 "file-fast" | "fast" => Some(StorageMode::FileFast),
1588 "acid" | "redb" => Some(StorageMode::Acid),
1589 _ => None,
1590 }
1591}
1592
1593fn parse_acid_backend_env(raw: &str) -> Option<AcidBackend> {
1594 match raw.to_ascii_lowercase().as_str() {
1595 "file" => Some(AcidBackend::File),
1596 "memory" | "in-memory" | "inmemory" => Some(AcidBackend::InMemory),
1597 _ => None,
1598 }
1599}
1600
1601fn parse_transport_mode_env(raw: &str) -> Option<TransportMode> {
1602 match raw.to_ascii_lowercase().as_str() {
1603 "http" => Some(TransportMode::Http),
1604 "tls" => Some(TransportMode::Tls),
1605 "mtls" => Some(TransportMode::Mtls),
1606 _ => None,
1607 }
1608}
1609
1610fn parse_http_version_env(raw: &str) -> Option<HttpVersion> {
1611 match raw.to_ascii_lowercase().as_str() {
1612 "http1" | "http1.1" | "http/1.1" | "1.1" | "h1" => Some(HttpVersion::Http1),
1613 "http2" | "2" | "h2" => Some(HttpVersion::Http2),
1614 _ => None,
1615 }
1616}
1617
1618fn parse_tls_version_env(raw: &str) -> Option<TlsVersion> {
1619 match raw.to_ascii_lowercase().as_str() {
1620 "1.2" | "tls1.2" | "tls-1.2" => Some(TlsVersion::V1_2),
1621 "1.3" | "tls1.3" | "tls-1.3" => Some(TlsVersion::V1_3),
1622 _ => None,
1623 }
1624}
1625
1626fn parse_alpn_protocol_env(raw: &str) -> Option<AlpnProtocol> {
1627 match raw.to_ascii_lowercase().as_str() {
1628 "http/1.1" | "http1" | "h1" => Some(AlpnProtocol::Http1_1),
1629 "h2" | "http2" => Some(AlpnProtocol::H2),
1630 _ => None,
1631 }
1632}
1633
1634fn parse_forwarded_headers_mode_env(raw: &str) -> Option<ForwardedHeadersMode> {
1635 match raw.to_ascii_lowercase().as_str() {
1636 "none" => Some(ForwardedHeadersMode::None),
1637 "x-forwarded" | "xforwarded" => Some(ForwardedHeadersMode::XForwarded),
1638 "forwarded" => Some(ForwardedHeadersMode::Forwarded),
1639 _ => None,
1640 }
1641}
1642
1643fn parse_proxy_identity_mode_env(raw: &str) -> Option<ProxyIdentityMode> {
1644 match raw.to_ascii_lowercase().as_str() {
1645 "none" => Some(ProxyIdentityMode::None),
1646 "header" => Some(ProxyIdentityMode::Header),
1647 _ => None,
1648 }
1649}
1650
1651fn default_alpn_protocols(versions: &[HttpVersion]) -> Vec<AlpnProtocol> {
1652 let mut protocols = Vec::new();
1653 if versions.contains(&HttpVersion::Http2) {
1655 protocols.push(AlpnProtocol::H2);
1656 }
1657 if versions.contains(&HttpVersion::Http1) {
1658 protocols.push(AlpnProtocol::Http1_1);
1659 }
1660 protocols
1661}
1662
1663fn validate_socket_addr(raw: &str) -> Result<SocketAddr, ConfigValidationError> {
1664 raw.parse::<SocketAddr>()
1665 .map_err(|error| ConfigValidationError::InvalidBindAddress {
1666 value: raw.to_string(),
1667 reason: error.to_string(),
1668 })
1669}
1670
1671fn validate_cors_origins(origins: &str) -> Result<(), ConfigValidationError> {
1672 if origins == "*" {
1673 return Ok(());
1674 }
1675
1676 let mut parsed_any = false;
1677 for origin in origins.split(',').map(str::trim) {
1678 if origin.is_empty() {
1679 return Err(ConfigValidationError::EmptyCorsOrigin);
1680 }
1681 HeaderValue::from_str(origin).map_err(|_| ConfigValidationError::InvalidCorsOrigin {
1682 value: origin.to_string(),
1683 })?;
1684 parsed_any = true;
1685 }
1686
1687 if !parsed_any {
1688 return Err(ConfigValidationError::EmptyCorsOrigin);
1689 }
1690
1691 Ok(())
1692}
1693
1694fn validate_stream_base_path(raw: &str) -> Result<(), ConfigValidationError> {
1695 let trimmed = raw.trim();
1696 if trimmed.is_empty() {
1697 return Err(ConfigValidationError::InvalidStreamBasePath {
1698 value: raw.to_string(),
1699 reason: "must be a non-empty absolute path".to_string(),
1700 });
1701 }
1702 if !trimmed.starts_with('/') {
1703 return Err(ConfigValidationError::InvalidStreamBasePath {
1704 value: raw.to_string(),
1705 reason: "must start with '/'".to_string(),
1706 });
1707 }
1708
1709 if trimmed != "/" && trimmed.ends_with('/') {
1710 return Err(ConfigValidationError::InvalidStreamBasePath {
1711 value: raw.to_string(),
1712 reason: "must not end with '/' unless the path is '/'".to_string(),
1713 });
1714 }
1715
1716 Ok(())
1717}
1718
/// Returns `true` when `value` is a power of two between 1 and 256 inclusive.
fn valid_acid_shard_count(value: usize) -> bool {
    const MAX_SHARDS: usize = 256;
    // A power of two has exactly one bit set; zero has none.
    value != 0 && value <= MAX_SHARDS && value.count_ones() == 1
}
1722
1723fn validate_proxy(config: &Config) -> Result<(), ConfigValidationError> {
1724 let proxy = &config.proxy;
1725 if !proxy.enabled {
1726 if !proxy.trusted_proxies.is_empty() {
1727 return Err(ConfigValidationError::ProxyDisabledDisallowsTrustedProxies);
1728 }
1729 if proxy.forwarded_headers != ForwardedHeadersMode::None {
1730 return Err(
1731 ConfigValidationError::ProxyDisabledDisallowsForwardedHeaders {
1732 mode: proxy.forwarded_headers,
1733 },
1734 );
1735 }
1736 if proxy.identity.mode != ProxyIdentityMode::None {
1737 return Err(ConfigValidationError::ProxyDisabledDisallowsIdentityMode {
1738 mode: proxy.identity.mode,
1739 });
1740 }
1741 if proxy.identity.header_name.is_some() {
1742 return Err(ConfigValidationError::ProxyDisabledDisallowsIdentityHeader);
1743 }
1744 return Ok(());
1745 }
1746
1747 if proxy.forwarded_headers == ForwardedHeadersMode::None {
1748 return Err(ConfigValidationError::ProxyEnabledRequiresForwardedHeaders);
1749 }
1750 if proxy.trusted_proxies.is_empty() {
1751 return Err(ConfigValidationError::ProxyEnabledRequiresTrustedProxies);
1752 }
1753 for value in &proxy.trusted_proxies {
1754 if !valid_ip_or_cidr(value) {
1755 return Err(ConfigValidationError::InvalidTrustedProxy {
1756 value: value.clone(),
1757 });
1758 }
1759 }
1760
1761 match proxy.identity.mode {
1762 ProxyIdentityMode::None => {
1763 if proxy.identity.header_name.is_some() {
1764 return Err(ConfigValidationError::IdentityHeaderRequiresHeaderMode);
1765 }
1766 }
1767 ProxyIdentityMode::Header => {
1768 if config.transport.mode != TransportMode::Mtls {
1769 return Err(ConfigValidationError::HeaderIdentityRequiresMtls);
1770 }
1771 let Some(header_name) = proxy.identity.header_name.as_deref() else {
1772 return Err(ConfigValidationError::HeaderIdentityRequiresHeaderName);
1773 };
1774 HeaderName::from_bytes(header_name.as_bytes()).map_err(|_| {
1775 ConfigValidationError::InvalidIdentityHeaderName {
1776 value: header_name.to_string(),
1777 }
1778 })?;
1779 }
1780 }
1781
1782 Ok(())
1783}
1784
/// Returns `true` when `raw` is a bare IP address or an `address/prefix` CIDR.
///
/// For the CIDR form the address must parse as IPv4 or IPv6, and the prefix
/// must consist of ASCII digits only and be at most 32 (IPv4) or 128 (IPv6).
/// Surrounding whitespace is not trimmed.
fn valid_ip_or_cidr(raw: &str) -> bool {
    if raw.parse::<IpAddr>().is_ok() {
        return true;
    }

    let Some((address, prefix)) = raw.split_once('/') else {
        return false;
    };

    // Integer parsing accepts a leading '+', which is not valid CIDR
    // notation, so the prefix is restricted to plain digits first.
    if prefix.is_empty() || !prefix.bytes().all(|byte| byte.is_ascii_digit()) {
        return false;
    }

    let (Ok(address), Ok(prefix)) = (address.parse::<IpAddr>(), prefix.parse::<u8>()) else {
        return false;
    };

    let max_prefix = match address {
        IpAddr::V4(_) => 32,
        IpAddr::V6(_) => 128,
    };
    prefix <= max_prefix
}
1805
1806#[cfg(test)]
1807mod tests {
1808 use super::*;
1809 use std::collections::HashMap;
1810 use std::fs;
1811 use std::sync::atomic::{AtomicU64, Ordering};
1812
1813 fn lookup(pairs: &[(&str, &str)]) -> impl Fn(&str) -> Option<String> {
1814 let map: HashMap<String, String> = pairs
1815 .iter()
1816 .map(|(key, value)| ((*key).to_string(), (*value).to_string()))
1817 .collect();
1818 move |key: &str| map.get(key).cloned()
1819 }
1820
1821 fn temp_config_dir() -> PathBuf {
1822 static COUNTER: AtomicU64 = AtomicU64::new(0);
1823 let id = COUNTER.fetch_add(1, Ordering::Relaxed);
1824 let path =
1825 std::env::temp_dir().join(format!("ds-config-tests-{}-{}", std::process::id(), id));
1826 fs::create_dir_all(&path).expect("create temp config dir");
1827 path
1828 }
1829
    /// Pins the values produced by `Config::default()` so an accidental change
    /// to a built-in default fails loudly.
    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.server.bind_address, "0.0.0.0:4437");
        assert_eq!(config.limits.max_memory_bytes, 100 * 1024 * 1024);
        assert_eq!(config.limits.max_stream_bytes, 10 * 1024 * 1024);
        assert_eq!(config.http.cors_origins, "*");
        assert_eq!(config.transport.connection.long_poll_timeout_secs, 30);
        assert_eq!(config.transport.connection.sse_reconnect_interval_secs, 60);
        assert_eq!(config.http.stream_base_path, DEFAULT_STREAM_BASE_PATH);
        assert_eq!(config.storage.mode, StorageMode::Memory);
        assert_eq!(config.storage.data_dir, "./data/streams");
        assert_eq!(config.storage.acid_shard_count, 16);
        assert_eq!(config.storage.acid_backend, AcidBackend::File);
        assert_eq!(config.transport.mode, TransportMode::Http);
        assert_eq!(config.transport.http.versions, vec![HttpVersion::Http1]);
        assert_eq!(config.transport.tls.cert_path, None);
        assert_eq!(config.transport.tls.key_path, None);
        assert_eq!(config.observability.rust_log, "info");
    }
1850
    /// Loads through `Config::from_env()` and expects the built-in defaults.
    ///
    /// NOTE(review): this presumably relies on the test process having no
    /// `DS_*` variables set and no config files in the default location;
    /// confirm the CI environment guarantees that.
    #[test]
    fn test_from_env_uses_defaults_when_no_ds_vars() {
        let config = Config::from_env().expect("config from env");
        assert_eq!(config.server.bind_address, "0.0.0.0:4437");
        assert_eq!(config.storage.mode, StorageMode::Memory);
        assert_eq!(config.observability.rust_log, "info");
    }
1858
    /// Feeds a set of `DS_*` overrides through an in-memory lookup and checks
    /// that each one lands in the expected config field.
    ///
    /// NOTE(review): per the test name, the keys mix the current nested
    /// spelling with legacy ones (presumably `DS_SERVER__PORT`,
    /// `DS_SERVER__SSE_RECONNECT_INTERVAL_SECS` and `DS_TLS__CERT_PATH`);
    /// confirm against the env-override implementation.
    #[test]
    fn test_env_overrides_parse_new_and_legacy_keys() {
        let options = ConfigLoadOptions::default();
        let env = lookup(&[
            ("DS_SERVER__PORT", "8080"),
            ("DS_LIMITS__MAX_MEMORY_BYTES", "200000000"),
            ("DS_LIMITS__MAX_STREAM_BYTES", "20000000"),
            ("DS_HTTP__CORS_ORIGINS", "https://example.com"),
            ("DS_TRANSPORT__CONNECTION__LONG_POLL_TIMEOUT_SECS", "5"),
            ("DS_SERVER__SSE_RECONNECT_INTERVAL_SECS", "120"),
            ("DS_HTTP__STREAM_BASE_PATH", "/streams"),
            ("DS_STORAGE__MODE", "file-fast"),
            ("DS_STORAGE__DATA_DIR", "/tmp/ds-store"),
            ("DS_STORAGE__ACID_SHARD_COUNT", "32"),
            ("DS_TRANSPORT__MODE", "tls"),
            ("DS_TLS__CERT_PATH", "/tmp/cert.pem"),
            ("DS_TRANSPORT__TLS__KEY_PATH", "/tmp/key.pem"),
            ("DS_TRANSPORT__HTTP__VERSIONS", "http1,http2"),
            ("DS_OBSERVABILITY__RUST_LOG", "debug"),
        ]);
        let config = Config::from_sources_with_lookup(&options, &env).expect("config from env");

        // The port override is applied to the default bind host.
        assert_eq!(config.server.bind_address, "0.0.0.0:8080");
        assert_eq!(config.limits.max_memory_bytes, 200_000_000);
        assert_eq!(config.limits.max_stream_bytes, 20_000_000);
        assert_eq!(config.http.cors_origins, "https://example.com");
        assert_eq!(config.transport.connection.long_poll_timeout_secs, 5);
        assert_eq!(config.transport.connection.sse_reconnect_interval_secs, 120);
        assert_eq!(config.http.stream_base_path, "/streams");
        assert_eq!(config.storage.mode, StorageMode::FileFast);
        assert_eq!(config.storage.data_dir, "/tmp/ds-store");
        assert_eq!(config.storage.acid_shard_count, 32);
        assert_eq!(config.transport.mode, TransportMode::Tls);
        assert_eq!(
            config.transport.http.versions,
            vec![HttpVersion::Http1, HttpVersion::Http2]
        );
        // No ALPN override was supplied, so the list follows the HTTP versions.
        assert_eq!(
            config.transport.tls.alpn_protocols,
            vec![AlpnProtocol::H2, AlpnProtocol::Http1_1]
        );
        assert_eq!(
            config.transport.tls.cert_path.as_deref(),
            Some("/tmp/cert.pem")
        );
        assert_eq!(
            config.transport.tls.key_path.as_deref(),
            Some("/tmp/key.pem")
        );
        assert_eq!(config.observability.rust_log, "debug");
    }
1910
    /// An unrecognized TLS version in the environment must surface as a typed
    /// `ConfigLoadError::InvalidValue` naming the key and the raw value.
    #[test]
    fn test_invalid_env_override_returns_typed_error() {
        let err = Config::from_sources_with_lookup(
            &ConfigLoadOptions::default(),
            &lookup(&[("DS_TRANSPORT__TLS__MIN_VERSION", "tls1.0")]),
        )
        .expect_err("expected invalid env override");

        assert_eq!(
            err,
            ConfigLoadError::InvalidValue {
                input_source: "environment",
                key: "DS_TRANSPORT__TLS__MIN_VERSION",
                value: "tls1.0".to_string(),
                reason: "unrecognized value".to_string(),
            }
        );
    }
1929
    /// Loads the `ProdTls` profile from an empty config directory with no
    /// environment overrides and checks the values set by the built-in patch.
    #[test]
    fn test_built_in_profile_defaults_apply_cleanly() {
        // Empty directory: only the built-in profile patch can contribute.
        let config_dir = temp_config_dir();
        let config = Config::from_sources_with_lookup(
            &ConfigLoadOptions {
                config_dir,
                profile: DeploymentProfile::ProdTls,
                config_override: None,
            },
            &lookup(&[]),
        )
        .expect("config");

        assert_eq!(config.storage.mode, StorageMode::FileDurable);
        assert_eq!(config.storage.data_dir, "/var/lib/durable-streams");
        assert_eq!(config.transport.mode, TransportMode::Tls);
        assert_eq!(
            config.transport.http.versions,
            vec![HttpVersion::Http1, HttpVersion::Http2]
        );
        assert_eq!(
            config.transport.tls.alpn_protocols,
            vec![AlpnProtocol::H2, AlpnProtocol::Http1_1]
        );
    }
1957
    /// Exercises source layering: `default.toml`, then the profile file
    /// (`dev.toml`), then `local.toml`, then environment overrides.
    ///
    /// The assertions show later sources winning: the bind address comes from
    /// the environment, the stream path and storage settings from `dev.toml`,
    /// and `rust_log` from the environment.
    #[test]
    fn test_sources_layer_default_profile_local_override_and_env() {
        let config_dir = temp_config_dir();
        fs::write(
            config_dir.join("default.toml"),
            r#"
[server]
bind_address = "0.0.0.0:4437"

[http]
stream_base_path = "/v1/stream"

[storage]
mode = "memory"

[transport.connection]
long_poll_timeout_secs = 30

[observability]
rust_log = "warn"
"#,
        )
        .expect("write default config");
        fs::write(
            config_dir.join("dev.toml"),
            r#"
[server]
bind_address = "127.0.0.1:7777"

[http]
stream_base_path = "/streams"

[storage]
mode = "file-fast"
data_dir = "/tmp/dev-store"
"#,
        )
        .expect("write profile config");
        fs::write(
            config_dir.join("local.toml"),
            r#"
[server]
bind_address = "127.0.0.1:8888"
"#,
        )
        .expect("write local config");

        let config = Config::from_sources_with_lookup(
            &ConfigLoadOptions {
                config_dir,
                profile: DeploymentProfile::Dev,
                config_override: None,
            },
            &lookup(&[
                ("DS_SERVER__BIND_ADDRESS", "127.0.0.1:9999"),
                ("DS_OBSERVABILITY__RUST_LOG", "debug"),
            ]),
        )
        .expect("config from sources");

        assert_eq!(config.server.bind_address, "127.0.0.1:9999");
        assert_eq!(config.http.stream_base_path, "/streams");
        assert_eq!(config.storage.mode, StorageMode::FileFast);
        assert_eq!(config.storage.data_dir, "/tmp/dev-store");
        assert_eq!(config.observability.rust_log, "debug");
    }
2024
    /// A config file that uses the legacy top-level `[tls]` table, without
    /// setting a transport mode, must end up in `TransportMode::Tls` with the
    /// paths carried over to `transport.tls`.
    #[test]
    fn test_legacy_tls_fields_infer_tls_mode_when_mode_not_set() {
        let config_dir = temp_config_dir();
        fs::write(
            config_dir.join("default.toml"),
            r#"
[tls]
cert_path = "/tmp/cert.pem"
key_path = "/tmp/key.pem"
"#,
        )
        .expect("write config");

        let config = Config::from_sources_with_lookup(
            &ConfigLoadOptions {
                config_dir,
                ..ConfigLoadOptions::default()
            },
            &lookup(&[]),
        )
        .expect("config from sources");

        assert_eq!(config.transport.mode, TransportMode::Tls);
        assert_eq!(
            config.transport.tls.cert_path.as_deref(),
            Some("/tmp/cert.pem")
        );
        assert_eq!(
            config.transport.tls.key_path.as_deref(),
            Some("/tmp/key.pem")
        );
    }
2057
    /// Smoke test: the rendered JSON of the default config mentions the
    /// `transport`, `observability` and `proxy` sections.
    #[test]
    fn test_render_effective_json_contains_nested_sections() {
        let rendered = Config::default()
            .render_effective_json()
            .expect("render effective config");
        assert!(rendered.contains("\"transport\""));
        assert!(rendered.contains("\"observability\""));
        assert!(rendered.contains("\"proxy\""));
    }
2067
    /// Runs `Config::validate` over three configurations that must all pass:
    /// the defaults, a TLS setup with HTTP/1 + HTTP/2, and an mTLS setup with
    /// an enabled proxy using header-based identity.
    #[test]
    fn test_validate_accepts_valid_config_matrix() {
        let valid_configs = [
            Config::default(),
            // TLS with both HTTP versions; ALPN order differs from the derived
            // default, which validation does not care about.
            Config {
                transport: TransportConfig {
                    mode: TransportMode::Tls,
                    http: TransportHttpConfig {
                        versions: vec![HttpVersion::Http1, HttpVersion::Http2],
                    },
                    tls: TransportTlsConfig {
                        cert_path: Some("/tmp/cert.pem".to_string()),
                        key_path: Some("/tmp/key.pem".to_string()),
                        client_ca_path: None,
                        min_version: TlsVersion::V1_2,
                        max_version: TlsVersion::V1_3,
                        alpn_protocols: vec![AlpnProtocol::Http1_1, AlpnProtocol::H2],
                    },
                    connection: TransportConnectionConfig {
                        long_poll_timeout_secs: 30,
                        sse_reconnect_interval_secs: 60,
                    },
                },
                ..Config::default()
            },
            // mTLS plus a proxy that supplies the client identity in a header.
            Config {
                transport: TransportConfig {
                    mode: TransportMode::Mtls,
                    http: TransportHttpConfig {
                        versions: vec![HttpVersion::Http1],
                    },
                    tls: TransportTlsConfig {
                        cert_path: Some("/tmp/cert.pem".to_string()),
                        key_path: Some("/tmp/key.pem".to_string()),
                        client_ca_path: Some("/tmp/ca.pem".to_string()),
                        min_version: TlsVersion::V1_2,
                        max_version: TlsVersion::V1_3,
                        alpn_protocols: vec![AlpnProtocol::Http1_1],
                    },
                    connection: TransportConnectionConfig {
                        long_poll_timeout_secs: 30,
                        sse_reconnect_interval_secs: 60,
                    },
                },
                proxy: ProxyConfig {
                    enabled: true,
                    forwarded_headers: ForwardedHeadersMode::XForwarded,
                    trusted_proxies: vec!["127.0.0.1/32".to_string()],
                    identity: ProxyIdentityConfig {
                        mode: ProxyIdentityMode::Header,
                        header_name: Some("x-client-identity".to_string()),
                        require_tls: true,
                    },
                },
                ..Config::default()
            },
        ];

        for config in valid_configs {
            assert!(
                config.validate().is_ok(),
                "config should validate: {config:?}"
            );
        }
    }
2133
2134 fn assert_invalid_configs(
2135 invalid_cases: impl IntoIterator<Item = (Config, ConfigValidationError)>,
2136 ) {
2137 for (config, expected) in invalid_cases {
2138 assert_eq!(config.validate().expect_err("config should fail"), expected);
2139 }
2140 }
2141
    /// Transport and TLS mismatches that validation must reject: a TLS field
    /// set in HTTP mode, a missing key in TLS mode, and HTTP/2 in HTTP mode.
    #[test]
    fn test_validate_rejects_http_transport_tls_misconfigurations() {
        assert_invalid_configs([
            // Plain HTTP with a certificate path configured.
            (
                Config {
                    transport: TransportConfig {
                        mode: TransportMode::Http,
                        tls: TransportTlsConfig {
                            cert_path: Some("/tmp/cert.pem".to_string()),
                            ..Config::default().transport.tls
                        },
                        ..Config::default().transport
                    },
                    ..Config::default()
                },
                ConfigValidationError::HttpModeDisallowsTlsField { field: "cert_path" },
            ),
            // TLS mode with a certificate but no key.
            (
                Config {
                    transport: TransportConfig {
                        mode: TransportMode::Tls,
                        tls: TransportTlsConfig {
                            cert_path: Some("/tmp/cert.pem".to_string()),
                            key_path: None,
                            ..Config::default().transport.tls
                        },
                        ..Config::default().transport
                    },
                    ..Config::default()
                },
                ConfigValidationError::MissingTlsField {
                    mode: TransportMode::Tls,
                    field: "key_path",
                },
            ),
            // Plain HTTP that lists HTTP/2 among its versions.
            (
                Config {
                    transport: TransportConfig {
                        mode: TransportMode::Http,
                        http: TransportHttpConfig {
                            versions: vec![HttpVersion::Http1, HttpVersion::Http2],
                        },
                        tls: TransportTlsConfig {
                            alpn_protocols: vec![AlpnProtocol::Http1_1, AlpnProtocol::H2],
                            ..Config::default().transport.tls
                        },
                        ..Config::default().transport
                    },
                    ..Config::default()
                },
                ConfigValidationError::HttpModeDoesNotSupportHttp2,
            ),
        ]);
    }
2196
    /// Rejections for an inverted TLS version range, an enabled proxy without
    /// a forwarded-headers mode, and a trusted proxy entry that is neither an
    /// IP address nor a CIDR.
    #[test]
    fn test_validate_rejects_invalid_tls_ranges_and_proxy_headers() {
        assert_invalid_configs([
            // min_version (1.3) is above max_version (1.2).
            (
                Config {
                    transport: TransportConfig {
                        mode: TransportMode::Tls,
                        tls: TransportTlsConfig {
                            cert_path: Some("/tmp/cert.pem".to_string()),
                            key_path: Some("/tmp/key.pem".to_string()),
                            min_version: TlsVersion::V1_3,
                            max_version: TlsVersion::V1_2,
                            alpn_protocols: vec![AlpnProtocol::Http1_1],
                            ..Config::default().transport.tls
                        },
                        ..Config::default().transport
                    },
                    ..Config::default()
                },
                ConfigValidationError::InvalidTlsVersionRange,
            ),
            // Proxy enabled while forwarded headers stay at `None`.
            (
                Config {
                    proxy: ProxyConfig {
                        enabled: true,
                        forwarded_headers: ForwardedHeadersMode::None,
                        trusted_proxies: vec!["127.0.0.1".to_string()],
                        ..Config::default().proxy
                    },
                    ..Config::default()
                },
                ConfigValidationError::ProxyEnabledRequiresForwardedHeaders,
            ),
            // Trusted proxy entry that does not parse.
            (
                Config {
                    proxy: ProxyConfig {
                        enabled: true,
                        forwarded_headers: ForwardedHeadersMode::Forwarded,
                        trusted_proxies: vec!["not-a-cidr".to_string()],
                        ..Config::default().proxy
                    },
                    ..Config::default()
                },
                ConfigValidationError::InvalidTrustedProxy {
                    value: "not-a-cidr".to_string(),
                },
            ),
        ]);
    }
2246
/// Table test for header-based proxy identity: each configuration is paired with the
/// validation error it is expected to produce and passed to `assert_invalid_configs`.
#[test]
fn test_validate_rejects_invalid_proxy_identity_requirements() {
    // Header identity is requested while the transport is plain TLS (not mTLS);
    // the expected error is `HeaderIdentityRequiresMtls`.
    let header_identity_on_tls_transport = Config {
        transport: TransportConfig {
            mode: TransportMode::Tls,
            tls: TransportTlsConfig {
                cert_path: Some("/tmp/cert.pem".to_string()),
                key_path: Some("/tmp/key.pem".to_string()),
                alpn_protocols: vec![AlpnProtocol::Http1_1],
                ..Config::default().transport.tls
            },
            ..Config::default().transport
        },
        proxy: ProxyConfig {
            enabled: true,
            forwarded_headers: ForwardedHeadersMode::XForwarded,
            trusted_proxies: vec!["127.0.0.1".to_string()],
            identity: ProxyIdentityConfig {
                mode: ProxyIdentityMode::Header,
                header_name: Some("x-client-identity".to_string()),
                require_tls: true,
            },
        },
        ..Config::default()
    };

    // The transport is mTLS with a client CA configured, but header identity is
    // requested without naming the header to read.
    let header_identity_without_header_name = Config {
        transport: TransportConfig {
            mode: TransportMode::Mtls,
            tls: TransportTlsConfig {
                cert_path: Some("/tmp/cert.pem".to_string()),
                key_path: Some("/tmp/key.pem".to_string()),
                client_ca_path: Some("/tmp/ca.pem".to_string()),
                alpn_protocols: vec![AlpnProtocol::Http1_1],
                ..Config::default().transport.tls
            },
            ..Config::default().transport
        },
        proxy: ProxyConfig {
            enabled: true,
            forwarded_headers: ForwardedHeadersMode::XForwarded,
            trusted_proxies: vec!["127.0.0.1".to_string()],
            identity: ProxyIdentityConfig {
                mode: ProxyIdentityMode::Header,
                header_name: None,
                require_tls: true,
            },
        },
        ..Config::default()
    };

    assert_invalid_configs([
        (
            header_identity_on_tls_transport,
            ConfigValidationError::HeaderIdentityRequiresMtls,
        ),
        (
            header_identity_without_header_name,
            ConfigValidationError::HeaderIdentityRequiresHeaderName,
        ),
    ]);
}
2305
/// The default configuration uses `"*"` for `http.cors_origins` and must report
/// exactly one warning, whose text mentions `cors_origins`.
#[test]
fn test_wildcard_cors_emits_warning() {
    let defaults = Config::default();
    // Precondition: the default really is the wildcard origin.
    assert_eq!(defaults.http.cors_origins, "*");

    let emitted = defaults.warnings();
    assert_eq!(emitted.len(), 1);
    assert!(emitted[0].contains("cors_origins"));
}
2316
/// Setting `http.allow_wildcard_cors` on an otherwise default configuration must
/// leave `warnings()` empty.
#[test]
fn test_allow_wildcard_cors_suppresses_warning() {
    let http = HttpConfig {
        allow_wildcard_cors: true,
        ..Config::default().http
    };
    let config = Config {
        http,
        ..Config::default()
    };

    let emitted = config.warnings();
    assert!(emitted.is_empty());
}
2328
/// Replacing the CORS origins with an explicit origin on an otherwise default
/// configuration must leave `warnings()` empty.
#[test]
fn test_explicit_origins_no_warning() {
    let http = HttpConfig {
        cors_origins: "https://example.com".to_string(),
        ..Config::default().http
    };
    let config = Config {
        http,
        ..Config::default()
    };

    let emitted = config.warnings();
    assert!(emitted.is_empty());
}
2340
/// For each production profile, validating the default configuration must fail with
/// `WildcardCorsOriginsProd` carrying that profile's `as_str()` name.
#[test]
fn test_validate_profile_rejects_wildcard_cors_for_prod_profiles() {
    let defaults = Config::default();
    let prod_profiles = [
        DeploymentProfile::Prod,
        DeploymentProfile::ProdTls,
        DeploymentProfile::ProdMtls,
    ];

    for profile in prod_profiles {
        let rejection = defaults.validate_profile(&profile).expect_err("should fail");
        let wanted = ConfigValidationError::WildcardCorsOriginsProd {
            profile: profile.as_str().to_string(),
        };
        assert_eq!(rejection, wanted);
    }
}
2358
/// The default configuration must pass profile validation for the `Default`, `Dev`
/// and custom-named (`"staging"`) profiles.
#[test]
fn test_validate_profile_allows_wildcard_cors_for_non_prod_profiles() {
    let defaults = Config::default();
    let non_prod_profiles = [
        DeploymentProfile::Default,
        DeploymentProfile::Dev,
        DeploymentProfile::Named("staging".to_string()),
    ];

    for profile in non_prod_profiles {
        let outcome = defaults.validate_profile(&profile);
        assert!(outcome.is_ok(), "non-prod profile {profile:?} should pass");
    }
}
2373
/// With `http.allow_wildcard_cors` set, the otherwise default configuration must
/// pass profile validation for `Prod`, `ProdTls` and `ProdMtls`.
#[test]
fn test_validate_profile_allows_wildcard_cors_with_escape_hatch() {
    let http = HttpConfig {
        allow_wildcard_cors: true,
        ..Config::default().http
    };
    let config = Config {
        http,
        ..Config::default()
    };

    // Evaluate each profile separately so a failure names the profile involved.
    let prod = config.validate_profile(&DeploymentProfile::Prod);
    let prod_tls = config.validate_profile(&DeploymentProfile::ProdTls);
    let prod_mtls = config.validate_profile(&DeploymentProfile::ProdMtls);

    assert!(prod.is_ok());
    assert!(prod_tls.is_ok());
    assert!(prod_mtls.is_ok());
}
2391
/// `validate()` must accept an empty `storage.data_dir` when every other setting is
/// left at its default.
// NOTE(review): the test name implies the default storage mode is memory — confirm
// against `Config::default()`.
#[test]
fn test_memory_mode_allows_empty_data_dir() {
    let storage = StorageConfig {
        data_dir: String::new(),
        ..Config::default().storage
    };
    let config = Config {
        storage,
        ..Config::default()
    };

    let outcome = config.validate();
    assert!(outcome.is_ok());
}
2403
/// Loading with default options and a lookup that supplies
/// `DS_HTTP__ALLOW_WILDCARD_CORS=true` must set `http.allow_wildcard_cors`, produce no
/// warnings, and pass validation for the `Prod` profile.
#[test]
fn test_allow_wildcard_cors_env_override() {
    let options = ConfigLoadOptions::default();
    let overrides = [("DS_HTTP__ALLOW_WILDCARD_CORS", "true")];
    let env = lookup(&overrides);

    let config = Config::from_sources_with_lookup(&options, &env).expect("config from env");

    assert!(config.http.allow_wildcard_cors);
    assert!(config.warnings().is_empty());
    assert!(config.validate_profile(&DeploymentProfile::Prod).is_ok());
}
2415}