Skip to main content

liter_llm_proxy/config/
mod.rs

1pub mod files;
2pub mod key;
3pub mod model;
4pub mod server;
5
6pub use files::FileStorageConfig;
7pub use key::VirtualKeyConfig;
8pub use model::{AliasEntry, ModelEntry};
9pub use server::ServerConfig;
10
11use std::collections::HashMap;
12use std::path::Path;
13
14use serde::Deserialize;
15
16// ---------------------------------------------------------------------------
17// Default helpers
18// ---------------------------------------------------------------------------
19
/// Serde fallback for `GeneralConfig::default_timeout_secs` when the key is
/// absent from the config: 120.
fn default_timeout() -> u64 {
    const FALLBACK_TIMEOUT_SECS: u64 = 120;
    FALLBACK_TIMEOUT_SECS
}
23
/// Serde fallback for `GeneralConfig::max_retries` when the key is absent
/// from the config: 3.
fn default_retries() -> u32 {
    const FALLBACK_MAX_RETRIES: u32 = 3;
    FALLBACK_MAX_RETRIES
}
27
/// Serde fallback for `CacheConfig::backend` when the key is absent from the
/// config: the string `"memory"`.
fn default_cache_backend() -> String {
    String::from("memory")
}
31
32// ---------------------------------------------------------------------------
33// Sub-configs defined in mod.rs (not large enough for their own files)
34// ---------------------------------------------------------------------------
35
/// General proxy behaviour: master key, timeouts, retries, feature flags.
///
/// Deserialized with `deny_unknown_fields`, so any key that is not one of the
/// fields below is a deserialization error.
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct GeneralConfig {
    /// Optional master key; `None` when the key is omitted.
    pub master_key: Option<String>,
    /// Default timeout in seconds; falls back to `default_timeout()` (120)
    /// when omitted.
    #[serde(default = "default_timeout")]
    pub default_timeout_secs: u64,
    /// Maximum retry count; falls back to `default_retries()` (3) when
    /// omitted.
    #[serde(default = "default_retries")]
    pub max_retries: u32,
    /// Cost-tracking feature flag; `false` when omitted.
    #[serde(default)]
    pub enable_cost_tracking: bool,
    /// Tracing feature flag; `false` when omitted.
    #[serde(default)]
    pub enable_tracing: bool,
}
50
51impl Default for GeneralConfig {
52    fn default() -> Self {
53        Self {
54            master_key: None,
55            default_timeout_secs: default_timeout(),
56            max_retries: default_retries(),
57            enable_cost_tracking: false,
58            enable_tracing: false,
59        }
60    }
61}
62
/// Global rate-limit settings (requests-per-minute / tokens-per-minute).
///
/// Both limits are optional. Unknown keys in this table are rejected
/// (`deny_unknown_fields`).
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct RateLimitConfig {
    /// Requests-per-minute limit; `None` when omitted.
    pub rpm: Option<u32>,
    /// Tokens-per-minute limit; `None` when omitted.
    pub tpm: Option<u64>,
}
70
/// How budget limits are enforced.
///
/// Serialized names are the lowercase variant names (`"hard"`, `"soft"`)
/// because of `rename_all = "lowercase"`. `Hard` is the `Default` variant.
#[derive(Debug, Clone, Default, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum EnforcementMode {
    /// Requests exceeding the budget are rejected.
    #[default]
    Hard,
    /// Requests exceeding the budget are logged but allowed through.
    Soft,
}
81
/// Budget enforcement settings with optional per-model limits.
///
/// Unknown keys in this table are rejected (`deny_unknown_fields`).
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct BudgetConfig {
    /// Optional overall limit; `None` when omitted.
    /// NOTE(review): the unit (e.g. currency) is not visible here — confirm
    /// against the code that enforces it.
    pub global_limit: Option<f64>,
    /// Per-model limits keyed by a model string; empty when omitted.
    #[serde(default)]
    pub model_limits: HashMap<String, f64>,
    /// Enforcement mode; defaults to `EnforcementMode::Hard` when omitted.
    #[serde(default)]
    pub enforcement: EnforcementMode,
}
92
/// Semantic cache configuration.
///
/// Unknown keys in this table are rejected (`deny_unknown_fields`).
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct CacheConfig {
    /// Optional cap on the number of entries; `None` when omitted.
    pub max_entries: Option<usize>,
    /// Optional time-to-live in seconds; `None` when omitted.
    pub ttl_seconds: Option<u64>,
    /// Backend name; falls back to `default_cache_backend()` (`"memory"`)
    /// when omitted.
    #[serde(default = "default_cache_backend")]
    pub backend: String,
    /// Free-form string key/value settings for the backend; empty when
    /// omitted. Which keys are meaningful is decided by the backend, not here.
    #[serde(default)]
    pub backend_config: HashMap<String, String>,
}
104
/// Periodic health-check probe settings.
///
/// Both fields are optional. Unknown keys in this table are rejected
/// (`deny_unknown_fields`).
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct HealthConfig {
    /// Probe interval in seconds; `None` when omitted.
    pub interval_secs: Option<u64>,
    /// Model string used for the probe; `None` when omitted.
    pub probe_model: Option<String>,
}
112
/// Provider cooldown duration after consecutive failures.
///
/// Unknown keys in this table are rejected (`deny_unknown_fields`).
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct CooldownConfig {
    /// Cooldown length in seconds. Required: the field has no serde default,
    /// so a `[cooldown]` table without it fails to deserialize.
    pub duration_secs: u64,
}
119
120// ---------------------------------------------------------------------------
121// Top-level ProxyConfig
122// ---------------------------------------------------------------------------
123
124/// Root configuration for the liter-llm proxy server.
125///
126/// Loaded from a `liter-llm-proxy.toml` file. After deserialization all
127/// `${VAR_NAME}` patterns in string values are replaced with the
128/// corresponding environment variable.
129#[derive(Debug, Clone, Deserialize)]
130#[serde(deny_unknown_fields)]
131#[derive(Default)]
132pub struct ProxyConfig {
133    #[serde(default)]
134    pub server: ServerConfig,
135    #[serde(default)]
136    pub general: GeneralConfig,
137    #[serde(default)]
138    pub models: Vec<ModelEntry>,
139    #[serde(default)]
140    pub aliases: Vec<AliasEntry>,
141    pub rate_limit: Option<RateLimitConfig>,
142    pub budget: Option<BudgetConfig>,
143    pub cache: Option<CacheConfig>,
144    pub files: Option<FileStorageConfig>,
145    #[serde(default)]
146    pub keys: Vec<VirtualKeyConfig>,
147    pub health: Option<HealthConfig>,
148    pub cooldown: Option<CooldownConfig>,
149}
150
151// ---------------------------------------------------------------------------
152// Environment variable interpolation
153// ---------------------------------------------------------------------------
154
/// Expand every `${VAR_NAME}` in `s` using the process environment.
///
/// * A variable that cannot be read through `std::env::var` (for example
///   because it is unset) expands to the empty string.
/// * A `${` with no matching `}` is not a reference: it and everything after
///   it are copied through unchanged.
/// * A lone `$` that is not followed by `{` is ordinary text.
///
/// The name is everything between `${` and the *first* following `}`; no
/// nesting or escaping is recognised.
pub fn interpolate_env_vars(s: &str) -> String {
    const OPEN: &str = "${";
    const CLOSE: char = '}';

    let mut out = String::with_capacity(s.len());
    let mut rest = s;

    while let Some(start) = rest.find(OPEN) {
        // Text ahead of the marker is copied verbatim.
        out.push_str(&rest[..start]);
        let after_open = &rest[start + OPEN.len()..];

        match after_open.split_once(CLOSE) {
            Some((name, tail)) => {
                // Unreadable variables contribute nothing to the output.
                if let Ok(value) = std::env::var(name) {
                    out.push_str(&value);
                }
                rest = tail;
            }
            None => {
                // Unterminated reference: keep `${` and the remainder as-is.
                out.push_str(OPEN);
                out.push_str(after_open);
                rest = "";
            }
        }
    }

    out.push_str(rest);
    out
}
192
193/// Apply env-var interpolation to a raw TOML string, then deserialize.
194///
195/// This is the simplest correct approach: interpolate the whole TOML source
196/// before parsing, so every string value (including nested tables and arrays)
197/// gets expanded uniformly.
198fn parse_with_env_interpolation(raw: &str) -> Result<ProxyConfig, String> {
199    let expanded = interpolate_env_vars(raw);
200    toml::from_str(&expanded).map_err(|e| format!("invalid TOML config: {e}"))
201}
202
203impl ProxyConfig {
204    /// Parse from a TOML string with env-var interpolation.
205    pub fn from_toml_str(s: &str) -> Result<Self, String> {
206        parse_with_env_interpolation(s)
207    }
208
209    /// Load from a TOML file path with env-var interpolation.
210    pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self, String> {
211        let path = path.as_ref();
212        let content =
213            std::fs::read_to_string(path).map_err(|e| format!("failed to read config file {}: {e}", path.display()))?;
214        Self::from_toml_str(&content)
215    }
216
217    /// Discover `liter-llm-proxy.toml` by walking from the current directory
218    /// up to the filesystem root.
219    ///
220    /// Returns `Ok(None)` if no config file is found.
221    pub fn discover() -> Result<Option<Self>, String> {
222        let mut current = std::env::current_dir().map_err(|e| format!("failed to get current directory: {e}"))?;
223        loop {
224            let config_path = current.join("liter-llm-proxy.toml");
225            if config_path.exists() {
226                return Ok(Some(Self::from_toml_file(config_path)?));
227            }
228            match current.parent() {
229                Some(parent) => current = parent.to_path_buf(),
230                None => break,
231            }
232        }
233        Ok(None)
234    }
235}
236
// Unit tests for config parsing, defaults, unknown-field rejection and
// `${VAR}` interpolation. All fixtures are inline TOML strings.
#[cfg(test)]
mod tests {
    use super::*;

