Skip to main content

crabllm_core/
config.rs

1use serde::{Deserialize, Serialize};
2use std::collections::{BTreeMap, HashMap};
3
/// Per-model token pricing. One rate per usage axis. Secondary rates are
/// `Option<f64>` so "absent" means *fall back to the coarser bucket* rather
/// than *free* — see [`crate::ModelInfo::cost`] for the fallback chain.
///
/// Field names use the canonical "input/output" vocabulary; legacy
/// "prompt/completion/cache_hit" names are accepted via serde aliases so
/// existing configs and the generated `models/cloud.toml` continue to load.
/// Aliases only affect deserialization: serialized output always uses the
/// canonical field names, and unset optional rates (plus an empty
/// tool-cost map) are omitted entirely.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct PricingConfig {
    /// Cost per million uncached input tokens in USD.
    /// Required when deserializing (no serde default on this field).
    #[serde(alias = "prompt_cost_per_million")]
    pub input_cost_per_million: f64,
    /// Cost per million output tokens in USD.
    /// Required when deserializing (no serde default on this field).
    #[serde(alias = "completion_cost_per_million")]
    pub output_cost_per_million: f64,

    /// Cost per million cache-read input tokens in USD.
    /// `None` → falls back to `input_cost_per_million`.
    /// Also accepted under the legacy key `cache_hit_cost_per_million`.
    #[serde(
        alias = "cache_hit_cost_per_million",
        default,
        skip_serializing_if = "Option::is_none"
    )]
    pub cache_read_cost_per_million: Option<f64>,
    /// Cost per million cache-write input tokens in USD (Anthropic charges
    /// ~1.25× of base input for this). `None` → falls back to
    /// `input_cost_per_million`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_write_cost_per_million: Option<f64>,
    /// Cost per million reasoning output tokens in USD.
    /// `None` → falls back to `output_cost_per_million`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_cost_per_million: Option<f64>,
    /// Cost per million audio input tokens in USD.
    /// `None` → falls back to `input_cost_per_million`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub audio_input_cost_per_million: Option<f64>,
    /// Cost per million audio output tokens in USD.
    /// `None` → falls back to `output_cost_per_million`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub audio_output_cost_per_million: Option<f64>,

    /// Per-call cost in USD for upstream-side tools like web search. Keyed by
    /// tool name (must match the names crabllm reports in
    /// [`crate::Usage::server_tool_calls`]). A `BTreeMap` keeps serialized
    /// output in a stable, sorted key order; omitted when empty.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub server_tool_cost_per_call: BTreeMap<String, f64>,
}
53
/// Top-level gateway configuration, loaded from TOML.
///
/// Every field has a serde default, so even an empty file deserializes.
/// Provider field combinations are not checked during deserialization;
/// that is the job of [`ProviderConfig::validate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfig {
    /// Address to listen on, e.g. "0.0.0.0:8080".
    /// Defaults to `127.0.0.1:5632` (loopback only) when omitted.
    #[serde(default = "default_listen")]
    pub listen: String,
    /// Named provider configurations, keyed by provider name.
    #[serde(default)]
    pub providers: HashMap<String, ProviderConfig>,
    /// Virtual API keys for client authentication.
    #[serde(default)]
    pub keys: Vec<KeyConfig>,
    /// Extension configurations. Each key is an extension name, value is its config.
    /// Stored as untyped JSON; `None` when no extensions section is present.
    #[serde(default)]
    pub extensions: Option<serde_json::Value>,
    /// Storage backend configuration. `None` when the `[storage]` table is absent.
    #[serde(default)]
    pub storage: Option<StorageConfig>,
    /// Model name aliases. Maps friendly names to canonical model names.
    #[serde(default)]
    pub aliases: HashMap<String, String>,
    /// Per-model metadata overrides (context window, pricing). Merged with
    /// built-in defaults at lookup time — only specify what you want to override.
    #[serde(default)]
    pub models: HashMap<String, crate::ModelInfo>,
    /// Path to cloud model metadata TOML file (pricing + context windows).
    /// Entries are merged into `models` at startup (config entries win).
    /// When loaded via [`GatewayConfig::from_file`], a relative path is
    /// resolved against the config file's directory. Merging is per model
    /// name: a config entry replaces the cloud entry for that model wholesale.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cloud_models: Option<String>,
    /// Admin API bearer token. If set, enables /v1/admin/* endpoints.
    /// NOTE(review): unlike `ProviderConfig::secret_key`, this token is
    /// included when the config is serialized, and in `Debug` output.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub admin_token: Option<String>,
    /// Graceful shutdown timeout in seconds. Default: 30.
    #[serde(default = "default_shutdown_timeout")]
    pub shutdown_timeout: u64,
    /// Serve OpenAPI documentation at `/openapi.json` and `/docs`.
    /// Defaults to `true`; set `openapi = false` to disable.
    /// Ignored unless the binary is built with the `openapi` feature.
    #[serde(default = "default_openapi")]
    pub openapi: bool,
}
95
/// Configuration for a single LLM provider.
///
/// Which fields are required depends on `kind`; see
/// [`ProviderConfig::validate`]. Note that the derived `Debug` prints every
/// field, including credentials.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// Provider kind determines the dispatch path.
    /// Omitted from serialized output when it is the default (`openai`).
    /// [`ProviderConfig::effective_kind`] may infer `Anthropic` from `base_url`.
    #[serde(default, skip_serializing_if = "ProviderKind::is_default")]
    pub kind: ProviderKind,
    /// API key (supports `${ENV_VAR}` interpolation).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// Base URL override. OpenAI-compat providers have sensible defaults.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
    /// Model names served by this provider. Must be non-empty to pass
    /// validation.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub models: Vec<String>,
    /// Routing weight for weighted random selection. Higher = more traffic.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub weight: Option<u16>,
    /// Max retries on transient errors before fallback. 0 disables retry.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_retries: Option<u32>,
    /// API version string, used by Azure OpenAI.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_version: Option<String>,
    /// Per-request timeout in seconds. Default: 30.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timeout: Option<u64>,
    /// Wall-clock deadline for the entire retry loop in seconds. Default: 15.
    /// The loop stops retrying once this much time has elapsed since the
    /// first attempt, even if `max_retries` has not been exhausted.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub retry_deadline: Option<u64>,
    /// AWS region for Bedrock provider.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub region: Option<String>,
    /// AWS access key ID for Bedrock provider.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub access_key: Option<String>,
    /// AWS secret access key for Bedrock provider.
    /// Accepted on input but never serialized (`skip_serializing`), so a
    /// serialize → deserialize round trip drops it.
    #[serde(default, skip_serializing)]
    pub secret_key: Option<String>,
}
138
/// Serde default for `GatewayConfig::shutdown_timeout`, in seconds.
fn default_shutdown_timeout() -> u64 {
    const GRACE_PERIOD_SECS: u64 = 30;
    GRACE_PERIOD_SECS
}
142
/// Serde default for `GatewayConfig::openapi`: documentation endpoints on.
fn default_openapi() -> bool {
    const SERVE_DOCS: bool = true;
    SERVE_DOCS
}
146
/// Serde default for `GatewayConfig::listen`: loopback interface, port 5632.
fn default_listen() -> String {
    String::from("127.0.0.1:5632")
}
150
/// Which provider implementation to use. Known variants map to named
/// dispatch paths. A self-defined name deserializes to [`Custom`], which
/// dispatches as OpenAI-compatible and requires `base_url` at validation.
///
/// (De)serialized as a plain string: the lowercase name returned by
/// [`ProviderKind::as_str`]. Known names match exactly and case-sensitively,
/// so e.g. `"OpenAI"` deserializes to `Custom("OpenAI")`. The default is
/// [`Openai`].
///
/// [`Custom`]: ProviderKind::Custom
/// [`Openai`]: ProviderKind::Openai
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub enum ProviderKind {
    #[default]
    Openai,
    Anthropic,
    Deepseek,
    Google,
    /// Validation requires `region`, `access_key` and `secret_key`.
    Bedrock,
    /// Validation requires neither `api_key` nor `base_url`.
    Ollama,
    /// Azure OpenAI (see `ProviderConfig::api_version`).
    Azure,
    /// Self-defined kind — any string that doesn't match a known variant.
    /// Dispatched through the OpenAI-compatible path; `base_url` required.
    Custom(String),
}
170
171impl ProviderKind {
172    pub fn as_str(&self) -> &str {
173        match self {
174            Self::Openai => "openai",
175            Self::Anthropic => "anthropic",
176            Self::Deepseek => "deepseek",
177            Self::Google => "google",
178            Self::Bedrock => "bedrock",
179            Self::Ollama => "ollama",
180            Self::Azure => "azure",
181            Self::Custom(s) => s,
182        }
183    }
184
185    /// Returns true if this is the default variant (Openai).
186    pub fn is_default(&self) -> bool {
187        matches!(self, Self::Openai)
188    }
189}
190
191impl std::fmt::Display for ProviderKind {
192    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193        f.write_str(self.as_str())
194    }
195}
196
197impl serde::Serialize for ProviderKind {
198    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
199        s.serialize_str(self.as_str())
200    }
201}
202
203impl<'de> serde::Deserialize<'de> for ProviderKind {
204    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
205        let s = String::deserialize(d)?;
206        Ok(match s.as_str() {
207            "openai" => Self::Openai,
208            "anthropic" => Self::Anthropic,
209            "deepseek" => Self::Deepseek,
210            "google" => Self::Google,
211            "bedrock" => Self::Bedrock,
212            "ollama" => Self::Ollama,
213            "azure" => Self::Azure,
214            _ => Self::Custom(s),
215        })
216    }
217}
218
219impl ProviderConfig {
220    /// Resolve the effective provider kind.
221    ///
222    /// Returns `Anthropic` if the field is explicitly set to `Anthropic`,
223    /// or if `base_url` contains "anthropic". Otherwise returns the
224    /// configured kind.
225    pub fn effective_kind(&self) -> ProviderKind {
226        if self.kind != ProviderKind::Openai {
227            return self.kind.clone();
228        }
229        if let Some(url) = &self.base_url
230            && url.contains("anthropic")
231        {
232            return ProviderKind::Anthropic;
233        }
234        self.kind.clone()
235    }
236
237    /// Validate field combinations.
238    pub fn validate(&self, provider_name: &str) -> Result<(), String> {
239        if self.models.is_empty() {
240            return Err(format!("provider '{provider_name}' has no models"));
241        }
242        match &self.kind {
243            ProviderKind::Bedrock => {
244                if self.region.is_none() {
245                    return Err(format!(
246                        "provider '{provider_name}' (bedrock) requires region"
247                    ));
248                }
249                if self.access_key.is_none() {
250                    return Err(format!(
251                        "provider '{provider_name}' (bedrock) requires access_key"
252                    ));
253                }
254                if self.secret_key.is_none() {
255                    return Err(format!(
256                        "provider '{provider_name}' (bedrock) requires secret_key"
257                    ));
258                }
259            }
260            ProviderKind::Ollama => {
261                // Ollama doesn't require api_key or base_url.
262            }
263            ProviderKind::Custom(name) => {
264                if self.base_url.is_none() {
265                    return Err(format!(
266                        "provider '{provider_name}' (custom kind '{name}') requires base_url"
267                    ));
268                }
269            }
270            _ => {
271                if self.api_key.is_none() && self.base_url.is_none() {
272                    return Err(format!(
273                        "provider '{provider_name}' requires api_key or base_url"
274                    ));
275                }
276            }
277        }
278        Ok(())
279    }
280}
281
/// Per-key rate limit override. When set on a key, these values take
/// precedence over the global `[extensions.rate_limit]` config.
///
/// Both limits are optional and omitted from serialized output when unset.
/// NOTE(review): whether an unset field falls back to the global value or
/// means "no limit" is decided by the rate-limit extension — confirm there.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct KeyRateLimit {
    /// Request-count limit per minute for this key.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub requests_per_minute: Option<u64>,
    /// Token-count limit per minute for this key.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tokens_per_minute: Option<u64>,
}
292
/// Virtual API key for client authentication.
///
/// `name`, `key` and `models` are required when deserializing; `rate_limit`
/// is optional. NOTE(review): `key` is included in serialized output and in
/// the derived `Debug` representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyConfig {
    /// Human-readable name for this key.
    pub name: String,
    /// The key string clients send in Authorization header.
    pub key: String,
    /// Which models this key can access. `["*"]` means all.
    pub models: Vec<String>,
    /// Per-key rate limit override. Takes precedence over the global
    /// `[extensions.rate_limit]` config when present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rate_limit: Option<KeyRateLimit>,
}
307
/// Storage backend configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageConfig {
    /// Backend kind: "memory" (default) or "sqlite" (requires feature).
    /// Free-form string: unknown values are not rejected at deserialization.
    #[serde(default = "StorageConfig::default_kind")]
    pub kind: String,
    /// File path for persistent backends (required for sqlite).
    /// `None` when omitted.
    #[serde(default)]
    pub path: Option<String>,
}
318
319impl StorageConfig {
320    fn default_kind() -> String {
321        "memory".to_string()
322    }
323}
324
325impl GatewayConfig {
326    /// Load config from a TOML file, expanding `${VAR}` patterns in
327    /// string values. If `cloud_models` is set, loads the referenced
328    /// file and merges entries into `models` (config entries win over
329    /// cloud file entries).
330    #[cfg(feature = "gateway")]
331    pub fn from_file(path: &std::path::Path) -> Result<Self, Box<dyn std::error::Error>> {
332        let raw = std::fs::read_to_string(path)?;
333        let expanded = expand_env_vars(&raw);
334
335        let mut config: GatewayConfig = toml::from_str(&expanded)?;
336
337        let config_dir = path.parent().unwrap_or_else(|| std::path::Path::new("."));
338        config.load_cloud_models(config_dir)?;
339
340        Ok(config)
341    }
342
343    /// Load cloud model metadata from the configured TOML file and merge
344    /// into `self.models`. Config entries take precedence — cloud file
345    /// entries only fill gaps.
346    #[cfg(feature = "gateway")]
347    fn load_cloud_models(
348        &mut self,
349        config_dir: &std::path::Path,
350    ) -> Result<(), Box<dyn std::error::Error>> {
351        let Some(ref path) = self.cloud_models else {
352            return Ok(());
353        };
354        let full = config_dir.join(path);
355        let raw = std::fs::read_to_string(&full)
356            .map_err(|e| format!("cloud_models '{}': {e}", full.display()))?;
357        let table: HashMap<String, crate::ModelInfo> =
358            toml::from_str(&raw).map_err(|e| format!("cloud_models '{}': {e}", full.display()))?;
359        for (model, info) in table {
360            self.models.entry(model).or_insert(info);
361        }
362        Ok(())
363    }
364}
365
/// Expand `${VAR}` patterns in a string using environment variables.
///
/// Substitution is purely textual over the whole input, including TOML
/// comments and unquoted positions. Values are inserted verbatim, with no
/// quoting or escaping, and there is no escape syntax for a literal `${`.
/// Unknown variables, and variables whose value is not valid Unicode, are
/// replaced with an empty string.
///
/// A `${` with no closing `}` anywhere after it is copied through unchanged,
/// rather than being treated as a variable name that swallows the rest of
/// the input.
// Always compiled; the dead-code lint is silenced when `gateway` (which gates
// the caller `GatewayConfig::from_file`) is off. This keeps the expansion
// logic unit-testable without that feature.
#[cfg_attr(not(feature = "gateway"), allow(dead_code))]
fn expand_env_vars(input: &str) -> String {
    let mut result = String::with_capacity(input.len());
    let mut rest = input;

    while let Some(start) = rest.find("${") {
        result.push_str(&rest[..start]);
        let after_open = &rest[start + 2..];
        let Some(close) = after_open.find('}') else {
            // Unterminated reference: keep `${...` literally.
            rest = &rest[start..];
            break;
        };
        // Unset or non-UTF-8 variables expand to nothing.
        if let Ok(value) = std::env::var(&after_open[..close]) {
            result.push_str(&value);
        }
        rest = &after_open[close + 1..];
    }

    result.push_str(rest);
    result
}