1pub mod files;
2pub mod key;
3pub mod model;
4pub mod server;
5
6pub use files::FileStorageConfig;
7pub use key::VirtualKeyConfig;
8pub use model::{AliasEntry, ModelEntry};
9pub use server::ServerConfig;
10
11use std::collections::HashMap;
12use std::path::Path;
13
14use serde::Deserialize;
15
/// Serde default for `GeneralConfig::default_timeout_secs`: 120.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 120;
    DEFAULT_TIMEOUT_SECS
}
23
/// Serde default for `GeneralConfig::max_retries`: 3.
fn default_retries() -> u32 {
    const DEFAULT_MAX_RETRIES: u32 = 3;
    DEFAULT_MAX_RETRIES
}
27
/// Serde default for `CacheConfig::backend`: the string `"memory"`.
fn default_cache_backend() -> String {
    String::from("memory")
}
31
32#[derive(Debug, Clone, Deserialize)]
38#[serde(deny_unknown_fields)]
39pub struct GeneralConfig {
40 pub master_key: Option<String>,
41 #[serde(default = "default_timeout")]
42 pub default_timeout_secs: u64,
43 #[serde(default = "default_retries")]
44 pub max_retries: u32,
45 #[serde(default)]
46 pub enable_cost_tracking: bool,
47 #[serde(default)]
48 pub enable_tracing: bool,
49}
50
51impl Default for GeneralConfig {
52 fn default() -> Self {
53 Self {
54 master_key: None,
55 default_timeout_secs: default_timeout(),
56 max_retries: default_retries(),
57 enable_cost_tracking: false,
58 enable_tracing: false,
59 }
60 }
61}
62
63#[derive(Debug, Clone, Deserialize)]
65#[serde(deny_unknown_fields)]
66pub struct RateLimitConfig {
67 pub rpm: Option<u32>,
68 pub tpm: Option<u64>,
69}
70
71#[derive(Debug, Clone, Default, Deserialize, PartialEq, Eq)]
73#[serde(rename_all = "lowercase")]
74pub enum EnforcementMode {
75 #[default]
77 Hard,
78 Soft,
80}
81
82#[derive(Debug, Clone, Deserialize)]
84#[serde(deny_unknown_fields)]
85pub struct BudgetConfig {
86 pub global_limit: Option<f64>,
87 #[serde(default)]
88 pub model_limits: HashMap<String, f64>,
89 #[serde(default)]
90 pub enforcement: EnforcementMode,
91}
92
93#[derive(Debug, Clone, Deserialize)]
95#[serde(deny_unknown_fields)]
96pub struct CacheConfig {
97 pub max_entries: Option<usize>,
98 pub ttl_seconds: Option<u64>,
99 #[serde(default = "default_cache_backend")]
100 pub backend: String,
101 #[serde(default)]
102 pub backend_config: HashMap<String, String>,
103}
104
105#[derive(Debug, Clone, Deserialize)]
107#[serde(deny_unknown_fields)]
108pub struct HealthConfig {
109 pub interval_secs: Option<u64>,
110 pub probe_model: Option<String>,
111}
112
113#[derive(Debug, Clone, Deserialize)]
115#[serde(deny_unknown_fields)]
116pub struct CooldownConfig {
117 pub duration_secs: u64,
118}
119
120#[derive(Debug, Clone, Deserialize)]
130#[serde(deny_unknown_fields)]
131#[derive(Default)]
132pub struct ProxyConfig {
133 #[serde(default)]
134 pub server: ServerConfig,
135 #[serde(default)]
136 pub general: GeneralConfig,
137 #[serde(default)]
138 pub models: Vec<ModelEntry>,
139 #[serde(default)]
140 pub aliases: Vec<AliasEntry>,
141 pub rate_limit: Option<RateLimitConfig>,
142 pub budget: Option<BudgetConfig>,
143 pub cache: Option<CacheConfig>,
144 pub files: Option<FileStorageConfig>,
145 #[serde(default)]
146 pub keys: Vec<VirtualKeyConfig>,
147 pub health: Option<HealthConfig>,
148 pub cooldown: Option<CooldownConfig>,
149}
150
/// Expands `${NAME}` references in `s` using the process environment.
///
/// * `${NAME}` is replaced by the value `std::env::var(NAME)` returns; when
///   that lookup fails (for example the variable is unset) the reference
///   expands to the empty string.
/// * A `${` with no closing `}` is copied through unchanged, together with
///   everything after it.
/// * A `$` that is not immediately followed by `{` is an ordinary character.
///
/// Substituted values are not scanned again for further references.
pub fn interpolate_env_vars(s: &str) -> String {
    const OPEN: &str = "${";

    let mut out = String::with_capacity(s.len());
    let mut rest = s;

    while let Some(open_at) = rest.find(OPEN) {
        // Everything before the marker is literal text.
        out.push_str(&rest[..open_at]);

        let after_open = &rest[open_at + OPEN.len()..];
        match after_open.find('}') {
            Some(close_at) => {
                let var_name = &after_open[..close_at];
                if let Ok(value) = std::env::var(var_name) {
                    out.push_str(&value);
                }
                // Resume scanning just past the closing brace.
                rest = &after_open[close_at + 1..];
            }
            None => {
                // Unterminated reference: emit the remainder verbatim and stop.
                out.push_str(&rest[open_at..]);
                return out;
            }
        }
    }

    out.push_str(rest);
    out
}
192
193fn parse_with_env_interpolation(raw: &str) -> Result<ProxyConfig, String> {
199 let expanded = interpolate_env_vars(raw);
200 toml::from_str(&expanded).map_err(|e| format!("invalid TOML config: {e}"))
201}
202
203impl ProxyConfig {
204 pub fn from_toml_str(s: &str) -> Result<Self, String> {
206 parse_with_env_interpolation(s)
207 }
208
209 pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self, String> {
211 let path = path.as_ref();
212 let content =
213 std::fs::read_to_string(path).map_err(|e| format!("failed to read config file {}: {e}", path.display()))?;
214 Self::from_toml_str(&content)
215 }
216
217 pub fn discover() -> Result<Option<Self>, String> {
222 let mut current = std::env::current_dir().map_err(|e| format!("failed to get current directory: {e}"))?;
223 loop {
224 let config_path = current.join("liter-llm-proxy.toml");
225 if config_path.exists() {
226 return Ok(Some(Self::from_toml_file(config_path)?));
227 }
228 match current.parent() {
229 Some(parent) => current = parent.to_path_buf(),
230 None => break,
231 }
232 }
233 Ok(None)
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    // An empty document yields the built-in defaults and no optional sections.
    #[test]
    fn parse_minimal_config() {
        let cfg = ProxyConfig::from_toml_str("").unwrap();

        assert_eq!(cfg.server.host, "0.0.0.0");
        assert_eq!(cfg.server.port, 4000);
        assert_eq!(cfg.general.default_timeout_secs, 120);
        assert_eq!(cfg.general.max_retries, 3);

        assert!(cfg.models.is_empty());
        assert!(cfg.keys.is_empty());

        assert!(cfg.rate_limit.is_none());
        assert!(cfg.budget.is_none());
        assert!(cfg.cache.is_none());
        assert!(cfg.files.is_none());
        assert!(cfg.health.is_none());
        assert!(cfg.cooldown.is_none());
    }

    // A document that sets every section round-trips into the expected values.
    #[test]
    fn parse_full_config() {
        let raw = r#"
[server]
host = "127.0.0.1"
port = 8080
request_timeout_secs = 300
body_limit_bytes = 5242880
cors_origins = ["https://example.com"]

[general]
master_key = "sk-master"
default_timeout_secs = 60
max_retries = 5
enable_cost_tracking = true
enable_tracing = true

[[models]]
name = "gpt-4o"
provider_model = "openai/gpt-4o"
api_key = "sk-openai"
base_url = "https://api.openai.com/v1"
timeout_secs = 30
fallbacks = ["claude-sonnet"]

[[models]]
name = "claude-sonnet"
provider_model = "anthropic/claude-sonnet-4-20250514"

[[aliases]]
pattern = "anthropic/*"
api_key = "sk-anthropic"

[[keys]]
key = "vk-team-a"
description = "Team A key"
models = ["gpt-4o"]
rpm = 60
tpm = 100000
budget_limit = 50.0

[rate_limit]
rpm = 120
tpm = 500000

[budget]
global_limit = 100.0
enforcement = "soft"

[budget.model_limits]
"openai/gpt-4o" = 50.0

[cache]
max_entries = 1024
ttl_seconds = 600
backend = "memory"

[files]
backend = "s3"
prefix = "proxy-files/"

[files.backend_config]
bucket = "my-bucket"

[health]
interval_secs = 30
probe_model = "openai/gpt-4o-mini"

[cooldown]
duration_secs = 60
"#;
        let cfg = ProxyConfig::from_toml_str(raw).unwrap();

        // [server]
        assert_eq!(cfg.server.host, "127.0.0.1");
        assert_eq!(cfg.server.port, 8080);
        assert_eq!(cfg.server.request_timeout_secs, 300);
        assert_eq!(cfg.server.body_limit_bytes, 5_242_880);
        assert_eq!(cfg.server.cors_origins, vec!["https://example.com"]);

        // [general]
        assert_eq!(cfg.general.master_key.as_deref(), Some("sk-master"));
        assert_eq!(cfg.general.default_timeout_secs, 60);
        assert_eq!(cfg.general.max_retries, 5);
        assert!(cfg.general.enable_cost_tracking);
        assert!(cfg.general.enable_tracing);

        // [[models]]
        assert_eq!(cfg.models.len(), 2);
        assert_eq!(cfg.models[0].name, "gpt-4o");
        assert_eq!(cfg.models[0].provider_model, "openai/gpt-4o");
        assert_eq!(cfg.models[0].api_key.as_deref(), Some("sk-openai"));
        assert_eq!(cfg.models[0].fallbacks, vec!["claude-sonnet"]);
        assert_eq!(cfg.models[1].name, "claude-sonnet");
        assert!(cfg.models[1].api_key.is_none());

        // [[aliases]]
        assert_eq!(cfg.aliases.len(), 1);
        assert_eq!(cfg.aliases[0].pattern, "anthropic/*");

        // [[keys]]
        assert_eq!(cfg.keys.len(), 1);
        assert_eq!(cfg.keys[0].key, "vk-team-a");
        assert_eq!(cfg.keys[0].models, vec!["gpt-4o"]);
        assert_eq!(cfg.keys[0].rpm, Some(60));

        // [rate_limit]
        let rate_limit = cfg.rate_limit.unwrap();
        assert_eq!(rate_limit.rpm, Some(120));
        assert_eq!(rate_limit.tpm, Some(500_000));

        // [budget]
        let budget_cfg = cfg.budget.unwrap();
        assert_eq!(budget_cfg.global_limit, Some(100.0));
        assert_eq!(budget_cfg.enforcement, EnforcementMode::Soft);
        assert_eq!(budget_cfg.model_limits.get("openai/gpt-4o"), Some(&50.0));

        // [cache]
        let cache_cfg = cfg.cache.unwrap();
        assert_eq!(cache_cfg.max_entries, Some(1024));
        assert_eq!(cache_cfg.ttl_seconds, Some(600));
        assert_eq!(cache_cfg.backend, "memory");

        // [files]
        let files_cfg = cfg.files.unwrap();
        assert_eq!(files_cfg.backend, "s3");
        assert_eq!(files_cfg.prefix, "proxy-files/");
        assert_eq!(files_cfg.backend_config.get("bucket").unwrap(), "my-bucket");

        // [health]
        let health_cfg = cfg.health.unwrap();
        assert_eq!(health_cfg.interval_secs, Some(30));
        assert_eq!(health_cfg.probe_model.as_deref(), Some("openai/gpt-4o-mini"));

        // [cooldown]
        assert_eq!(cfg.cooldown.unwrap().duration_secs, 60);
    }

    // `${VAR}` references inside TOML values are replaced before parsing.
    #[test]
    fn env_var_interpolation() {
        // SAFETY: NOTE(review): mutating the environment is only sound while no
        // other thread touches it concurrently. These variable names are used
        // by no other test in this module, but tests may run on parallel
        // threads — confirm the runner configuration.
        unsafe {
            std::env::set_var("LITER_TEST_KEY", "sk-from-env");
            std::env::set_var("LITER_TEST_HOST", "10.0.0.1");
        }

        let raw = r#"
[server]
host = "${LITER_TEST_HOST}"

[general]
master_key = "${LITER_TEST_KEY}"
"#;
        let cfg = ProxyConfig::from_toml_str(raw).unwrap();
        assert_eq!(cfg.server.host, "10.0.0.1");
        assert_eq!(cfg.general.master_key.as_deref(), Some("sk-from-env"));

        // SAFETY: same consideration as for the `set_var` calls above.
        unsafe {
            std::env::remove_var("LITER_TEST_KEY");
            std::env::remove_var("LITER_TEST_HOST");
        }
    }

    // Values without any `${...}` reference pass through untouched.
    #[test]
    fn env_var_interpolation_preserves_literals() {
        let raw = r#"
[server]
host = "literal-value"
"#;
        let cfg = ProxyConfig::from_toml_str(raw).unwrap();
        assert_eq!(cfg.server.host, "literal-value");
    }

    // A reference to an unset variable expands to nothing.
    #[test]
    fn env_var_interpolation_unknown_var_becomes_empty() {
        let expanded = interpolate_env_vars("prefix-${SURELY_NONEXISTENT_VAR_12345}-suffix");
        assert_eq!(expanded, "prefix--suffix");
    }

    // `deny_unknown_fields` on the root struct.
    #[test]
    fn rejects_unknown_top_level_field() {
        let raw = r#"
unknown_field = true
"#;
        assert!(ProxyConfig::from_toml_str(raw).is_err());
    }

    // Unknown keys inside `[server]` are rejected as well.
    #[test]
    fn rejects_unknown_server_field() {
        let raw = r#"
[server]
host = "0.0.0.0"
bogus = 42
"#;
        assert!(ProxyConfig::from_toml_str(raw).is_err());
    }

    // Unknown keys inside `[general]` are rejected as well.
    #[test]
    fn rejects_unknown_general_field() {
        let raw = r#"
[general]
unknown_option = true
"#;
        assert!(ProxyConfig::from_toml_str(raw).is_err());
    }

    // `ProxyConfig::default()` carries the documented default values.
    #[test]
    fn default_values_applied() {
        let cfg = ProxyConfig::default();

        assert_eq!(cfg.server.host, "0.0.0.0");
        assert_eq!(cfg.server.port, 4000);
        assert_eq!(cfg.server.request_timeout_secs, 600);
        assert_eq!(cfg.server.body_limit_bytes, 10_485_760);
        assert_eq!(cfg.server.cors_origins, vec!["*"]);

        assert_eq!(cfg.general.default_timeout_secs, 120);
        assert_eq!(cfg.general.max_retries, 3);
        assert!(!cfg.general.enable_cost_tracking);
        assert!(!cfg.general.enable_tracing);
    }

    // Omitting `enforcement` selects the `Hard` variant.
    #[test]
    fn budget_default_enforcement() {
        let raw = r#"
[budget]
global_limit = 100.0
"#;
        let cfg = ProxyConfig::from_toml_str(raw).unwrap();
        assert_eq!(cfg.budget.unwrap().enforcement, EnforcementMode::Hard);
    }

    // Omitting `backend` in `[cache]` selects "memory".
    #[test]
    fn cache_default_backend() {
        let raw = r#"
[cache]
max_entries = 256
"#;
        let cfg = ProxyConfig::from_toml_str(raw).unwrap();
        assert_eq!(cfg.cache.unwrap().backend, "memory");
    }

    // An empty `[files]` table is filled entirely from defaults.
    #[test]
    fn files_default_values() {
        let raw = r#"
[files]
"#;
        let cfg = ProxyConfig::from_toml_str(raw).unwrap();
        let files_cfg = cfg.files.unwrap();
        assert_eq!(files_cfg.backend, "memory");
        assert_eq!(files_cfg.prefix, "liter-llm-files/");
        assert!(files_cfg.backend_config.is_empty());
    }

    // Two `[[models]]` entries may share a name; both are kept in order.
    #[test]
    fn multiple_models_same_name() {
        let raw = r#"
[[models]]
name = "gpt-4o"
provider_model = "openai/gpt-4o"
api_key = "sk-key-1"

[[models]]
name = "gpt-4o"
provider_model = "azure/gpt-4o"
api_key = "sk-key-2"
"#;
        let cfg = ProxyConfig::from_toml_str(raw).unwrap();
        assert_eq!(cfg.models.len(), 2);
        assert_eq!(cfg.models[0].name, "gpt-4o");
        assert_eq!(cfg.models[1].name, "gpt-4o");
        assert_ne!(cfg.models[0].provider_model, cfg.models[1].provider_model);
    }

    // `fallbacks` is parsed when given and empty otherwise.
    #[test]
    fn model_with_fallbacks() {
        let raw = r#"
[[models]]
name = "primary"
provider_model = "openai/gpt-4o"
fallbacks = ["fallback-1", "fallback-2"]

[[models]]
name = "fallback-1"
provider_model = "anthropic/claude-sonnet-4-20250514"

[[models]]
name = "fallback-2"
provider_model = "groq/llama3-70b"
"#;
        let cfg = ProxyConfig::from_toml_str(raw).unwrap();
        assert_eq!(cfg.models[0].fallbacks, vec!["fallback-1", "fallback-2"]);
        assert!(cfg.models[1].fallbacks.is_empty());
        assert!(cfg.models[2].fallbacks.is_empty());
    }

    // Inputs with no `${` marker come back unchanged.
    #[test]
    fn interpolate_env_vars_basic() {
        assert_eq!(interpolate_env_vars("no vars here"), "no vars here");
        assert_eq!(interpolate_env_vars(""), "");
        assert_eq!(interpolate_env_vars("$not_a_var"), "$not_a_var");
    }

    // Several references in one string are each expanded.
    #[test]
    fn interpolate_env_vars_multiple() {
        // SAFETY: NOTE(review): mutating the environment is only sound while no
        // other thread touches it concurrently. These variable names are used
        // by no other test in this module, but tests may run on parallel
        // threads — confirm the runner configuration.
        unsafe {
            std::env::set_var("LITER_A", "hello");
            std::env::set_var("LITER_B", "world");
        }

        let expanded = interpolate_env_vars("${LITER_A} ${LITER_B}!");
        assert_eq!(expanded, "hello world!");

        // SAFETY: same consideration as for the `set_var` calls above.
        unsafe {
            std::env::remove_var("LITER_A");
            std::env::remove_var("LITER_B");
        }
    }

    // A `${` that is never closed is kept verbatim.
    #[test]
    fn interpolate_env_vars_unclosed_brace_treated_as_literal() {
        assert_eq!(interpolate_env_vars("prefix-${UNCLOSED"), "prefix-${UNCLOSED");
        assert_eq!(interpolate_env_vars("${"), "${");
        assert_eq!(interpolate_env_vars("a${b"), "a${b");
    }
}