    // 1. Parse minimal config (empty string): every section falls back to its
    //    default or `None`.
    #[test]
    fn parse_minimal_config() {
        let config = ProxyConfig::from_toml_str("").unwrap();
        assert_eq!(config.server.host, "0.0.0.0");
        assert_eq!(config.server.port, 4000);
        assert_eq!(config.general.default_timeout_secs, 120);
        assert_eq!(config.general.max_retries, 3);
        assert!(config.models.is_empty());
        assert!(config.keys.is_empty());
        assert!(config.rate_limit.is_none());
        assert!(config.budget.is_none());
        assert!(config.cache.is_none());
        assert!(config.files.is_none());
        assert!(config.health.is_none());
        assert!(config.cooldown.is_none());
    }

    // 2. Parse full config with all sections populated and check that each
    //    value round-trips into the matching field.
    #[test]
    fn parse_full_config() {
        let toml = r#"
[server]
host = "127.0.0.1"
port = 8080
request_timeout_secs = 300
body_limit_bytes = 5242880
cors_origins = ["https://example.com"]

[general]
master_key = "sk-master"
default_timeout_secs = 60
max_retries = 5
enable_cost_tracking = true
enable_tracing = true

[[models]]
name = "gpt-4o"
provider_model = "openai/gpt-4o"
api_key = "sk-openai"
base_url = "https://api.openai.com/v1"
timeout_secs = 30
fallbacks = ["claude-sonnet"]

[[models]]
name = "claude-sonnet"
provider_model = "anthropic/claude-sonnet-4-20250514"

[[aliases]]
pattern = "anthropic/*"
api_key = "sk-anthropic"

[[keys]]
key = "vk-team-a"
description = "Team A key"
models = ["gpt-4o"]
rpm = 60
tpm = 100000
budget_limit = 50.0

[rate_limit]
rpm = 120
tpm = 500000

[budget]
global_limit = 100.0
enforcement = "soft"

[budget.model_limits]
"openai/gpt-4o" = 50.0

[cache]
max_entries = 1024
ttl_seconds = 600
backend = "memory"

[files]
backend = "s3"
prefix = "proxy-files/"

[files.backend_config]
bucket = "my-bucket"

[health]
interval_secs = 30
probe_model = "openai/gpt-4o-mini"

[cooldown]
duration_secs = 60
"#;
        let config = ProxyConfig::from_toml_str(toml).unwrap();

        // Server
        assert_eq!(config.server.host, "127.0.0.1");
        assert_eq!(config.server.port, 8080);
        assert_eq!(config.server.request_timeout_secs, 300);
        assert_eq!(config.server.body_limit_bytes, 5_242_880);
        assert_eq!(config.server.cors_origins, vec!["https://example.com"]);

        // General
        assert_eq!(config.general.master_key.as_deref(), Some("sk-master"));
        assert_eq!(config.general.default_timeout_secs, 60);
        assert_eq!(config.general.max_retries, 5);
        assert!(config.general.enable_cost_tracking);
        assert!(config.general.enable_tracing);

        // Models
        assert_eq!(config.models.len(), 2);
        assert_eq!(config.models[0].name, "gpt-4o");
        assert_eq!(config.models[0].provider_model, "openai/gpt-4o");
        assert_eq!(config.models[0].api_key.as_deref(), Some("sk-openai"));
        assert_eq!(config.models[0].fallbacks, vec!["claude-sonnet"]);
        assert_eq!(config.models[1].name, "claude-sonnet");
        assert!(config.models[1].api_key.is_none());

        // Aliases
        assert_eq!(config.aliases.len(), 1);
        assert_eq!(config.aliases[0].pattern, "anthropic/*");

        // Keys
        assert_eq!(config.keys.len(), 1);
        assert_eq!(config.keys[0].key, "vk-team-a");
        assert_eq!(config.keys[0].models, vec!["gpt-4o"]);
        assert_eq!(config.keys[0].rpm, Some(60));

        // Rate limit
        let rl = config.rate_limit.unwrap();
        assert_eq!(rl.rpm, Some(120));
        assert_eq!(rl.tpm, Some(500_000));

        // Budget
        let budget = config.budget.unwrap();
        assert_eq!(budget.global_limit, Some(100.0));
        assert_eq!(budget.enforcement, EnforcementMode::Soft);
        assert_eq!(budget.model_limits.get("openai/gpt-4o"), Some(&50.0));

        // Cache
        let cache = config.cache.unwrap();
        assert_eq!(cache.max_entries, Some(1024));
        assert_eq!(cache.ttl_seconds, Some(600));
        assert_eq!(cache.backend, "memory");

        // Files
        let files = config.files.unwrap();
        assert_eq!(files.backend, "s3");
        assert_eq!(files.prefix, "proxy-files/");
        assert_eq!(files.backend_config.get("bucket").unwrap(), "my-bucket");

        // Health
        let health = config.health.unwrap();
        assert_eq!(health.interval_secs, Some(30));
        assert_eq!(health.probe_model.as_deref(), Some("openai/gpt-4o-mini"));

        // Cooldown
        assert_eq!(config.cooldown.unwrap().duration_secs, 60);
    }

