1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
/// Top-level gateway configuration: the HTTP listener, the set of upstream
/// providers, and optional reliability/routing sections.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfig {
    /// Bind address and port for the gateway's HTTP server.
    pub server: ServerConfig,
    /// Upstream providers the gateway can forward requests to.
    pub providers: Vec<ProviderConfig>,
    /// Retry/backoff/fallback settings; `None` means the section was omitted
    /// from the config file, and it is skipped on serialization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reliability: Option<ReliabilityConfig>,
    /// Model-to-provider routing rules; `None` means no routing section,
    /// and it is skipped on serialization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub routing: Option<RoutingConfig>,
}
17
/// Network settings for the gateway's listening socket.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    /// Host/interface to bind, e.g. "0.0.0.0" or "127.0.0.1".
    pub host: String,
    /// TCP port to listen on.
    pub port: u16,
}
23
/// Kind of upstream provider, selecting which API dialect the gateway speaks.
/// Serialized in snake_case (e.g. `ClaudeCode` <-> "claude_code").
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ProviderType {
    /// OpenAI-compatible HTTP API; used when `provider_type` is omitted.
    #[default]
    OpenAiCompatible,
    /// Anthropic's own API (e.g. https://api.anthropic.com/v1 in the tests).
    Anthropic,
    /// Cursor-backed provider. NOTE(review): tests construct it with an empty
    /// `base_url`, suggesting a non-HTTP/CLI transport — confirm with the
    /// provider implementation.
    Cursor,
    /// Claude Code-backed provider — presumably CLI-driven (empty `base_url`
    /// in tests); confirm with the provider implementation.
    ClaudeCode,
    /// Codex CLI-backed provider — presumably CLI-driven (empty `base_url`
    /// in tests); confirm with the provider implementation.
    CodexCli,
}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ProviderConfig {
43 pub name: String,
44 pub base_url: String,
45 #[serde(default)]
49 pub api_key_envs: Vec<String>,
50 pub enabled: bool,
51 #[serde(default)]
53 pub provider_type: ProviderType,
54 #[serde(default)]
57 pub extra_headers: HashMap<String, String>,
58 #[serde(default)]
59 pub rate_limit: Option<RateLimitConfig>,
60 #[serde(default)]
65 pub models: Vec<String>,
66}
67
/// Rate-limit settings for a provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Sustained request budget per minute.
    pub requests_per_minute: u32,
    /// Burst allowance — presumably the token-bucket burst capacity above the
    /// sustained rate; confirm against the limiter implementation.
    pub burst_size: u32,
}
73
/// Retry, backoff, and fallback behaviour for upstream calls.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReliabilityConfig {
    /// Maximum retry attempts (default: 3).
    #[serde(default = "default_max_retries")]
    pub max_retries: u32,
    /// Initial backoff delay in milliseconds (default: 200).
    #[serde(default = "default_base_backoff_ms")]
    pub base_backoff_ms: u64,
    /// Upper bound on the backoff delay in milliseconds (default: 10_000).
    #[serde(default = "default_max_backoff_ms")]
    pub max_backoff_ms: u64,
    /// Ordered provider names to try when the primary fails; empty by default.
    #[serde(default)]
    pub fallback_chain: Vec<String>,
}
91
/// Serde default for `ReliabilityConfig::max_retries`.
/// `const fn`: the value is a compile-time constant, and const fns remain
/// callable as ordinary functions, so serde's `default = "path"` still works.
const fn default_max_retries() -> u32 {
    3
}

/// Serde default for `ReliabilityConfig::base_backoff_ms` (milliseconds).
const fn default_base_backoff_ms() -> u64 {
    200
}

