use anyhow::Context;
use serde::Deserialize;
/// Sample rates, in Hz, listed as supported by this module.
pub(crate) const SUPPORTED_RATES: &[u32] = &[8_000, 16_000, 24_000, 44_100, 48_000];
/// Default sample rate in Hz; must be one of [`SUPPORTED_RATES`] (checked by the tests).
pub(crate) const DEFAULT_SAMPLE_RATE: u32 = 48_000;
/// Pool checkout timeout expressed in milliseconds, clamped to `u32::MAX`.
///
/// The seconds→milliseconds multiplication saturates instead of overflowing,
/// and any result that does not fit in a `u32` becomes `u32::MAX`, so the
/// final narrowing cast can never truncate.
pub(crate) fn pool_retry_after_ms(limits: &RuntimeLimits) -> u32 {
    let millis = limits.pool_checkout_timeout_secs.saturating_mul(1_000);
    if millis > u64::from(u32::MAX) {
        u32::MAX
    } else {
        // In range: checked just above, so this cast is lossless.
        millis as u32
    }
}
pub(crate) fn pool_retry_after_secs(limits: &RuntimeLimits) -> u64 {
limits.pool_checkout_timeout_secs
}
/// Which cross-origin callers are accepted, judged by the request's `Origin` header.
///
/// Loopback origins are always accepted regardless of these fields.
#[derive(Debug, Clone, Default)]
pub struct OriginPolicy {
    /// When `true`, every origin is accepted.
    pub allow_any: bool,
    /// Extra origins accepted on an exact, ASCII case-insensitive match.
    pub allowed_origins: Vec<String>,
}
impl OriginPolicy {
    /// A policy with `allow_any` off and no extra origins, i.e. only loopback
    /// origins (plus requests without an `Origin` or with `null`) are accepted.
    /// Identical to `OriginPolicy::default()`.
    pub fn loopback_only() -> Self {
        Self {
            allow_any: false,
            allowed_origins: Vec::new(),
        }
    }
}
/// Outcome of checking a request's `Origin` header against an `OriginPolicy`.
#[derive(Debug)]
pub(crate) enum OriginVerdict {
    /// Accept, with no origin to echo (the header was absent or `null`).
    AllowedNoEcho,
    /// Accept; carries the origin exactly as the client sent it, presumably
    /// so the caller can echo it back.
    Allowed(String),
    /// Reject the request.
    Denied,
}
/// Returns `true` when `origin` is an `http`/`https` origin whose host is a
/// loopback literal: `localhost`, `127.0.0.1` or `[::1]` (ASCII case-insensitive).
///
/// After the host, only three shapes are accepted:
/// - nothing at all;
/// - a `:port` made of one or more ASCII digits;
/// - a `/`-prefixed remainder, optionally after such a port.
///
/// Anything else is rejected. That covers look-alike hosts
/// (`http://localhost.evil.example.com`) and non-numeric or empty ports
/// (`http://localhost:@evil.example.com`, `http://localhost:3000.evil.example.com`,
/// `http://127.0.0.1:`).
fn is_loopback_origin(origin: &str) -> bool {
    const HOST_PREFIXES: &[&str] = &[
        "http://localhost",
        "https://localhost",
        "http://127.0.0.1",
        "https://127.0.0.1",
        "http://[::1]",
        "https://[::1]",
    ];

    /// Validates what follows the loopback host: an optional `:digits` port,
    /// then either end of string or a `/`.
    fn tail_is_port_or_path(rest: &str) -> bool {
        let after_port = match rest.strip_prefix(':') {
            Some(port_and_more) => {
                let digits = port_and_more
                    .bytes()
                    .take_while(u8::is_ascii_digit)
                    .count();
                // A bare ':' or a non-numeric port is not a valid origin.
                if digits == 0 {
                    return false;
                }
                // `digits` counts ASCII bytes only, so this is a char boundary.
                &port_and_more[digits..]
            }
            None => rest,
        };
        after_port.is_empty() || after_port.starts_with('/')
    }

    let lowered = origin.to_ascii_lowercase();
    HOST_PREFIXES
        .iter()
        .any(|prefix| match lowered.strip_prefix(prefix) {
            Some(rest) => tail_is_port_or_path(rest),
            None => false,
        })
}
impl OriginPolicy {
    /// Decides whether a request carrying the given `Origin` header value may proceed.
    ///
    /// - No header, or the literal `null` in any ASCII case: `AllowedNoEcho`.
    /// - `allow_any` is set, the origin is loopback, or it exactly matches an
    ///   entry of `allowed_origins` (ASCII case-insensitive): `Allowed`, carrying
    ///   the origin as sent.
    /// - Anything else: `Denied`.
    pub(crate) fn evaluate(&self, origin: Option<&str>) -> OriginVerdict {
        // NOTE(review): browsers also send `Origin: null` from sandboxed iframes and
        // other opaque-origin contexts, so treating it like a missing header lets any
        // web page reach this endpoint through a sandboxed frame. The tests pin this
        // behaviour; confirm it is intended.
        match origin {
            None => OriginVerdict::AllowedNoEcho,
            Some(value) if value.eq_ignore_ascii_case("null") => OriginVerdict::AllowedNoEcho,
            Some(value) => {
                // Only consulted when the cheaper checks fail (short-circuit below).
                let explicitly_listed = || {
                    self.allowed_origins
                        .iter()
                        .any(|entry| entry.eq_ignore_ascii_case(value))
                };
                if self.allow_any || is_loopback_origin(value) || explicitly_listed() {
                    OriginVerdict::Allowed(value.to_owned())
                } else {
                    OriginVerdict::Denied
                }
            }
        }
    }
}
/// Runtime limits: timeouts, size caps and rate limiting.
#[derive(Debug, Clone)]
pub struct RuntimeLimits {
    /// Idle timeout, in seconds.
    pub idle_timeout_secs: u64,
    /// Maximum WebSocket frame size, in bytes.
    pub ws_frame_max_bytes: usize,
    /// Maximum request body size, in bytes.
    pub body_limit_bytes: usize,
    /// Requests allowed per minute; the default of `0` means rate limiting is off.
    pub rate_limit_per_minute: u32,
    /// Burst size for the rate limiter.
    pub rate_limit_burst: u32,
    /// Maximum session length, in seconds.
    pub max_session_secs: u64,
    /// Time allowed for draining during shutdown, in seconds.
    pub shutdown_drain_secs: u64,
    /// Time to wait for a pooled resource before giving up, in seconds.
    pub pool_checkout_timeout_secs: u64,
}
impl Default for RuntimeLimits {
    fn default() -> Self {
        const KIB: usize = 1024;
        const MIB: usize = 1024 * KIB;
        Self {
            idle_timeout_secs: 5 * 60,
            ws_frame_max_bytes: 512 * KIB,
            body_limit_bytes: 50 * MIB,
            // Rate limiting is opt-in.
            rate_limit_per_minute: 0,
            rate_limit_burst: 10,
            max_session_secs: 60 * 60,
            shutdown_drain_secs: 10,
            pool_checkout_timeout_secs: 30,
        }
    }
}
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct RuntimeLimitsConfig {
pub idle_timeout_secs: u64,
pub ws_frame_max_bytes: usize,
pub body_limit_bytes: usize,
pub rate_limit_per_minute: u32,
pub rate_limit_burst: u32,
pub max_session_secs: u64,
pub shutdown_drain_secs: u64,
pub pool_checkout_timeout_secs: u64,
}
impl Default for RuntimeLimitsConfig {
fn default() -> Self {
let d = RuntimeLimits::default();
Self {
idle_timeout_secs: d.idle_timeout_secs,
ws_frame_max_bytes: d.ws_frame_max_bytes,
body_limit_bytes: d.body_limit_bytes,
rate_limit_per_minute: d.rate_limit_per_minute,
rate_limit_burst: d.rate_limit_burst,
max_session_secs: d.max_session_secs,
shutdown_drain_secs: d.shutdown_drain_secs,
pool_checkout_timeout_secs: d.pool_checkout_timeout_secs,
}
}
}
impl From<RuntimeLimitsConfig> for RuntimeLimits {
fn from(cfg: RuntimeLimitsConfig) -> Self {
Self {
idle_timeout_secs: cfg.idle_timeout_secs,
ws_frame_max_bytes: cfg.ws_frame_max_bytes,
body_limit_bytes: cfg.body_limit_bytes,
rate_limit_per_minute: cfg.rate_limit_per_minute,
rate_limit_burst: cfg.rate_limit_burst,
max_session_secs: cfg.max_session_secs,
shutdown_drain_secs: cfg.shutdown_drain_secs,
pool_checkout_timeout_secs: cfg.pool_checkout_timeout_secs,
}
}
}
pub fn load_config_file(path: &std::path::Path) -> anyhow::Result<RuntimeLimits> {
let content = std::fs::read_to_string(path)
.with_context(|| format!("Failed to read config file: {}", path.display()))?;
let cfg: RuntimeLimitsConfig = toml::from_str(&content)
.with_context(|| format!("Failed to parse config file: {}", path.display()))?;
Ok(cfg.into())
}
/// Top-level server settings.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// TCP port number.
    pub port: u16,
    /// Host/address string, e.g. `"127.0.0.1"`.
    pub host: String,
    /// Which `Origin` values are accepted.
    pub origin_policy: OriginPolicy,
    /// Timeouts, size caps and rate limiting.
    pub limits: RuntimeLimits,
    /// Whether metrics are enabled (off for `local`).
    pub metrics_enabled: bool,
    /// Presumably whether proxy-supplied client information is trusted — confirm at use sites.
    pub trust_proxy: bool,
    /// Presumably the config file the limits came from, if any — confirm at use sites.
    pub config_path: Option<std::path::PathBuf>,
}
impl ServerConfig {
    /// Local-only configuration on `port`.
    ///
    /// Uses host `127.0.0.1`, a loopback-only origin policy and default limits,
    /// with metrics and proxy trust turned off and no config path.
    pub fn local(port: u16) -> Self {
        let host = String::from("127.0.0.1");
        Self {
            host,
            port,
            origin_policy: OriginPolicy::loopback_only(),
            limits: RuntimeLimits::default(),
            metrics_enabled: false,
            trust_proxy: false,
            config_path: None,
        }
    }
}
// Unit tests for the configuration module: defaults, origin matching/policy,
// TOML loading and the pool retry-after helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_runtime_limits_default_rate_limit_disabled() {
        let limits = RuntimeLimits::default();
        assert_eq!(
            limits.rate_limit_per_minute, 0,
            "rate limiting must be off by default (privacy-first)"
        );
        assert_eq!(limits.rate_limit_burst, 10, "default burst size must be 10");
    }

    #[test]
    fn test_runtime_limits_default_session_and_drain() {
        let limits = RuntimeLimits::default();
        assert_eq!(
            limits.max_session_secs, 3600,
            "default session cap must be 1 hour to stop silence-streamers from \
             holding a triplet forever"
        );
        assert_eq!(
            limits.shutdown_drain_secs, 10,
            "default shutdown drain must be 10 s — comfortably inside the usual \
             k8s terminationGracePeriodSeconds = 30"
        );
    }

    #[test]
    fn test_supported_rates_contains_common() {
        assert!(
            SUPPORTED_RATES.contains(&8000),
            "SUPPORTED_RATES must include 8000 Hz"
        );
        assert!(
            SUPPORTED_RATES.contains(&16000),
            "SUPPORTED_RATES must include 16000 Hz"
        );
        assert!(
            SUPPORTED_RATES.contains(&48000),
            "SUPPORTED_RATES must include 48000 Hz"
        );
    }

    #[test]
    fn test_default_sample_rate_in_supported() {
        assert!(
            SUPPORTED_RATES.contains(&DEFAULT_SAMPLE_RATE),
            "DEFAULT_SAMPLE_RATE ({DEFAULT_SAMPLE_RATE}) must be present in SUPPORTED_RATES"
        );
    }

    #[test]
    fn test_loopback_origin_matcher() {
        assert!(is_loopback_origin("http://localhost"));
        assert!(is_loopback_origin("https://localhost:3000"));
        assert!(is_loopback_origin("http://127.0.0.1:9876"));
        assert!(is_loopback_origin("HTTPS://127.0.0.1"));
        assert!(is_loopback_origin("http://[::1]:9876"));
        assert!(!is_loopback_origin("https://evil.example.com"));
        assert!(!is_loopback_origin("http://192.168.1.10"));
        assert!(!is_loopback_origin("http://localhost.evil.example.com"));
    }

    #[test]
    fn test_origin_policy_default_denies_third_party() {
        let policy = OriginPolicy::loopback_only();
        assert!(matches!(
            policy.evaluate(Some("https://evil.example.com")),
            OriginVerdict::Denied
        ));
    }

    #[test]
    fn test_origin_policy_allows_loopback_by_default() {
        let policy = OriginPolicy::loopback_only();
        assert!(matches!(
            policy.evaluate(Some("http://localhost:3000")),
            OriginVerdict::Allowed(_)
        ));
    }

    #[test]
    fn test_origin_policy_allows_listed_origin() {
        let policy = OriginPolicy {
            allow_any: false,
            allowed_origins: vec!["https://app.example.com".into()],
        };
        assert!(matches!(
            policy.evaluate(Some("https://app.example.com")),
            OriginVerdict::Allowed(_)
        ));
        assert!(matches!(
            policy.evaluate(Some("https://app.example.com.evil.com")),
            OriginVerdict::Denied
        ));
    }

    #[test]
    fn test_origin_policy_allow_any_short_circuits() {
        let policy = OriginPolicy {
            allow_any: true,
            allowed_origins: vec![],
        };
        assert!(matches!(
            policy.evaluate(Some("https://anything.example.com")),
            OriginVerdict::Allowed(_)
        ));
    }

    #[test]
    fn test_runtime_limits_from_toml() {
        let toml_str = r#"
            idle_timeout_secs = 600
            rate_limit_per_minute = 120
        "#;
        let cfg: RuntimeLimitsConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(cfg.idle_timeout_secs, 600);
        assert_eq!(cfg.rate_limit_per_minute, 120);
        // Keys absent from the TOML fall back to the defaults.
        assert_eq!(cfg.max_session_secs, 3600);
    }

    #[test]
    fn test_runtime_limits_config_to_limits() {
        let cfg = RuntimeLimitsConfig::default();
        let limits: RuntimeLimits = cfg.into();
        let defaults = RuntimeLimits::default();
        assert_eq!(limits.idle_timeout_secs, defaults.idle_timeout_secs);
        assert_eq!(limits.ws_frame_max_bytes, defaults.ws_frame_max_bytes);
        assert_eq!(limits.body_limit_bytes, defaults.body_limit_bytes);
        assert_eq!(limits.rate_limit_per_minute, defaults.rate_limit_per_minute);
        assert_eq!(limits.rate_limit_burst, defaults.rate_limit_burst);
        assert_eq!(limits.max_session_secs, defaults.max_session_secs);
        assert_eq!(limits.shutdown_drain_secs, defaults.shutdown_drain_secs);
        assert_eq!(
            limits.pool_checkout_timeout_secs,
            defaults.pool_checkout_timeout_secs
        );
    }

    #[test]
    fn test_runtime_limits_config_to_limits_copies_every_field() {
        // Distinct values per field so a swapped or dropped field is caught.
        let cfg = RuntimeLimitsConfig {
            idle_timeout_secs: 1,
            ws_frame_max_bytes: 2,
            body_limit_bytes: 3,
            rate_limit_per_minute: 4,
            rate_limit_burst: 5,
            max_session_secs: 6,
            shutdown_drain_secs: 7,
            pool_checkout_timeout_secs: 8,
        };
        let limits = RuntimeLimits::from(cfg);
        assert_eq!(limits.idle_timeout_secs, 1);
        assert_eq!(limits.ws_frame_max_bytes, 2);
        assert_eq!(limits.body_limit_bytes, 3);
        assert_eq!(limits.rate_limit_per_minute, 4);
        assert_eq!(limits.rate_limit_burst, 5);
        assert_eq!(limits.max_session_secs, 6);
        assert_eq!(limits.shutdown_drain_secs, 7);
        assert_eq!(limits.pool_checkout_timeout_secs, 8);
    }

    #[test]
    fn test_origin_policy_no_header_allowed() {
        let policy = OriginPolicy::loopback_only();
        assert!(matches!(
            policy.evaluate(None),
            OriginVerdict::AllowedNoEcho
        ));
        assert!(matches!(
            policy.evaluate(Some("null")),
            OriginVerdict::AllowedNoEcho
        ));
    }

    #[test]
    fn test_pool_retry_after_defaults() {
        let limits = RuntimeLimits::default();
        assert_eq!(pool_retry_after_secs(&limits), 30);
        assert_eq!(pool_retry_after_ms(&limits), 30_000);
    }

    #[test]
    fn test_pool_retry_after_ms_saturation() {
        let limits = RuntimeLimits {
            pool_checkout_timeout_secs: u32::MAX as u64,
            ..Default::default()
        };
        let ms = pool_retry_after_ms(&limits);
        assert_eq!(ms, u32::MAX);
    }

    #[test]
    fn test_load_config_file_not_found() {
        let result = load_config_file(std::path::Path::new("/nonexistent/config.toml"));
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("Failed to read config file"));
    }

    #[test]
    fn test_load_config_file_bad_toml() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(tmp.path(), b"not valid toml {{{").unwrap();
        let result = load_config_file(tmp.path());
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("Failed to parse config file"));
    }

    #[test]
    fn test_load_config_file_valid() {
        let tmp = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(
            tmp.path(),
            b"idle_timeout_secs = 123\nws_frame_max_bytes = 1024\n",
        )
        .unwrap();
        let limits = load_config_file(tmp.path()).unwrap();
        assert_eq!(limits.idle_timeout_secs, 123);
        assert_eq!(limits.ws_frame_max_bytes, 1024);
        assert_eq!(limits.max_session_secs, 3600);
    }

    #[test]
    fn test_server_config_local_defaults() {
        let cfg = ServerConfig::local(9876);
        assert_eq!(cfg.port, 9876);
        assert_eq!(cfg.host, "127.0.0.1");
        assert!(!cfg.metrics_enabled);
        assert!(!cfg.trust_proxy);
        assert!(cfg.config_path.is_none());
    }
}