use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use thiserror::Error;
/// Errors reported while reading, parsing, or validating server configuration.
#[derive(Debug, Error)]
pub enum ConfigError {
    /// A TOML conversion failed. The payload is the rendered message of the
    /// underlying `toml` error. This variant is used both when parsing TOML
    /// text and when serializing a config back to TOML.
    #[error("failed to parse TOML config: {0}")]
    TomlParse(String),

    /// The value of an environment variable could not be parsed.
    #[error("failed to parse environment variable {name}: {reason}")]
    EnvParse {
        /// Name of the environment variable.
        name: String,
        /// Human-readable explanation of the failure.
        reason: String,
    },

    /// The configuration failed validation; the payload describes why.
    #[error("configuration validation failed: {0}")]
    Validation(String),

    /// A configuration file could not be read.
    #[error("failed to read config file {path}: {source}")]
    Io {
        /// Display form of the path that was being read.
        path: String,
        /// The underlying I/O error.
        #[source]
        source: std::io::Error,
    },
}
/// The `bind` section of the server configuration.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct BindConfig {
    /// Host or address string (presumably the interface the server listens on).
    pub host: String,
    /// Port number.
    pub port: u16,
}
impl Default for BindConfig {
fn default() -> Self {
Self {
host: "0.0.0.0".to_string(),
port: 8080,
}
}
}
/// The `model` section of the server configuration. Both fields default to `None`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ModelConfig {
    /// Optional filesystem path for the model.
    pub path: Option<PathBuf>,
    /// Optional free-form quantization hint (accepted values are not defined here).
    pub quantization_hint: Option<String>,
}
/// The `tokenizer` section of the server configuration. Both fields default to `None`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct TokenizerConfigSection {
    /// Optional filesystem path for the tokenizer.
    pub path: Option<PathBuf>,
    /// Optional tokenizer kind, as a free-form string (accepted values are not defined here).
    pub kind: Option<String>,
}
/// The `sampling` section of the server configuration.
///
/// Only `PartialEq` (not `Eq`) is derived because the struct holds `f32` fields.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct SamplingConfig {
    /// Default maximum token count (presumably applied when a request gives none).
    pub default_max_tokens: usize,
    /// Default sampling temperature.
    pub default_temperature: f32,
    /// Default top-p value.
    pub default_top_p: f32,
}
impl Default for SamplingConfig {
fn default() -> Self {
Self {
default_max_tokens: 256,
default_temperature: 0.7,
default_top_p: 1.0,
}
}
}
/// The `limits` section of the server configuration.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct LimitsConfig {
    /// Upper bound on input tokens.
    pub max_input_tokens: usize,
    /// Upper bound on concurrently handled requests.
    pub max_concurrent_requests: usize,
    /// Per-request timeout; milliseconds, going by the `_ms` suffix.
    pub per_request_timeout_ms: u64,
}
impl Default for LimitsConfig {
fn default() -> Self {
Self {
max_input_tokens: 8192,
max_concurrent_requests: 32,
per_request_timeout_ms: 60_000,
}
}
}
/// The `auth` section of the server configuration.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct AuthConfig {
    /// Optional bearer token; `None` by default.
    pub bearer_token: Option<String>,
}
/// The `observability` section of the server configuration.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ObservabilityConfig {
    /// Log level, stored as a plain string (it is not interpreted in this type).
    pub log_level: String,
    /// Whether metrics are enabled.
    pub metrics_enabled: bool,
    /// Path associated with metrics (presumably the HTTP route that serves them).
    pub metrics_path: String,
}
impl Default for ObservabilityConfig {
fn default() -> Self {
Self {
log_level: "info".to_string(),
metrics_enabled: true,
metrics_path: "/metrics".to_string(),
}
}
}
/// Fully resolved server configuration.
///
/// When deserialized directly, any section that is absent falls back to that
/// section's `Default`, and a missing `seed` falls back to `default_seed()`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ServerConfig {
    /// `bind` section.
    #[serde(default)]
    pub bind: BindConfig,

    /// `model` section.
    #[serde(default)]
    pub model: ModelConfig,

    /// `tokenizer` section.
    #[serde(default)]
    pub tokenizer: TokenizerConfigSection,

    /// `sampling` section.
    #[serde(default)]
    pub sampling: SamplingConfig,

    /// `limits` section.
    #[serde(default)]
    pub limits: LimitsConfig,

    /// `auth` section.
    #[serde(default)]
    pub auth: AuthConfig,

    /// `observability` section.
    #[serde(default)]
    pub observability: ObservabilityConfig,

    /// Top-level seed value.
    #[serde(default = "default_seed")]
    pub seed: u64,
}
/// Seed used when the configuration does not provide one.
fn default_seed() -> u64 {
    const DEFAULT_SEED: u64 = 42;
    DEFAULT_SEED
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
bind: BindConfig::default(),
model: ModelConfig::default(),
tokenizer: TokenizerConfigSection::default(),
sampling: SamplingConfig::default(),
limits: LimitsConfig::default(),
auth: AuthConfig::default(),
observability: ObservabilityConfig::default(),
seed: default_seed(),
}
}
}
/// A flat, all-optional view of the server configuration.
///
/// Every field is an `Option`; `None` means "not specified by this layer".
/// Layers are combined with [`PartialServerConfig::merge`] and turned into a
/// full configuration with [`ServerConfig::from_partial`].
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct PartialServerConfig {
    // Maps to `ServerConfig::bind`.
    pub host: Option<String>,
    pub port: Option<u16>,

    // Maps to `ServerConfig::model`.
    pub model_path: Option<PathBuf>,
    pub quantization_hint: Option<String>,

    // Maps to `ServerConfig::tokenizer`.
    pub tokenizer_path: Option<PathBuf>,
    pub tokenizer_kind: Option<String>,

    // Maps to `ServerConfig::sampling`.
    pub default_max_tokens: Option<usize>,
    pub default_temperature: Option<f32>,
    pub default_top_p: Option<f32>,

    // Maps to `ServerConfig::limits`.
    pub max_input_tokens: Option<usize>,
    pub max_concurrent_requests: Option<usize>,
    pub per_request_timeout_ms: Option<u64>,

    // Maps to `ServerConfig::auth`.
    pub bearer_token: Option<String>,

    // Maps to `ServerConfig::observability`.
    pub log_level: Option<String>,
    pub metrics_enabled: Option<bool>,
    pub metrics_path: Option<String>,

    // Maps to `ServerConfig::seed`.
    pub seed: Option<u64>,
}
impl PartialServerConfig {
pub fn merge(mut self, other: PartialServerConfig) -> Self {
macro_rules! merge_field {
($name:ident) => {
if other.$name.is_some() {
self.$name = other.$name;
}
};
}
merge_field!(host);
merge_field!(port);
merge_field!(model_path);
merge_field!(quantization_hint);
merge_field!(tokenizer_path);
merge_field!(tokenizer_kind);
merge_field!(default_max_tokens);
merge_field!(default_temperature);
merge_field!(default_top_p);
merge_field!(max_input_tokens);
merge_field!(max_concurrent_requests);
merge_field!(per_request_timeout_ms);
merge_field!(bearer_token);
merge_field!(log_level);
merge_field!(metrics_enabled);
merge_field!(metrics_path);
merge_field!(seed);
self
}
pub fn from_toml_str(s: &str) -> Result<Self, ConfigError> {
let helper: TomlHelper =
toml::from_str(s).map_err(|e| ConfigError::TomlParse(e.to_string()))?;
Ok(helper.into_partial())
}
}
/// Deserialization target mirroring the sectioned TOML layout, with every
/// section (and the top-level `seed`) optional.
///
/// The container-level `#[serde(default)]` makes any absent key fall back to
/// the value in `TomlHelper::default()`, i.e. `None`.
#[derive(Debug, Default, Deserialize)]
#[serde(default)]
struct TomlHelper {
    bind: Option<BindPartial>,
    model: Option<ModelPartial>,
    tokenizer: Option<TokenizerPartial>,
    sampling: Option<SamplingPartial>,
    limits: Option<LimitsPartial>,
    auth: Option<AuthPartial>,
    observability: Option<ObservabilityPartial>,
    seed: Option<u64>,
}
/// All-optional form of the `bind` TOML table.
#[derive(Debug, Default, Deserialize)]
struct BindPartial {
    host: Option<String>,
    port: Option<u16>,
}
/// All-optional form of the `model` TOML table.
#[derive(Debug, Default, Deserialize)]
struct ModelPartial {
    path: Option<PathBuf>,
    quantization_hint: Option<String>,
}
/// All-optional form of the `tokenizer` TOML table.
#[derive(Debug, Default, Deserialize)]
struct TokenizerPartial {
    path: Option<PathBuf>,
    kind: Option<String>,
}
/// All-optional form of the `sampling` TOML table.
#[derive(Debug, Default, Deserialize)]
struct SamplingPartial {
    default_max_tokens: Option<usize>,
    default_temperature: Option<f32>,
    default_top_p: Option<f32>,
}
/// All-optional form of the `limits` TOML table.
#[derive(Debug, Default, Deserialize)]
struct LimitsPartial {
    max_input_tokens: Option<usize>,
    max_concurrent_requests: Option<usize>,
    per_request_timeout_ms: Option<u64>,
}
/// All-optional form of the `auth` TOML table.
#[derive(Debug, Default, Deserialize)]
struct AuthPartial {
    bearer_token: Option<String>,
}
/// All-optional form of the `observability` TOML table.
#[derive(Debug, Default, Deserialize)]
struct ObservabilityPartial {
    log_level: Option<String>,
    metrics_enabled: Option<bool>,
    metrics_path: Option<String>,
}
impl TomlHelper {
fn into_partial(self) -> PartialServerConfig {
let bind = self.bind.unwrap_or_default();
let model = self.model.unwrap_or_default();
let tok = self.tokenizer.unwrap_or_default();
let samp = self.sampling.unwrap_or_default();
let lim = self.limits.unwrap_or_default();
let auth = self.auth.unwrap_or_default();
let obs = self.observability.unwrap_or_default();
PartialServerConfig {
host: bind.host,
port: bind.port,
model_path: model.path,
quantization_hint: model.quantization_hint,
tokenizer_path: tok.path,
tokenizer_kind: tok.kind,
default_max_tokens: samp.default_max_tokens,
default_temperature: samp.default_temperature,
default_top_p: samp.default_top_p,
max_input_tokens: lim.max_input_tokens,
max_concurrent_requests: lim.max_concurrent_requests,
per_request_timeout_ms: lim.per_request_timeout_ms,
bearer_token: auth.bearer_token,
log_level: obs.log_level,
metrics_enabled: obs.metrics_enabled,
metrics_path: obs.metrics_path,
seed: self.seed,
}
}
}
impl ServerConfig {
pub fn from_toml(s: &str) -> Result<Self, ConfigError> {
let partial = PartialServerConfig::from_toml_str(s)?;
Ok(Self::from_partial(partial))
}
pub fn from_partial(p: PartialServerConfig) -> Self {
let mut out = Self::default();
if let Some(v) = p.host {
out.bind.host = v;
}
if let Some(v) = p.port {
out.bind.port = v;
}
if let Some(v) = p.model_path {
out.model.path = Some(v);
}
if let Some(v) = p.quantization_hint {
out.model.quantization_hint = Some(v);
}
if let Some(v) = p.tokenizer_path {
out.tokenizer.path = Some(v);
}
if let Some(v) = p.tokenizer_kind {
out.tokenizer.kind = Some(v);
}
if let Some(v) = p.default_max_tokens {
out.sampling.default_max_tokens = v;
}
if let Some(v) = p.default_temperature {
out.sampling.default_temperature = v;
}
if let Some(v) = p.default_top_p {
out.sampling.default_top_p = v;
}
if let Some(v) = p.max_input_tokens {
out.limits.max_input_tokens = v;
}
if let Some(v) = p.max_concurrent_requests {
out.limits.max_concurrent_requests = v;
}
if let Some(v) = p.per_request_timeout_ms {
out.limits.per_request_timeout_ms = v;
}
if let Some(v) = p.bearer_token {
out.auth.bearer_token = Some(v);
}
if let Some(v) = p.log_level {
out.observability.log_level = v;
}
if let Some(v) = p.metrics_enabled {
out.observability.metrics_enabled = v;
}
if let Some(v) = p.metrics_path {
out.observability.metrics_path = v;
}
if let Some(v) = p.seed {
out.seed = v;
}
out
}
pub fn to_toml_string(&self) -> Result<String, ConfigError> {
toml::to_string_pretty(self).map_err(|e| ConfigError::TomlParse(e.to_string()))
}
pub fn from_toml_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self, ConfigError> {
let p = path.as_ref();
let body = std::fs::read_to_string(p).map_err(|e| ConfigError::Io {
path: p.display().to_string(),
source: e,
})?;
Self::from_toml(&body)
}
pub fn partial_from_file<P: AsRef<std::path::Path>>(
path: P,
) -> Result<PartialServerConfig, ConfigError> {
let p = path.as_ref();
let body = std::fs::read_to_string(p).map_err(|e| ConfigError::Io {
path: p.display().to_string(),
source: e,
})?;
PartialServerConfig::from_toml_str(&body)
}
pub fn load(
toml_path: Option<&std::path::Path>,
env_partial: Option<PartialServerConfig>,
cli_partial: Option<PartialServerConfig>,
) -> Result<Self, ConfigError> {
let mut merged = PartialServerConfig::default();
if let Some(p) = toml_path {
let from_file = Self::partial_from_file(p)?;
merged = merged.merge(from_file);
}
if let Some(env) = env_partial {
merged = merged.merge(env);
}
if let Some(cli) = cli_partial {
merged = merged.merge(cli);
}
let cfg = Self::from_partial(merged);
cfg.validate()?;
Ok(cfg)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializing the default config and parsing it back yields an equal value.
    #[test]
    fn defaults_roundtrip() {
        let original = ServerConfig::default();
        let rendered = original.to_toml_string().expect("to_toml");
        let reparsed = ServerConfig::from_toml(&rendered).expect("from_toml");
        assert_eq!(original, reparsed);
    }

    /// A default partial config leaves `host` and `port` unset.
    #[test]
    fn partial_default_is_empty() {
        let empty = PartialServerConfig::default();
        assert!(empty.host.is_none());
        assert!(empty.port.is_none());
    }

    /// The overlay wins where it sets a value; the base survives elsewhere.
    #[test]
    fn partial_merge_overrides() {
        let base = PartialServerConfig {
            port: Some(1),
            log_level: Some("info".to_string()),
            ..Default::default()
        };
        let overlay = PartialServerConfig {
            port: Some(2),
            ..Default::default()
        };

        let combined = base.merge(overlay);

        assert_eq!(combined.port, Some(2));
        assert_eq!(combined.log_level.as_deref(), Some("info"));
    }

    /// Values set in a partial config land in the matching sections.
    #[test]
    fn from_partial_applies_fields() {
        let partial = PartialServerConfig {
            host: Some("1.2.3.4".to_string()),
            port: Some(9999),
            default_top_p: Some(0.9),
            ..Default::default()
        };

        let resolved = ServerConfig::from_partial(partial);

        assert_eq!(resolved.bind.host, "1.2.3.4");
        assert_eq!(resolved.bind.port, 9999);
        assert!((resolved.sampling.default_top_p - 0.9).abs() < f32::EPSILON);
    }
}