/// Serde default for `ReliabilityConfig::max_backoff_ms` (milliseconds).
const fn default_max_backoff_ms() -> u64 {
    10_000
}
101
/// Routing rules mapping requested models to providers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingConfig {
    /// Explicit routes; keys are model names and values presumably provider
    /// names — confirm against the router implementation. Empty by default.
    #[serde(default)]
    pub model_routes: HashMap<String, String>,
    /// Route used when no explicit entry matches; `None` means no default,
    /// and the field is skipped on serialization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub default_route: Option<String>,
}
113
/// Configuration for a worker agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentConfig {
    /// Role label for this agent.
    pub role: String,
    /// Skill identifiers the agent advertises.
    pub skills: Vec<String>,
    /// Address of the coordinating "king" node — NOTE(review): semantics
    /// inferred from the name only; confirm with the agent runtime.
    pub king_address: String,
}
120
impl GatewayConfig {
    /// Parses a gateway configuration from a TOML document.
    ///
    /// # Errors
    /// Returns the underlying `toml` deserialization error when the document
    /// is malformed or missing required fields.
    pub fn from_toml(content: &str) -> Result<Self, toml::de::Error> {
        toml::from_str(content)
    }

    /// Serializes the configuration as pretty-printed TOML.
    ///
    /// # Errors
    /// Returns the underlying `toml` serialization error.
    pub fn to_toml(&self) -> Result<String, toml::ser::Error> {
        toml::to_string_pretty(self)
    }

    /// Parses a gateway configuration from a JSON document.
    ///
    /// # Errors
    /// Returns the underlying `serde_json` deserialization error.
    pub fn from_json(content: &str) -> Result<Self, serde_json::Error> {
        serde_json::from_str(content)
    }

