Skip to main content

durable_streams_server/
config.rs

1use axum::http::HeaderValue;
2use figment::{
3    Figment,
4    providers::{Format, Toml},
5};
6use serde::Deserialize;
7use std::env;
8use std::path::PathBuf;
9use std::time::Duration;
10
/// Which storage backend the server runs with.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum StorageMode {
    /// Streams are held in process memory.
    Memory,
    /// File-backed storage that does NOT fsync/fdatasync after each append.
    FileFast,
    /// File-backed storage that fsyncs/fdatasyncs after each append.
    FileDurable,
    /// ACID storage built on sharded redb databases.
    Acid,
}

impl StorageMode {
    /// Canonical lowercase, hyphenated name of this mode.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            StorageMode::Memory => "memory",
            StorageMode::FileFast => "file-fast",
            StorageMode::FileDurable => "file-durable",
            StorageMode::Acid => "acid",
        }
    }

    /// `true` for the two file-backed modes (`FileFast` and `FileDurable`).
    #[must_use]
    pub fn uses_file_backend(self) -> bool {
        match self {
            StorageMode::FileFast | StorageMode::FileDurable => true,
            StorageMode::Memory | StorageMode::Acid => false,
        }
    }

    /// `true` only for `FileDurable`, the mode that syncs to disk on every append.
    #[must_use]
    pub fn sync_on_append(self) -> bool {
        self == StorageMode::FileDurable
    }
}
45
46/// Server configuration
47#[derive(Debug, Clone)]
48pub struct Config {
49    /// Port to bind the server to
50    pub port: u16,
51    /// Maximum total memory usage in bytes
52    pub max_memory_bytes: u64,
53    /// Maximum bytes per stream
54    pub max_stream_bytes: u64,
55    /// CORS allowed origins (comma-separated, "*" for all)
56    pub cors_origins: String,
57    /// Long-poll timeout duration
58    pub long_poll_timeout: Duration,
59    /// SSE reconnect interval in seconds (0 disables).
60    ///
61    /// Matches Caddy's `sse_reconnect_interval`. Connections are closed after
62    /// this many idle seconds to enable CDN request collapsing.
63    pub sse_reconnect_interval_secs: u64,
64    /// Selected storage mode
65    pub storage_mode: StorageMode,
66    /// Root directory for file/acid-backed storage.
67    ///
68    /// Matches Caddy's `data_dir`.
69    pub data_dir: String,
70    /// Number of shards for acid/redb storage mode.
71    pub acid_shard_count: usize,
72    /// Optional TLS certificate path (PEM). Requires `tls_key_path`.
73    pub tls_cert_path: Option<String>,
74    /// Optional TLS private key path (PEM or PKCS#8). Requires `tls_cert_path`.
75    pub tls_key_path: Option<String>,
76    /// Default log filter when `RUST_LOG` is not explicitly set.
77    pub rust_log: String,
78}
79
/// Where [`Config::from_sources`] looks for its TOML layers.
#[derive(Clone, Debug)]
pub struct ConfigLoadOptions {
    /// Directory holding `default.toml`, `<profile>.toml` and `local.toml`.
    pub config_dir: PathBuf,
    /// Profile name; selects `<config_dir>/<profile>.toml`.
    pub profile: String,
    /// Explicit override file, merged after every file in `config_dir`.
    pub config_override: Option<PathBuf>,
}

