1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct PricingConfig {
7 pub prompt_cost_per_million: f64,
9 pub completion_cost_per_million: f64,
11}
12
13pub fn cost(pricing: &PricingConfig, prompt_tokens: u32, completion_tokens: u32) -> f64 {
15 (prompt_tokens as f64 * pricing.prompt_cost_per_million
16 + completion_tokens as f64 * pricing.completion_cost_per_million)
17 / 1_000_000.0
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct GatewayConfig {
23 pub listen: String,
25 pub providers: HashMap<String, ProviderConfig>,
27 #[serde(default)]
29 pub keys: Vec<KeyConfig>,
30 #[serde(default)]
32 pub extensions: Option<serde_json::Value>,
33 #[serde(default)]
35 pub storage: Option<StorageConfig>,
36 #[serde(default)]
38 pub aliases: HashMap<String, String>,
39 #[serde(default)]
41 pub pricing: HashMap<String, PricingConfig>,
42 #[serde(default, skip_serializing_if = "Option::is_none")]
44 pub admin_token: Option<String>,
45 #[serde(default = "default_shutdown_timeout")]
47 pub shutdown_timeout: u64,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct ProviderConfig {
53 #[serde(
55 default,
56 alias = "standard",
57 skip_serializing_if = "ProviderKind::is_default"
58 )]
59 pub kind: ProviderKind,
60 #[serde(default, skip_serializing_if = "Option::is_none")]
62 pub api_key: Option<String>,
63 #[serde(default, skip_serializing_if = "Option::is_none")]
65 pub base_url: Option<String>,
66 #[serde(default, skip_serializing_if = "Vec::is_empty")]
68 pub models: Vec<String>,
69 #[serde(default, skip_serializing_if = "Option::is_none")]
71 pub weight: Option<u16>,
72 #[serde(default, skip_serializing_if = "Option::is_none")]
74 pub max_retries: Option<u32>,
75 #[serde(default, skip_serializing_if = "Option::is_none")]
77 pub api_version: Option<String>,
78 #[serde(default, skip_serializing_if = "Option::is_none")]
80 pub timeout: Option<u64>,
81 #[serde(default, skip_serializing_if = "Option::is_none")]
83 pub region: Option<String>,
84 #[serde(default, skip_serializing_if = "Option::is_none")]
86 pub access_key: Option<String>,
87 #[serde(default, skip_serializing)]
89 pub secret_key: Option<String>,
90 #[serde(default, skip_serializing_if = "Option::is_none")]
92 pub model_path: Option<String>,
93 #[serde(default, skip_serializing_if = "Option::is_none")]
95 pub n_gpu_layers: Option<u32>,
96 #[serde(default, skip_serializing_if = "Option::is_none")]
98 pub n_ctx: Option<u32>,
99 #[serde(default, skip_serializing_if = "Option::is_none")]
101 pub n_threads: Option<u32>,
102}
103
/// Fallback for `GatewayConfig::shutdown_timeout` when the field is absent
/// from the deserialized input.
fn default_shutdown_timeout() -> u64 {
    const DEFAULT_SHUTDOWN_TIMEOUT: u64 = 30;
    DEFAULT_SHUTDOWN_TIMEOUT
}
107
108#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
110#[serde(rename_all = "snake_case")]
111pub enum ProviderKind {
112 #[default]
113 #[serde(alias = "openai")]
114 OpenaiCompat,
115 Anthropic,
116 Google,
117 Bedrock,
118 Ollama,
119 Azure,
120 #[serde(alias = "llama_cpp")]
121 LlamaCpp,
122}
123
124impl ProviderKind {
125 pub fn is_default(&self) -> bool {
127 *self == Self::OpenaiCompat
128 }
129}
130
131impl ProviderConfig {
132 pub fn effective_kind(&self) -> ProviderKind {
138 if self.kind == ProviderKind::Anthropic {
139 return ProviderKind::Anthropic;
140 }
141 if let Some(url) = &self.base_url
142 && url.contains("anthropic")
143 {
144 return ProviderKind::Anthropic;
145 }
146 self.kind
147 }
148
149 pub fn validate(&self, provider_name: &str) -> Result<(), String> {
151 if self.models.is_empty() {
152 return Err(format!("provider '{provider_name}' has no models"));
153 }
154 match self.kind {
155 ProviderKind::Bedrock => {
156 if self.region.is_none() {
157 return Err(format!(
158 "provider '{provider_name}' (bedrock) requires region"
159 ));
160 }
161 if self.access_key.is_none() {
162 return Err(format!(
163 "provider '{provider_name}' (bedrock) requires access_key"
164 ));
165 }
166 if self.secret_key.is_none() {
167 return Err(format!(
168 "provider '{provider_name}' (bedrock) requires secret_key"
169 ));
170 }
171 }
172 ProviderKind::Ollama => {
173 }
175 ProviderKind::LlamaCpp => match &self.model_path {
176 None => {
177 return Err(format!(
178 "provider '{provider_name}' (llamacpp) requires model_path"
179 ));
180 }
181 Some(path) => {
182 if !std::path::Path::new(path).exists() {
183 return Err(format!(
184 "provider '{provider_name}' (llamacpp): model_path '{path}' does not exist"
185 ));
186 }
187 }
188 },
189 _ => {
190 if self.api_key.is_none() && self.base_url.is_none() {
191 return Err(format!(
192 "provider '{provider_name}' requires api_key or base_url"
193 ));
194 }
195 }
196 }
197 Ok(())
198 }
199}
200
201#[derive(Debug, Clone, Serialize, Deserialize)]
203pub struct KeyConfig {
204 pub name: String,
206 pub key: String,
208 pub models: Vec<String>,
210}
211
212#[derive(Debug, Clone, Serialize, Deserialize)]
214pub struct StorageConfig {
215 #[serde(default = "StorageConfig::default_kind")]
217 pub kind: String,
218 #[serde(default)]
220 pub path: Option<String>,
221}
222
223impl StorageConfig {
224 fn default_kind() -> String {
225 "memory".to_string()
226 }
227}
228
229impl GatewayConfig {
230 #[cfg(feature = "gateway")]
232 pub fn from_file(path: &std::path::Path) -> Result<Self, Box<dyn std::error::Error>> {
233 let raw = std::fs::read_to_string(path)?;
234 let expanded = expand_env_vars(&raw);
235 let config: GatewayConfig = toml::from_str(&expanded)?;
236 Ok(config)
237 }
238}
239
/// Replaces every `${NAME}` placeholder in `input` with the value of the
/// environment variable `NAME`.
///
/// - A variable that cannot be read (unset, or not valid Unicode) expands to
///   the empty string.
/// - A `$` that is not immediately followed by `{` is copied through as is.
/// - An unterminated placeholder (`${` with no closing `}`) is copied through
///   literally instead of swallowing the rest of the input.
/// - Substituted values are not scanned again for further placeholders.
#[cfg(feature = "gateway")]
fn expand_env_vars(input: &str) -> String {
    const OPEN: &str = "${";

    let mut result = String::with_capacity(input.len());
    let mut rest = input;

    while let Some(start) = rest.find(OPEN) {
        // Text before the placeholder is copied verbatim.
        result.push_str(&rest[..start]);

        let after_open = &rest[start + OPEN.len()..];
        let Some(end) = after_open.find('}') else {
            // No closing brace: keep the remaining text untouched so a typo
            // in the config surfaces downstream instead of being erased.
            result.push_str(&rest[start..]);
            return result;
        };

        // Unreadable variables deliberately contribute nothing.
        if let Ok(val) = std::env::var(&after_open[..end]) {
            result.push_str(&val);
        }
        rest = &after_open[end + 1..];
    }

    result.push_str(rest);
    result
}