Skip to main content

durable_streams_server/
config.rs

1//! Configuration loading and runtime settings for the server crate.
2//!
3//! [`Config`] is the resolved runtime view. Use [`Config::from_sources`] for the
4//! layered TOML plus environment-variable flow used by the binary, or
5//! [`Config::from_env`] when tests only need the `DS_*` override surface.
6
7use crate::router::DEFAULT_STREAM_BASE_PATH;
8use axum::http::HeaderValue;
9use figment::{
10    Figment,
11    providers::{Format, Toml},
12};
13use serde::Deserialize;
14use std::env;
15use std::path::PathBuf;
16use std::time::Duration;
17
/// Which persistence backend the server runs on.
///
/// The two `File*` variants share the file backend and differ only in whether
/// every append is flushed to disk.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StorageMode {
    /// Keep everything in process memory.
    Memory,
    /// File backend that skips fsync/fdatasync after each append.
    FileFast,
    /// File backend that issues fsync/fdatasync after each append.
    FileDurable,
    /// Transactional backend built on sharded redb databases.
    Acid,
}
30
/// Where the redb databases behind [`StorageMode::Acid`] live.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AcidBackend {
    /// On-disk redb files; the default. Data survives a restart.
    File,
    /// redb held entirely in memory: transactional semantics with no disk
    /// I/O, but nothing survives shutdown.
    InMemory,
}
40
41impl AcidBackend {
42    #[must_use]
43    pub fn as_str(self) -> &'static str {
44        match self {
45            Self::File => "file",
46            Self::InMemory => "memory",
47        }
48    }
49}
50
51impl StorageMode {
52    #[must_use]
53    pub fn as_str(self) -> &'static str {
54        match self {
55            Self::Memory => "memory",
56            Self::FileFast => "file-fast",
57            Self::FileDurable => "file-durable",
58            Self::Acid => "acid",
59        }
60    }
61
62    #[must_use]
63    pub fn uses_file_backend(self) -> bool {
64        matches!(self, Self::FileFast | Self::FileDurable)
65    }
66
67    #[must_use]
68    pub fn sync_on_append(self) -> bool {
69        matches!(self, Self::FileDurable)
70    }
71}
72
73/// Server configuration
74#[derive(Debug, Clone)]
75pub struct Config {
76    /// TCP port to bind the server to.
77    pub port: u16,
78    /// Maximum total in-process payload bytes across all streams.
79    pub max_memory_bytes: u64,
80    /// Maximum payload bytes retained for any single stream.
81    pub max_stream_bytes: u64,
82    /// Maximum byte length of a stream name.
83    pub max_stream_name_bytes: usize,
84    /// Maximum number of `/`-separated segments in a stream name.
85    pub max_stream_name_segments: usize,
86    /// CORS allowlist as `"*"` or a comma-separated origin list.
87    pub cors_origins: String,
88    /// Long-poll timeout used by `GET ?live=long-poll`.
89    pub long_poll_timeout: Duration,
90    /// SSE reconnect interval in seconds (`0` disables forced reconnects).
91    ///
92    /// Matches Caddy's `sse_reconnect_interval`. Connections are closed after
93    /// this many idle seconds to enable CDN request collapsing.
94    pub sse_reconnect_interval_secs: u64,
95    /// Mount path for the protocol HTTP surface.
96    pub stream_base_path: String,
97    /// Selected persistence backend.
98    pub storage_mode: StorageMode,
99    /// Root directory for file-backed and acid-backed storage.
100    ///
101    /// Matches Caddy's `data_dir`.
102    pub data_dir: String,
103    /// Number of shards used by the acid/redb backend.
104    pub acid_shard_count: usize,
105    /// Redb backend selection for the acid storage mode.
106    pub acid_backend: AcidBackend,
107    /// Optional TLS certificate path in PEM format. Requires `tls_key_path`.
108    pub tls_cert_path: Option<String>,
109    /// Optional TLS private key path in PEM or PKCS#8 format. Requires `tls_cert_path`.
110    pub tls_key_path: Option<String>,
111    /// Default tracing filter used when `RUST_LOG` is not explicitly set.
112    pub rust_log: String,
113}
114
/// Tells [`Config::from_sources`] where to find TOML layers and which ones to load.
#[derive(Clone, Debug)]
pub struct ConfigLoadOptions {
    /// Folder holding `default.toml`, `<profile>.toml` and `local.toml`.
    pub config_dir: PathBuf,
    /// Profile name whose file is merged right after `default.toml` (e.g. `dev`, `prod`).
    pub profile: String,
    /// Extra TOML file merged after all of the standard files, if any.
    pub config_override: Option<PathBuf>,
}
124
125impl Default for ConfigLoadOptions {
126    fn default() -> Self {
127        Self {
128            config_dir: PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("config"),
129            profile: "default".to_string(),
130            config_override: None,
131        }
132    }
133}
134
135#[derive(Debug, Deserialize, Default)]
136#[serde(default)]
137struct SettingsFile {
138    server: ServerSettingsFile,
139    limits: LimitsSettingsFile,
140    http: HttpSettingsFile,
141    storage: StorageSettingsFile,
142    tls: TlsSettingsFile,
143    log: LogSettingsFile,
144}
145
146#[derive(Debug, Deserialize, Default)]
147#[serde(default)]
148struct ServerSettingsFile {
149    port: Option<u16>,
150    long_poll_timeout_secs: Option<u64>,
151    sse_reconnect_interval_secs: Option<u64>,
152}
153
154#[derive(Debug, Deserialize, Default)]
155#[serde(default)]
156#[allow(clippy::struct_field_names)]
157struct LimitsSettingsFile {
158    max_memory_bytes: Option<u64>,
159    max_stream_bytes: Option<u64>,
160    max_stream_name_bytes: Option<usize>,
161    max_stream_name_segments: Option<usize>,
162}
163
164#[derive(Debug, Deserialize, Default)]
165#[serde(default)]
166struct HttpSettingsFile {
167    cors_origins: Option<String>,
168    stream_base_path: Option<String>,
169}
170
171#[derive(Debug, Deserialize, Default)]
172#[serde(default)]
173struct StorageSettingsFile {
174    mode: Option<String>,
175    data_dir: Option<String>,
176    acid_shard_count: Option<usize>,
177    acid_backend: Option<String>,
178}
179
180#[derive(Debug, Deserialize, Default)]
181#[serde(default)]
182struct TlsSettingsFile {
183    cert_path: Option<String>,
184    key_path: Option<String>,
185}
186
187#[derive(Debug, Deserialize, Default)]
188#[serde(default)]
189struct LogSettingsFile {
190    rust_log: Option<String>,
191}
192
193impl Config {
194    /// Load configuration from `DS_*` environment variables with sensible defaults.
195    ///
196    /// Used by tests and as a simple entry point when TOML layering is not needed.
197    /// # Errors
198    ///
199    /// Returns an error when any `DS_*` environment variable is present but invalid.
200    pub fn from_env() -> Result<Self, String> {
201        let mut config = Self::default();
202        config.apply_env_overrides(&|key| env::var(key).ok())?;
203        Ok(config)
204    }
205
206    /// Load configuration from layered TOML files plus environment overrides.
207    ///
208    /// Order (later wins):
209    /// 1. built-in defaults
210    /// 2. `config/default.toml` (if present)
211    /// 3. `config/<profile>.toml` (if present)
212    /// 4. `config/local.toml` (if present)
213    /// 5. `--config <path>` override file (if provided)
214    /// 6. `DS_*` env vars
215    ///
216    /// # Errors
217    ///
218    /// Returns an error when config files cannot be parsed or an explicit
219    /// override file path does not exist/read.
220    pub fn from_sources(options: &ConfigLoadOptions) -> Result<Self, String> {
221        let get = |key: &str| env::var(key).ok();
222        Self::from_sources_with_lookup(options, &get)
223    }
224
225    fn from_sources_with_lookup(
226        options: &ConfigLoadOptions,
227        get: &impl Fn(&str) -> Option<String>,
228    ) -> Result<Self, String> {
229        let mut figment = Figment::new();
230
231        let default_path = options.config_dir.join("default.toml");
232        if default_path.is_file() {
233            figment = figment.merge(Toml::file(&default_path));
234        }
235
236        let profile_path = options
237            .config_dir
238            .join(format!("{}.toml", options.profile.trim()));
239        if profile_path.is_file() {
240            figment = figment.merge(Toml::file(&profile_path));
241        }
242
243        let local_path = options.config_dir.join("local.toml");
244        if local_path.is_file() {
245            figment = figment.merge(Toml::file(&local_path));
246        }
247
248        if let Some(override_path) = &options.config_override {
249            if !override_path.is_file() {
250                return Err(format!(
251                    "config override file not found: '{}'",
252                    override_path.display()
253                ));
254            }
255            figment = figment.merge(Toml::file(override_path));
256        }
257
258        let settings: SettingsFile = figment
259            .extract()
260            .map_err(|e| format!("failed to parse TOML config: {e}"))?;
261
262        let mut config = Self::apply_file_settings(settings)?;
263        config.apply_env_overrides(get)?;
264        Ok(config)
265    }
266
267    fn apply_file_settings(settings: SettingsFile) -> Result<Self, String> {
268        let mut config = Self::default();
269
270        if let Some(port) = settings.server.port {
271            config.port = port;
272        }
273        if let Some(long_poll_timeout_secs) = settings.server.long_poll_timeout_secs {
274            config.long_poll_timeout = Duration::from_secs(long_poll_timeout_secs);
275        }
276        if let Some(sse_reconnect_interval_secs) = settings.server.sse_reconnect_interval_secs {
277            config.sse_reconnect_interval_secs = sse_reconnect_interval_secs;
278        }
279
280        if let Some(max_memory_bytes) = settings.limits.max_memory_bytes {
281            config.max_memory_bytes = max_memory_bytes;
282        }
283        if let Some(max_stream_bytes) = settings.limits.max_stream_bytes {
284            config.max_stream_bytes = max_stream_bytes;
285        }
286        if let Some(max_stream_name_bytes) = settings.limits.max_stream_name_bytes {
287            config.max_stream_name_bytes = max_stream_name_bytes;
288        }
289        if let Some(max_stream_name_segments) = settings.limits.max_stream_name_segments {
290            config.max_stream_name_segments = max_stream_name_segments;
291        }
292
293        if let Some(cors_origins) = settings.http.cors_origins {
294            config.cors_origins = cors_origins;
295        }
296        if let Some(stream_base_path) = settings.http.stream_base_path {
297            config.stream_base_path = Self::parse_stream_base_path_value(&stream_base_path)
298                .map_err(|reason| format!("invalid http.stream_base_path value: {reason}"))?;
299        }
300
301        if let Some(mode) = settings.storage.mode {
302            config.storage_mode = Self::parse_storage_mode_value(&mode)
303                .ok_or_else(|| format!("invalid storage.mode value: '{mode}'"))?;
304        }
305        if let Some(data_dir) = settings.storage.data_dir {
306            config.data_dir = data_dir;
307        }
308        if let Some(acid_shard_count) = settings.storage.acid_shard_count {
309            if Self::valid_acid_shard_count(acid_shard_count) {
310                config.acid_shard_count = acid_shard_count;
311            } else {
312                return Err(format!(
313                    "invalid storage.acid_shard_count value: '{acid_shard_count}' (must be power-of-two in 1..=256)"
314                ));
315            }
316        }
317        if let Some(acid_backend) = settings.storage.acid_backend {
318            config.acid_backend = Self::parse_acid_backend_value(&acid_backend)
319                .ok_or_else(|| format!("invalid storage.acid_backend value: '{acid_backend}'"))?;
320        }
321
322        config.tls_cert_path = settings.tls.cert_path;
323        config.tls_key_path = settings.tls.key_path;
324
325        if let Some(rust_log) = settings.log.rust_log {
326            config.rust_log = rust_log;
327        }
328
329        Ok(config)
330    }
331
332    /// Apply `DS_*` environment variable overrides on top of current config.
333    fn apply_env_overrides(&mut self, get: &impl Fn(&str) -> Option<String>) -> Result<(), String> {
334        if let Some(port) = get("DS_SERVER__PORT") {
335            self.port = port
336                .parse()
337                .map_err(|_| format!("invalid DS_SERVER__PORT value: '{port}'"))?;
338        }
339        if let Some(long_poll_timeout_secs) = get("DS_SERVER__LONG_POLL_TIMEOUT_SECS") {
340            self.long_poll_timeout = Duration::from_secs(
341                long_poll_timeout_secs
342                    .parse()
343                    .map_err(|_| format!("invalid DS_SERVER__LONG_POLL_TIMEOUT_SECS value: '{long_poll_timeout_secs}'"))?,
344            );
345        }
346        if let Some(sse_reconnect_interval_secs) = get("DS_SERVER__SSE_RECONNECT_INTERVAL_SECS") {
347            self.sse_reconnect_interval_secs = sse_reconnect_interval_secs.parse().map_err(|_| {
348                format!("invalid DS_SERVER__SSE_RECONNECT_INTERVAL_SECS value: '{sse_reconnect_interval_secs}'")
349            })?;
350        }
351
352        if let Some(max_memory_bytes) = get("DS_LIMITS__MAX_MEMORY_BYTES") {
353            self.max_memory_bytes = max_memory_bytes.parse().map_err(|_| {
354                format!("invalid DS_LIMITS__MAX_MEMORY_BYTES value: '{max_memory_bytes}'")
355            })?;
356        }
357        if let Some(max_stream_bytes) = get("DS_LIMITS__MAX_STREAM_BYTES") {
358            self.max_stream_bytes = max_stream_bytes.parse().map_err(|_| {
359                format!("invalid DS_LIMITS__MAX_STREAM_BYTES value: '{max_stream_bytes}'")
360            })?;
361        }
362        if let Some(max_stream_name_bytes) = get("DS_LIMITS__MAX_STREAM_NAME_BYTES") {
363            self.max_stream_name_bytes = max_stream_name_bytes.parse().map_err(|_| {
364                format!("invalid DS_LIMITS__MAX_STREAM_NAME_BYTES value: '{max_stream_name_bytes}'")
365            })?;
366        }
367        if let Some(max_stream_name_segments) = get("DS_LIMITS__MAX_STREAM_NAME_SEGMENTS") {
368            self.max_stream_name_segments = max_stream_name_segments.parse().map_err(|_| {
369                format!("invalid DS_LIMITS__MAX_STREAM_NAME_SEGMENTS value: '{max_stream_name_segments}'")
370            })?;
371        }
372
373        if let Some(cors_origins) = get("DS_HTTP__CORS_ORIGINS") {
374            self.cors_origins = cors_origins;
375        }
376        if let Some(stream_base_path) = get("DS_HTTP__STREAM_BASE_PATH") {
377            self.stream_base_path = Self::parse_stream_base_path_value(&stream_base_path)
378                .map_err(|reason| format!("invalid DS_HTTP__STREAM_BASE_PATH value: {reason}"))?;
379        }
380
381        if let Some(storage_mode) = get("DS_STORAGE__MODE") {
382            self.storage_mode = Self::parse_storage_mode_value(&storage_mode)
383                .ok_or_else(|| format!("invalid DS_STORAGE__MODE value: '{storage_mode}'"))?;
384        }
385
386        if let Some(data_dir) = get("DS_STORAGE__DATA_DIR") {
387            self.data_dir = data_dir;
388        }
389
390        if let Some(acid_shard_count) = get("DS_STORAGE__ACID_SHARD_COUNT") {
391            let parsed = acid_shard_count.parse::<usize>().map_err(|_| {
392                format!("invalid DS_STORAGE__ACID_SHARD_COUNT value: '{acid_shard_count}'")
393            })?;
394            if !Self::valid_acid_shard_count(parsed) {
395                return Err(format!(
396                    "invalid DS_STORAGE__ACID_SHARD_COUNT value: '{acid_shard_count}' (must be power-of-two in 1..=256)"
397                ));
398            }
399            self.acid_shard_count = parsed;
400        }
401
402        if let Some(acid_backend) = get("DS_STORAGE__ACID_BACKEND") {
403            self.acid_backend = Self::parse_acid_backend_value(&acid_backend).ok_or_else(|| {
404                format!("invalid DS_STORAGE__ACID_BACKEND value: '{acid_backend}'")
405            })?;
406        }
407
408        if let Some(cert_path) = get("DS_TLS__CERT_PATH") {
409            self.tls_cert_path = Some(cert_path);
410        }
411        if let Some(key_path) = get("DS_TLS__KEY_PATH") {
412            self.tls_key_path = Some(key_path);
413        }
414
415        if let Some(rust_log) = get("DS_LOG__RUST_LOG") {
416            self.rust_log = rust_log;
417        }
418
419        Ok(())
420    }
421
422    /// Validate configuration invariants before server startup.
423    ///
424    /// # Errors
425    ///
426    /// Returns an error string when config is internally inconsistent.
427    pub fn validate(&self) -> std::result::Result<(), String> {
428        match (&self.tls_cert_path, &self.tls_key_path) {
429            (Some(_), Some(_)) | (None, None) => Ok(()),
430            (Some(_), None) => Err(
431                "tls.cert_path is set but tls.key_path is missing; both must be set together"
432                    .to_string(),
433            ),
434            (None, Some(_)) => Err(
435                "tls.key_path is set but tls.cert_path is missing; both must be set together"
436                    .to_string(),
437            ),
438        }?;
439
440        Self::validate_cors_origins(&self.cors_origins)?;
441        Self::parse_stream_base_path_value(&self.stream_base_path).map(|_| ())?;
442
443        if self.max_stream_name_bytes == 0 {
444            return Err(
445                "limits.max_stream_name_bytes must be at least 1".to_string(),
446            );
447        }
448        if self.max_stream_name_segments == 0 {
449            return Err(
450                "limits.max_stream_name_segments must be at least 1".to_string(),
451            );
452        }
453
454        Ok(())
455    }
456
457    fn validate_cors_origins(origins: &str) -> Result<(), String> {
458        if origins == "*" {
459            return Ok(());
460        }
461
462        let mut parsed_any = false;
463        for origin in origins.split(',').map(str::trim) {
464            if origin.is_empty() {
465                return Err("http.cors_origins contains an empty origin entry".to_string());
466            }
467            HeaderValue::from_str(origin)
468                .map_err(|_| format!("invalid http.cors_origins entry: '{origin}'"))?;
469            parsed_any = true;
470        }
471
472        if !parsed_any {
473            return Err(
474                "http.cors_origins must be '*' or a non-empty comma-separated list".to_string(),
475            );
476        }
477
478        Ok(())
479    }
480
481    fn parse_stream_base_path_value(raw: &str) -> Result<String, String> {
482        let trimmed = raw.trim();
483        if trimmed.is_empty() {
484            return Err("must be a non-empty absolute path".to_string());
485        }
486        if !trimmed.starts_with('/') {
487            return Err(format!("'{trimmed}' (must start with '/')"));
488        }
489
490        if trimmed == "/" {
491            return Ok("/".to_string());
492        }
493
494        let normalized = trimmed.trim_end_matches('/');
495        if normalized.is_empty() {
496            return Err("must be a non-empty absolute path".to_string());
497        }
498
499        Ok(normalized.to_string())
500    }
501
502    /// True when direct TLS termination is enabled on this server.
503    #[must_use]
504    pub fn tls_enabled(&self) -> bool {
505        self.tls_cert_path.is_some() && self.tls_key_path.is_some()
506    }
507
508    fn parse_storage_mode_value(raw: &str) -> Option<StorageMode> {
509        match raw.to_ascii_lowercase().as_str() {
510            "memory" => Some(StorageMode::Memory),
511            "file" | "file-durable" | "durable" => Some(StorageMode::FileDurable),
512            "file-fast" | "fast" => Some(StorageMode::FileFast),
513            "acid" | "redb" => Some(StorageMode::Acid),
514            _ => None,
515        }
516    }
517
518    fn valid_acid_shard_count(value: usize) -> bool {
519        (1..=256).contains(&value) && value.is_power_of_two()
520    }
521
522    fn parse_acid_backend_value(raw: &str) -> Option<AcidBackend> {
523        match raw.to_ascii_lowercase().as_str() {
524            "file" => Some(AcidBackend::File),
525            "memory" | "in-memory" | "inmemory" => Some(AcidBackend::InMemory),
526            _ => None,
527        }
528    }
529}
530
531impl Default for Config {
532    fn default() -> Self {
533        Self {
534            port: 4437,
535            max_memory_bytes: 100 * 1024 * 1024,
536            max_stream_bytes: 10 * 1024 * 1024,
537            max_stream_name_bytes: 1024,
538            max_stream_name_segments: 8,
539            cors_origins: "*".to_string(),
540            long_poll_timeout: Duration::from_secs(30),
541            sse_reconnect_interval_secs: 60,
542            stream_base_path: DEFAULT_STREAM_BASE_PATH.to_string(),
543            storage_mode: StorageMode::Memory,
544            data_dir: "./data/streams".to_string(),
545            acid_shard_count: 16,
546            acid_backend: AcidBackend::File,
547            tls_cert_path: None,
548            tls_key_path: None,
549            rust_log: "info".to_string(),
550        }
551    }
552}
553
/// Long-poll timeout as a distinct type, handed to handlers through an axum `Extension`.
#[derive(Clone, Copy, Debug)]
pub struct LongPollTimeout(pub Duration);
557
/// SSE reconnect interval in whole seconds; `0` disables forced reconnects.
///
/// Mirrors Caddy's `sse_reconnect_interval` and reaches handlers through an
/// axum `Extension`.
#[derive(Clone, Copy, Debug)]
pub struct SseReconnectInterval(pub u64);
563
564#[cfg(test)]
565mod tests {
566    use super::*;
567    use std::collections::HashMap;
568    use std::fs;
569    use std::sync::atomic::{AtomicU64, Ordering};
570
571    /// Helper: build a lookup function from key-value pairs.
572    fn lookup(pairs: &[(&str, &str)]) -> impl Fn(&str) -> Option<String> {
573        let map: HashMap<String, String> = pairs
574            .iter()
575            .map(|(k, v)| ((*k).to_string(), (*v).to_string()))
576            .collect();
577        move |key: &str| map.get(key).cloned()
578    }
579
580    fn temp_config_dir() -> PathBuf {
581        static COUNTER: AtomicU64 = AtomicU64::new(0);
582        let id = COUNTER.fetch_add(1, Ordering::Relaxed);
583        let path =
584            std::env::temp_dir().join(format!("ds-config-tests-{}-{}", std::process::id(), id));
585        fs::create_dir_all(&path).expect("create temp config dir");
586        path
587    }
588
589    #[test]
590    fn test_default_config() {
591        let config = Config::default();
592        assert_eq!(config.port, 4437);
593        assert_eq!(config.max_memory_bytes, 100 * 1024 * 1024);
594        assert_eq!(config.max_stream_bytes, 10 * 1024 * 1024);
595        assert_eq!(config.cors_origins, "*");
596        assert_eq!(config.long_poll_timeout, Duration::from_secs(30));
597        assert_eq!(config.sse_reconnect_interval_secs, 60);
598        assert_eq!(config.stream_base_path, DEFAULT_STREAM_BASE_PATH);
599        assert_eq!(config.storage_mode, StorageMode::Memory);
600        assert_eq!(config.data_dir, "./data/streams");
601        assert_eq!(config.acid_shard_count, 16);
602        assert_eq!(config.acid_backend, AcidBackend::File);
603        assert_eq!(config.tls_cert_path, None);
604        assert_eq!(config.tls_key_path, None);
605        assert_eq!(config.rust_log, "info");
606    }
607
608    #[test]
609    fn test_from_env_uses_defaults_when_no_ds_vars() {
610        // from_env reads real env; in test context no DS_* vars are set
611        let config = Config::from_env().expect("config from env");
612        assert_eq!(config.port, 4437);
613        assert_eq!(config.storage_mode, StorageMode::Memory);
614        assert_eq!(config.rust_log, "info");
615    }
616
617    #[test]
618    fn test_env_overrides_parse_all_ds_vars() {
619        let mut config = Config::default();
620        let get = lookup(&[
621            ("DS_SERVER__PORT", "8080"),
622            ("DS_LIMITS__MAX_MEMORY_BYTES", "200000000"),
623            ("DS_LIMITS__MAX_STREAM_BYTES", "20000000"),
624            ("DS_HTTP__CORS_ORIGINS", "https://example.com"),
625            ("DS_SERVER__LONG_POLL_TIMEOUT_SECS", "5"),
626            ("DS_SERVER__SSE_RECONNECT_INTERVAL_SECS", "120"),
627            ("DS_HTTP__STREAM_BASE_PATH", "/streams"),
628            ("DS_STORAGE__MODE", "file-fast"),
629            ("DS_STORAGE__DATA_DIR", "/tmp/ds-store"),
630            ("DS_STORAGE__ACID_SHARD_COUNT", "32"),
631            ("DS_TLS__CERT_PATH", "/tmp/cert.pem"),
632            ("DS_TLS__KEY_PATH", "/tmp/key.pem"),
633            ("DS_LOG__RUST_LOG", "debug"),
634        ]);
635        config
636            .apply_env_overrides(&get)
637            .expect("apply env overrides");
638        assert_eq!(config.port, 8080);
639        assert_eq!(config.max_memory_bytes, 200_000_000);
640        assert_eq!(config.max_stream_bytes, 20_000_000);
641        assert_eq!(config.cors_origins, "https://example.com");
642        assert_eq!(config.long_poll_timeout, Duration::from_secs(5));
643        assert_eq!(config.sse_reconnect_interval_secs, 120);
644        assert_eq!(config.stream_base_path, "/streams");
645        assert_eq!(config.storage_mode, StorageMode::FileFast);
646        assert_eq!(config.data_dir, "/tmp/ds-store");
647        assert_eq!(config.acid_shard_count, 32);
648        assert_eq!(config.tls_cert_path.as_deref(), Some("/tmp/cert.pem"));
649        assert_eq!(config.tls_key_path.as_deref(), Some("/tmp/key.pem"));
650        assert_eq!(config.rust_log, "debug");
651    }
652
653    #[test]
654    fn test_env_overrides_reject_unparseable_values() {
655        let mut config = Config::default();
656        let get = lookup(&[
657            ("DS_SERVER__PORT", "not-a-number"),
658            ("DS_LIMITS__MAX_MEMORY_BYTES", ""),
659            ("DS_SERVER__LONG_POLL_TIMEOUT_SECS", "abc"),
660        ]);
661        let err = config
662            .apply_env_overrides(&get)
663            .expect_err("invalid env override should fail");
664        assert_eq!(err, "invalid DS_SERVER__PORT value: 'not-a-number'");
665        assert_eq!(config.port, 4437);
666        assert_eq!(config.max_memory_bytes, 100 * 1024 * 1024);
667        assert_eq!(config.long_poll_timeout, Duration::from_secs(30));
668    }
669
670    #[test]
671    fn test_env_overrides_partial() {
672        let mut config = Config::default();
673        let get = lookup(&[("DS_SERVER__PORT", "9090")]);
674        config
675            .apply_env_overrides(&get)
676            .expect("apply env overrides");
677        assert_eq!(config.port, 9090);
678        // Everything else stays at defaults
679        assert_eq!(config.storage_mode, StorageMode::Memory);
680        assert_eq!(config.rust_log, "info");
681    }
682
683    #[test]
684    fn test_from_sources_file_layers_and_env_override() {
685        let config_dir = temp_config_dir();
686        fs::write(
687            config_dir.join("default.toml"),
688            r#"
689                [server]
690                port = 4437
691                [http]
692                stream_base_path = "/v1/stream"
693                [storage]
694                mode = "memory"
695                [log]
696                rust_log = "warn"
697            "#,
698        )
699        .expect("write default.toml");
700
701        fs::write(
702            config_dir.join("dev.toml"),
703            r#"
704                [server]
705                port = 7777
706                [http]
707                stream_base_path = "/streams"
708                [storage]
709                mode = "file-fast"
710                data_dir = "/tmp/dev-store"
711            "#,
712        )
713        .expect("write dev.toml");
714
715        fs::write(
716            config_dir.join("local.toml"),
717            r"
718                [server]
719                port = 8888
720            ",
721        )
722        .expect("write local.toml");
723
724        let options = ConfigLoadOptions {
725            config_dir,
726            profile: "dev".to_string(),
727            config_override: None,
728        };
729
730        // DS_SERVER__PORT env override wins over all TOML layers
731        let env = lookup(&[("DS_SERVER__PORT", "9999"), ("DS_LOG__RUST_LOG", "debug")]);
732        let config = Config::from_sources_with_lookup(&options, &env).expect("config from sources");
733
734        assert_eq!(config.port, 9999);
735        assert_eq!(config.stream_base_path, "/streams");
736        assert_eq!(config.storage_mode, StorageMode::FileFast);
737        assert_eq!(config.data_dir, "/tmp/dev-store");
738        assert_eq!(config.rust_log, "debug");
739    }
740
741    #[test]
742    fn test_from_sources_env_overrides_toml() {
743        let config_dir = temp_config_dir();
744        fs::write(
745            config_dir.join("default.toml"),
746            r#"
747                [server]
748                port = 4437
749                [storage]
750                mode = "memory"
751            "#,
752        )
753        .expect("write default.toml");
754
755        let options = ConfigLoadOptions {
756            config_dir,
757            profile: "default".to_string(),
758            config_override: None,
759        };
760
761        let env = lookup(&[
762            ("DS_SERVER__PORT", "12345"),
763            ("DS_STORAGE__MODE", "acid"),
764            ("DS_STORAGE__ACID_SHARD_COUNT", "32"),
765            ("DS_TLS__CERT_PATH", "/tmp/cert.pem"),
766            ("DS_TLS__KEY_PATH", "/tmp/key.pem"),
767        ]);
768        let config = Config::from_sources_with_lookup(&options, &env).expect("config from sources");
769
770        assert_eq!(config.port, 12345);
771        assert_eq!(config.storage_mode, StorageMode::Acid);
772        assert_eq!(config.acid_shard_count, 32);
773        assert_eq!(config.tls_cert_path.as_deref(), Some("/tmp/cert.pem"));
774        assert_eq!(config.tls_key_path.as_deref(), Some("/tmp/key.pem"));
775    }
776
777    #[test]
778    fn test_validate_tls_pair_ok_when_both_absent_or_present() {
779        let mut config = Config::default();
780        assert!(config.validate().is_ok());
781        assert!(!config.tls_enabled());
782
783        config.tls_cert_path = Some("/tmp/cert.pem".to_string());
784        config.tls_key_path = Some("/tmp/key.pem".to_string());
785        assert!(config.validate().is_ok());
786        assert!(config.tls_enabled());
787    }
788
789    #[test]
790    fn test_validate_tls_pair_rejects_partial_configuration() {
791        let mut config = Config {
792            tls_cert_path: Some("/tmp/cert.pem".to_string()),
793            ..Config::default()
794        };
795        assert!(config.validate().is_err());
796
797        config.tls_cert_path = None;
798        config.tls_key_path = Some("/tmp/key.pem".to_string());
799        assert!(config.validate().is_err());
800    }
801
802    #[test]
803    fn test_storage_mode_aliases() {
804        let mut config = Config::default();
805        config
806            .apply_env_overrides(&lookup(&[("DS_STORAGE__MODE", "acid")]))
807            .expect("apply env overrides");
808        assert_eq!(config.storage_mode, StorageMode::Acid);
809
810        let mut config = Config::default();
811        config
812            .apply_env_overrides(&lookup(&[("DS_STORAGE__MODE", "redb")]))
813            .expect("apply env overrides");
814        assert_eq!(config.storage_mode, StorageMode::Acid);
815    }
816
817    #[test]
818    fn test_acid_shard_count_valid_values() {
819        let mut config = Config::default();
820        config
821            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_SHARD_COUNT", "1")]))
822            .expect("apply env overrides");
823        assert_eq!(config.acid_shard_count, 1);
824
825        let mut config = Config::default();
826        config
827            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_SHARD_COUNT", "256")]))
828            .expect("apply env overrides");
829        assert_eq!(config.acid_shard_count, 256);
830    }
831
832    #[test]
833    fn test_acid_shard_count_invalid_values_return_error() {
834        let mut config = Config::default();
835        let err = config
836            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_SHARD_COUNT", "0")]))
837            .expect_err("invalid shard count should fail");
838        assert_eq!(
839            err,
840            "invalid DS_STORAGE__ACID_SHARD_COUNT value: '0' (must be power-of-two in 1..=256)"
841        );
842        assert_eq!(config.acid_shard_count, 16);
843
844        let mut config = Config::default();
845        let err = config
846            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_SHARD_COUNT", "3")]))
847            .expect_err("invalid shard count should fail");
848        assert_eq!(
849            err,
850            "invalid DS_STORAGE__ACID_SHARD_COUNT value: '3' (must be power-of-two in 1..=256)"
851        );
852        assert_eq!(config.acid_shard_count, 16);
853
854        let mut config = Config::default();
855        let err = config
856            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_SHARD_COUNT", "abc")]))
857            .expect_err("invalid shard count should fail");
858        assert_eq!(err, "invalid DS_STORAGE__ACID_SHARD_COUNT value: 'abc'");
859        assert_eq!(config.acid_shard_count, 16);
860    }
861
862    #[test]
863    fn test_acid_backend_env_override() {
864        let mut config = Config::default();
865        assert_eq!(config.acid_backend, AcidBackend::File);
866
867        config
868            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_BACKEND", "memory")]))
869            .expect("apply env overrides");
870        assert_eq!(config.acid_backend, AcidBackend::InMemory);
871
872        let mut config = Config::default();
873        config
874            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_BACKEND", "in-memory")]))
875            .expect("apply env overrides");
876        assert_eq!(config.acid_backend, AcidBackend::InMemory);
877
878        let mut config = Config::default();
879        config
880            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_BACKEND", "file")]))
881            .expect("apply env overrides");
882        assert_eq!(config.acid_backend, AcidBackend::File);
883    }
884
885    #[test]
886    fn test_acid_backend_env_override_rejects_invalid() {
887        let mut config = Config::default();
888        let err = config
889            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_BACKEND", "sqlite")]))
890            .expect_err("invalid acid backend should fail");
891        assert_eq!(err, "invalid DS_STORAGE__ACID_BACKEND value: 'sqlite'");
892        assert_eq!(config.acid_backend, AcidBackend::File);
893    }
894
895    #[test]
896    fn test_env_overrides_reject_invalid_storage_mode() {
897        let mut config = Config::default();
898        let err = config
899            .apply_env_overrides(&lookup(&[("DS_STORAGE__MODE", "memroy")]))
900            .expect_err("invalid storage mode should fail");
901        assert_eq!(err, "invalid DS_STORAGE__MODE value: 'memroy'");
902    }
903
904    #[test]
905    fn test_validate_rejects_invalid_cors_origins() {
906        let config = Config {
907            cors_origins: "https://good.example, ,https://other.example".to_string(),
908            ..Config::default()
909        };
910        assert_eq!(
911            config
912                .validate()
913                .expect_err("invalid cors origins should fail"),
914            "http.cors_origins contains an empty origin entry"
915        );
916    }
917
918    #[test]
919    fn test_stream_base_path_normalizes_trailing_slash() {
920        let mut config = Config::default();
921        config
922            .apply_env_overrides(&lookup(&[("DS_HTTP__STREAM_BASE_PATH", "/streams/")]))
923            .expect("apply env overrides");
924        assert_eq!(config.stream_base_path, "/streams");
925    }
926
927    #[test]
928    fn test_stream_base_path_rejects_relative_path() {
929        let mut config = Config::default();
930        let err = config
931            .apply_env_overrides(&lookup(&[("DS_HTTP__STREAM_BASE_PATH", "streams")]))
932            .expect_err("relative base path should fail");
933        assert_eq!(
934            err,
935            "invalid DS_HTTP__STREAM_BASE_PATH value: 'streams' (must start with '/')"
936        );
937    }
938
939    #[test]
940    fn test_validate_rejects_invalid_stream_base_path() {
941        let config = Config {
942            stream_base_path: "streams".to_string(),
943            ..Config::default()
944        };
945        assert_eq!(
946            config
947                .validate()
948                .expect_err("invalid stream base path should fail"),
949            "'streams' (must start with '/')"
950        );
951    }
952
953    #[test]
954    fn test_long_poll_timeout_newtype() {
955        let timeout = LongPollTimeout(Duration::from_secs(10));
956        assert_eq!(timeout.0, Duration::from_secs(10));
957    }
958
959    #[test]
960    fn test_sse_reconnect_interval_newtype() {
961        let interval = SseReconnectInterval(120);
962        assert_eq!(interval.0, 120);
963    }
964}