1use serde::{Deserialize, Serialize};
2use std::path::Path;
3
4use crate::error::Result;
5
6#[derive(Debug, Clone, Default, Deserialize, Serialize)]
8pub struct GatewayConfig {
9 #[serde(default)]
10 pub server: ServerConfig,
11 #[serde(default)]
12 pub providers: Vec<ProviderConfig>,
13 #[serde(default)]
14 pub routes: Vec<ModelRoute>,
15 #[serde(default)]
16 pub virtual_keys: Vec<VirtualKeyConfig>,
17 #[serde(default)]
18 pub logging: LoggingConfig,
19}
20
21#[derive(Debug, Clone, Deserialize, Serialize)]
23pub struct ServerConfig {
24 #[serde(default = "default_host")]
25 pub host: String,
26 #[serde(default = "default_port")]
27 pub port: u16,
28}
29
30impl Default for ServerConfig {
31 fn default() -> Self {
32 Self {
33 host: default_host(),
34 port: default_port(),
35 }
36 }
37}
38
39fn default_host() -> String {
40 "0.0.0.0".to_string()
41}
42
43fn default_port() -> u16 {
44 4000
45}
46
47#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
49#[serde(rename_all = "snake_case")]
50pub enum ProviderKind {
51 Openai,
53 Anthropic,
55 OpenaiCompatible,
57}
58
59#[derive(Debug, Clone, Deserialize, Serialize)]
61pub struct ProviderConfig {
62 pub name: String,
63 pub kind: ProviderKind,
64 pub api_base: String,
66 #[serde(default)]
68 pub api_key: Option<String>,
69 #[serde(default)]
71 pub api_key_env: Option<String>,
72 #[serde(default)]
74 pub egress_proxy: Option<String>,
75}
76
77impl ProviderConfig {
78 pub fn resolve_api_key(&self) -> Option<String> {
80 if let Some(k) = &self.api_key {
81 return Some(k.clone());
82 }
83 self.api_key_env
84 .as_ref()
85 .and_then(|e| std::env::var(e).ok())
86 }
87}
88
89#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Deserialize, Serialize)]
91#[serde(rename_all = "snake_case")]
92pub enum BalancingStrategy {
93 #[default]
94 RoundRobin,
95 Random,
96 PowerOfTwo,
97 ConsistentHash,
98 CacheAware,
99}
100
101#[derive(Debug, Clone, Deserialize, Serialize)]
103pub struct Target {
104 pub provider: String,
106 #[serde(default)]
108 pub model: Option<String>,
109 #[serde(default = "default_weight")]
110 pub weight: u32,
111}
112
/// Weight given to a target whose config omits `weight`.
fn default_weight() -> u32 {
    const UNIT_WEIGHT: u32 = 1;
    UNIT_WEIGHT
}
116
117#[derive(Debug, Clone, Deserialize, Serialize)]
119pub struct ModelRoute {
120 pub model: String,
122 #[serde(default)]
123 pub strategy: BalancingStrategy,
124 pub targets: Vec<Target>,
125}
126
127#[derive(Debug, Clone, Deserialize, Serialize)]
129pub struct VirtualKeyConfig {
130 pub key: String,
131 #[serde(default)]
132 pub name: Option<String>,
133 #[serde(default)]
135 pub models: Vec<String>,
136}
137
138#[derive(Debug, Clone, Default, Deserialize, Serialize)]
140pub struct LoggingConfig {
141 #[serde(default)]
142 pub clickhouse_url: Option<String>,
143}
144
145impl GatewayConfig {
146 pub fn from_toml_str(s: &str) -> Result<Self> {
148 Ok(toml::from_str(s)?)
149 }
150
151 pub fn load(path: &Path) -> Result<Self> {
153 let raw = std::fs::read_to_string(path)?;
154 Self::from_toml_str(&raw)
155 }
156
157 pub fn resolve_provider(&self, name: &str) -> Option<&ProviderConfig> {
159 self.providers.iter().find(|p| p.name == name)
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// One provider plus one route with a single target; every other field is
    /// left to its default.
    const MINIMAL_TOML: &str = r#"
    [[providers]]
    name = "openai"
    kind = "openai"
    api_base = "https://api.openai.com"

    [[routes]]
    model = "gpt-4o"
    strategy = "round_robin"
    [[routes.targets]]
    provider = "openai"
    "#;

    /// The minimal document parses, and omitted fields take their defaults.
    #[test]
    fn parses_minimal_config() {
        let cfg = GatewayConfig::from_toml_str(MINIMAL_TOML).unwrap();

        assert_eq!(cfg.server.port, 4000);
        assert_eq!(cfg.providers.len(), 1);

        let route = &cfg.routes[0];
        assert_eq!(route.strategy, BalancingStrategy::RoundRobin);
        assert_eq!(route.targets[0].weight, 1);
    }
}