use crate::store::validate_durable_settings;
use serde::Deserialize;
use std::path::{Path, PathBuf};
/// Errors returned by [`load_config`].
///
/// Every variant carries the offending config file's path, rendered with `Path::display`,
/// so the message identifies which file was at fault.
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
/// The file could not be read (e.g. it is missing, not readable, or not valid UTF-8).
#[error("failed to read config file {path}: {source}")]
Read {
path: String,
/// Underlying I/O error. thiserror treats a field named `source` as the error source,
/// so it is also reachable through `std::error::Error::source`.
source: std::io::Error,
},
/// The file was read but is not valid TOML, or a value has the wrong type for its key.
#[error("failed to parse config file {path}: {source}")]
Parse {
path: String,
/// Underlying TOML/serde deserialization error.
source: toml::de::Error,
},
/// The file parsed, but a value failed one of the range/consistency checks in
/// [`load_config`].
#[error("invalid value in config file {path}: {message}")]
Validation {
path: String,
/// Human-readable description of the violated constraint.
message: String,
},
}
/// Raw settings deserialized from a TOML config file by [`load_config`].
///
/// Every field is an `Option`, so a key absent from the file deserializes to `None`.
/// NOTE(review): presumably callers layer these over CLI flags / built-in defaults — confirm.
///
/// Unknown keys are silently ignored because the struct has no `deny_unknown_fields`.
/// In particular, keys for a feature that is compiled out (`access_log`, `mirror`,
/// `tls_cert`, `tls_key`) are accepted and dropped without any error.
#[derive(Deserialize, Default)]
pub struct FileConfig {
pub port: Option<u16>,
pub account_id: Option<String>,
pub region: Option<String>,
pub shard_limit: Option<u32>,
pub create_stream_ms: Option<u64>,
pub delete_stream_ms: Option<u64>,
pub update_stream_ms: Option<u64>,
/// Must lie in `1..=86400` when set.
pub iterator_ttl_seconds: Option<u64>,
/// Must be at most 86400 when set (0 is accepted).
pub retention_check_interval_secs: Option<u64>,
pub enforce_limits: Option<bool>,
pub state_dir: Option<PathBuf>,
/// Checked together with `max_retained_bytes` by `validate_durable_settings`.
/// This file's tests show that values above 86400 are rejected.
pub snapshot_interval_secs: Option<u64>,
/// Checked by `validate_durable_settings`. This file's tests show that 0 is rejected.
pub max_retained_bytes: Option<u64>,
pub max_request_body_mb: Option<u64>,
/// One of `off`, `error`, `warn`, `info`, `debug`, `trace` (exact, case-sensitive match).
pub log_level: Option<String>,
/// One of `plain`, `json` (exact match).
pub log_format: Option<String>,
pub otlp_endpoint: Option<String>,
/// One of `grpc`, `http` (exact match).
pub otlp_protocol: Option<String>,
/// Must lie in `0.0..=1.0`. NaN is rejected as well.
pub otel_sample_ratio: Option<f64>,
pub otel_service_name: Option<String>,
#[cfg(feature = "access-log")]
pub access_log: Option<bool>,
pub capture: Option<PathBuf>,
pub scrub: Option<bool>,
/// The `[mirror]` table; see `MirrorConfig`.
#[cfg(feature = "mirror")]
pub mirror: Option<MirrorConfig>,
/// Must be set together with `tls_key`: either both or neither.
#[cfg(feature = "tls")]
pub tls_cert: Option<PathBuf>,
/// Must be set together with `tls_cert`: either both or neither.
#[cfg(feature = "tls")]
pub tls_key: Option<PathBuf>,
}
/// Settings from the `[mirror]` table of the config file. Only compiled with the `mirror` feature.
///
/// All keys are optional. [`load_config`] enforces the following when the keys are set:
/// - `concurrency >= 1`
/// - `initial_backoff_ms >= 1`
/// - `max_backoff_ms >= initial_backoff_ms`, when both are set
#[cfg(feature = "mirror")]
#[derive(Deserialize, Default)]
pub struct MirrorConfig {
/// NOTE(review): presumably the mirror target address/URL — confirm against the consumer.
pub to: Option<String>,
pub diff: Option<bool>,
/// Must be at least 1 when set.
pub concurrency: Option<usize>,
pub max_retries: Option<usize>,
/// Must be at least 1 when set.
pub initial_backoff_ms: Option<u64>,
/// When `initial_backoff_ms` is also set, must be greater than or equal to it.
pub max_backoff_ms: Option<u64>,
}
pub fn load_config(path: &Path) -> Result<FileConfig, ConfigError> {
let content = std::fs::read_to_string(path).map_err(|e| ConfigError::Read {
path: path.display().to_string(),
source: e,
})?;
let config: FileConfig = toml::from_str(&content).map_err(|e| ConfigError::Parse {
path: path.display().to_string(),
source: e,
})?;
if let Some(ttl) = config.iterator_ttl_seconds
&& !(1..=86400).contains(&ttl)
{
return Err(ConfigError::Validation {
path: path.display().to_string(),
message: format!("iterator_ttl_seconds must be between 1 and 86400, got {ttl}"),
});
}
if let Some(v) = config.retention_check_interval_secs
&& v > 86400
{
return Err(ConfigError::Validation {
path: path.display().to_string(),
message: format!("retention_check_interval_secs must be between 0 and 86400, got {v}"),
});
}
if let Err(err) =
validate_durable_settings(config.snapshot_interval_secs, config.max_retained_bytes)
{
return Err(ConfigError::Validation {
path: path.display().to_string(),
message: err.to_string(),
});
}
if let Some(ref level) = config.log_level
&& !["off", "error", "warn", "info", "debug", "trace"].contains(&level.as_str())
{
return Err(ConfigError::Validation {
path: path.display().to_string(),
message: format!(
"log_level must be one of: off, error, warn, info, debug, trace — got \"{level}\""
),
});
}
if let Some(ref format) = config.log_format
&& !["plain", "json"].contains(&format.as_str())
{
return Err(ConfigError::Validation {
path: path.display().to_string(),
message: format!("log_format must be one of: plain, json — got \"{format}\""),
});
}
if let Some(ref protocol) = config.otlp_protocol
&& !["grpc", "http"].contains(&protocol.as_str())
{
return Err(ConfigError::Validation {
path: path.display().to_string(),
message: format!("otlp_protocol must be one of: grpc, http — got \"{protocol}\""),
});
}
if let Some(ratio) = config.otel_sample_ratio
&& !(0.0..=1.0).contains(&ratio)
{
return Err(ConfigError::Validation {
path: path.display().to_string(),
message: format!("otel_sample_ratio must be between 0.0 and 1.0, got {ratio}"),
});
}
#[cfg(feature = "mirror")]
if let Some(ref mirror) = config.mirror
&& let Some(concurrency) = mirror.concurrency
&& concurrency == 0
{
return Err(ConfigError::Validation {
path: path.display().to_string(),
message: format!("mirror.concurrency must be at least 1, got {concurrency}"),
});
}
#[cfg(feature = "mirror")]
if let Some(ref mirror) = config.mirror
&& let Some(initial) = mirror.initial_backoff_ms
&& initial == 0
{
return Err(ConfigError::Validation {
path: path.display().to_string(),
message: format!("mirror.initial_backoff_ms must be at least 1, got {initial}"),
});
}
#[cfg(feature = "mirror")]
if let Some(ref mirror) = config.mirror
&& let Some(initial) = mirror.initial_backoff_ms
&& let Some(max) = mirror.max_backoff_ms
&& max < initial
{
return Err(ConfigError::Validation {
path: path.display().to_string(),
message: format!(
"mirror.max_backoff_ms ({max}) must be >= mirror.initial_backoff_ms ({initial})"
),
});
}
#[cfg(feature = "tls")]
match (&config.tls_cert, &config.tls_key) {
(Some(_), None) | (None, Some(_)) => {
return Err(ConfigError::Validation {
path: path.display().to_string(),
message: "tls_cert and tls_key must both be set or both be omitted".into(),
});
}
_ => {}
}
Ok(config)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `contents` to a fresh temporary file, which is deleted when the handle drops.
    fn write_temp_toml(contents: &str) -> NamedTempFile {
        let mut file = NamedTempFile::new().expect("create temp config file");
        file.write_all(contents.as_bytes())
            .expect("write temp config file");
        file
    }

    /// Loads `contents` through [`load_config`] and returns the error, panicking on success.
    ///
    /// `FileConfig` does not derive `Debug`, so `Result::expect_err` cannot be used here.
    fn load_err(contents: &str) -> ConfigError {
        let file = write_temp_toml(contents);
        match load_config(file.path()) {
            Ok(_) => panic!("expected load_config to fail for {contents:?}"),
            Err(err) => err,
        }
    }

    /// Asserts that `contents` is rejected with a `Validation` error mentioning `needle`.
    fn assert_validation_error(contents: &str, needle: &str) {
        let err = load_err(contents);
        assert!(
            matches!(err, ConfigError::Validation { .. }),
            "expected a validation error, got {err:?}"
        );
        let message = err.to_string();
        assert!(
            message.contains(needle),
            "error {message:?} does not mention {needle:?}"
        );
    }

    #[test]
    fn accepts_valid_values_at_their_bounds() {
        let file = write_temp_toml(
            "port = 4567\n\
             iterator_ttl_seconds = 86400\n\
             retention_check_interval_secs = 0\n\
             log_level = \"debug\"\n\
             log_format = \"json\"\n\
             otlp_protocol = \"grpc\"\n\
             otel_sample_ratio = 1.0\n",
        );
        let config = load_config(file.path()).expect("valid config should load");
        assert_eq!(config.port, Some(4567));
        assert_eq!(config.iterator_ttl_seconds, Some(86400));
        assert_eq!(config.log_level.as_deref(), Some("debug"));
        assert_eq!(config.log_format.as_deref(), Some("json"));
    }

    #[test]
    fn load_config_rejects_zero_max_retained_bytes() {
        assert_validation_error(
            "max_retained_bytes = 0\n",
            "max_retained_bytes must be greater than 0",
        );
    }

    #[test]
    fn load_config_rejects_out_of_range_snapshot_interval() {
        assert_validation_error(
            "snapshot_interval_secs = 86401\n",
            "snapshot_interval_secs must be between 0 and 86400",
        );
    }

    #[test]
    fn rejects_out_of_range_iterator_ttl() {
        let needle = "iterator_ttl_seconds must be between 1 and 86400";
        assert_validation_error("iterator_ttl_seconds = 0\n", needle);
        assert_validation_error("iterator_ttl_seconds = 86401\n", needle);
    }

    #[test]
    fn rejects_out_of_range_retention_check_interval() {
        assert_validation_error(
            "retention_check_interval_secs = 86401\n",
            "retention_check_interval_secs must be between 0 and 86400",
        );
    }

    #[test]
    fn rejects_invalid_log_level() {
        assert_validation_error("log_level = \"verbose\"\n", "log_level must be one of");
    }

    #[test]
    fn rejects_invalid_log_format() {
        assert_validation_error("log_format = \"pretty\"\n", "log_format must be one of");
    }

    #[test]
    fn rejects_invalid_otlp_protocol() {
        assert_validation_error("otlp_protocol = \"tcp\"\n", "otlp_protocol must be one of");
    }

    #[test]
    fn rejects_out_of_range_otel_sample_ratio() {
        assert_validation_error(
            "otel_sample_ratio = 1.1\n",
            "otel_sample_ratio must be between 0.0 and 1.0",
        );
    }

    #[test]
    fn rejects_nan_otel_sample_ratio() {
        assert_validation_error(
            "otel_sample_ratio = nan\n",
            "otel_sample_ratio must be between 0.0 and 1.0",
        );
    }

    #[cfg(feature = "tls")]
    #[test]
    fn rejects_tls_cert_without_key() {
        assert_validation_error(
            "tls_cert = \"cert.pem\"\n",
            "tls_cert and tls_key must both be set or both be omitted",
        );
    }

    #[test]
    fn reports_parse_error_for_wrong_value_type() {
        let err = load_err("port = \"not-a-number\"\n");
        assert!(matches!(err, ConfigError::Parse { .. }), "got {err:?}");
    }

    #[test]
    fn reports_read_error_for_missing_file() {
        let file = NamedTempFile::new().expect("create temp file");
        let missing = file.path().to_path_buf();
        // Dropping the handle deletes the file, leaving a path known not to exist.
        drop(file);
        let err = match load_config(&missing) {
            Ok(_) => panic!("loading a missing file should fail"),
            Err(err) => err,
        };
        assert!(matches!(err, ConfigError::Read { .. }), "got {err:?}");
    }
}