    /// Serializes the configuration as pretty-printed JSON.
    ///
    /// # Errors
    /// Returns the underlying `serde_json` serialization error.
    pub fn to_json(&self) -> Result<String, serde_json::Error> {
        serde_json::to_string_pretty(self)
    }
}
138
impl AgentConfig {
    /// Parses an agent configuration from a TOML document.
    ///
    /// # Errors
    /// Returns the underlying `toml` deserialization error when the document
    /// is malformed or missing required fields.
    pub fn from_toml(content: &str) -> Result<Self, toml::de::Error> {
        toml::from_str(content)
    }
}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider entry with an empty key pool, no headers, no rate
    /// limit, and no model list.
    fn provider(
        name: &str,
        base_url: &str,
        provider_type: ProviderType,
        enabled: bool,
    ) -> ProviderConfig {
        ProviderConfig {
            name: name.into(),
            base_url: base_url.into(),
            api_key_envs: vec![],
            enabled,
            provider_type,
            extra_headers: HashMap::new(),
            rate_limit: None,
            models: vec![],
        }
    }

    /// Wraps providers in a minimal gateway config with no reliability or
    /// routing sections.
    fn gateway(host: &str, port: u16, providers: Vec<ProviderConfig>) -> GatewayConfig {
        GatewayConfig {
            server: ServerConfig {
                host: host.into(),
                port,
            },
            providers,
            reliability: None,
            routing: None,
        }
    }

    /// Serializes a config holding one disabled provider of `provider_type`,
    /// checks the expected snake_case token appears in the JSON, and checks
    /// the variant survives a deserialization round trip.
    fn check_provider_type_roundtrip(name: &str, provider_type: ProviderType, json_token: &str) {
        let config = gateway(
            "127.0.0.1",
            8080,
            vec![provider(name, "", provider_type.clone(), false)],
        );
        let json_str = config.to_json().unwrap();
        assert!(json_str.contains(json_token));
        let parsed = GatewayConfig::from_json(&json_str).unwrap();
        assert_eq!(parsed.providers[0].provider_type, provider_type);
    }

    #[test]
    fn parse_gateway_config_with_pool() {
        let toml_str = r#"
[server]
host = "0.0.0.0"
port = 8080

[[providers]]
name = "openai"
base_url = "https://api.openai.com/v1"
api_key_envs = ["OPENAI_API_KEY_1", "OPENAI_API_KEY_2"]
enabled = true
provider_type = "open_ai_compatible"

[[providers]]
name = "anthropic"
base_url = "https://api.anthropic.com/v1"
api_key_envs = ["ANTHROPIC_API_KEY"]
enabled = true
provider_type = "anthropic"

[[providers]]
name = "openrouter"
base_url = "https://openrouter.ai/api/v1"
api_key_envs = ["OPENROUTER_API_KEY"]
enabled = true
provider_type = "open_ai_compatible"

[providers.extra_headers]
"HTTP-Referer" = "https://github.com/ai-evo-agents"
"X-Title" = "evo-gateway"
"#;
        let config = GatewayConfig::from_toml(toml_str).unwrap();
        assert_eq!(config.server.port, 8080);
        assert_eq!(config.providers.len(), 3);
        assert_eq!(config.providers[0].api_key_envs.len(), 2);
        assert_eq!(config.providers[1].provider_type, ProviderType::Anthropic);
        // The [providers.extra_headers] table attaches to the last provider.
        let openrouter = &config.providers[2];
        assert!(openrouter.extra_headers.contains_key("HTTP-Referer"));
    }

    #[test]
    fn roundtrip_gateway_config_toml() {
        let original = gateway(
            "127.0.0.1",
            3000,
            vec![provider(
                "test",
                "http://localhost:11434",
                ProviderType::OpenAiCompatible,
                true,
            )],
        );
        let parsed = GatewayConfig::from_toml(&original.to_toml().unwrap()).unwrap();
        assert_eq!(parsed.server.port, 3000);
        assert_eq!(parsed.providers[0].api_key_envs.len(), 0);
    }

    #[test]
    fn roundtrip_gateway_config_json() {
        let mut openai = provider(
            "openai",
            "https://api.openai.com/v1",
            ProviderType::OpenAiCompatible,
            true,
        );
        openai.api_key_envs = vec!["OPENAI_API_KEY".into()];
        let mut anthropic = provider(
            "anthropic",
            "https://api.anthropic.com/v1",
            ProviderType::Anthropic,
            true,
        );
        anthropic.api_key_envs = vec!["ANTHROPIC_API_KEY".into()];
        let config = gateway("0.0.0.0", 8080, vec![openai, anthropic]);
        let parsed = GatewayConfig::from_json(&config.to_json().unwrap()).unwrap();
        assert_eq!(parsed.server.port, 8080);
        assert_eq!(parsed.providers.len(), 2);
        assert_eq!(parsed.providers[1].provider_type, ProviderType::Anthropic);
        assert_eq!(parsed.providers[0].api_key_envs[0], "OPENAI_API_KEY");
    }

    #[test]
    fn roundtrip_provider_type_claude_code() {
        check_provider_type_roundtrip("claude-code", ProviderType::ClaudeCode, "\"claude_code\"");
    }

    #[test]
    fn roundtrip_provider_type_codex_cli() {
        check_provider_type_roundtrip("codex-cli", ProviderType::CodexCli, "\"codex_cli\"");
    }

    #[test]
    fn roundtrip_provider_type_cursor() {
        check_provider_type_roundtrip("cursor", ProviderType::Cursor, "\"cursor\"");
    }

    #[test]
    fn roundtrip_provider_models_field() {
        let mut openai = provider(
            "openai",
            "https://api.openai.com/v1",
            ProviderType::OpenAiCompatible,
            true,
        );
        openai.models = vec!["gpt-4o".into(), "gpt-4o-mini".into()];
        let config = gateway("127.0.0.1", 8080, vec![openai]);
        let json_str = config.to_json().unwrap();
        assert!(json_str.contains("gpt-4o"));
        let parsed = GatewayConfig::from_json(&json_str).unwrap();
        assert_eq!(parsed.providers[0].models.len(), 2);
        assert_eq!(parsed.providers[0].models[0], "gpt-4o");
        assert_eq!(parsed.providers[0].models[1], "gpt-4o-mini");
    }

    #[test]
    fn models_field_defaults_to_empty() {
        let json_str = r#"{
            "server": { "host": "127.0.0.1", "port": 8080 },
            "providers": [{
                "name": "test",
                "base_url": "",
                "enabled": true
            }]
        }"#;
        let config = GatewayConfig::from_json(json_str).unwrap();
        assert!(config.providers[0].models.is_empty());
    }
}