1use serde::{Deserialize, Serialize};
2use std::collections::{BTreeMap, HashMap};
3
4#[derive(Debug, Clone, Default, Serialize, Deserialize)]
12#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
13pub struct PricingConfig {
14 #[serde(alias = "prompt_cost_per_million")]
16 pub input_cost_per_million: f64,
17 #[serde(alias = "completion_cost_per_million")]
19 pub output_cost_per_million: f64,
20
21 #[serde(
24 alias = "cache_hit_cost_per_million",
25 default,
26 skip_serializing_if = "Option::is_none"
27 )]
28 pub cache_read_cost_per_million: Option<f64>,
29 #[serde(default, skip_serializing_if = "Option::is_none")]
33 pub cache_write_cost_per_million: Option<f64>,
34 #[serde(default, skip_serializing_if = "Option::is_none")]
37 pub reasoning_cost_per_million: Option<f64>,
38 #[serde(default, skip_serializing_if = "Option::is_none")]
41 pub audio_input_cost_per_million: Option<f64>,
42 #[serde(default, skip_serializing_if = "Option::is_none")]
45 pub audio_output_cost_per_million: Option<f64>,
46
47 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
51 pub server_tool_cost_per_call: BTreeMap<String, f64>,
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct GatewayConfig {
57 #[serde(default = "default_listen")]
59 pub listen: String,
60 #[serde(default)]
62 pub providers: HashMap<String, ProviderConfig>,
63 #[serde(default)]
65 pub keys: Vec<KeyConfig>,
66 #[serde(default)]
68 pub extensions: Option<serde_json::Value>,
69 #[serde(default)]
71 pub storage: Option<StorageConfig>,
72 #[serde(default)]
74 pub aliases: HashMap<String, String>,
75 #[serde(default)]
78 pub models: HashMap<String, crate::ModelInfo>,
79 #[serde(default, skip_serializing_if = "Option::is_none")]
82 pub cloud_models: Option<String>,
83 #[serde(default, skip_serializing_if = "Option::is_none")]
85 pub admin_token: Option<String>,
86 #[serde(default = "default_shutdown_timeout")]
88 pub shutdown_timeout: u64,
89 #[serde(default = "default_openapi")]
93 pub openapi: bool,
94}
95
96#[derive(Debug, Default, Clone, Serialize, Deserialize)]
98pub struct ProviderConfig {
99 #[serde(default, skip_serializing_if = "ProviderKind::is_default")]
101 pub kind: ProviderKind,
102 #[serde(default, skip_serializing_if = "Option::is_none")]
104 pub api_key: Option<String>,
105 #[serde(default, skip_serializing_if = "Option::is_none")]
107 pub base_url: Option<String>,
108 #[serde(default, skip_serializing_if = "Vec::is_empty")]
110 pub models: Vec<String>,
111 #[serde(default, skip_serializing_if = "Option::is_none")]
113 pub weight: Option<u16>,
114 #[serde(default, skip_serializing_if = "Option::is_none")]
116 pub max_retries: Option<u32>,
117 #[serde(default, skip_serializing_if = "Option::is_none")]
119 pub api_version: Option<String>,
120 #[serde(default, skip_serializing_if = "Option::is_none")]
122 pub timeout: Option<u64>,
123 #[serde(default, skip_serializing_if = "Option::is_none")]
127 pub retry_deadline: Option<u64>,
128 #[serde(default, skip_serializing_if = "Option::is_none")]
130 pub region: Option<String>,
131 #[serde(default, skip_serializing_if = "Option::is_none")]
133 pub access_key: Option<String>,
134 #[serde(default, skip_serializing)]
136 pub secret_key: Option<String>,
137}
138
/// Serde default for `GatewayConfig::shutdown_timeout`.
fn default_shutdown_timeout() -> u64 {
    const DEFAULT_TIMEOUT: u64 = 30;
    DEFAULT_TIMEOUT
}
142
/// Serde default for `GatewayConfig::openapi`: enabled unless configured otherwise.
fn default_openapi() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
146
/// Serde default for `GatewayConfig::listen`: loopback interface, port 5632.
fn default_listen() -> String {
    String::from("127.0.0.1:5632")
}
150
/// The API family an upstream provider speaks.
///
/// Known kinds have fixed lowercase names (see [`ProviderKind::as_str`]); any other name is
/// kept verbatim in [`ProviderKind::Custom`]. The default is [`ProviderKind::Openai`].
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub enum ProviderKind {
    #[default]
    Openai,
    Anthropic,
    Deepseek,
    Google,
    Bedrock,
    Ollama,
    Azure,
    /// Any provider name not listed above, stored exactly as written.
    Custom(String),
}