    // 3. Env var interpolation: `${VAR}` inside TOML string values is replaced
    //    by the variable's value before parsing.
    #[test]
    fn env_var_interpolation() {
        // SAFETY: no other test in this module reads or writes these two
        // variable names.
        // NOTE(review): the test harness may run tests on several threads;
        // confirm nothing outside this module touches the environment
        // concurrently.
        unsafe {
            std::env::set_var("LITER_TEST_KEY", "sk-from-env");
            std::env::set_var("LITER_TEST_HOST", "10.0.0.1");
        }

        let toml = r#"
[server]
host = "${LITER_TEST_HOST}"

[general]
master_key = "${LITER_TEST_KEY}"
"#;
        let config = ProxyConfig::from_toml_str(toml).unwrap();
        assert_eq!(config.server.host, "10.0.0.1");
        assert_eq!(config.general.master_key.as_deref(), Some("sk-from-env"));

        // SAFETY: cleaning up test-only env vars set above.
        unsafe {
            std::env::remove_var("LITER_TEST_KEY");
            std::env::remove_var("LITER_TEST_HOST");
        }
    }

    // Values without any `${...}` marker pass through untouched.
    #[test]
    fn env_var_interpolation_preserves_literals() {
        let toml = r#"
[server]
host = "literal-value"
"#;
        let config = ProxyConfig::from_toml_str(toml).unwrap();
        assert_eq!(config.server.host, "literal-value");
    }

