1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct PricingConfig {
7 pub prompt_cost_per_million: f64,
9 pub completion_cost_per_million: f64,
11}
12
13pub fn cost(pricing: &PricingConfig, prompt_tokens: u32, completion_tokens: u32) -> f64 {
15 (prompt_tokens as f64 * pricing.prompt_cost_per_million
16 + completion_tokens as f64 * pricing.completion_cost_per_million)
17 / 1_000_000.0
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct GatewayConfig {
23 pub listen: String,
25 #[serde(default)]
27 pub providers: HashMap<String, ProviderConfig>,
28 #[serde(default)]
30 pub keys: Vec<KeyConfig>,
31 #[serde(default)]
33 pub extensions: Option<serde_json::Value>,
34 #[serde(default)]
36 pub storage: Option<StorageConfig>,
37 #[serde(default)]
39 pub aliases: HashMap<String, String>,
40 #[serde(default)]
42 pub pricing: HashMap<String, PricingConfig>,
43 #[serde(default, skip_serializing_if = "Option::is_none")]
45 pub admin_token: Option<String>,
46 #[serde(default = "default_shutdown_timeout")]
48 pub shutdown_timeout: u64,
49 #[serde(default, skip_serializing_if = "Option::is_none")]
51 pub llamacpp: Option<LlamaCppGatewayConfig>,
52}
53
54#[derive(Debug, Clone, Default, Serialize, Deserialize)]
62pub struct LlamaCppGatewayConfig {
63 #[serde(default)]
65 pub models: Vec<String>,
66 #[serde(default, skip_serializing_if = "Option::is_none")]
68 pub idle_timeout_secs: Option<u64>,
69 #[serde(default, skip_serializing_if = "Option::is_none")]
71 pub n_gpu_layers: Option<u32>,
72 #[serde(default, skip_serializing_if = "Option::is_none")]
74 pub n_ctx: Option<u32>,
75 #[serde(default, skip_serializing_if = "Option::is_none")]
77 pub n_threads: Option<u32>,
78 #[serde(default, skip_serializing_if = "Option::is_none")]
80 pub cache_dir: Option<String>,
81}
82
83#[derive(Debug, Default, Clone, Serialize, Deserialize)]
85pub struct ProviderConfig {
86 #[serde(
88 default,
89 alias = "standard",
90 skip_serializing_if = "ProviderKind::is_default"
91 )]
92 pub kind: ProviderKind,
93 #[serde(default, skip_serializing_if = "Option::is_none")]
95 pub api_key: Option<String>,
96 #[serde(default, skip_serializing_if = "Option::is_none")]
98 pub base_url: Option<String>,
99 #[serde(default, skip_serializing_if = "Vec::is_empty")]
101 pub models: Vec<String>,
102 #[serde(default, skip_serializing_if = "Option::is_none")]
104 pub weight: Option<u16>,
105 #[serde(default, skip_serializing_if = "Option::is_none")]
107 pub max_retries: Option<u32>,
108 #[serde(default, skip_serializing_if = "Option::is_none")]
110 pub api_version: Option<String>,
111 #[serde(default, skip_serializing_if = "Option::is_none")]
113 pub timeout: Option<u64>,
114 #[serde(default, skip_serializing_if = "Option::is_none")]
116 pub region: Option<String>,
117 #[serde(default, skip_serializing_if = "Option::is_none")]
119 pub access_key: Option<String>,
120 #[serde(default, skip_serializing)]
122 pub secret_key: Option<String>,
123}
124
/// Serde default for `GatewayConfig::shutdown_timeout`.
fn default_shutdown_timeout() -> u64 {
    const DEFAULT_SHUTDOWN_TIMEOUT: u64 = 30;
    DEFAULT_SHUTDOWN_TIMEOUT
}
128
129#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
131#[serde(rename_all = "snake_case")]
132pub enum ProviderKind {
133 #[default]
134 Openai,
135 Anthropic,
136 Google,
137 Bedrock,
138 Ollama,
139 Azure,
140}
141
142impl ProviderKind {
143 pub fn is_default(&self) -> bool {
145 *self == Self::Openai
146 }
147}
148
149impl ProviderConfig {
150 pub fn effective_kind(&self) -> ProviderKind {
156 if self.kind == ProviderKind::Anthropic {
157 return ProviderKind::Anthropic;
158 }
159 if let Some(url) = &self.base_url
160 && url.contains("anthropic")
161 {
162 return ProviderKind::Anthropic;
163 }
164 self.kind
165 }
166
167 pub fn validate(&self, provider_name: &str) -> Result<(), String> {
169 if self.models.is_empty() {
170 return Err(format!("provider '{provider_name}' has no models"));
171 }
172 match self.kind {
173 ProviderKind::Bedrock => {
174 if self.region.is_none() {
175 return Err(format!(
176 "provider '{provider_name}' (bedrock) requires region"
177 ));
178 }
179 if self.access_key.is_none() {
180 return Err(format!(
181 "provider '{provider_name}' (bedrock) requires access_key"
182 ));
183 }
184 if self.secret_key.is_none() {
185 return Err(format!(
186 "provider '{provider_name}' (bedrock) requires secret_key"
187 ));
188 }
189 }
190 ProviderKind::Ollama => {
191 }
193 _ => {
194 if self.api_key.is_none() && self.base_url.is_none() {
195 return Err(format!(
196 "provider '{provider_name}' requires api_key or base_url"
197 ));
198 }
199 }
200 }
201 Ok(())
202 }
203}
204
205#[derive(Debug, Clone, Serialize, Deserialize)]
207pub struct KeyConfig {
208 pub name: String,
210 pub key: String,
212 pub models: Vec<String>,
214}
215
216#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct StorageConfig {
219 #[serde(default = "StorageConfig::default_kind")]
221 pub kind: String,
222 #[serde(default)]
224 pub path: Option<String>,
225}
226
227impl StorageConfig {
228 fn default_kind() -> String {
229 "memory".to_string()
230 }
231}
232
233impl GatewayConfig {
234 #[cfg(feature = "gateway")]
236 pub fn from_file(path: &std::path::Path) -> Result<Self, Box<dyn std::error::Error>> {
237 let raw = std::fs::read_to_string(path)?;
238 let expanded = expand_env_vars(&raw);
239
240 let raw_value: toml::Value = toml::from_str(&expanded)?;
245 if let Some(providers) = raw_value.get("providers").and_then(|v| v.as_table()) {
246 for (name, entry) in providers {
247 if let Some(kind) = entry.get("kind").and_then(|v| v.as_str())
248 && (kind == "llamacpp" || kind == "llama_cpp")
249 {
250 return Err(format!(
251 "provider '{name}' uses kind = '{kind}', which is no longer supported. \
252 Move llama.cpp configuration to a top-level [llamacpp] section. \
253 Each model becomes an entry in llamacpp.models; pool-wide settings \
254 (n_ctx, n_gpu_layers, n_threads, idle_timeout_secs) live under [llamacpp]."
255 )
256 .into());
257 }
258 }
259 }
260
261 let config: GatewayConfig = toml::from_str(&expanded)?;
262 Ok(config)
263 }
264}
265
/// Expands `${NAME}` references in `input` using the process environment.
///
/// Rules:
/// - `${NAME}` is replaced by the value of environment variable `NAME`. If that
///   lookup fails (variable unset, or its value is not valid Unicode), the
///   reference is replaced by an empty string.
/// - A `$` not immediately followed by `{` is copied through unchanged.
/// - A `${` with no closing `}` before the next newline (or the end of input)
///   is copied through literally. Scanning resumes right after that `${`, so
///   later references are still expanded and the rest of the text is never
///   swallowed.
///
/// There is no escape syntax for a literal `${NAME}`.
///
/// NOTE(review): `from_file` runs this over the raw TOML source, so values are
/// spliced in verbatim. A value containing `"`, `\` or a newline can break, or
/// change the structure of, the surrounding TOML. Only reference variables
/// whose contents are trusted and TOML-safe.
#[cfg(feature = "gateway")]
fn expand_env_vars(input: &str) -> String {
    let mut result = String::with_capacity(input.len());
    let mut rest = input;

    while let Some(open) = rest.find("${") {
        result.push_str(&rest[..open]);
        let after_open = &rest[open + 2..];
        // A reference is closed by the first `}`; meeting a newline (or the end
        // of input) first means it was never closed.
        match after_open.find(|c: char| c == '}' || c == '\n') {
            Some(close) if after_open[close..].starts_with('}') => {
                if let Ok(value) = std::env::var(&after_open[..close]) {
                    result.push_str(&value);
                }
                rest = &after_open[close + 1..];
            }
            _ => {
                result.push_str("${");
                rest = after_open;
            }
        }
    }

    result.push_str(rest);
    result
}