impl ProviderKind {
    /// Returns the lowercase configuration name of this kind.
    ///
    /// For [`ProviderKind::Custom`] this is the stored name itself.
    pub fn as_str(&self) -> &str {
        match self {
            Self::Openai => "openai",
            Self::Anthropic => "anthropic",
            Self::Deepseek => "deepseek",
            Self::Google => "google",
            Self::Bedrock => "bedrock",
            Self::Ollama => "ollama",
            Self::Azure => "azure",
            Self::Custom(name) => name,
        }
    }

    /// Returns `true` only for the default kind, [`ProviderKind::Openai`].
    ///
    /// Used to omit `kind` when serializing a provider that uses the default.
    pub fn is_default(&self) -> bool {
        matches!(self, Self::Openai)
    }
}

impl std::fmt::Display for ProviderKind {
    /// Writes [`ProviderKind::as_str`], honouring width/fill/alignment/precision flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write_str`) applies flags such as `{:<10}`, exactly as `str` does,
        // so kinds line up correctly in tabular log/CLI output.
        f.pad(self.as_str())
    }
}
196
197impl serde::Serialize for ProviderKind {
198 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
199 s.serialize_str(self.as_str())
200 }
201}
202
203impl<'de> serde::Deserialize<'de> for ProviderKind {
204 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
205 let s = String::deserialize(d)?;
206 Ok(match s.as_str() {
207 "openai" => Self::Openai,
208 "anthropic" => Self::Anthropic,
209 "deepseek" => Self::Deepseek,
210 "google" => Self::Google,
211 "bedrock" => Self::Bedrock,
212 "ollama" => Self::Ollama,
213 "azure" => Self::Azure,
214 _ => Self::Custom(s),
215 })
216 }
217}
218
219impl ProviderConfig {
220 pub fn effective_kind(&self) -> ProviderKind {
226 if self.kind != ProviderKind::Openai {
227 return self.kind.clone();
228 }
229 if let Some(url) = &self.base_url
230 && url.contains("anthropic")
231 {
232 return ProviderKind::Anthropic;
233 }
234 self.kind.clone()
235 }
236
237 pub fn validate(&self, provider_name: &str) -> Result<(), String> {
239 if self.models.is_empty() {
240 return Err(format!("provider '{provider_name}' has no models"));
241 }
242 match &self.kind {
243 ProviderKind::Bedrock => {
244 if self.region.is_none() {
245 return Err(format!(
246 "provider '{provider_name}' (bedrock) requires region"
247 ));
248 }
249 if self.access_key.is_none() {
250 return Err(format!(
251 "provider '{provider_name}' (bedrock) requires access_key"
252 ));
253 }
254 if self.secret_key.is_none() {
255 return Err(format!(
256 "provider '{provider_name}' (bedrock) requires secret_key"
257 ));
258 }
259 }
260 ProviderKind::Ollama => {
261 }
263 ProviderKind::Custom(name) => {
264 if self.base_url.is_none() {
265 return Err(format!(
266 "provider '{provider_name}' (custom kind '{name}') requires base_url"
267 ));
268 }
269 }
270 _ => {
271 if self.api_key.is_none() && self.base_url.is_none() {
272 return Err(format!(
273 "provider '{provider_name}' requires api_key or base_url"
274 ));
275 }
276 }
277 }
278 Ok(())
279 }
280}
281
282#[derive(Debug, Clone, Serialize, Deserialize)]
285#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
286pub struct KeyRateLimit {
287 #[serde(default, skip_serializing_if = "Option::is_none")]
288 pub requests_per_minute: Option<u64>,
289 #[serde(default, skip_serializing_if = "Option::is_none")]
290 pub tokens_per_minute: Option<u64>,
291}
292
293#[derive(Debug, Clone, Serialize, Deserialize)]
295pub struct KeyConfig {
296 pub name: String,
298 pub key: String,
300 pub models: Vec<String>,
302 #[serde(default, skip_serializing_if = "Option::is_none")]
305 pub rate_limit: Option<KeyRateLimit>,
306}
307
308#[derive(Debug, Clone, Serialize, Deserialize)]
310pub struct StorageConfig {
311 #[serde(default = "StorageConfig::default_kind")]
313 pub kind: String,
314 #[serde(default)]
316 pub path: Option<String>,
317}
318
319impl StorageConfig {
320 fn default_kind() -> String {
321 "memory".to_string()
322 }
323}
324
325impl GatewayConfig {
326 #[cfg(feature = "gateway")]
331 pub fn from_file(path: &std::path::Path) -> Result<Self, Box<dyn std::error::Error>> {
332 let raw = std::fs::read_to_string(path)?;
333 let expanded = expand_env_vars(&raw);
334
335 let mut config: GatewayConfig = toml::from_str(&expanded)?;
336
337 let config_dir = path.parent().unwrap_or_else(|| std::path::Path::new("."));
338 config.load_cloud_models(config_dir)?;
339
340 Ok(config)
341 }
342
343 #[cfg(feature = "gateway")]
347 fn load_cloud_models(
348 &mut self,
349 config_dir: &std::path::Path,
350 ) -> Result<(), Box<dyn std::error::Error>> {
351 let Some(ref path) = self.cloud_models else {
352 return Ok(());
353 };
354 let full = config_dir.join(path);
355 let raw = std::fs::read_to_string(&full)
356 .map_err(|e| format!("cloud_models '{}': {e}", full.display()))?;
357 let table: HashMap<String, crate::ModelInfo> =
358 toml::from_str(&raw).map_err(|e| format!("cloud_models '{}': {e}", full.display()))?;
359 for (model, info) in table {
360 self.models.entry(model).or_insert(info);
361 }
362 Ok(())
363 }
364}
365
366#[cfg(feature = "gateway")]
/// Expands `${NAME}` references in `input` using the process environment.
///
/// Each complete `${NAME}` is replaced by the value of environment variable `NAME`. If the
/// variable is unset, or its value is not valid Unicode, the reference expands to the empty
/// string. A `$` not followed by `{` is copied through unchanged. An unterminated `${...`
/// (no `}` before the end of input) is copied through literally instead of swallowing the
/// rest of the text, so the resulting parse error points at the real problem.
///
/// NOTE(review): values are inserted verbatim with no TOML escaping; a value containing `"`
/// or a newline can break or alter the surrounding TOML.
fn expand_env_vars(input: &str) -> String {
    let mut result = String::with_capacity(input.len());
    let mut rest = input;

    while let Some(start) = rest.find("${") {
        result.push_str(&rest[..start]);
        let after_open = &rest[start + 2..];
        match after_open.find('}') {
            Some(end) => {
                if let Ok(value) = std::env::var(&after_open[..end]) {
                    result.push_str(&value);
                }
                rest = &after_open[end + 1..];
            }
            None => {
                // Unterminated reference: keep it and everything after it verbatim.
                result.push_str(&rest[start..]);
                rest = "";
            }
        }
    }

    result.push_str(rest);
    result
}
392}