// liter_llm_proxy/service_pool.rs
use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use std::time::Duration;
4
5use tower::Layer;
6
7use liter_llm::client::{ClientConfigBuilder, DefaultClient};
8use liter_llm::error::LiterLlmError;
9use liter_llm::tower::types::{LlmRequest, LlmResponse};
10use liter_llm::tower::{
11 BudgetConfig, BudgetLayer, BudgetState, CacheConfig, CacheLayer, CooldownLayer, CostTrackingLayer, Enforcement,
12 HealthCheckLayer, LlmService, ModelRateLimitLayer, RateLimitConfig, TracingLayer,
13};
14
15use crate::config::{ModelEntry, ProxyConfig};
16use crate::error::ProxyError;
17
/// Shorthand for the boxed, clonable tower service type every model's
/// middleware stack is erased to. Clonable so each request can take an
/// owned copy out of the pool.
type Bcs = tower::util::BoxCloneService<LlmRequest, LlmResponse, LiterLlmError>;
19
20struct SyncBoxService {
27 inner: Mutex<Bcs>,
28}
29
30impl SyncBoxService {
31 fn clone_service(&self) -> Result<Bcs, ProxyError> {
37 self.inner
38 .lock()
39 .map(|guard| guard.clone())
40 .map_err(|_| ProxyError::internal("service mutex poisoned"))
41 }
42}
43
/// Holds one ready-to-use middleware stack and one raw client per configured
/// model name, plus a fallback client for callers that do not name a model.
pub struct ServicePool {
    // Model name -> fully layered service stack (see `build_service_stack`).
    services: HashMap<String, SyncBoxService>,
    // Model name -> bare client, for callers that bypass the middleware.
    clients: HashMap<String, Arc<DefaultClient>>,
    // Fallback client used by `first_client`; `None` when no models are configured.
    default_client: Option<Arc<DefaultClient>>,
}
59
60impl ServicePool {
64 pub fn from_config(config: &ProxyConfig) -> Result<Self, String> {
75 let mut grouped: HashMap<String, Vec<&ModelEntry>> = HashMap::new();
78 for entry in &config.models {
79 grouped.entry(entry.name.clone()).or_default().push(entry);
80 }
81
82 let mut services = HashMap::new();
83 let mut clients = HashMap::new();
84 let mut default_client: Option<Arc<DefaultClient>> = None;
85
86 for (name, entries) in &grouped {
87 let entry = entries[0];
89
90 let client = build_client(entry, config)?;
91 let client_arc = Arc::new(client);
92
93 if default_client.is_none() {
95 default_client = Some(Arc::clone(&client_arc));
96 }
97
98 let svc = build_service_stack(config, Arc::clone(&client_arc));
99
100 services.insert(name.clone(), SyncBoxService { inner: Mutex::new(svc) });
101 clients.insert(name.clone(), client_arc);
102 }
103
104 Ok(Self {
105 services,
106 clients,
107 default_client,
108 })
109 }
110
111 pub fn get_service(&self, model: &str) -> Result<Bcs, ProxyError> {
117 self.services
118 .get(model)
119 .ok_or_else(|| ProxyError::not_found(format!("model '{model}' not found")))?
120 .clone_service()
121 }
122
123 pub fn get_client(&self, model: &str) -> Result<Arc<DefaultClient>, ProxyError> {
132 self.clients
133 .get(model)
134 .cloned()
135 .ok_or_else(|| ProxyError::not_found(format!("model '{model}' not found")))
136 }
137
138 pub fn first_client(&self) -> Result<Arc<DefaultClient>, ProxyError> {
143 self.default_client
144 .clone()
145 .ok_or_else(|| ProxyError::service_unavailable("no models configured"))
146 }
147
148 pub fn model_names(&self) -> Vec<&str> {
150 self.services.keys().map(String::as_str).collect()
151 }
152
153 pub fn has_any_service(&self) -> bool {
155 !self.services.is_empty()
156 }
157}
158
159fn build_client(entry: &ModelEntry, config: &ProxyConfig) -> Result<DefaultClient, String> {
161 let api_key = entry.api_key.as_deref().unwrap_or("");
162
163 let mut builder = ClientConfigBuilder::new(api_key);
164
165 if let Some(ref url) = entry.base_url {
166 builder = builder.base_url(url);
167 }
168
169 let timeout_secs = entry.timeout_secs.unwrap_or(config.general.default_timeout_secs);
170 builder = builder.timeout(Duration::from_secs(timeout_secs));
171 builder = builder.max_retries(config.general.max_retries);
172
173 let client_config = builder.build();
174
175 DefaultClient::new(client_config, Some(&entry.provider_model))
176 .map_err(|e| format!("failed to build client for model '{}': {e}", entry.name))
177}
178
/// Assembles the tower middleware stack for one model's service.
///
/// Each configured layer wraps the service built so far, so the LAST layer
/// applied below is the OUTERMOST at request time. With everything enabled,
/// a request flows outside-in through:
/// tracing -> budget -> cost tracking -> rate limit -> cooldown -> health -> cache -> base service.
/// Reordering these `if` blocks changes runtime semantics — e.g. the cache
/// sits innermost so cache hits still pass through rate limiting and budget.
fn build_service_stack(config: &ProxyConfig, client: Arc<DefaultClient>) -> Bcs {
    let base = LlmService::new_from_arc(client);
    let mut svc: Bcs = tower::util::BoxCloneService::new(base);

    // Response cache (innermost). Defaults: 256 entries, 300 s TTL.
    if let Some(ref cache_cfg) = config.cache {
        let max_entries = cache_cfg.max_entries.unwrap_or(256);
        let ttl = Duration::from_secs(cache_cfg.ttl_seconds.unwrap_or(300));
        let tower_cache_cfg = CacheConfig {
            max_entries,
            ttl,
            backend: liter_llm::tower::CacheBackend::Memory,
        };
        let layer = CacheLayer::new(tower_cache_cfg);
        svc = tower::util::BoxCloneService::new(layer.layer(svc));
    }

    // Health checking is only enabled when an interval is explicitly set.
    if let Some(ref health_cfg) = config.health
        && let Some(interval_secs) = health_cfg.interval_secs
    {
        let layer = HealthCheckLayer::new(Duration::from_secs(interval_secs));
        svc = tower::util::BoxCloneService::new(layer.layer(svc));
    }

    if let Some(ref cooldown_cfg) = config.cooldown {
        let layer = CooldownLayer::new(Duration::from_secs(cooldown_cfg.duration_secs));
        svc = tower::util::BoxCloneService::new(layer.layer(svc));
    }

    // Rate limiting uses a fixed 60 s window; rpm/tpm come from config.
    if let Some(ref rl_cfg) = config.rate_limit {
        let tower_rl_cfg = RateLimitConfig {
            rpm: rl_cfg.rpm,
            tpm: rl_cfg.tpm,
            window: Duration::from_secs(60),
        };
        let layer = ModelRateLimitLayer::new(tower_rl_cfg);
        svc = tower::util::BoxCloneService::new(layer.layer(svc));
    }

    // Cost tracking sits inside the budget layer so budgets see tracked costs.
    // NOTE(review): this ordering is assumed intentional — confirm with the
    // liter_llm tower layer docs.
    if config.general.enable_cost_tracking {
        svc = tower::util::BoxCloneService::new(CostTrackingLayer.layer(svc));
    }

    // Each model gets its own fresh BudgetState; budgets are per-stack,
    // not shared across models.
    if let Some(ref budget_cfg) = config.budget {
        let enforcement = match budget_cfg.enforcement {
            crate::config::EnforcementMode::Soft => Enforcement::Soft,
            crate::config::EnforcementMode::Hard => Enforcement::Hard,
        };
        let tower_budget_cfg = BudgetConfig {
            global_limit: budget_cfg.global_limit,
            model_limits: budget_cfg.model_limits.clone(),
            enforcement,
        };
        let state = Arc::new(BudgetState::new());
        let layer = BudgetLayer::new(tower_budget_cfg, state);
        svc = tower::util::BoxCloneService::new(layer.layer(svc));
    }

    // Tracing is outermost so spans cover the whole stack.
    if config.general.enable_tracing {
        svc = tower::util::BoxCloneService::new(TracingLayer.layer(svc));
    }

    svc
}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ProxyConfig;

    // Minimal config: one model, no middleware sections.
    fn config_with_one_model() -> ProxyConfig {
        ProxyConfig::from_toml_str(
            r#"
[[models]]
name = "test-model"
provider_model = "openai/gpt-4o"
api_key = "sk-test"
"#,
        )
        .expect("valid TOML")
    }

    // Two distinct model names, different providers.
    fn config_with_two_models() -> ProxyConfig {
        ProxyConfig::from_toml_str(
            r#"
[[models]]
name = "model-a"
provider_model = "openai/gpt-4o"
api_key = "sk-a"

[[models]]
name = "model-b"
provider_model = "anthropic/claude-sonnet-4-20250514"
api_key = "sk-b"
"#,
        )
        .expect("valid TOML")
    }

    // An empty config is valid and yields an empty (but usable) pool.
    #[test]
    fn build_from_empty_config() {
        let config = ProxyConfig::default();
        let pool = ServicePool::from_config(&config).expect("empty config should build");
        assert!(pool.services.is_empty());
        assert!(pool.clients.is_empty());
        assert!(!pool.has_any_service());
    }

    #[test]
    fn build_from_config_with_one_model() {
        let config = config_with_one_model();
        let pool = ServicePool::from_config(&config).expect("should build");
        assert_eq!(pool.services.len(), 1);
        assert_eq!(pool.clients.len(), 1);
        assert!(pool.has_any_service());
    }

    #[test]
    fn get_service_for_unknown_model_returns_not_found() {
        let config = config_with_one_model();
        let pool = ServicePool::from_config(&config).expect("should build");
        let result = pool.get_service("nonexistent");
        assert!(result.is_err());
        let err = result.unwrap_err();
        // The error message is part of the API surface seen by proxy callers.
        assert!(err.to_string().contains("not found"));
    }

    #[test]
    fn get_service_for_known_model_succeeds() {
        let config = config_with_one_model();
        let pool = ServicePool::from_config(&config).expect("should build");
        let result = pool.get_service("test-model");
        assert!(result.is_ok());
    }

    #[test]
    fn get_client_for_known_model_succeeds() {
        let config = config_with_one_model();
        let pool = ServicePool::from_config(&config).expect("should build");
        let result = pool.get_client("test-model");
        assert!(result.is_ok());
    }

    #[test]
    fn get_client_for_unknown_model_returns_not_found() {
        let config = config_with_one_model();
        let pool = ServicePool::from_config(&config).expect("should build");
        let result = pool.get_client("nonexistent");
        assert!(result.is_err());
    }

    // model_names() is unordered (HashMap keys), so sort before comparing.
    #[test]
    fn model_names_returns_correct_list() {
        let config = config_with_two_models();
        let pool = ServicePool::from_config(&config).expect("should build");
        let mut names = pool.model_names();
        names.sort();
        assert_eq!(names, vec!["model-a", "model-b"]);
    }

    #[test]
    fn has_any_service_returns_false_for_empty_pool() {
        let config = ProxyConfig::default();
        let pool = ServicePool::from_config(&config).expect("should build");
        assert!(!pool.has_any_service());
    }

    #[test]
    fn has_any_service_returns_true_for_nonempty_pool() {
        let config = config_with_one_model();
        let pool = ServicePool::from_config(&config).expect("should build");
        assert!(pool.has_any_service());
    }

    // Exercises every optional middleware section at once so layer
    // construction in build_service_stack is covered end to end.
    #[tokio::test]
    async fn build_with_middleware_config() {
        let config = ProxyConfig::from_toml_str(
            r#"
[general]
enable_cost_tracking = true
enable_tracing = true

[[models]]
name = "gpt"
provider_model = "openai/gpt-4o"
api_key = "sk-test"

[cache]
max_entries = 128
ttl_seconds = 60

[rate_limit]
rpm = 100

[budget]
global_limit = 50.0
enforcement = "soft"

[cooldown]
duration_secs = 30

[health]
interval_secs = 10
"#,
        )
        .expect("valid TOML");

        let pool = ServicePool::from_config(&config).expect("should build with middleware");
        assert!(pool.has_any_service());
        assert!(pool.get_service("gpt").is_ok());
    }

    // Duplicate names must not error; the first config entry wins.
    #[test]
    fn duplicate_model_names_use_first_entry() {
        let config = ProxyConfig::from_toml_str(
            r#"
[[models]]
name = "gpt"
provider_model = "openai/gpt-4o"
api_key = "sk-1"

[[models]]
name = "gpt"
provider_model = "azure/gpt-4o"
api_key = "sk-2"
"#,
        )
        .expect("valid TOML");

        let pool = ServicePool::from_config(&config).expect("should build");
        assert_eq!(pool.services.len(), 1);
        assert!(pool.get_service("gpt").is_ok());
    }
}