//! Gateway configuration types for `crabllm_core` (`config.rs`):
//! provider, virtual-key, storage, and pricing settings loaded from TOML.

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
/// Per-model token pricing configuration.
///
/// Rates are expressed in USD per **million** tokens, matching the
/// convention of most LLM provider price sheets; see [`cost`] for the
/// conversion to a per-request dollar amount.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PricingConfig {
    /// Cost per million prompt (input) tokens in USD.
    pub prompt_cost_per_million: f64,
    /// Cost per million completion (output) tokens in USD.
    pub completion_cost_per_million: f64,
}
12
13/// Compute the cost in USD for a given number of prompt and completion tokens.
14pub fn cost(pricing: &PricingConfig, prompt_tokens: u32, completion_tokens: u32) -> f64 {
15    (prompt_tokens as f64 * pricing.prompt_cost_per_million
16        + completion_tokens as f64 * pricing.completion_cost_per_million)
17        / 1_000_000.0
18}
19
/// Top-level gateway configuration, loaded from TOML.
///
/// All fields except `listen` and `providers` are optional in the TOML
/// (they carry `#[serde(default)]`). See `GatewayConfig::from_file` for
/// loading with `${ENV_VAR}` expansion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfig {
    /// Address to listen on, e.g. "0.0.0.0:8080".
    pub listen: String,
    /// Named provider configurations, keyed by provider name.
    pub providers: HashMap<String, ProviderConfig>,
    /// Virtual API keys for client authentication. Empty if none configured.
    #[serde(default)]
    pub keys: Vec<KeyConfig>,
    /// Extension configurations kept as raw JSON. Each key is an extension
    /// name, value is its config — presumably interpreted by each extension
    /// itself; the schema is not validated here.
    #[serde(default)]
    pub extensions: Option<serde_json::Value>,
    /// Storage backend configuration. `None` implies the backend's default
    /// (see [`StorageConfig`]).
    #[serde(default)]
    pub storage: Option<StorageConfig>,
    /// Model name aliases. Maps friendly names to canonical model names.
    #[serde(default)]
    pub aliases: HashMap<String, String>,
    /// Per-model token pricing for cost tracking and budget enforcement,
    /// keyed by model name.
    #[serde(default)]
    pub pricing: HashMap<String, PricingConfig>,
    /// Admin API bearer token. If set, enables /v1/admin/* endpoints.
    /// Omitted from serialized output when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub admin_token: Option<String>,
    /// Graceful shutdown timeout in seconds. Default: 30
    /// (via `default_shutdown_timeout`).
    #[serde(default = "default_shutdown_timeout")]
    pub shutdown_timeout: u64,
}
49
/// Configuration for a single LLM provider.
///
/// Most fields are optional and only meaningful for particular
/// [`ProviderKind`]s (AWS fields for Bedrock, `model_path`/`n_*` for
/// LlamaCpp, etc.); cross-field requirements are enforced by
/// `ProviderConfig::validate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// Provider kind determines the dispatch path.
    /// Defaults to [`ProviderKind::OpenaiCompat`] and is omitted from
    /// serialized output when it is that default.
    // NOTE(review): serde `alias = "standard"` lets a TOML file spell this
    // field as `standard = "..."` — confirm that alias is intentional.
    #[serde(
        default,
        alias = "standard",
        skip_serializing_if = "ProviderKind::is_default"
    )]
    pub kind: ProviderKind,
    /// API key (supports `${ENV_VAR}` interpolation).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// Base URL override. OpenAI-compat providers have sensible defaults.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
    /// Model names served by this provider. Must be non-empty to pass
    /// validation.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub models: Vec<String>,
    /// Routing weight for weighted random selection. Higher = more traffic.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub weight: Option<u16>,
    /// Max retries on transient errors before fallback. 0 disables retry.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_retries: Option<u32>,
    /// API version string, used by Azure OpenAI.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_version: Option<String>,
    /// Per-request timeout in seconds. Default: 30.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timeout: Option<u64>,
    /// AWS region for Bedrock provider.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub region: Option<String>,
    /// AWS access key ID for Bedrock provider.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub access_key: Option<String>,
    /// AWS secret access key for Bedrock provider.
    /// Unconditionally skipped on serialization (`skip_serializing`), so the
    /// secret never appears in serialized/dumped config.
    #[serde(default, skip_serializing)]
    pub secret_key: Option<String>,
    /// Path to a GGUF model file for the LlamaCpp provider.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_path: Option<String>,
    /// Number of GPU layers to offload (LlamaCpp). Default: 0 (CPU only).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub n_gpu_layers: Option<u32>,
    /// Context size in tokens (LlamaCpp). Default: 2048.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub n_ctx: Option<u32>,
    /// Number of threads for inference (LlamaCpp). Default: system-chosen.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub n_threads: Option<u32>,
}
103
/// Serde default for `GatewayConfig::shutdown_timeout`: 30 seconds.
fn default_shutdown_timeout() -> u64 {
    const DEFAULT_SECONDS: u64 = 30;
    DEFAULT_SECONDS
}
107
/// Which provider implementation to use.
///
/// Serialized in `snake_case`, so the TOML values are `openai_compat`
/// (or the alias `openai`), `anthropic`, `google`, `bedrock`, `ollama`,
/// `azure`, and `llamacpp` (or the alias `llama_cpp`).
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ProviderKind {
    /// OpenAI-compatible HTTP API; the default when `kind` is omitted.
    #[default]
    #[serde(alias = "openai")]
    OpenaiCompat,
    Anthropic,
    Google,
    Bedrock,
    Ollama,
    Azure,
    /// Local GGUF inference; accepts both `llamacpp` and `llama_cpp` in config.
    #[serde(alias = "llama_cpp")]
    LlamaCpp,
}
123
124impl ProviderKind {
125    /// Returns true if this is the default variant (OpenaiCompat).
126    pub fn is_default(&self) -> bool {
127        *self == Self::OpenaiCompat
128    }
129}
130
131impl ProviderConfig {
132    /// Resolve the effective provider kind.
133    ///
134    /// Returns `Anthropic` if the field is explicitly set to `Anthropic`,
135    /// or if `base_url` contains "anthropic". Otherwise returns the
136    /// configured kind.
137    pub fn effective_kind(&self) -> ProviderKind {
138        if self.kind == ProviderKind::Anthropic {
139            return ProviderKind::Anthropic;
140        }
141        if let Some(url) = &self.base_url
142            && url.contains("anthropic")
143        {
144            return ProviderKind::Anthropic;
145        }
146        self.kind
147    }
148
149    /// Validate field combinations.
150    pub fn validate(&self, provider_name: &str) -> Result<(), String> {
151        if self.models.is_empty() {
152            return Err(format!("provider '{provider_name}' has no models"));
153        }
154        match self.kind {
155            ProviderKind::Bedrock => {
156                if self.region.is_none() {
157                    return Err(format!(
158                        "provider '{provider_name}' (bedrock) requires region"
159                    ));
160                }
161                if self.access_key.is_none() {
162                    return Err(format!(
163                        "provider '{provider_name}' (bedrock) requires access_key"
164                    ));
165                }
166                if self.secret_key.is_none() {
167                    return Err(format!(
168                        "provider '{provider_name}' (bedrock) requires secret_key"
169                    ));
170                }
171            }
172            ProviderKind::Ollama => {
173                // Ollama doesn't require api_key or base_url.
174            }
175            ProviderKind::LlamaCpp => match &self.model_path {
176                None => {
177                    return Err(format!(
178                        "provider '{provider_name}' (llamacpp) requires model_path"
179                    ));
180                }
181                Some(path) => {
182                    if !std::path::Path::new(path).exists() {
183                        return Err(format!(
184                            "provider '{provider_name}' (llamacpp): model_path '{path}' does not exist"
185                        ));
186                    }
187                }
188            },
189            _ => {
190                if self.api_key.is_none() && self.base_url.is_none() {
191                    return Err(format!(
192                        "provider '{provider_name}' requires api_key or base_url"
193                    ));
194                }
195            }
196        }
197        Ok(())
198    }
199}
200
/// Virtual API key for client authentication.
///
/// Declared under the top-level `keys` list in [`GatewayConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyConfig {
    /// Human-readable name for this key (for display/audit; not the secret).
    pub name: String,
    /// The key string clients send in the Authorization header.
    pub key: String,
    /// Which models this key can access. `["*"]` means all.
    pub models: Vec<String>,
}
211
/// Storage backend configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageConfig {
    /// Backend kind: "memory" (default, via `StorageConfig::default_kind`)
    /// or "sqlite" (requires feature).
    #[serde(default = "StorageConfig::default_kind")]
    pub kind: String,
    /// File path for persistent backends (required for sqlite).
    #[serde(default)]
    pub path: Option<String>,
}
222
223impl StorageConfig {
224    fn default_kind() -> String {
225        "memory".to_string()
226    }
227}
228
229impl GatewayConfig {
230    /// Load config from a TOML file, expanding `${VAR}` patterns in string values.
231    #[cfg(feature = "gateway")]
232    pub fn from_file(path: &std::path::Path) -> Result<Self, Box<dyn std::error::Error>> {
233        let raw = std::fs::read_to_string(path)?;
234        let expanded = expand_env_vars(&raw);
235        let config: GatewayConfig = toml::from_str(&expanded)?;
236        Ok(config)
237    }
238}
239
/// Expand `${VAR}` patterns in a string using environment variables.
///
/// Variables that are unset (or whose value is not valid Unicode, which
/// makes `std::env::var` return `Err`) are replaced with the empty string.
/// An unterminated `${` — no closing `}` before end of input — is kept in
/// the output verbatim instead of being silently discarded.
#[cfg(feature = "gateway")]
fn expand_env_vars(input: &str) -> String {
    let mut result = String::with_capacity(input.len());
    let mut chars = input.chars().peekable();

    while let Some(c) = chars.next() {
        if c == '$' && chars.peek() == Some(&'{') {
            chars.next(); // consume '{'
            let mut var_name = String::new();
            let mut closed = false;
            for ch in chars.by_ref() {
                if ch == '}' {
                    closed = true;
                    break;
                }
                var_name.push(ch);
            }
            if closed {
                if let Ok(val) = std::env::var(&var_name) {
                    result.push_str(&val);
                }
            } else {
                // Bug fix: previously an unterminated "${..." swallowed the
                // remainder of the input; emit it back as literal text.
                result.push_str("${");
                result.push_str(&var_name);
            }
        } else {
            result.push(c);
        }
    }

    result
}