impl Default for ConfigLoadOptions {
    /// The `config` directory, the `default` profile, and no override file.
    fn default() -> Self {
        let config_dir = PathBuf::from("config");
        let profile = String::from("default");
        Self {
            config_dir,
            profile,
            config_override: None,
        }
    }
}
96
97#[derive(Debug, Deserialize, Default)]
98#[serde(default)]
99struct SettingsFile {
100    server: ServerSettingsFile,
101    limits: LimitsSettingsFile,
102    http: HttpSettingsFile,
103    storage: StorageSettingsFile,
104    tls: TlsSettingsFile,
105    log: LogSettingsFile,
106}
107
108#[derive(Debug, Deserialize, Default)]
109#[serde(default)]
110struct ServerSettingsFile {
111    port: Option<u16>,
112    long_poll_timeout_secs: Option<u64>,
113    sse_reconnect_interval_secs: Option<u64>,
114}
115
116#[derive(Debug, Deserialize, Default)]
117#[serde(default)]
118struct LimitsSettingsFile {
119    max_memory_bytes: Option<u64>,
120    max_stream_bytes: Option<u64>,
121}
122
123#[derive(Debug, Deserialize, Default)]
124#[serde(default)]
125struct HttpSettingsFile {
126    cors_origins: Option<String>,
127}
128
129#[derive(Debug, Deserialize, Default)]
130#[serde(default)]
131struct StorageSettingsFile {
132    mode: Option<String>,
133    data_dir: Option<String>,
134    acid_shard_count: Option<usize>,
135}
136
137#[derive(Debug, Deserialize, Default)]
138#[serde(default)]
139struct TlsSettingsFile {
140    cert_path: Option<String>,
141    key_path: Option<String>,
142}
143
144#[derive(Debug, Deserialize, Default)]
145#[serde(default)]
146struct LogSettingsFile {
147    rust_log: Option<String>,
148}
149
150impl Config {
151    /// Load configuration from `DS_*` environment variables with sensible defaults.
152    ///
153    /// Used by tests and as a simple entry point when TOML layering is not needed.
154    /// # Errors
155    ///
156    /// Returns an error when any `DS_*` environment variable is present but invalid.
157    pub fn from_env() -> Result<Self, String> {
158        let mut config = Self::default();
159        config.apply_env_overrides(&|key| env::var(key).ok())?;
160        Ok(config)
161    }
162
163    /// Load configuration from layered TOML files plus environment overrides.
164    ///
165    /// Order (later wins):
166    /// 1. built-in defaults
167    /// 2. `config/default.toml` (if present)
168    /// 3. `config/<profile>.toml` (if present)
169    /// 4. `config/local.toml` (if present)
170    /// 5. `--config <path>` override file (if provided)
171    /// 6. `DS_*` env vars
172    ///
173    /// # Errors
174    ///
175    /// Returns an error when config files cannot be parsed or an explicit
176    /// override file path does not exist/read.
177    pub fn from_sources(options: &ConfigLoadOptions) -> Result<Self, String> {
178        let get = |key: &str| env::var(key).ok();
179        Self::from_sources_with_lookup(options, &get)
180    }
181
182    fn from_sources_with_lookup(
183        options: &ConfigLoadOptions,
184        get: &impl Fn(&str) -> Option<String>,
185    ) -> Result<Self, String> {
186        let mut figment = Figment::new();
187
188        let default_path = options.config_dir.join("default.toml");
189        if default_path.is_file() {
190            figment = figment.merge(Toml::file(&default_path));
191        }
192
193        let profile_path = options
194            .config_dir
195            .join(format!("{}.toml", options.profile.trim()));
196        if profile_path.is_file() {
197            figment = figment.merge(Toml::file(&profile_path));
198        }
199
200        let local_path = options.config_dir.join("local.toml");
201        if local_path.is_file() {
202            figment = figment.merge(Toml::file(&local_path));
203        }
204
205        if let Some(override_path) = &options.config_override {
206            if !override_path.is_file() {
207                return Err(format!(
208                    "config override file not found: '{}'",
209                    override_path.display()
210                ));
211            }
212            figment = figment.merge(Toml::file(override_path));
213        }
214
215        let settings: SettingsFile = figment
216            .extract()
217            .map_err(|e| format!("failed to parse TOML config: {e}"))?;
218
219        let mut config = Self::apply_file_settings(settings)?;
220        config.apply_env_overrides(get)?;
221        Ok(config)
222    }
223
224    fn apply_file_settings(settings: SettingsFile) -> Result<Self, String> {
225        let mut config = Self::default();
226
227        if let Some(port) = settings.server.port {
228            config.port = port;
229        }
230        if let Some(long_poll_timeout_secs) = settings.server.long_poll_timeout_secs {
231            config.long_poll_timeout = Duration::from_secs(long_poll_timeout_secs);
232        }
233        if let Some(sse_reconnect_interval_secs) = settings.server.sse_reconnect_interval_secs {
234            config.sse_reconnect_interval_secs = sse_reconnect_interval_secs;
235        }
236
237        if let Some(max_memory_bytes) = settings.limits.max_memory_bytes {
238            config.max_memory_bytes = max_memory_bytes;
239        }
240        if let Some(max_stream_bytes) = settings.limits.max_stream_bytes {
241            config.max_stream_bytes = max_stream_bytes;
242        }
243
244        if let Some(cors_origins) = settings.http.cors_origins {
245            config.cors_origins = cors_origins;
246        }
247
248        if let Some(mode) = settings.storage.mode {
249            config.storage_mode = Self::parse_storage_mode_value(&mode)
250                .ok_or_else(|| format!("invalid storage.mode value: '{mode}'"))?;
251        }
252        if let Some(data_dir) = settings.storage.data_dir {
253            config.data_dir = data_dir;
254        }
255        if let Some(acid_shard_count) = settings.storage.acid_shard_count {
256            if Self::valid_acid_shard_count(acid_shard_count) {
257                config.acid_shard_count = acid_shard_count;
258            } else {
259                return Err(format!(
260                    "invalid storage.acid_shard_count value: '{acid_shard_count}' (must be power-of-two in 1..=256)"
261                ));
262            }
263        }
264
265        config.tls_cert_path = settings.tls.cert_path;
266        config.tls_key_path = settings.tls.key_path;
267
268        if let Some(rust_log) = settings.log.rust_log {
269            config.rust_log = rust_log;
270        }
271
272        Ok(config)
273    }
274
275    /// Apply `DS_*` environment variable overrides on top of current config.
276    fn apply_env_overrides(&mut self, get: &impl Fn(&str) -> Option<String>) -> Result<(), String> {
277        if let Some(port) = get("DS_SERVER__PORT") {
278            self.port = port
279                .parse()
280                .map_err(|_| format!("invalid DS_SERVER__PORT value: '{port}'"))?;
281        }
282        if let Some(long_poll_timeout_secs) = get("DS_SERVER__LONG_POLL_TIMEOUT_SECS") {
283            self.long_poll_timeout = Duration::from_secs(
284                long_poll_timeout_secs
285                    .parse()
286                    .map_err(|_| format!("invalid DS_SERVER__LONG_POLL_TIMEOUT_SECS value: '{long_poll_timeout_secs}'"))?,
287            );
288        }
289        if let Some(sse_reconnect_interval_secs) = get("DS_SERVER__SSE_RECONNECT_INTERVAL_SECS") {
290            self.sse_reconnect_interval_secs = sse_reconnect_interval_secs.parse().map_err(|_| {
291                format!("invalid DS_SERVER__SSE_RECONNECT_INTERVAL_SECS value: '{sse_reconnect_interval_secs}'")
292            })?;
293        }
294
295        if let Some(max_memory_bytes) = get("DS_LIMITS__MAX_MEMORY_BYTES") {
296            self.max_memory_bytes = max_memory_bytes.parse().map_err(|_| {
297                format!("invalid DS_LIMITS__MAX_MEMORY_BYTES value: '{max_memory_bytes}'")
298            })?;
299        }
300        if let Some(max_stream_bytes) = get("DS_LIMITS__MAX_STREAM_BYTES") {
301            self.max_stream_bytes = max_stream_bytes.parse().map_err(|_| {
302                format!("invalid DS_LIMITS__MAX_STREAM_BYTES value: '{max_stream_bytes}'")
303            })?;
304        }
305
306        if let Some(cors_origins) = get("DS_HTTP__CORS_ORIGINS") {
307            self.cors_origins = cors_origins;
308        }
309
310        if let Some(storage_mode) = get("DS_STORAGE__MODE") {
311            self.storage_mode = Self::parse_storage_mode_value(&storage_mode)
312                .ok_or_else(|| format!("invalid DS_STORAGE__MODE value: '{storage_mode}'"))?;
313        }
314
315        if let Some(data_dir) = get("DS_STORAGE__DATA_DIR") {
316            self.data_dir = data_dir;
317        }
318
319        if let Some(acid_shard_count) = get("DS_STORAGE__ACID_SHARD_COUNT") {
320            let parsed = acid_shard_count.parse::<usize>().map_err(|_| {
321                format!("invalid DS_STORAGE__ACID_SHARD_COUNT value: '{acid_shard_count}'")
322            })?;
323            if !Self::valid_acid_shard_count(parsed) {
324                return Err(format!(
325                    "invalid DS_STORAGE__ACID_SHARD_COUNT value: '{acid_shard_count}' (must be power-of-two in 1..=256)"
326                ));
327            }
328            self.acid_shard_count = parsed;
329        }
330
331        if let Some(cert_path) = get("DS_TLS__CERT_PATH") {
332            self.tls_cert_path = Some(cert_path);
333        }
334        if let Some(key_path) = get("DS_TLS__KEY_PATH") {
335            self.tls_key_path = Some(key_path);
336        }
337
338        if let Some(rust_log) = get("DS_LOG__RUST_LOG") {
339            self.rust_log = rust_log;
340        }
341
342        Ok(())
343    }
344
345    /// Validate configuration invariants before server startup.
346    ///
347    /// # Errors
348    ///
349    /// Returns an error string when config is internally inconsistent.
350    pub fn validate(&self) -> std::result::Result<(), String> {
351        match (&self.tls_cert_path, &self.tls_key_path) {
352            (Some(_), Some(_)) | (None, None) => Ok(()),
353            (Some(_), None) => Err(
354                "tls.cert_path is set but tls.key_path is missing; both must be set together"
355                    .to_string(),
356            ),
357            (None, Some(_)) => Err(
358                "tls.key_path is set but tls.cert_path is missing; both must be set together"
359                    .to_string(),
360            ),
361        }?;
362
363        Self::validate_cors_origins(&self.cors_origins)?;
364
365        Ok(())
366    }
367
368    fn validate_cors_origins(origins: &str) -> Result<(), String> {
369        if origins == "*" {
370            return Ok(());
371        }
372
373        let mut parsed_any = false;
374        for origin in origins.split(',').map(str::trim) {
375            if origin.is_empty() {
376                return Err("http.cors_origins contains an empty origin entry".to_string());
377            }
378            HeaderValue::from_str(origin)
379                .map_err(|_| format!("invalid http.cors_origins entry: '{origin}'"))?;
380            parsed_any = true;
381        }
382
383        if !parsed_any {
384            return Err(
385                "http.cors_origins must be '*' or a non-empty comma-separated list".to_string(),
386            );
387        }
388
389        Ok(())
390    }
391
392    /// True when direct TLS termination is enabled on this server.
393    #[must_use]
394    pub fn tls_enabled(&self) -> bool {
395        self.tls_cert_path.is_some() && self.tls_key_path.is_some()
396    }
397
398    fn parse_storage_mode_value(raw: &str) -> Option<StorageMode> {
399        match raw.to_ascii_lowercase().as_str() {
400            "memory" => Some(StorageMode::Memory),
401            "file" | "file-durable" | "durable" => Some(StorageMode::FileDurable),
402            "file-fast" | "fast" => Some(StorageMode::FileFast),
403            "acid" | "redb" => Some(StorageMode::Acid),
404            _ => None,
405        }
406    }
407
408    fn valid_acid_shard_count(value: usize) -> bool {
409        (1..=256).contains(&value) && value.is_power_of_two()
410    }
411}
412
413impl Default for Config {
414    fn default() -> Self {
415        Self {
416            port: 4437,
417            max_memory_bytes: 100 * 1024 * 1024,
418            max_stream_bytes: 10 * 1024 * 1024,
419            cors_origins: "*".to_string(),
420            long_poll_timeout: Duration::from_secs(30),
421            sse_reconnect_interval_secs: 60,
422            storage_mode: StorageMode::Memory,
423            data_dir: "./data/streams".to_string(),
424            acid_shard_count: 16,
425            tls_cert_path: None,
426            tls_key_path: None,
427            rust_log: "info".to_string(),
428        }
429    }
430}
431
/// Newtype around the long-poll timeout, handed to handlers via axum `Extension`.
#[derive(Clone, Copy, Debug)]
pub struct LongPollTimeout(pub Duration);
435
/// Newtype around the SSE reconnect interval in seconds; `0` turns it off.
///
/// Mirrors Caddy's `sse_reconnect_interval` and is handed to handlers via axum `Extension`.
#[derive(Clone, Copy, Debug)]
pub struct SseReconnectInterval(pub u64);
441
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::fs;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Helper: build a lookup function from key-value pairs.
    fn lookup(pairs: &[(&str, &str)]) -> impl Fn(&str) -> Option<String> {
        let map: HashMap<String, String> = pairs
            .iter()
            .map(|(k, v)| ((*k).to_string(), (*v).to_string()))
            .collect();
        move |key: &str| map.get(key).cloned()
    }

    /// Per-test scratch directory that is removed again when dropped, so test
    /// runs do not accumulate `ds-config-tests-*` directories in the temp dir.
    struct TempConfigDir(PathBuf);

    impl TempConfigDir {
        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for TempConfigDir {
        fn drop(&mut self) {
            // Best effort: a cleanup failure must not fail (or panic inside) a test.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    /// Create a fresh, empty, uniquely named directory for one test.
    fn temp_config_dir() -> TempConfigDir {
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let id = COUNTER.fetch_add(1, Ordering::Relaxed);
        let path =
            std::env::temp_dir().join(format!("ds-config-tests-{}-{}", std::process::id(), id));
        // A crashed earlier run with the same pid could have left files behind;
        // start from an empty directory so stale layers cannot leak in.
        let _ = fs::remove_dir_all(&path);
        fs::create_dir_all(&path).expect("create temp config dir");
        TempConfigDir(path)
    }

    /// Load options pointing at `dir` with the given profile and no override file.
    fn options_for(dir: &TempConfigDir, profile: &str) -> ConfigLoadOptions {
        ConfigLoadOptions {
            config_dir: dir.path().to_path_buf(),
            profile: profile.to_string(),
            config_override: None,
        }
    }

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.port, 4437);
        assert_eq!(config.max_memory_bytes, 100 * 1024 * 1024);
        assert_eq!(config.max_stream_bytes, 10 * 1024 * 1024);
        assert_eq!(config.cors_origins, "*");
        assert_eq!(config.long_poll_timeout, Duration::from_secs(30));
        assert_eq!(config.sse_reconnect_interval_secs, 60);
        assert_eq!(config.storage_mode, StorageMode::Memory);
        assert_eq!(config.data_dir, "./data/streams");
        assert_eq!(config.acid_shard_count, 16);
        assert_eq!(config.tls_cert_path, None);
        assert_eq!(config.tls_key_path, None);
        assert_eq!(config.rust_log, "info");
    }

    #[test]
    fn test_from_env_uses_defaults_when_no_ds_vars() {
        // from_env reads the real process environment. If the developer or CI
        // exports DS_* variables, the default-value assertions cannot hold.
        if std::env::vars_os().any(|(k, _)| k.to_string_lossy().starts_with("DS_")) {
            eprintln!("skipping: DS_* variables are set in the process environment");
            return;
        }
        let config = Config::from_env().expect("config from env");
        assert_eq!(config.port, 4437);
        assert_eq!(config.storage_mode, StorageMode::Memory);
        assert_eq!(config.rust_log, "info");
    }

    #[test]
    fn test_env_overrides_parse_all_ds_vars() {
        let mut config = Config::default();
        let get = lookup(&[
            ("DS_SERVER__PORT", "8080"),
            ("DS_LIMITS__MAX_MEMORY_BYTES", "200000000"),
            ("DS_LIMITS__MAX_STREAM_BYTES", "20000000"),
            ("DS_HTTP__CORS_ORIGINS", "https://example.com"),
            ("DS_SERVER__LONG_POLL_TIMEOUT_SECS", "5"),
            ("DS_SERVER__SSE_RECONNECT_INTERVAL_SECS", "120"),
            ("DS_STORAGE__MODE", "file-fast"),
            ("DS_STORAGE__DATA_DIR", "/tmp/ds-store"),
            ("DS_STORAGE__ACID_SHARD_COUNT", "32"),
            ("DS_TLS__CERT_PATH", "/tmp/cert.pem"),
            ("DS_TLS__KEY_PATH", "/tmp/key.pem"),
            ("DS_LOG__RUST_LOG", "debug"),
        ]);
        config
            .apply_env_overrides(&get)
            .expect("apply env overrides");
        assert_eq!(config.port, 8080);
        assert_eq!(config.max_memory_bytes, 200_000_000);
        assert_eq!(config.max_stream_bytes, 20_000_000);
        assert_eq!(config.cors_origins, "https://example.com");
        assert_eq!(config.long_poll_timeout, Duration::from_secs(5));
        assert_eq!(config.sse_reconnect_interval_secs, 120);
        assert_eq!(config.storage_mode, StorageMode::FileFast);
        assert_eq!(config.data_dir, "/tmp/ds-store");
        assert_eq!(config.acid_shard_count, 32);
        assert_eq!(config.tls_cert_path.as_deref(), Some("/tmp/cert.pem"));
        assert_eq!(config.tls_key_path.as_deref(), Some("/tmp/key.pem"));
        assert_eq!(config.rust_log, "debug");
    }

    #[test]
    fn test_env_overrides_reject_unparseable_values() {
        let mut config = Config::default();
        let get = lookup(&[
            ("DS_SERVER__PORT", "not-a-number"),
            ("DS_LIMITS__MAX_MEMORY_BYTES", ""),
            ("DS_SERVER__LONG_POLL_TIMEOUT_SECS", "abc"),
        ]);
        let err = config
            .apply_env_overrides(&get)
            .expect_err("invalid env override should fail");
        assert_eq!(err, "invalid DS_SERVER__PORT value: 'not-a-number'");
        assert_eq!(config.port, 4437);
        assert_eq!(config.max_memory_bytes, 100 * 1024 * 1024);
        assert_eq!(config.long_poll_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_env_overrides_partial() {
        let mut config = Config::default();
        let get = lookup(&[("DS_SERVER__PORT", "9090")]);
        config
            .apply_env_overrides(&get)
            .expect("apply env overrides");
        assert_eq!(config.port, 9090);
        // Everything else stays at defaults
        assert_eq!(config.storage_mode, StorageMode::Memory);
        assert_eq!(config.rust_log, "info");
    }

    #[test]
    fn test_from_sources_file_layers_and_env_override() {
        let dir = temp_config_dir();
        fs::write(
            dir.path().join("default.toml"),
            r#"
                [server]
                port = 4437
                [storage]
                mode = "memory"
                [log]
                rust_log = "warn"
            "#,
        )
        .expect("write default.toml");

        fs::write(
            dir.path().join("dev.toml"),
            r#"
                [server]
                port = 7777
                [storage]
                mode = "file-fast"
                data_dir = "/tmp/dev-store"
            "#,
        )
        .expect("write dev.toml");

        fs::write(
            dir.path().join("local.toml"),
            r"
                [server]
                port = 8888
            ",
        )
        .expect("write local.toml");

        let options = options_for(&dir, "dev");

        // DS_SERVER__PORT env override wins over all TOML layers
        let env = lookup(&[("DS_SERVER__PORT", "9999"), ("DS_LOG__RUST_LOG", "debug")]);
        let config = Config::from_sources_with_lookup(&options, &env).expect("config from sources");

        assert_eq!(config.port, 9999);
        assert_eq!(config.storage_mode, StorageMode::FileFast);
        assert_eq!(config.data_dir, "/tmp/dev-store");
        assert_eq!(config.rust_log, "debug");
    }

    #[test]
    fn test_from_sources_env_overrides_toml() {
        let dir = temp_config_dir();
        fs::write(
            dir.path().join("default.toml"),
            r#"
                [server]
                port = 4437
                [storage]
                mode = "memory"
            "#,
        )
        .expect("write default.toml");

        let options = options_for(&dir, "default");

        let env = lookup(&[
            ("DS_SERVER__PORT", "12345"),
            ("DS_STORAGE__MODE", "acid"),
            ("DS_STORAGE__ACID_SHARD_COUNT", "32"),
            ("DS_TLS__CERT_PATH", "/tmp/cert.pem"),
            ("DS_TLS__KEY_PATH", "/tmp/key.pem"),
        ]);
        let config = Config::from_sources_with_lookup(&options, &env).expect("config from sources");

        assert_eq!(config.port, 12345);
        assert_eq!(config.storage_mode, StorageMode::Acid);
        assert_eq!(config.acid_shard_count, 32);
        assert_eq!(config.tls_cert_path.as_deref(), Some("/tmp/cert.pem"));
        assert_eq!(config.tls_key_path.as_deref(), Some("/tmp/key.pem"));
    }

    #[test]
    fn test_from_sources_override_file_wins_over_local() {
        let dir = temp_config_dir();
        fs::write(dir.path().join("local.toml"), "[server]\nport = 8888\n")
            .expect("write local.toml");
        let override_path = dir.path().join("override.toml");
        fs::write(&override_path, "[server]\nport = 5555\n").expect("write override.toml");

        let options = ConfigLoadOptions {
            config_override: Some(override_path),
            ..options_for(&dir, "default")
        };
        let config =
            Config::from_sources_with_lookup(&options, &lookup(&[])).expect("config from sources");
        assert_eq!(config.port, 5555);
    }

    #[test]
    fn test_from_sources_missing_override_file_is_error() {
        let dir = temp_config_dir();
        let missing = dir.path().join("missing.toml");
        let options = ConfigLoadOptions {
            config_override: Some(missing.clone()),
            ..options_for(&dir, "default")
        };
        let err = Config::from_sources_with_lookup(&options, &lookup(&[]))
            .expect_err("missing override file should fail");
        assert_eq!(
            err,
            format!("config override file not found: '{}'", missing.display())
        );
    }

    #[test]
    fn test_from_sources_rejects_invalid_toml_values() {
        let dir = temp_config_dir();
        let options = options_for(&dir, "default");

        fs::write(dir.path().join("default.toml"), "[storage]\nmode = \"bogus\"\n")
            .expect("write default.toml");
        let err = Config::from_sources_with_lookup(&options, &lookup(&[]))
            .expect_err("invalid storage.mode should fail");
        assert_eq!(err, "invalid storage.mode value: 'bogus'");

        fs::write(dir.path().join("default.toml"), "[storage]\nacid_shard_count = 3\n")
            .expect("write default.toml");
        let err = Config::from_sources_with_lookup(&options, &lookup(&[]))
            .expect_err("invalid storage.acid_shard_count should fail");
        assert_eq!(
            err,
            "invalid storage.acid_shard_count value: '3' (must be power-of-two in 1..=256)"
        );
    }

    #[test]
    fn test_validate_tls_pair_ok_when_both_absent_or_present() {
        let mut config = Config::default();
        assert!(config.validate().is_ok());
        assert!(!config.tls_enabled());

        config.tls_cert_path = Some("/tmp/cert.pem".to_string());
        config.tls_key_path = Some("/tmp/key.pem".to_string());
        assert!(config.validate().is_ok());
        assert!(config.tls_enabled());
    }

    #[test]
    fn test_validate_tls_pair_rejects_partial_configuration() {
        let mut config = Config {
            tls_cert_path: Some("/tmp/cert.pem".to_string()),
            ..Config::default()
        };
        assert!(config.validate().is_err());

        config.tls_cert_path = None;
        config.tls_key_path = Some("/tmp/key.pem".to_string());
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_storage_mode_aliases() {
        let mut config = Config::default();
        config
            .apply_env_overrides(&lookup(&[("DS_STORAGE__MODE", "acid")]))
            .expect("apply env overrides");
        assert_eq!(config.storage_mode, StorageMode::Acid);

        let mut config = Config::default();
        config
            .apply_env_overrides(&lookup(&[("DS_STORAGE__MODE", "redb")]))
            .expect("apply env overrides");
        assert_eq!(config.storage_mode, StorageMode::Acid);
    }

    #[test]
    fn test_acid_shard_count_valid_values() {
        let mut config = Config::default();
        config
            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_SHARD_COUNT", "1")]))
            .expect("apply env overrides");
        assert_eq!(config.acid_shard_count, 1);

        let mut config = Config::default();
        config
            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_SHARD_COUNT", "256")]))
            .expect("apply env overrides");
        assert_eq!(config.acid_shard_count, 256);
    }

    #[test]
    fn test_acid_shard_count_invalid_values_return_error() {
        let mut config = Config::default();
        let err = config
            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_SHARD_COUNT", "0")]))
            .expect_err("invalid shard count should fail");
        assert_eq!(
            err,
            "invalid DS_STORAGE__ACID_SHARD_COUNT value: '0' (must be power-of-two in 1..=256)"
        );
        assert_eq!(config.acid_shard_count, 16);

        let mut config = Config::default();
        let err = config
            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_SHARD_COUNT", "3")]))
            .expect_err("invalid shard count should fail");
        assert_eq!(
            err,
            "invalid DS_STORAGE__ACID_SHARD_COUNT value: '3' (must be power-of-two in 1..=256)"
        );
        assert_eq!(config.acid_shard_count, 16);

        let mut config = Config::default();
        let err = config
            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_SHARD_COUNT", "abc")]))
            .expect_err("invalid shard count should fail");
        assert_eq!(err, "invalid DS_STORAGE__ACID_SHARD_COUNT value: 'abc'");
        assert_eq!(config.acid_shard_count, 16);
    }

    #[test]
    fn test_env_overrides_reject_invalid_storage_mode() {
        let mut config = Config::default();
        let err = config
            .apply_env_overrides(&lookup(&[("DS_STORAGE__MODE", "memroy")]))
            .expect_err("invalid storage mode should fail");
        assert_eq!(err, "invalid DS_STORAGE__MODE value: 'memroy'");
    }

    #[test]
    fn test_validate_rejects_invalid_cors_origins() {
        let config = Config {
            cors_origins: "https://good.example, ,https://other.example".to_string(),
            ..Config::default()
        };
        assert_eq!(
            config
                .validate()
                .expect_err("invalid cors origins should fail"),
            "http.cors_origins contains an empty origin entry"
        );
    }

    #[test]
    fn test_long_poll_timeout_newtype() {
        let timeout = LongPollTimeout(Duration::from_secs(10));
        assert_eq!(timeout.0, Duration::from_secs(10));
    }

    #[test]
    fn test_sse_reconnect_interval_newtype() {
        let interval = SseReconnectInterval(120);
        assert_eq!(interval.0, 120);
    }
}