Skip to main content

crabtalk_core/
config.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4/// Per-model token pricing configuration.
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct PricingConfig {
7    /// Cost per million prompt tokens in USD.
8    pub prompt_cost_per_million: f64,
9    /// Cost per million completion tokens in USD.
10    pub completion_cost_per_million: f64,
11}
12
13/// Compute the cost in USD for a given number of prompt and completion tokens.
14pub fn cost(pricing: &PricingConfig, prompt_tokens: u32, completion_tokens: u32) -> f64 {
15    (prompt_tokens as f64 * pricing.prompt_cost_per_million
16        + completion_tokens as f64 * pricing.completion_cost_per_million)
17        / 1_000_000.0
18}
19
20/// Top-level gateway configuration, loaded from TOML.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct GatewayConfig {
23    /// Address to listen on, e.g. "0.0.0.0:8080".
24    pub listen: String,
25    /// Named provider configurations.
26    pub providers: HashMap<String, ProviderConfig>,
27    /// Virtual API keys for client authentication.
28    #[serde(default)]
29    pub keys: Vec<KeyConfig>,
30    /// Extension configurations. Each key is an extension name, value is its config.
31    #[serde(default)]
32    pub extensions: Option<serde_json::Value>,
33    /// Storage backend configuration.
34    #[serde(default)]
35    pub storage: Option<StorageConfig>,
36    /// Model name aliases. Maps friendly names to canonical model names.
37    #[serde(default)]
38    pub aliases: HashMap<String, String>,
39    /// Per-model token pricing for cost tracking and budget enforcement.
40    #[serde(default)]
41    pub pricing: HashMap<String, PricingConfig>,
42    /// Graceful shutdown timeout in seconds. Default: 30.
43    #[serde(default = "default_shutdown_timeout")]
44    pub shutdown_timeout: u64,
45}
46
47/// Configuration for a single LLM provider.
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct ProviderConfig {
50    /// Provider kind determines the dispatch path.
51    #[serde(
52        default,
53        alias = "standard",
54        skip_serializing_if = "ProviderKind::is_default"
55    )]
56    pub kind: ProviderKind,
57    /// API key (supports `${ENV_VAR}` interpolation).
58    #[serde(default, skip_serializing_if = "Option::is_none")]
59    pub api_key: Option<String>,
60    /// Base URL override. OpenAI-compat providers have sensible defaults.
61    #[serde(default, skip_serializing_if = "Option::is_none")]
62    pub base_url: Option<String>,
63    /// Model names served by this provider.
64    #[serde(default, skip_serializing_if = "Vec::is_empty")]
65    pub models: Vec<String>,
66    /// Routing weight for weighted random selection. Higher = more traffic.
67    #[serde(default, skip_serializing_if = "Option::is_none")]
68    pub weight: Option<u16>,
69    /// Max retries on transient errors before fallback. 0 disables retry.
70    #[serde(default, skip_serializing_if = "Option::is_none")]
71    pub max_retries: Option<u32>,
72    /// API version string, used by Azure OpenAI.
73    #[serde(default, skip_serializing_if = "Option::is_none")]
74    pub api_version: Option<String>,
75    /// Per-request timeout in seconds. Default: 30.
76    #[serde(default, skip_serializing_if = "Option::is_none")]
77    pub timeout: Option<u64>,
78    /// AWS region for Bedrock provider.
79    #[serde(default, skip_serializing_if = "Option::is_none")]
80    pub region: Option<String>,
81    /// AWS access key ID for Bedrock provider.
82    #[serde(default, skip_serializing_if = "Option::is_none")]
83    pub access_key: Option<String>,
84    /// AWS secret access key for Bedrock provider.
85    #[serde(default, skip_serializing)]
86    pub secret_key: Option<String>,
87}
88
/// Serde default for `GatewayConfig::shutdown_timeout`, in seconds.
fn default_shutdown_timeout() -> u64 {
    const DEFAULT_SHUTDOWN_TIMEOUT_SECS: u64 = 30;
    DEFAULT_SHUTDOWN_TIMEOUT_SECS
}
92
93/// Which provider implementation to use.
94#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
95#[serde(rename_all = "snake_case")]
96pub enum ProviderKind {
97    #[default]
98    #[serde(alias = "openai")]
99    OpenaiCompat,
100    Anthropic,
101    Google,
102    Bedrock,
103    Ollama,
104    Azure,
105}
106
107impl ProviderKind {
108    /// Returns true if this is the default variant (OpenaiCompat).
109    pub fn is_default(&self) -> bool {
110        *self == Self::OpenaiCompat
111    }
112}
113
114impl ProviderConfig {
115    /// Resolve the effective provider kind.
116    ///
117    /// Returns `Anthropic` if the field is explicitly set to `Anthropic`,
118    /// or if `base_url` contains "anthropic". Otherwise returns the
119    /// configured kind.
120    pub fn effective_kind(&self) -> ProviderKind {
121        if self.kind == ProviderKind::Anthropic {
122            return ProviderKind::Anthropic;
123        }
124        if let Some(url) = &self.base_url
125            && url.contains("anthropic")
126        {
127            return ProviderKind::Anthropic;
128        }
129        self.kind
130    }
131
132    /// Validate field combinations.
133    pub fn validate(&self, provider_name: &str) -> Result<(), String> {
134        if self.models.is_empty() {
135            return Err(format!("provider '{provider_name}' has no models"));
136        }
137        match self.kind {
138            ProviderKind::Bedrock => {
139                if self.region.is_none() {
140                    return Err(format!(
141                        "provider '{provider_name}' (bedrock) requires region"
142                    ));
143                }
144                if self.access_key.is_none() {
145                    return Err(format!(
146                        "provider '{provider_name}' (bedrock) requires access_key"
147                    ));
148                }
149                if self.secret_key.is_none() {
150                    return Err(format!(
151                        "provider '{provider_name}' (bedrock) requires secret_key"
152                    ));
153                }
154            }
155            ProviderKind::Ollama => {
156                // Ollama doesn't require api_key or base_url.
157            }
158            _ => {
159                if self.api_key.is_none() && self.base_url.is_none() {
160                    return Err(format!(
161                        "provider '{provider_name}' requires api_key or base_url"
162                    ));
163                }
164            }
165        }
166        Ok(())
167    }
168}
169
170/// Virtual API key for client authentication.
171#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct KeyConfig {
173    /// Human-readable name for this key.
174    pub name: String,
175    /// The key string clients send in Authorization header.
176    pub key: String,
177    /// Which models this key can access. `["*"]` means all.
178    pub models: Vec<String>,
179}
180
181/// Storage backend configuration.
182#[derive(Debug, Clone, Serialize, Deserialize)]
183pub struct StorageConfig {
184    /// Backend kind: "memory" (default) or "sqlite" (requires feature).
185    #[serde(default = "StorageConfig::default_kind")]
186    pub kind: String,
187    /// File path for persistent backends (required for sqlite).
188    #[serde(default)]
189    pub path: Option<String>,
190}
191
192impl StorageConfig {
193    fn default_kind() -> String {
194        "memory".to_string()
195    }
196}
197
198impl GatewayConfig {
199    /// Load config from a TOML file, expanding `${VAR}` patterns in string values.
200    #[cfg(feature = "gateway")]
201    pub fn from_file(path: &std::path::Path) -> Result<Self, Box<dyn std::error::Error>> {
202        let raw = std::fs::read_to_string(path)?;
203        let expanded = expand_env_vars(&raw);
204        let config: GatewayConfig = toml::from_str(&expanded)?;
205        Ok(config)
206    }
207}
208
/// Expand `${VAR}` patterns in a string using environment variables.
///
/// - A variable that is unset (or whose lookup otherwise fails) expands to
///   the empty string.
/// - A `$` that is not immediately followed by `{` is copied through as-is.
/// - A `${` with no closing `}` is copied through literally. Treating the
///   rest of the input as a variable name would silently drop everything
///   after the stray `${`.
///
/// Values are inserted verbatim: no quoting or escaping is applied, and there
/// is no syntax for escaping a literal `${`.
#[cfg(feature = "gateway")]
fn expand_env_vars(input: &str) -> String {
    let mut result = String::with_capacity(input.len());
    let mut rest = input;

    // "${" and '}' are ASCII, so every byte offset used below is a char boundary.
    while let Some(open) = rest.find("${") {
        result.push_str(&rest[..open]);
        let after_open = &rest[open + 2..];

        let Some(close) = after_open.find('}') else {
            // Unterminated placeholder: keep the text untouched and stop scanning.
            result.push_str(&rest[open..]);
            return result;
        };

        if let Ok(val) = std::env::var(&after_open[..close]) {
            result.push_str(&val);
        }
        rest = &after_open[close + 1..];
    }

    result.push_str(rest);
    result
}