1use crate::router::DEFAULT_STREAM_BASE_PATH;
8use axum::http::HeaderValue;
9use figment::{
10 Figment,
11 providers::{Format, Toml},
12};
13use serde::Deserialize;
14use std::env;
15use std::path::PathBuf;
16use std::time::Duration;
17
/// Storage mode selected by `storage.mode` / `DS_STORAGE__MODE`.
///
/// Input spellings are matched case-insensitively by
/// `Config::parse_storage_mode_value`; [`StorageMode::as_str`] returns the
/// canonical name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageMode {
    /// `memory` — the default mode.
    Memory,
    /// `file-fast` / `fast` — file backend; [`StorageMode::sync_on_append`]
    /// returns `false`.
    FileFast,
    /// `file-durable` / `file` / `durable` — file backend;
    /// [`StorageMode::sync_on_append`] returns `true`.
    FileDurable,
    /// `acid` / `redb` — the ACID store, tuned by `Config::acid_backend` and
    /// `Config::acid_shard_count` (presumably redb-based, given the alias —
    /// confirm with the storage layer).
    Acid,
}
30
/// Backing medium for [`StorageMode::Acid`], selected by
/// `storage.acid_backend` / `DS_STORAGE__ACID_BACKEND`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AcidBackend {
    /// `file` — the default.
    File,
    /// `memory` / `in-memory` / `inmemory`.
    InMemory,
}
40
41impl AcidBackend {
42 #[must_use]
43 pub fn as_str(self) -> &'static str {
44 match self {
45 Self::File => "file",
46 Self::InMemory => "memory",
47 }
48 }
49}
50
51impl StorageMode {
52 #[must_use]
53 pub fn as_str(self) -> &'static str {
54 match self {
55 Self::Memory => "memory",
56 Self::FileFast => "file-fast",
57 Self::FileDurable => "file-durable",
58 Self::Acid => "acid",
59 }
60 }
61
62 #[must_use]
63 pub fn uses_file_backend(self) -> bool {
64 matches!(self, Self::FileFast | Self::FileDurable)
65 }
66
67 #[must_use]
68 pub fn sync_on_append(self) -> bool {
69 matches!(self, Self::FileDurable)
70 }
71}
72
/// Fully resolved server configuration.
///
/// Starts from [`Config::default`], then is overlaid by TOML layers
/// ([`Config::from_sources`]) and/or `DS_<SECTION>__<KEY>` environment
/// variables. Loading does not run the cross-field checks in
/// [`Config::validate`]; call it separately.
#[derive(Debug, Clone)]
pub struct Config {
    /// `server.port` / `DS_SERVER__PORT`; default `4437`.
    pub port: u16,
    /// `limits.max_memory_bytes` / `DS_LIMITS__MAX_MEMORY_BYTES`; default 100 MiB.
    /// NOTE(review): presumably a total in-memory budget — confirm with the storage layer.
    pub max_memory_bytes: u64,
    /// `limits.max_stream_bytes` / `DS_LIMITS__MAX_STREAM_BYTES`; default 10 MiB.
    /// NOTE(review): presumably a per-stream size cap — confirm.
    pub max_stream_bytes: u64,
    /// `limits.max_stream_name_bytes`; default `1024`; `validate` rejects `0`.
    pub max_stream_name_bytes: usize,
    /// `limits.max_stream_name_segments`; default `8`; `validate` rejects `0`.
    pub max_stream_name_segments: usize,
    /// `http.cors_origins`: `*` (the default) or a comma-separated origin list.
    pub cors_origins: String,
    /// `server.long_poll_timeout_secs`; default 30 seconds.
    pub long_poll_timeout: Duration,
    /// `server.sse_reconnect_interval_secs`; default `60`.
    pub sse_reconnect_interval_secs: u64,
    /// `http.stream_base_path`. Values loaded from files or the environment are
    /// trimmed, must start with `/`, and have trailing slashes stripped.
    pub stream_base_path: String,
    /// `storage.mode`; default [`StorageMode::Memory`].
    pub storage_mode: StorageMode,
    /// `storage.data_dir`; default `./data/streams`.
    /// NOTE(review): presumably only read by the file-backed and ACID modes — confirm.
    pub data_dir: String,
    /// `storage.acid_shard_count`; must be a power of two in `1..=256`; default `16`.
    pub acid_shard_count: usize,
    /// `storage.acid_backend`; default [`AcidBackend::File`].
    pub acid_backend: AcidBackend,
    /// `tls.cert_path`; `validate` requires it to be set together with `tls_key_path`.
    pub tls_cert_path: Option<String>,
    /// `tls.key_path`; `validate` requires it to be set together with `tls_cert_path`.
    pub tls_key_path: Option<String>,
    /// `log.rust_log`; default `info`.
    /// NOTE(review): presumably a `RUST_LOG`-style filter directive — confirm with the logging setup.
    pub rust_log: String,
}
114
/// Where [`Config::from_sources`] looks for TOML layers.
#[derive(Debug, Clone)]
pub struct ConfigLoadOptions {
    /// Directory containing `default.toml`, `<profile>.toml` and `local.toml`
    /// (each optional).
    pub config_dir: PathBuf,
    /// Profile name; trimmed and used as `<profile>.toml` inside `config_dir`.
    pub profile: String,
    /// Extra TOML file merged after `local.toml`; loading fails if it does not exist.
    pub config_override: Option<PathBuf>,
}
124
125impl Default for ConfigLoadOptions {
126 fn default() -> Self {
127 Self {
128 config_dir: PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("config"),
129 profile: "default".to_string(),
130 config_override: None,
131 }
132 }
133}
134
/// Raw shape of the merged TOML layers. Every section and key is optional:
/// `#[serde(default)]` fills in anything missing.
///
/// Values are converted and validated in `Config::apply_file_settings`.
/// NOTE(review): there is no `deny_unknown_fields`, so serde ignores unknown
/// or misspelled keys and the corresponding setting keeps its default.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct SettingsFile {
    /// `[server]` table.
    server: ServerSettingsFile,
    /// `[limits]` table.
    limits: LimitsSettingsFile,
    /// `[http]` table.
    http: HttpSettingsFile,
    /// `[storage]` table.
    storage: StorageSettingsFile,
    /// `[tls]` table.
    tls: TlsSettingsFile,
    /// `[log]` table.
    log: LogSettingsFile,
}
145
/// `[server]` section of the TOML config.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct ServerSettingsFile {
    port: Option<u16>,
    // Whole seconds; converted to a `Duration` when applied.
    long_poll_timeout_secs: Option<u64>,
    sse_reconnect_interval_secs: Option<u64>,
}
153
/// `[limits]` section of the TOML config.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
// Every field shares the `max_` prefix because the names mirror the TOML keys
// in `[limits]`. Renaming them would change the accepted config schema.
#[allow(clippy::struct_field_names)]
struct LimitsSettingsFile {
    max_memory_bytes: Option<u64>,
    max_stream_bytes: Option<u64>,
    max_stream_name_bytes: Option<usize>,
    max_stream_name_segments: Option<usize>,
}
163
/// `[http]` section of the TOML config.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct HttpSettingsFile {
    // `*` or a comma-separated origin list; checked later by `Config::validate`.
    cors_origins: Option<String>,
    // Normalized by `Config::parse_stream_base_path_value` when applied.
    stream_base_path: Option<String>,
}
170
/// `[storage]` section of the TOML config.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct StorageSettingsFile {
    // Kept as a string so aliases (`file`, `fast`, `redb`, ...) can be parsed
    // case-insensitively by `Config::parse_storage_mode_value`.
    mode: Option<String>,
    data_dir: Option<String>,
    // Must be a power of two in 1..=256; rejected otherwise when applied.
    acid_shard_count: Option<usize>,
    // Parsed by `Config::parse_acid_backend_value`.
    acid_backend: Option<String>,
}
179
/// `[tls]` section of the TOML config; `Config::validate` requires both paths
/// or neither.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct TlsSettingsFile {
    cert_path: Option<String>,
    key_path: Option<String>,
}
186
/// `[log]` section of the TOML config.
#[derive(Debug, Deserialize, Default)]
#[serde(default)]
struct LogSettingsFile {
    rust_log: Option<String>,
}
192
193impl Config {
194 pub fn from_env() -> Result<Self, String> {
201 let mut config = Self::default();
202 config.apply_env_overrides(&|key| env::var(key).ok())?;
203 Ok(config)
204 }
205
206 pub fn from_sources(options: &ConfigLoadOptions) -> Result<Self, String> {
221 let get = |key: &str| env::var(key).ok();
222 Self::from_sources_with_lookup(options, &get)
223 }
224
225 fn from_sources_with_lookup(
226 options: &ConfigLoadOptions,
227 get: &impl Fn(&str) -> Option<String>,
228 ) -> Result<Self, String> {
229 let mut figment = Figment::new();
230
231 let default_path = options.config_dir.join("default.toml");
232 if default_path.is_file() {
233 figment = figment.merge(Toml::file(&default_path));
234 }
235
236 let profile_path = options
237 .config_dir
238 .join(format!("{}.toml", options.profile.trim()));
239 if profile_path.is_file() {
240 figment = figment.merge(Toml::file(&profile_path));
241 }
242
243 let local_path = options.config_dir.join("local.toml");
244 if local_path.is_file() {
245 figment = figment.merge(Toml::file(&local_path));
246 }
247
248 if let Some(override_path) = &options.config_override {
249 if !override_path.is_file() {
250 return Err(format!(
251 "config override file not found: '{}'",
252 override_path.display()
253 ));
254 }
255 figment = figment.merge(Toml::file(override_path));
256 }
257
258 let settings: SettingsFile = figment
259 .extract()
260 .map_err(|e| format!("failed to parse TOML config: {e}"))?;
261
262 let mut config = Self::apply_file_settings(settings)?;
263 config.apply_env_overrides(get)?;
264 Ok(config)
265 }
266
267 fn apply_file_settings(settings: SettingsFile) -> Result<Self, String> {
268 let mut config = Self::default();
269
270 if let Some(port) = settings.server.port {
271 config.port = port;
272 }
273 if let Some(long_poll_timeout_secs) = settings.server.long_poll_timeout_secs {
274 config.long_poll_timeout = Duration::from_secs(long_poll_timeout_secs);
275 }
276 if let Some(sse_reconnect_interval_secs) = settings.server.sse_reconnect_interval_secs {
277 config.sse_reconnect_interval_secs = sse_reconnect_interval_secs;
278 }
279
280 if let Some(max_memory_bytes) = settings.limits.max_memory_bytes {
281 config.max_memory_bytes = max_memory_bytes;
282 }
283 if let Some(max_stream_bytes) = settings.limits.max_stream_bytes {
284 config.max_stream_bytes = max_stream_bytes;
285 }
286 if let Some(max_stream_name_bytes) = settings.limits.max_stream_name_bytes {
287 config.max_stream_name_bytes = max_stream_name_bytes;
288 }
289 if let Some(max_stream_name_segments) = settings.limits.max_stream_name_segments {
290 config.max_stream_name_segments = max_stream_name_segments;
291 }
292
293 if let Some(cors_origins) = settings.http.cors_origins {
294 config.cors_origins = cors_origins;
295 }
296 if let Some(stream_base_path) = settings.http.stream_base_path {
297 config.stream_base_path = Self::parse_stream_base_path_value(&stream_base_path)
298 .map_err(|reason| format!("invalid http.stream_base_path value: {reason}"))?;
299 }
300
301 if let Some(mode) = settings.storage.mode {
302 config.storage_mode = Self::parse_storage_mode_value(&mode)
303 .ok_or_else(|| format!("invalid storage.mode value: '{mode}'"))?;
304 }
305 if let Some(data_dir) = settings.storage.data_dir {
306 config.data_dir = data_dir;
307 }
308 if let Some(acid_shard_count) = settings.storage.acid_shard_count {
309 if Self::valid_acid_shard_count(acid_shard_count) {
310 config.acid_shard_count = acid_shard_count;
311 } else {
312 return Err(format!(
313 "invalid storage.acid_shard_count value: '{acid_shard_count}' (must be power-of-two in 1..=256)"
314 ));
315 }
316 }
317 if let Some(acid_backend) = settings.storage.acid_backend {
318 config.acid_backend = Self::parse_acid_backend_value(&acid_backend)
319 .ok_or_else(|| format!("invalid storage.acid_backend value: '{acid_backend}'"))?;
320 }
321
322 config.tls_cert_path = settings.tls.cert_path;
323 config.tls_key_path = settings.tls.key_path;
324
325 if let Some(rust_log) = settings.log.rust_log {
326 config.rust_log = rust_log;
327 }
328
329 Ok(config)
330 }
331
332 fn apply_env_overrides(&mut self, get: &impl Fn(&str) -> Option<String>) -> Result<(), String> {
334 if let Some(port) = get("DS_SERVER__PORT") {
335 self.port = port
336 .parse()
337 .map_err(|_| format!("invalid DS_SERVER__PORT value: '{port}'"))?;
338 }
339 if let Some(long_poll_timeout_secs) = get("DS_SERVER__LONG_POLL_TIMEOUT_SECS") {
340 self.long_poll_timeout = Duration::from_secs(
341 long_poll_timeout_secs
342 .parse()
343 .map_err(|_| format!("invalid DS_SERVER__LONG_POLL_TIMEOUT_SECS value: '{long_poll_timeout_secs}'"))?,
344 );
345 }
346 if let Some(sse_reconnect_interval_secs) = get("DS_SERVER__SSE_RECONNECT_INTERVAL_SECS") {
347 self.sse_reconnect_interval_secs = sse_reconnect_interval_secs.parse().map_err(|_| {
348 format!("invalid DS_SERVER__SSE_RECONNECT_INTERVAL_SECS value: '{sse_reconnect_interval_secs}'")
349 })?;
350 }
351
352 if let Some(max_memory_bytes) = get("DS_LIMITS__MAX_MEMORY_BYTES") {
353 self.max_memory_bytes = max_memory_bytes.parse().map_err(|_| {
354 format!("invalid DS_LIMITS__MAX_MEMORY_BYTES value: '{max_memory_bytes}'")
355 })?;
356 }
357 if let Some(max_stream_bytes) = get("DS_LIMITS__MAX_STREAM_BYTES") {
358 self.max_stream_bytes = max_stream_bytes.parse().map_err(|_| {
359 format!("invalid DS_LIMITS__MAX_STREAM_BYTES value: '{max_stream_bytes}'")
360 })?;
361 }
362 if let Some(max_stream_name_bytes) = get("DS_LIMITS__MAX_STREAM_NAME_BYTES") {
363 self.max_stream_name_bytes = max_stream_name_bytes.parse().map_err(|_| {
364 format!("invalid DS_LIMITS__MAX_STREAM_NAME_BYTES value: '{max_stream_name_bytes}'")
365 })?;
366 }
367 if let Some(max_stream_name_segments) = get("DS_LIMITS__MAX_STREAM_NAME_SEGMENTS") {
368 self.max_stream_name_segments = max_stream_name_segments.parse().map_err(|_| {
369 format!("invalid DS_LIMITS__MAX_STREAM_NAME_SEGMENTS value: '{max_stream_name_segments}'")
370 })?;
371 }
372
373 if let Some(cors_origins) = get("DS_HTTP__CORS_ORIGINS") {
374 self.cors_origins = cors_origins;
375 }
376 if let Some(stream_base_path) = get("DS_HTTP__STREAM_BASE_PATH") {
377 self.stream_base_path = Self::parse_stream_base_path_value(&stream_base_path)
378 .map_err(|reason| format!("invalid DS_HTTP__STREAM_BASE_PATH value: {reason}"))?;
379 }
380
381 if let Some(storage_mode) = get("DS_STORAGE__MODE") {
382 self.storage_mode = Self::parse_storage_mode_value(&storage_mode)
383 .ok_or_else(|| format!("invalid DS_STORAGE__MODE value: '{storage_mode}'"))?;
384 }
385
386 if let Some(data_dir) = get("DS_STORAGE__DATA_DIR") {
387 self.data_dir = data_dir;
388 }
389
390 if let Some(acid_shard_count) = get("DS_STORAGE__ACID_SHARD_COUNT") {
391 let parsed = acid_shard_count.parse::<usize>().map_err(|_| {
392 format!("invalid DS_STORAGE__ACID_SHARD_COUNT value: '{acid_shard_count}'")
393 })?;
394 if !Self::valid_acid_shard_count(parsed) {
395 return Err(format!(
396 "invalid DS_STORAGE__ACID_SHARD_COUNT value: '{acid_shard_count}' (must be power-of-two in 1..=256)"
397 ));
398 }
399 self.acid_shard_count = parsed;
400 }
401
402 if let Some(acid_backend) = get("DS_STORAGE__ACID_BACKEND") {
403 self.acid_backend = Self::parse_acid_backend_value(&acid_backend).ok_or_else(|| {
404 format!("invalid DS_STORAGE__ACID_BACKEND value: '{acid_backend}'")
405 })?;
406 }
407
408 if let Some(cert_path) = get("DS_TLS__CERT_PATH") {
409 self.tls_cert_path = Some(cert_path);
410 }
411 if let Some(key_path) = get("DS_TLS__KEY_PATH") {
412 self.tls_key_path = Some(key_path);
413 }
414
415 if let Some(rust_log) = get("DS_LOG__RUST_LOG") {
416 self.rust_log = rust_log;
417 }
418
419 Ok(())
420 }
421
422 pub fn validate(&self) -> std::result::Result<(), String> {
428 match (&self.tls_cert_path, &self.tls_key_path) {
429 (Some(_), Some(_)) | (None, None) => Ok(()),
430 (Some(_), None) => Err(
431 "tls.cert_path is set but tls.key_path is missing; both must be set together"
432 .to_string(),
433 ),
434 (None, Some(_)) => Err(
435 "tls.key_path is set but tls.cert_path is missing; both must be set together"
436 .to_string(),
437 ),
438 }?;
439
440 Self::validate_cors_origins(&self.cors_origins)?;
441 Self::parse_stream_base_path_value(&self.stream_base_path).map(|_| ())?;
442
443 if self.max_stream_name_bytes == 0 {
444 return Err(
445 "limits.max_stream_name_bytes must be at least 1".to_string(),
446 );
447 }
448 if self.max_stream_name_segments == 0 {
449 return Err(
450 "limits.max_stream_name_segments must be at least 1".to_string(),
451 );
452 }
453
454 Ok(())
455 }
456
457 fn validate_cors_origins(origins: &str) -> Result<(), String> {
458 if origins == "*" {
459 return Ok(());
460 }
461
462 let mut parsed_any = false;
463 for origin in origins.split(',').map(str::trim) {
464 if origin.is_empty() {
465 return Err("http.cors_origins contains an empty origin entry".to_string());
466 }
467 HeaderValue::from_str(origin)
468 .map_err(|_| format!("invalid http.cors_origins entry: '{origin}'"))?;
469 parsed_any = true;
470 }
471
472 if !parsed_any {
473 return Err(
474 "http.cors_origins must be '*' or a non-empty comma-separated list".to_string(),
475 );
476 }
477
478 Ok(())
479 }
480
481 fn parse_stream_base_path_value(raw: &str) -> Result<String, String> {
482 let trimmed = raw.trim();
483 if trimmed.is_empty() {
484 return Err("must be a non-empty absolute path".to_string());
485 }
486 if !trimmed.starts_with('/') {
487 return Err(format!("'{trimmed}' (must start with '/')"));
488 }
489
490 if trimmed == "/" {
491 return Ok("/".to_string());
492 }
493
494 let normalized = trimmed.trim_end_matches('/');
495 if normalized.is_empty() {
496 return Err("must be a non-empty absolute path".to_string());
497 }
498
499 Ok(normalized.to_string())
500 }
501
502 #[must_use]
504 pub fn tls_enabled(&self) -> bool {
505 self.tls_cert_path.is_some() && self.tls_key_path.is_some()
506 }
507
508 fn parse_storage_mode_value(raw: &str) -> Option<StorageMode> {
509 match raw.to_ascii_lowercase().as_str() {
510 "memory" => Some(StorageMode::Memory),
511 "file" | "file-durable" | "durable" => Some(StorageMode::FileDurable),
512 "file-fast" | "fast" => Some(StorageMode::FileFast),
513 "acid" | "redb" => Some(StorageMode::Acid),
514 _ => None,
515 }
516 }
517
518 fn valid_acid_shard_count(value: usize) -> bool {
519 (1..=256).contains(&value) && value.is_power_of_two()
520 }
521
522 fn parse_acid_backend_value(raw: &str) -> Option<AcidBackend> {
523 match raw.to_ascii_lowercase().as_str() {
524 "file" => Some(AcidBackend::File),
525 "memory" | "in-memory" | "inmemory" => Some(AcidBackend::InMemory),
526 _ => None,
527 }
528 }
529}
530
531impl Default for Config {
532 fn default() -> Self {
533 Self {
534 port: 4437,
535 max_memory_bytes: 100 * 1024 * 1024,
536 max_stream_bytes: 10 * 1024 * 1024,
537 max_stream_name_bytes: 1024,
538 max_stream_name_segments: 8,
539 cors_origins: "*".to_string(),
540 long_poll_timeout: Duration::from_secs(30),
541 sse_reconnect_interval_secs: 60,
542 stream_base_path: DEFAULT_STREAM_BASE_PATH.to_string(),
543 storage_mode: StorageMode::Memory,
544 data_dir: "./data/streams".to_string(),
545 acid_shard_count: 16,
546 acid_backend: AcidBackend::File,
547 tls_cert_path: None,
548 tls_key_path: None,
549 rust_log: "info".to_string(),
550 }
551 }
552}
553
/// Newtype wrapper around the configured long-poll timeout
/// (`Config::long_poll_timeout`).
/// NOTE(review): presumably injected as shared request state/extension — confirm at the call site.
#[derive(Debug, Clone, Copy)]
pub struct LongPollTimeout(pub Duration);
557
/// Newtype wrapper around the SSE reconnect interval.
/// NOTE(review): presumably in seconds, mirroring `Config::sse_reconnect_interval_secs` — confirm.
#[derive(Debug, Clone, Copy)]
pub struct SseReconnectInterval(pub u64);
563
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::fs;
    use std::path::Path;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Builds an env-lookup closure backed by fixed pairs, so tests neither
    /// read nor mutate the real process environment.
    fn lookup(pairs: &[(&str, &str)]) -> impl Fn(&str) -> Option<String> {
        let map: HashMap<String, String> = pairs
            .iter()
            .map(|(k, v)| ((*k).to_string(), (*v).to_string()))
            .collect();
        move |key: &str| map.get(key).cloned()
    }

    /// Unique scratch directory for TOML fixtures.
    ///
    /// The directory is removed when the guard is dropped, so repeated runs do
    /// not leave `ds-config-tests-*` directories behind in the temp dir.
    struct TempConfigDir {
        path: PathBuf,
    }

    impl TempConfigDir {
        fn new() -> Self {
            // Process id plus a per-process counter keeps parallel tests and
            // concurrently running test binaries from sharing a directory.
            static COUNTER: AtomicU64 = AtomicU64::new(0);
            let id = COUNTER.fetch_add(1, Ordering::Relaxed);
            let path = std::env::temp_dir()
                .join(format!("ds-config-tests-{}-{}", std::process::id(), id));
            fs::create_dir_all(&path).expect("create temp config dir");
            Self { path }
        }

        fn path(&self) -> &Path {
            &self.path
        }

        fn write(&self, file_name: &str, contents: &str) {
            fs::write(self.path.join(file_name), contents)
                .unwrap_or_else(|e| panic!("write {file_name}: {e}"));
        }

        fn options(&self, profile: &str, config_override: Option<PathBuf>) -> ConfigLoadOptions {
            ConfigLoadOptions {
                config_dir: self.path.clone(),
                profile: profile.to_string(),
                config_override,
            }
        }
    }

    impl Drop for TempConfigDir {
        fn drop(&mut self) {
            // Best effort: a failed cleanup must not turn a passing test red.
            let _ = fs::remove_dir_all(&self.path);
        }
    }

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert_eq!(config.port, 4437);
        assert_eq!(config.max_memory_bytes, 100 * 1024 * 1024);
        assert_eq!(config.max_stream_bytes, 10 * 1024 * 1024);
        assert_eq!(config.cors_origins, "*");
        assert_eq!(config.long_poll_timeout, Duration::from_secs(30));
        assert_eq!(config.sse_reconnect_interval_secs, 60);
        assert_eq!(config.stream_base_path, DEFAULT_STREAM_BASE_PATH);
        assert_eq!(config.storage_mode, StorageMode::Memory);
        assert_eq!(config.data_dir, "./data/streams");
        assert_eq!(config.acid_shard_count, 16);
        assert_eq!(config.acid_backend, AcidBackend::File);
        assert_eq!(config.tls_cert_path, None);
        assert_eq!(config.tls_key_path, None);
        assert_eq!(config.rust_log, "info");
    }

    #[test]
    fn test_from_env_uses_defaults_when_no_ds_vars() {
        // `from_env` reads the real process environment. The premise of this
        // test only holds when no DS_* variable is exported, which a CI job or
        // developer shell may violate.
        let has_ds_vars = std::env::vars_os()
            .any(|(key, _)| key.to_string_lossy().starts_with("DS_"));
        if has_ds_vars {
            eprintln!("skipping: DS_* variables are set in the test environment");
            return;
        }

        let config = Config::from_env().expect("config from env");
        assert_eq!(config.port, 4437);
        assert_eq!(config.storage_mode, StorageMode::Memory);
        assert_eq!(config.rust_log, "info");
    }

    #[test]
    fn test_env_overrides_parse_all_ds_vars() {
        let mut config = Config::default();
        let get = lookup(&[
            ("DS_SERVER__PORT", "8080"),
            ("DS_LIMITS__MAX_MEMORY_BYTES", "200000000"),
            ("DS_LIMITS__MAX_STREAM_BYTES", "20000000"),
            ("DS_HTTP__CORS_ORIGINS", "https://example.com"),
            ("DS_SERVER__LONG_POLL_TIMEOUT_SECS", "5"),
            ("DS_SERVER__SSE_RECONNECT_INTERVAL_SECS", "120"),
            ("DS_HTTP__STREAM_BASE_PATH", "/streams"),
            ("DS_STORAGE__MODE", "file-fast"),
            ("DS_STORAGE__DATA_DIR", "/tmp/ds-store"),
            ("DS_STORAGE__ACID_SHARD_COUNT", "32"),
            ("DS_TLS__CERT_PATH", "/tmp/cert.pem"),
            ("DS_TLS__KEY_PATH", "/tmp/key.pem"),
            ("DS_LOG__RUST_LOG", "debug"),
        ]);
        config
            .apply_env_overrides(&get)
            .expect("apply env overrides");
        assert_eq!(config.port, 8080);
        assert_eq!(config.max_memory_bytes, 200_000_000);
        assert_eq!(config.max_stream_bytes, 20_000_000);
        assert_eq!(config.cors_origins, "https://example.com");
        assert_eq!(config.long_poll_timeout, Duration::from_secs(5));
        assert_eq!(config.sse_reconnect_interval_secs, 120);
        assert_eq!(config.stream_base_path, "/streams");
        assert_eq!(config.storage_mode, StorageMode::FileFast);
        assert_eq!(config.data_dir, "/tmp/ds-store");
        assert_eq!(config.acid_shard_count, 32);
        assert_eq!(config.tls_cert_path.as_deref(), Some("/tmp/cert.pem"));
        assert_eq!(config.tls_key_path.as_deref(), Some("/tmp/key.pem"));
        assert_eq!(config.rust_log, "debug");
    }

    #[test]
    fn test_env_overrides_reject_unparseable_values() {
        let mut config = Config::default();
        let get = lookup(&[
            ("DS_SERVER__PORT", "not-a-number"),
            ("DS_LIMITS__MAX_MEMORY_BYTES", ""),
            ("DS_SERVER__LONG_POLL_TIMEOUT_SECS", "abc"),
        ]);
        let err = config
            .apply_env_overrides(&get)
            .expect_err("invalid env override should fail");
        assert_eq!(err, "invalid DS_SERVER__PORT value: 'not-a-number'");
        assert_eq!(config.port, 4437);
        assert_eq!(config.max_memory_bytes, 100 * 1024 * 1024);
        assert_eq!(config.long_poll_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_env_overrides_partial() {
        let mut config = Config::default();
        let get = lookup(&[("DS_SERVER__PORT", "9090")]);
        config
            .apply_env_overrides(&get)
            .expect("apply env overrides");
        assert_eq!(config.port, 9090);
        assert_eq!(config.storage_mode, StorageMode::Memory);
        assert_eq!(config.rust_log, "info");
    }

    #[test]
    fn test_from_sources_file_layers_and_env_override() {
        let dir = TempConfigDir::new();
        dir.write(
            "default.toml",
            r#"
            [server]
            port = 4437
            [http]
            stream_base_path = "/v1/stream"
            [storage]
            mode = "memory"
            [log]
            rust_log = "warn"
            "#,
        );
        dir.write(
            "dev.toml",
            r#"
            [server]
            port = 7777
            [http]
            stream_base_path = "/streams"
            [storage]
            mode = "file-fast"
            data_dir = "/tmp/dev-store"
            "#,
        );
        dir.write(
            "local.toml",
            r"
            [server]
            port = 8888
            ",
        );

        let options = dir.options("dev", None);
        let env = lookup(&[("DS_SERVER__PORT", "9999"), ("DS_LOG__RUST_LOG", "debug")]);
        let config = Config::from_sources_with_lookup(&options, &env).expect("config from sources");

        assert_eq!(config.port, 9999);
        assert_eq!(config.stream_base_path, "/streams");
        assert_eq!(config.storage_mode, StorageMode::FileFast);
        assert_eq!(config.data_dir, "/tmp/dev-store");
        assert_eq!(config.rust_log, "debug");
    }

    #[test]
    fn test_from_sources_env_overrides_toml() {
        let dir = TempConfigDir::new();
        dir.write(
            "default.toml",
            r#"
            [server]
            port = 4437
            [storage]
            mode = "memory"
            "#,
        );

        let options = dir.options("default", None);
        let env = lookup(&[
            ("DS_SERVER__PORT", "12345"),
            ("DS_STORAGE__MODE", "acid"),
            ("DS_STORAGE__ACID_SHARD_COUNT", "32"),
            ("DS_TLS__CERT_PATH", "/tmp/cert.pem"),
            ("DS_TLS__KEY_PATH", "/tmp/key.pem"),
        ]);
        let config = Config::from_sources_with_lookup(&options, &env).expect("config from sources");

        assert_eq!(config.port, 12345);
        assert_eq!(config.storage_mode, StorageMode::Acid);
        assert_eq!(config.acid_shard_count, 32);
        assert_eq!(config.tls_cert_path.as_deref(), Some("/tmp/cert.pem"));
        assert_eq!(config.tls_key_path.as_deref(), Some("/tmp/key.pem"));
    }

    #[test]
    fn test_from_sources_without_files_uses_defaults() {
        let dir = TempConfigDir::new();
        let config = Config::from_sources_with_lookup(&dir.options("default", None), &lookup(&[]))
            .expect("config from sources");
        assert_eq!(config.port, 4437);
        assert_eq!(config.storage_mode, StorageMode::Memory);
        assert_eq!(config.stream_base_path, DEFAULT_STREAM_BASE_PATH);
        assert_eq!(config.rust_log, "info");
    }

    #[test]
    fn test_from_sources_override_file_takes_precedence() {
        let dir = TempConfigDir::new();
        dir.write("default.toml", "[server]\nport = 4437\n");
        dir.write("local.toml", "[server]\nport = 8888\n");
        let override_path = dir.path().join("override.toml");
        fs::write(&override_path, "[server]\nport = 5555\n").expect("write override.toml");

        let config = Config::from_sources_with_lookup(
            &dir.options("default", Some(override_path)),
            &lookup(&[]),
        )
        .expect("config from sources");
        assert_eq!(config.port, 5555);
    }

    #[test]
    fn test_from_sources_missing_override_is_error() {
        let dir = TempConfigDir::new();
        let missing = dir.path().join("missing.toml");
        let err = Config::from_sources_with_lookup(
            &dir.options("default", Some(missing.clone())),
            &lookup(&[]),
        )
        .expect_err("missing override file should fail");
        assert_eq!(
            err,
            format!("config override file not found: '{}'", missing.display())
        );
    }

    #[test]
    fn test_from_sources_rejects_invalid_file_values() {
        let dir = TempConfigDir::new();
        dir.write("default.toml", "[storage]\nmode = \"memroy\"\n");
        let err = Config::from_sources_with_lookup(&dir.options("default", None), &lookup(&[]))
            .expect_err("invalid storage.mode should fail");
        assert_eq!(err, "invalid storage.mode value: 'memroy'");

        let dir = TempConfigDir::new();
        dir.write("default.toml", "[storage]\nacid_shard_count = 3\n");
        let err = Config::from_sources_with_lookup(&dir.options("default", None), &lookup(&[]))
            .expect_err("invalid storage.acid_shard_count should fail");
        assert_eq!(
            err,
            "invalid storage.acid_shard_count value: '3' (must be power-of-two in 1..=256)"
        );
    }

    #[test]
    fn test_validate_tls_pair_ok_when_both_absent_or_present() {
        let mut config = Config::default();
        assert!(config.validate().is_ok());
        assert!(!config.tls_enabled());

        config.tls_cert_path = Some("/tmp/cert.pem".to_string());
        config.tls_key_path = Some("/tmp/key.pem".to_string());
        assert!(config.validate().is_ok());
        assert!(config.tls_enabled());
    }

    #[test]
    fn test_validate_tls_pair_rejects_partial_configuration() {
        let mut config = Config {
            tls_cert_path: Some("/tmp/cert.pem".to_string()),
            ..Config::default()
        };
        assert!(config.validate().is_err());

        config.tls_cert_path = None;
        config.tls_key_path = Some("/tmp/key.pem".to_string());
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_storage_mode_aliases() {
        let mut config = Config::default();
        config
            .apply_env_overrides(&lookup(&[("DS_STORAGE__MODE", "acid")]))
            .expect("apply env overrides");
        assert_eq!(config.storage_mode, StorageMode::Acid);

        let mut config = Config::default();
        config
            .apply_env_overrides(&lookup(&[("DS_STORAGE__MODE", "redb")]))
            .expect("apply env overrides");
        assert_eq!(config.storage_mode, StorageMode::Acid);
    }

    #[test]
    fn test_acid_shard_count_valid_values() {
        let mut config = Config::default();
        config
            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_SHARD_COUNT", "1")]))
            .expect("apply env overrides");
        assert_eq!(config.acid_shard_count, 1);

        let mut config = Config::default();
        config
            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_SHARD_COUNT", "256")]))
            .expect("apply env overrides");
        assert_eq!(config.acid_shard_count, 256);
    }

    #[test]
    fn test_acid_shard_count_invalid_values_return_error() {
        let mut config = Config::default();
        let err = config
            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_SHARD_COUNT", "0")]))
            .expect_err("invalid shard count should fail");
        assert_eq!(
            err,
            "invalid DS_STORAGE__ACID_SHARD_COUNT value: '0' (must be power-of-two in 1..=256)"
        );
        assert_eq!(config.acid_shard_count, 16);

        let mut config = Config::default();
        let err = config
            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_SHARD_COUNT", "3")]))
            .expect_err("invalid shard count should fail");
        assert_eq!(
            err,
            "invalid DS_STORAGE__ACID_SHARD_COUNT value: '3' (must be power-of-two in 1..=256)"
        );
        assert_eq!(config.acid_shard_count, 16);

        let mut config = Config::default();
        let err = config
            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_SHARD_COUNT", "abc")]))
            .expect_err("invalid shard count should fail");
        assert_eq!(err, "invalid DS_STORAGE__ACID_SHARD_COUNT value: 'abc'");
        assert_eq!(config.acid_shard_count, 16);
    }

    #[test]
    fn test_acid_backend_env_override() {
        let mut config = Config::default();
        assert_eq!(config.acid_backend, AcidBackend::File);

        config
            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_BACKEND", "memory")]))
            .expect("apply env overrides");
        assert_eq!(config.acid_backend, AcidBackend::InMemory);

        let mut config = Config::default();
        config
            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_BACKEND", "in-memory")]))
            .expect("apply env overrides");
        assert_eq!(config.acid_backend, AcidBackend::InMemory);

        let mut config = Config::default();
        config
            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_BACKEND", "file")]))
            .expect("apply env overrides");
        assert_eq!(config.acid_backend, AcidBackend::File);
    }

    #[test]
    fn test_acid_backend_env_override_rejects_invalid() {
        let mut config = Config::default();
        let err = config
            .apply_env_overrides(&lookup(&[("DS_STORAGE__ACID_BACKEND", "sqlite")]))
            .expect_err("invalid acid backend should fail");
        assert_eq!(err, "invalid DS_STORAGE__ACID_BACKEND value: 'sqlite'");
        assert_eq!(config.acid_backend, AcidBackend::File);
    }

    #[test]
    fn test_env_overrides_reject_invalid_storage_mode() {
        let mut config = Config::default();
        let err = config
            .apply_env_overrides(&lookup(&[("DS_STORAGE__MODE", "memroy")]))
            .expect_err("invalid storage mode should fail");
        assert_eq!(err, "invalid DS_STORAGE__MODE value: 'memroy'");
    }

    #[test]
    fn test_validate_rejects_invalid_cors_origins() {
        let config = Config {
            cors_origins: "https://good.example, ,https://other.example".to_string(),
            ..Config::default()
        };
        assert_eq!(
            config
                .validate()
                .expect_err("invalid cors origins should fail"),
            "http.cors_origins contains an empty origin entry"
        );
    }

    #[test]
    fn test_stream_base_path_normalizes_trailing_slash() {
        let mut config = Config::default();
        config
            .apply_env_overrides(&lookup(&[("DS_HTTP__STREAM_BASE_PATH", "/streams/")]))
            .expect("apply env overrides");
        assert_eq!(config.stream_base_path, "/streams");
    }

    #[test]
    fn test_stream_base_path_rejects_relative_path() {
        let mut config = Config::default();
        let err = config
            .apply_env_overrides(&lookup(&[("DS_HTTP__STREAM_BASE_PATH", "streams")]))
            .expect_err("relative base path should fail");
        assert_eq!(
            err,
            "invalid DS_HTTP__STREAM_BASE_PATH value: 'streams' (must start with '/')"
        );
    }

    #[test]
    fn test_validate_rejects_invalid_stream_base_path() {
        let config = Config {
            stream_base_path: "streams".to_string(),
            ..Config::default()
        };
        assert_eq!(
            config
                .validate()
                .expect_err("invalid stream base path should fail"),
            "'streams' (must start with '/')"
        );
    }

    #[test]
    fn test_long_poll_timeout_newtype() {
        let timeout = LongPollTimeout(Duration::from_secs(10));
        assert_eq!(timeout.0, Duration::from_secs(10));
    }

    #[test]
    fn test_sse_reconnect_interval_newtype() {
        let interval = SseReconnectInterval(120);
        assert_eq!(interval.0, 120);
    }
}