    // A reference to a variable that is not set expands to nothing.
    #[test]
    fn env_var_interpolation_unknown_var_becomes_empty() {
        let result = interpolate_env_vars("prefix-${SURELY_NONEXISTENT_VAR_12345}-suffix");
        assert_eq!(result, "prefix--suffix");
    }

    // 4. Unknown field rejection (`deny_unknown_fields`) at the top level and
    //    inside nested tables.
    #[test]
    fn rejects_unknown_top_level_field() {
        let toml = r#"
unknown_field = true
"#;
        assert!(ProxyConfig::from_toml_str(toml).is_err());
    }

    #[test]
    fn rejects_unknown_server_field() {
        let toml = r#"
[server]
host = "0.0.0.0"
bogus = 42
"#;
        assert!(ProxyConfig::from_toml_str(toml).is_err());
    }

    #[test]
    fn rejects_unknown_general_field() {
        let toml = r#"
[general]
unknown_option = true
"#;
        assert!(ProxyConfig::from_toml_str(toml).is_err());
    }

    // 5. Default values applied correctly, via `Default` rather than parsing.
    #[test]
    fn default_values_applied() {
        let config = ProxyConfig::default();
        assert_eq!(config.server.host, "0.0.0.0");
        assert_eq!(config.server.port, 4000);
        assert_eq!(config.server.request_timeout_secs, 600);
        assert_eq!(config.server.body_limit_bytes, 10_485_760);
        assert_eq!(config.server.cors_origins, vec!["*"]);
        assert_eq!(config.general.default_timeout_secs, 120);
        assert_eq!(config.general.max_retries, 3);
        assert!(!config.general.enable_cost_tracking);
        assert!(!config.general.enable_tracing);
    }

