use crate::config::{ConfigError, ServerConfig};
/// Log level names accepted for `observability.log_level`.
///
/// `ServerConfig::validate` compares against these entries ignoring ASCII case.
pub const VALID_LOG_LEVELS: &[&str] = &["error", "warn", "info", "debug", "trace", "off"];
/// Minimum accepted length of `auth.bearer_token`, compared against the token's `len()`.
///
/// NOTE(review): if the token is a `String`, `len()` counts bytes rather than
/// characters, while the validation message says "chars" — confirm the intended unit.
pub const MIN_BEARER_TOKEN_LEN: usize = 16;
/// Inclusive upper bound accepted for `sampling.default_max_tokens`.
pub const MAX_DEFAULT_MAX_TOKENS: usize = 8192;
impl ServerConfig {
pub fn validate(&self) -> Result<(), ConfigError> {
if self.bind.port == 0 {
return Err(ConfigError::Validation(
"bind.port must be in [1, 65535]".to_string(),
));
}
if self.sampling.default_max_tokens == 0
|| self.sampling.default_max_tokens > MAX_DEFAULT_MAX_TOKENS
{
return Err(ConfigError::Validation(format!(
"sampling.default_max_tokens must be in [1, {MAX_DEFAULT_MAX_TOKENS}], got {}",
self.sampling.default_max_tokens
)));
}
if !self.sampling.default_temperature.is_finite()
|| self.sampling.default_temperature < 0.0
|| self.sampling.default_temperature > 2.0
{
return Err(ConfigError::Validation(format!(
"sampling.default_temperature must be in [0, 2], got {}",
self.sampling.default_temperature
)));
}
if !self.sampling.default_top_p.is_finite()
|| self.sampling.default_top_p < 0.0
|| self.sampling.default_top_p > 1.0
{
return Err(ConfigError::Validation(format!(
"sampling.default_top_p must be in [0, 1], got {}",
self.sampling.default_top_p
)));
}
if !VALID_LOG_LEVELS
.iter()
.any(|l| l.eq_ignore_ascii_case(&self.observability.log_level))
{
return Err(ConfigError::Validation(format!(
"observability.log_level must be one of {VALID_LOG_LEVELS:?}, got {:?}",
self.observability.log_level
)));
}
if self.observability.metrics_path.is_empty()
|| !self.observability.metrics_path.starts_with('/')
{
return Err(ConfigError::Validation(format!(
"observability.metrics_path must be an absolute HTTP path, got {:?}",
self.observability.metrics_path
)));
}
if self.limits.max_concurrent_requests == 0 {
return Err(ConfigError::Validation(
"limits.max_concurrent_requests must be ≥ 1".to_string(),
));
}
if self.limits.per_request_timeout_ms == 0 {
return Err(ConfigError::Validation(
"limits.per_request_timeout_ms must be ≥ 1".to_string(),
));
}
if self.limits.max_input_tokens == 0 {
return Err(ConfigError::Validation(
"limits.max_input_tokens must be ≥ 1".to_string(),
));
}
if let Some(ref tok) = self.auth.bearer_token {
if tok.len() < MIN_BEARER_TOKEN_LEN {
return Err(ConfigError::Validation(format!(
"auth.bearer_token must be at least {MIN_BEARER_TOKEN_LEN} chars"
)));
}
}
if let Some(ref path) = self.model.path {
if !path.exists() {
return Err(ConfigError::Validation(format!(
"model.path does not exist: {}",
path.display()
)));
}
}
if let Some(ref path) = self.tokenizer.path {
if !path.exists() {
return Err(ConfigError::Validation(format!(
"tokenizer.path does not exist: {}",
path.display()
)));
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Starts from the default configuration, applies `tweak`, and returns the
    /// result of validating the modified configuration.
    fn validate_with(tweak: impl FnOnce(&mut ServerConfig)) -> Result<(), ConfigError> {
        let mut cfg = ServerConfig::default();
        tweak(&mut cfg);
        cfg.validate()
    }

    #[test]
    fn default_config_validates() {
        validate_with(|_| {}).expect("defaults should validate");
    }

    #[test]
    fn port_zero_rejected() {
        let outcome = validate_with(|cfg| cfg.bind.port = 0);
        assert!(outcome.is_err());
    }

    #[test]
    fn bad_log_level_rejected() {
        let outcome = validate_with(|cfg| cfg.observability.log_level = "loud".to_string());
        assert!(outcome.is_err());
    }

    #[test]
    fn bad_top_p_rejected() {
        let outcome = validate_with(|cfg| cfg.sampling.default_top_p = 1.5);
        assert!(outcome.is_err());
    }

    #[test]
    fn short_bearer_rejected() {
        let outcome = validate_with(|cfg| cfg.auth.bearer_token = Some("short".to_string()));
        assert!(outcome.is_err());
    }
}