use serde::{Deserialize, Serialize};
use std::path::Path;
use crate::error::Result;
/// Root of the gateway configuration, parsed from TOML by `GatewayConfig::from_toml_str`
/// and `GatewayConfig::load`.
///
/// Every field carries `#[serde(default)]`, so any section may be left out of the input
/// and falls back to the `Default` value of that field's type.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct GatewayConfig {
/// Host/port settings; `ServerConfig::default()` is used when the section is absent.
#[serde(default)]
pub server: ServerConfig,
/// Provider definitions; `GatewayConfig::resolve_provider` searches this list by name.
/// Empty when omitted.
#[serde(default)]
pub providers: Vec<ProviderConfig>,
/// Per-model route entries. Empty when omitted.
#[serde(default)]
pub routes: Vec<ModelRoute>,
/// Virtual key entries. Empty when omitted.
#[serde(default)]
pub virtual_keys: Vec<VirtualKeyConfig>,
/// Logging settings; `LoggingConfig::default()` is used when the section is absent.
#[serde(default)]
pub logging: LoggingConfig,
}
/// Host/port pair for the gateway server.
///
/// NOTE(review): presumably the address the listener binds to; this file only stores
/// the values, so confirm at the call site.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ServerConfig {
/// Host string; `default_host()` supplies the value when the key is absent.
#[serde(default = "default_host")]
pub host: String,
/// Port number; `default_port()` supplies the value when the key is absent.
#[serde(default = "default_port")]
pub port: u16,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
host: default_host(),
port: default_port(),
}
}
}
/// Fallback value for `ServerConfig::host`: the string `"0.0.0.0"`.
fn default_host() -> String {
    String::from("0.0.0.0")
}
/// Fallback value for `ServerConfig::port`: `4000`.
fn default_port() -> u16 {
    const DEFAULT_PORT: u16 = 4000;
    DEFAULT_PORT
}
/// Kind of provider a `ProviderConfig` describes.
///
/// With `rename_all = "snake_case"` the variants are written in config as
/// `"openai"`, `"anthropic"` and `"openai_compatible"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ProviderKind {
/// Serialized as `"openai"`.
Openai,
/// Serialized as `"anthropic"`.
Anthropic,
/// Serialized as `"openai_compatible"`.
OpenaiCompatible,
}
/// Settings for one provider entry in the gateway configuration.
///
/// `Debug` is implemented by hand instead of derived so that the inline `api_key`
/// secret is never written out by `{:?}` formatting (log lines, panic messages,
/// config dumps). Serialization is unaffected and still emits the real value.
#[derive(Clone, Deserialize, Serialize)]
pub struct ProviderConfig {
    /// Provider name; `GatewayConfig::resolve_provider` matches on this field.
    pub name: String,
    /// Which kind of provider this entry describes.
    pub kind: ProviderKind,
    /// Base URL of the provider API (e.g. `https://api.openai.com`).
    pub api_base: String,
    /// API key given inline; takes precedence over `api_key_env` in
    /// `resolve_api_key`. Optional.
    #[serde(default)]
    pub api_key: Option<String>,
    /// Name of an environment variable that `resolve_api_key` reads when `api_key`
    /// is not set. Optional.
    #[serde(default)]
    pub api_key_env: Option<String>,
    /// Optional egress proxy setting.
    /// NOTE(review): presumably a proxy URL for outbound requests — confirm with the
    /// code that consumes it.
    #[serde(default)]
    pub egress_proxy: Option<String>,
}