    // Omitting `enforcement` selects the `#[default]` variant.
    #[test]
    fn budget_default_enforcement() {
        let toml = r#"
[budget]
global_limit = 100.0
"#;
        let config = ProxyConfig::from_toml_str(toml).unwrap();
        assert_eq!(config.budget.unwrap().enforcement, EnforcementMode::Hard);
    }

    // Omitting `backend` falls back to `default_cache_backend()`.
    #[test]
    fn cache_default_backend() {
        let toml = r#"
[cache]
max_entries = 256
"#;
        let config = ProxyConfig::from_toml_str(toml).unwrap();
        assert_eq!(config.cache.unwrap().backend, "memory");
    }

    // An empty `[files]` table still yields `Some(..)` with field defaults.
    #[test]
    fn files_default_values() {
        let toml = r#"
[files]
"#;
        let config = ProxyConfig::from_toml_str(toml).unwrap();
        let files = config.files.unwrap();
        assert_eq!(files.backend, "memory");
        assert_eq!(files.prefix, "liter-llm-files/");
        assert!(files.backend_config.is_empty());
    }

    // 6. Multiple models with same name (load balancing): duplicates are kept
    //    as separate entries, in file order.
    #[test]
    fn multiple_models_same_name() {
        let toml = r#"
[[models]]
name = "gpt-4o"
provider_model = "openai/gpt-4o"
api_key = "sk-key-1"

[[models]]
name = "gpt-4o"
provider_model = "azure/gpt-4o"
api_key = "sk-key-2"
"#;
        let config = ProxyConfig::from_toml_str(toml).unwrap();
        assert_eq!(config.models.len(), 2);
        assert_eq!(config.models[0].name, "gpt-4o");
        assert_eq!(config.models[1].name, "gpt-4o");
        assert_ne!(config.models[0].provider_model, config.models[1].provider_model);
    }

    // 7. Model with fallbacks: the list is preserved in order and defaults to
    //    empty for entries that omit it.
    #[test]
    fn model_with_fallbacks() {
        let toml = r#"
[[models]]
name = "primary"
provider_model = "openai/gpt-4o"
fallbacks = ["fallback-1", "fallback-2"]

[[models]]
name = "fallback-1"
provider_model = "anthropic/claude-sonnet-4-20250514"

[[models]]
name = "fallback-2"
provider_model = "groq/llama3-70b"
"#;
        let config = ProxyConfig::from_toml_str(toml).unwrap();
        assert_eq!(config.models[0].fallbacks, vec!["fallback-1", "fallback-2"]);
        assert!(config.models[1].fallbacks.is_empty());
        assert!(config.models[2].fallbacks.is_empty());
    }

    // Inputs with no `${` marker come back unchanged, including a bare `$`.
    #[test]
    fn interpolate_env_vars_basic() {
        assert_eq!(interpolate_env_vars("no vars here"), "no vars here");
        assert_eq!(interpolate_env_vars(""), "");
        assert_eq!(interpolate_env_vars("$not_a_var"), "$not_a_var");
    }

    // Several references in one string are each expanded.
    #[test]
    fn interpolate_env_vars_multiple() {
        // SAFETY: no other test in this module reads or writes these two
        // variable names.
        // NOTE(review): the test harness may run tests on several threads;
        // confirm nothing outside this module touches the environment
        // concurrently.
        unsafe {
            std::env::set_var("LITER_A", "hello");
            std::env::set_var("LITER_B", "world");
        }
        let result = interpolate_env_vars("${LITER_A} ${LITER_B}!");
        assert_eq!(result, "hello world!");
        // SAFETY: cleaning up test-only env vars set above.
        unsafe {
            std::env::remove_var("LITER_A");
            std::env::remove_var("LITER_B");
        }
    }

    #[test]
    fn interpolate_env_vars_unclosed_brace_treated_as_literal() {
        // Unclosed `${` should be preserved as literal text, not silently dropped.
        assert_eq!(interpolate_env_vars("prefix-${UNCLOSED"), "prefix-${UNCLOSED");
        assert_eq!(interpolate_env_vars("${"), "${");
        assert_eq!(interpolate_env_vars("a${b"), "a${b");
    }
}