// crabllm_core/config.rs
1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
/// Per-model token pricing configuration.
///
/// Used by [`cost`] to convert token counts into a USD amount. Rates are
/// expressed per one million tokens, matching common provider rate cards.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PricingConfig {
    /// Cost per million prompt tokens in USD.
    pub prompt_cost_per_million: f64,
    /// Cost per million completion tokens in USD.
    pub completion_cost_per_million: f64,
}
12
13/// Compute the cost in USD for a given number of prompt and completion tokens.
14pub fn cost(pricing: &PricingConfig, prompt_tokens: u32, completion_tokens: u32) -> f64 {
15    (prompt_tokens as f64 * pricing.prompt_cost_per_million
16        + completion_tokens as f64 * pricing.completion_cost_per_million)
17        / 1_000_000.0
18}
19
/// Top-level gateway configuration, loaded from TOML.
///
/// Every section except `listen` is optional: `#[serde(default)]` fields
/// fall back to empty collections or `None` when absent from the file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfig {
    /// Address to listen on, e.g. "0.0.0.0:8080".
    pub listen: String,
    /// Named provider configurations, keyed by the provider's config-file name.
    #[serde(default)]
    pub providers: HashMap<String, ProviderConfig>,
    /// Virtual API keys for client authentication.
    #[serde(default)]
    pub keys: Vec<KeyConfig>,
    /// Extension configurations. Each key is an extension name, value is its config.
    /// Kept as an untyped `serde_json::Value` here; presumably each extension
    /// decodes its own section — TODO confirm at the extension loading site.
    #[serde(default)]
    pub extensions: Option<serde_json::Value>,
    /// Storage backend configuration.
    #[serde(default)]
    pub storage: Option<StorageConfig>,
    /// Model name aliases. Maps friendly names to canonical model names.
    #[serde(default)]
    pub aliases: HashMap<String, String>,
    /// Per-model token pricing for cost tracking and budget enforcement.
    /// Keys are model names; see [`PricingConfig`] and [`cost`].
    #[serde(default)]
    pub pricing: HashMap<String, PricingConfig>,
    /// Admin API bearer token. If set, enables /v1/admin/* endpoints.
    /// Omitted from serialized output when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub admin_token: Option<String>,
    /// Graceful shutdown timeout in seconds. Default: 30.
    #[serde(default = "default_shutdown_timeout")]
    pub shutdown_timeout: u64,
    /// Optional llama.cpp local backend configuration.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub llamacpp: Option<LlamaCppGatewayConfig>,
}
53
/// Gateway-level configuration for the llama.cpp local backend.
///
/// The `models` list determines which model names are registered for the
/// llama.cpp provider. Each entry is either an Ollama-registry model name
/// (`qwen2.5:0.5b`) or a filesystem path to a GGUF file. The pool spawns
/// a separate `llama-server` subprocess per model on first request and
/// evicts idle servers after `idle_timeout_secs`.
///
/// All `Option` fields are omitted from serialized output when unset; the
/// defaults quoted below are applied elsewhere — confirm at the pool's
/// construction site.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlamaCppGatewayConfig {
    /// Model names or GGUF paths to serve.
    #[serde(default)]
    pub models: Vec<String>,
    /// Idle timeout for the per-model server pool, in seconds. Default: 1800.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub idle_timeout_secs: Option<u64>,
    /// Number of GPU layers to offload. Default: 999 (auto).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub n_gpu_layers: Option<u32>,
    /// Context size in tokens. Default: 4096.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub n_ctx: Option<u32>,
    /// Number of inference threads. Default: system-chosen.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub n_threads: Option<u32>,
    /// Override for the GGUF cache directory. Defaults to `~/.crabtalk/models`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_dir: Option<String>,
}
82
/// Configuration for a single LLM provider.
///
/// Which optional fields are required depends on [`ProviderKind`]; see
/// [`ProviderConfig::validate`] for the per-kind rules.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    /// Provider kind determines the dispatch path.
    ///
    /// NOTE(review): the serde alias also accepts a legacy `standard` key
    /// for this field — confirm that is intentional. The default (`openai`)
    /// is skipped on serialization via `ProviderKind::is_default`.
    #[serde(
        default,
        alias = "standard",
        skip_serializing_if = "ProviderKind::is_default"
    )]
    pub kind: ProviderKind,
    /// API key (supports `${ENV_VAR}` interpolation).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// Base URL override. OpenAI-compat providers have sensible defaults.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
    /// Model names served by this provider.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub models: Vec<String>,
    /// Routing weight for weighted random selection. Higher = more traffic.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub weight: Option<u16>,
    /// Max retries on transient errors before fallback. 0 disables retry.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_retries: Option<u32>,
    /// API version string, used by Azure OpenAI.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_version: Option<String>,
    /// Per-request timeout in seconds. Default: 30.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timeout: Option<u64>,
    /// AWS region for Bedrock provider.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub region: Option<String>,
    /// AWS access key ID for Bedrock provider.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub access_key: Option<String>,
    /// AWS secret access key for Bedrock provider.
    ///
    /// Unconditionally `skip_serializing` — presumably to keep the secret
    /// out of serialized config dumps; confirm before changing.
    #[serde(default, skip_serializing)]
    pub secret_key: Option<String>,
}
124
/// Serde default for [`GatewayConfig::shutdown_timeout`], in seconds.
fn default_shutdown_timeout() -> u64 {
    const DEFAULT_SHUTDOWN_SECS: u64 = 30;
    DEFAULT_SHUTDOWN_SECS
}
128
/// Which provider implementation to use.
///
/// Serialized in snake_case, so the TOML values are "openai", "anthropic",
/// "google", "bedrock", "ollama", and "azure". `Openai` is the default.
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ProviderKind {
    #[default]
    Openai,
    Anthropic,
    Google,
    Bedrock,
    Ollama,
    Azure,
}
141
142impl ProviderKind {
143    /// Returns true if this is the default variant (Openai).
144    pub fn is_default(&self) -> bool {
145        *self == Self::Openai
146    }
147}
148
149impl ProviderConfig {
150    /// Resolve the effective provider kind.
151    ///
152    /// Returns `Anthropic` if the field is explicitly set to `Anthropic`,
153    /// or if `base_url` contains "anthropic". Otherwise returns the
154    /// configured kind.
155    pub fn effective_kind(&self) -> ProviderKind {
156        if self.kind == ProviderKind::Anthropic {
157            return ProviderKind::Anthropic;
158        }
159        if let Some(url) = &self.base_url
160            && url.contains("anthropic")
161        {
162            return ProviderKind::Anthropic;
163        }
164        self.kind
165    }
166
167    /// Validate field combinations.
168    pub fn validate(&self, provider_name: &str) -> Result<(), String> {
169        if self.models.is_empty() {
170            return Err(format!("provider '{provider_name}' has no models"));
171        }
172        match self.kind {
173            ProviderKind::Bedrock => {
174                if self.region.is_none() {
175                    return Err(format!(
176                        "provider '{provider_name}' (bedrock) requires region"
177                    ));
178                }
179                if self.access_key.is_none() {
180                    return Err(format!(
181                        "provider '{provider_name}' (bedrock) requires access_key"
182                    ));
183                }
184                if self.secret_key.is_none() {
185                    return Err(format!(
186                        "provider '{provider_name}' (bedrock) requires secret_key"
187                    ));
188                }
189            }
190            ProviderKind::Ollama => {
191                // Ollama doesn't require api_key or base_url.
192            }
193            _ => {
194                if self.api_key.is_none() && self.base_url.is_none() {
195                    return Err(format!(
196                        "provider '{provider_name}' requires api_key or base_url"
197                    ));
198                }
199            }
200        }
201        Ok(())
202    }
203}
204
/// Virtual API key for client authentication.
///
/// All fields are required when deserializing; there are no serde defaults.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyConfig {
    /// Human-readable name for this key.
    pub name: String,
    /// The key string clients send in the Authorization header.
    pub key: String,
    /// Which models this key can access. `["*"]` means all.
    pub models: Vec<String>,
}
215
/// Storage backend configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageConfig {
    /// Backend kind: "memory" (default) or "sqlite" (requires feature).
    /// Stringly-typed on purpose so optional backends don't force an enum
    /// variant into builds without the feature — confirm before changing.
    #[serde(default = "StorageConfig::default_kind")]
    pub kind: String,
    /// File path for persistent backends (required for sqlite).
    #[serde(default)]
    pub path: Option<String>,
}
226
227impl StorageConfig {
228    fn default_kind() -> String {
229        "memory".to_string()
230    }
231}
232
233impl GatewayConfig {
234    /// Load config from a TOML file, expanding `${VAR}` patterns in string values.
235    #[cfg(feature = "gateway")]
236    pub fn from_file(path: &std::path::Path) -> Result<Self, Box<dyn std::error::Error>> {
237        let raw = std::fs::read_to_string(path)?;
238        let expanded = expand_env_vars(&raw);
239
240        // Pre-parse as a generic toml::Value so we can surface a clear
241        // migration error for the removed `kind = "llamacpp"` provider
242        // variant before the typed deserialize turns it into a cryptic
243        // "unknown variant" error.
244        let raw_value: toml::Value = toml::from_str(&expanded)?;
245        if let Some(providers) = raw_value.get("providers").and_then(|v| v.as_table()) {
246            for (name, entry) in providers {
247                if let Some(kind) = entry.get("kind").and_then(|v| v.as_str())
248                    && (kind == "llamacpp" || kind == "llama_cpp")
249                {
250                    return Err(format!(
251                        "provider '{name}' uses kind = '{kind}', which is no longer supported. \
252                         Move llama.cpp configuration to a top-level [llamacpp] section. \
253                         Each model becomes an entry in llamacpp.models; pool-wide settings \
254                         (n_ctx, n_gpu_layers, n_threads, idle_timeout_secs) live under [llamacpp]."
255                    )
256                    .into());
257                }
258            }
259        }
260
261        let config: GatewayConfig = toml::from_str(&expanded)?;
262        Ok(config)
263    }
264}
265
/// Expand `${VAR}` patterns in a string using environment variables.
/// Unknown variables are replaced with empty string.
///
/// A `$` not immediately followed by `{` is copied through verbatim. For an
/// unterminated `${...` the entire remaining input is treated as the
/// variable name, so that tail is dropped unless an env var by that exact
/// name happens to exist.
#[cfg(feature = "gateway")]
fn expand_env_vars(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    let mut rest = input;

    while let Some(start) = rest.find("${") {
        // Literal text before the marker passes through untouched.
        out.push_str(&rest[..start]);
        let after_marker = &rest[start + 2..];
        match after_marker.find('}') {
            Some(end) => {
                if let Ok(value) = std::env::var(&after_marker[..end]) {
                    out.push_str(&value);
                }
                rest = &after_marker[end + 1..];
            }
            None => {
                // No closing brace: the whole tail is the variable name.
                if let Ok(value) = std::env::var(after_marker) {
                    out.push_str(&value);
                }
                rest = "";
            }
        }
    }

    out.push_str(rest);
    out
}