impl std::fmt::Debug for ProviderConfig {
    /// Formats like the derived `Debug` would, except that a present `api_key` is
    /// shown as `Some("<redacted>")`; an absent one still shows as `None`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the Some/None shape so a reader can still tell whether a key is set.
        let masked_key = self.api_key.as_ref().map(|_| "<redacted>");
        f.debug_struct("ProviderConfig")
            .field("name", &self.name)
            .field("kind", &self.kind)
            .field("api_base", &self.api_base)
            .field("api_key", &masked_key)
            .field("api_key_env", &self.api_key_env)
            .field("egress_proxy", &self.egress_proxy)
            .finish()
    }
}
impl ProviderConfig {
    /// Resolves the API key for this provider.
    ///
    /// Resolution order:
    /// 1. the inline `api_key`, returned as a clone when present;
    /// 2. otherwise the environment variable named by `api_key_env`.
    ///
    /// Returns `None` when neither field is set, or when the environment lookup
    /// fails (`std::env::var` returned an error, e.g. the variable is unset).
    pub fn resolve_api_key(&self) -> Option<String> {
        match (&self.api_key, &self.api_key_env) {
            // An inline key always wins; the environment is not consulted.
            (Some(inline), _) => Some(inline.clone()),
            (None, Some(var_name)) => std::env::var(var_name).ok(),
            (None, None) => None,
        }
    }
}
/// Balancing strategy selectable per `ModelRoute`.
///
/// With `rename_all = "snake_case"` the variants are written in config as
/// `"round_robin"`, `"random"`, `"power_of_two"`, `"consistent_hash"` and
/// `"cache_aware"`. The selection logic itself is not implemented in this file.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum BalancingStrategy {
/// Serialized as `"round_robin"`; this is the `Default` variant.
#[default]
RoundRobin,
/// Serialized as `"random"`.
Random,
/// Serialized as `"power_of_two"`.
PowerOfTwo,
/// Serialized as `"consistent_hash"`.
ConsistentHash,
/// Serialized as `"cache_aware"`.
CacheAware,
}
/// One entry in a `ModelRoute`'s `targets` list.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Target {
/// Provider reference (required).
/// NOTE(review): presumably matches a `ProviderConfig::name` — nothing in this file
/// validates that; confirm where routes are resolved.
pub provider: String,
/// Optional model name; `None` when omitted.
/// NOTE(review): presumably overrides the model name sent to the provider — confirm.
#[serde(default)]
pub model: Option<String>,
/// Weight of this target; `default_weight()` supplies `1` when the key is absent.
#[serde(default = "default_weight")]
pub weight: u32,
}
/// Fallback value for `Target::weight`: `1`.
fn default_weight() -> u32 {
    const DEFAULT_WEIGHT: u32 = 1;
    DEFAULT_WEIGHT
}
/// Route entry for a single model name.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ModelRoute {
/// Model name this route applies to (required), e.g. `"gpt-4o"`.
pub model: String,
/// Balancing strategy; falls back to `BalancingStrategy::default()` (`RoundRobin`)
/// when the key is absent.
#[serde(default)]
pub strategy: BalancingStrategy,
/// Targets for this route. Required: there is no serde default, so deserialization
/// fails if the key is missing.
pub targets: Vec<Target>,
}
/// One virtual key entry in the gateway configuration.
///
/// `Debug` is implemented by hand instead of derived so that `key` is never written
/// out by `{:?}` formatting (log lines, panic messages, config dumps). Serialization
/// is unaffected and still emits the real value.
#[derive(Clone, Deserialize, Serialize)]
pub struct VirtualKeyConfig {
    /// The key value itself (required). Treated as a secret for `Debug` purposes.
    pub key: String,
    /// Optional human-readable label for the key.
    #[serde(default)]
    pub name: Option<String>,
    /// Model names associated with this key; empty when omitted.
    /// NOTE(review): presumably an allow-list, and the meaning of an empty list is
    /// decided by the consumer — confirm where keys are enforced.
    #[serde(default)]
    pub models: Vec<String>,
}

impl std::fmt::Debug for VirtualKeyConfig {
    /// Formats like the derived `Debug` would, except that `key` is always shown as
    /// `"<redacted>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VirtualKeyConfig")
            .field("key", &"<redacted>")
            .field("name", &self.name)
            .field("models", &self.models)
            .finish()
    }
}
/// Logging section of the gateway configuration.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct LoggingConfig {
/// Optional ClickHouse URL; `None` when omitted.
/// NOTE(review): presumably the endpoint request logs are shipped to — this file only
/// stores the value; confirm with the logging code.
#[serde(default)]
pub clickhouse_url: Option<String>,
}
impl GatewayConfig {
    /// Parses a configuration from TOML text.
    ///
    /// # Errors
    /// Returns the crate error converted from the `toml::from_str` failure when the
    /// text is not valid TOML or does not match the expected schema.
    pub fn from_toml_str(s: &str) -> Result<Self> {
        let parsed: Self = toml::from_str(s)?;
        Ok(parsed)
    }

    /// Reads the file at `path` into a string and parses it with
    /// [`GatewayConfig::from_toml_str`].
    ///
    /// # Errors
    /// Returns the crate error converted from the I/O failure when the file cannot be
    /// read, or the parse error from `from_toml_str`.
    pub fn load(path: &Path) -> Result<Self> {
        Self::from_toml_str(&std::fs::read_to_string(path)?)
    }

    /// Returns the first provider whose `name` equals `name`, or `None` when no
    /// provider matches. The comparison is an exact string match.
    pub fn resolve_provider(&self, name: &str) -> Option<&ProviderConfig> {
        for provider in &self.providers {
            if provider.name == name {
                return Some(provider);
            }
        }
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A config with one provider and one route (and nothing else) parses, and every
    /// omitted setting takes its default: port 4000, round-robin strategy, weight 1.
    #[test]
    fn parses_minimal_config() {
        // Raw-string contents are kept flush-left so the TOML text is exactly as written.
        let minimal_toml = r#"
[[providers]]
name = "openai"
kind = "openai"
api_base = "https://api.openai.com"
[[routes]]
model = "gpt-4o"
strategy = "round_robin"
[[routes.targets]]
provider = "openai"
"#;
        let parsed = GatewayConfig::from_toml_str(minimal_toml).unwrap();

        assert_eq!(parsed.server.port, 4000);
        assert_eq!(parsed.providers.len(), 1);

        let first_route = &parsed.routes[0];
        assert_eq!(first_route.strategy, BalancingStrategy::RoundRobin);
        assert_eq!(first_route.targets[0].weight, 1);
    }
}