// liter_llm_proxy/service_pool.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use std::time::Duration;
4
5use tower::Layer;
6
7use liter_llm::client::{ClientConfigBuilder, DefaultClient};
8use liter_llm::error::LiterLlmError;
9use liter_llm::tower::types::{LlmRequest, LlmResponse};
10use liter_llm::tower::{
11    BudgetConfig, BudgetLayer, BudgetState, CacheConfig, CacheLayer, CooldownLayer, CostTrackingLayer, Enforcement,
12    HealthCheckLayer, LlmService, ModelRateLimitLayer, RateLimitConfig, TracingLayer,
13};
14
15use crate::config::{ModelEntry, ProxyConfig};
16use crate::error::ProxyError;
17
18type Bcs = tower::util::BoxCloneService<LlmRequest, LlmResponse, LiterLlmError>;
19
20/// Thread-safe wrapper around `BoxCloneService`.
21///
22/// Tower's `BoxCloneService` is `Send` but not `Sync`, because `Service::call`
23/// takes `&mut self`. We wrap it in a `Mutex` and clone on each request — the
24/// lock is held only for the duration of `Clone::clone` (a handful of `Arc`
25/// ref-count bumps).
26struct SyncBoxService {
27    inner: Mutex<Bcs>,
28}
29
30impl SyncBoxService {
31    /// Clone the inner service out of the mutex.
32    ///
33    /// # Errors
34    ///
35    /// Returns `ProxyError::internal` if the mutex is poisoned.
36    fn clone_service(&self) -> Result<Bcs, ProxyError> {
37        self.inner
38            .lock()
39            .map(|guard| guard.clone())
40            .map_err(|_| ProxyError::internal("service mutex poisoned"))
41    }
42}
43
44/// A pool of Tower service stacks, one per configured model name.
45///
46/// Each model name maps to a type-erased `BoxCloneService` with the full
47/// middleware stack applied (cache, health check, cooldown, rate limit, cost
48/// tracking, budget, tracing).
49pub struct ServicePool {
50    /// Model name -> Tower service stack.
51    services: HashMap<String, SyncBoxService>,
52    /// Model name -> raw `DefaultClient` (for File/Batch/Response operations
53    /// that bypass the Tower stack).
54    clients: HashMap<String, Arc<DefaultClient>>,
55    /// The first client inserted during construction, for deterministic
56    /// `first_client()` behaviour regardless of `HashMap` iteration order.
57    default_client: Option<Arc<DefaultClient>>,
58}
59
60// SAFETY: `SyncBoxService` wraps a `Mutex<BoxCloneService>` which is `Send + Sync`.
61// `Arc<DefaultClient>` is `Send + Sync`. The compiler verifies these bounds.
62
63impl ServicePool {
64    /// Build a pool from the proxy configuration.
65    ///
66    /// Groups `config.models` by `name` and creates a Tower service stack for
67    /// each unique model name.  When multiple deployments share a name, the
68    /// first entry is used (round-robin load balancing is planned for v2).
69    ///
70    /// # Errors
71    ///
72    /// Returns an error string if a `DefaultClient` cannot be constructed for
73    /// any model entry.
74    pub fn from_config(config: &ProxyConfig) -> Result<Self, String> {
75        // Group model entries by name, preserving insertion order for the
76        // first-entry-wins rule.
77        let mut grouped: HashMap<String, Vec<&ModelEntry>> = HashMap::new();
78        for entry in &config.models {
79            grouped.entry(entry.name.clone()).or_default().push(entry);
80        }
81
82        let mut services = HashMap::new();
83        let mut clients = HashMap::new();
84        let mut default_client: Option<Arc<DefaultClient>> = None;
85
86        for (name, entries) in &grouped {
87            // Use the first entry for now (round-robin is v2).
88            let entry = entries[0];
89
90            let client = build_client(entry, config)?;
91            let client_arc = Arc::new(client);
92
93            // Capture the very first client for deterministic `first_client()`.
94            if default_client.is_none() {
95                default_client = Some(Arc::clone(&client_arc));
96            }
97
98            let svc = build_service_stack(config, Arc::clone(&client_arc));
99
100            services.insert(name.clone(), SyncBoxService { inner: Mutex::new(svc) });
101            clients.insert(name.clone(), client_arc);
102        }
103
104        Ok(Self {
105            services,
106            clients,
107            default_client,
108        })
109    }
110
111    /// Clone and return a Tower service stack for the given model name.
112    ///
113    /// # Errors
114    ///
115    /// Returns `ProxyError::not_found` if no model with that name exists.
116    pub fn get_service(&self, model: &str) -> Result<Bcs, ProxyError> {
117        self.services
118            .get(model)
119            .ok_or_else(|| ProxyError::not_found(format!("model '{model}' not found")))?
120            .clone_service()
121    }
122
123    /// Return a reference to the raw `DefaultClient` for the given model.
124    ///
125    /// Useful for File, Batch, and Response API operations that bypass the
126    /// Tower middleware stack.
127    ///
128    /// # Errors
129    ///
130    /// Returns `ProxyError::not_found` if no model with that name exists.
131    pub fn get_client(&self, model: &str) -> Result<Arc<DefaultClient>, ProxyError> {
132        self.clients
133            .get(model)
134            .cloned()
135            .ok_or_else(|| ProxyError::not_found(format!("model '{model}' not found")))
136    }
137
138    /// Return the first available raw client.
139    ///
140    /// Used by File, Batch, and Response API endpoints that do not carry a
141    /// model field in the request body.
142    pub fn first_client(&self) -> Result<Arc<DefaultClient>, ProxyError> {
143        self.default_client
144            .clone()
145            .ok_or_else(|| ProxyError::service_unavailable("no models configured"))
146    }
147
148    /// Return the names of all available models.
149    pub fn model_names(&self) -> Vec<&str> {
150        self.services.keys().map(String::as_str).collect()
151    }
152
153    /// Return `true` if the pool contains at least one service.
154    pub fn has_any_service(&self) -> bool {
155        !self.services.is_empty()
156    }
157}
158
159/// Build a `DefaultClient` from a `ModelEntry` and global config defaults.
160fn build_client(entry: &ModelEntry, config: &ProxyConfig) -> Result<DefaultClient, String> {
161    let api_key = entry.api_key.as_deref().unwrap_or("");
162
163    let mut builder = ClientConfigBuilder::new(api_key);
164
165    if let Some(ref url) = entry.base_url {
166        builder = builder.base_url(url);
167    }
168
169    let timeout_secs = entry.timeout_secs.unwrap_or(config.general.default_timeout_secs);
170    builder = builder.timeout(Duration::from_secs(timeout_secs));
171    builder = builder.max_retries(config.general.max_retries);
172
173    let client_config = builder.build();
174
175    DefaultClient::new(client_config, Some(&entry.provider_model))
176        .map_err(|e| format!("failed to build client for model '{}': {e}", entry.name))
177}
178
179/// Compose the Tower middleware stack, following the same layering order as
180/// `managed.rs:build_service_stack`:
181///
182/// 1. Cache (innermost)
183/// 2. HealthCheck
184/// 3. Cooldown
185/// 4. RateLimit
186/// 5. CostTracking
187/// 6. Budget
188/// 7. Tracing (outermost)
189fn build_service_stack(config: &ProxyConfig, client: Arc<DefaultClient>) -> Bcs {
190    let base = LlmService::new_from_arc(client);
191    let mut svc: Bcs = tower::util::BoxCloneService::new(base);
192
193    // 1. Cache (innermost).
194    if let Some(ref cache_cfg) = config.cache {
195        let max_entries = cache_cfg.max_entries.unwrap_or(256);
196        let ttl = Duration::from_secs(cache_cfg.ttl_seconds.unwrap_or(300));
197        let tower_cache_cfg = CacheConfig {
198            max_entries,
199            ttl,
200            backend: liter_llm::tower::CacheBackend::Memory,
201        };
202        let layer = CacheLayer::new(tower_cache_cfg);
203        svc = tower::util::BoxCloneService::new(layer.layer(svc));
204    }
205
206    // 2. HealthCheck.
207    if let Some(ref health_cfg) = config.health
208        && let Some(interval_secs) = health_cfg.interval_secs
209    {
210        let layer = HealthCheckLayer::new(Duration::from_secs(interval_secs));
211        svc = tower::util::BoxCloneService::new(layer.layer(svc));
212    }
213
214    // 3. Cooldown.
215    if let Some(ref cooldown_cfg) = config.cooldown {
216        let layer = CooldownLayer::new(Duration::from_secs(cooldown_cfg.duration_secs));
217        svc = tower::util::BoxCloneService::new(layer.layer(svc));
218    }
219
220    // 4. RateLimit.
221    if let Some(ref rl_cfg) = config.rate_limit {
222        let tower_rl_cfg = RateLimitConfig {
223            rpm: rl_cfg.rpm,
224            tpm: rl_cfg.tpm,
225            window: Duration::from_secs(60),
226        };
227        let layer = ModelRateLimitLayer::new(tower_rl_cfg);
228        svc = tower::util::BoxCloneService::new(layer.layer(svc));
229    }
230
231    // 5. CostTracking.
232    if config.general.enable_cost_tracking {
233        svc = tower::util::BoxCloneService::new(CostTrackingLayer.layer(svc));
234    }
235
236    // 6. Budget.
237    if let Some(ref budget_cfg) = config.budget {
238        let enforcement = match budget_cfg.enforcement {
239            crate::config::EnforcementMode::Soft => Enforcement::Soft,
240            crate::config::EnforcementMode::Hard => Enforcement::Hard,
241        };
242        let tower_budget_cfg = BudgetConfig {
243            global_limit: budget_cfg.global_limit,
244            model_limits: budget_cfg.model_limits.clone(),
245            enforcement,
246        };
247        let state = Arc::new(BudgetState::new());
248        let layer = BudgetLayer::new(tower_budget_cfg, state);
249        svc = tower::util::BoxCloneService::new(layer.layer(svc));
250    }
251
252    // 7. Tracing (outermost).
253    if config.general.enable_tracing {
254        svc = tower::util::BoxCloneService::new(TracingLayer.layer(svc));
255    }
256
257    svc
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ProxyConfig;

    /// Parse a TOML fixture into a `ProxyConfig`, panicking on invalid input.
    fn parse_config(toml: &str) -> ProxyConfig {
        ProxyConfig::from_toml_str(toml).expect("valid TOML")
    }

    /// Build a pool from `config`, panicking if construction fails.
    fn build_pool(config: &ProxyConfig) -> ServicePool {
        ServicePool::from_config(config).expect("should build")
    }

    /// Fixture: a single model named `test-model`.
    fn config_with_one_model() -> ProxyConfig {
        parse_config(
            r#"
[[models]]
name = "test-model"
provider_model = "openai/gpt-4o"
api_key = "sk-test"
"#,
        )
    }

    /// Fixture: two distinct models, `model-a` and `model-b`.
    fn config_with_two_models() -> ProxyConfig {
        parse_config(
            r#"
[[models]]
name = "model-a"
provider_model = "openai/gpt-4o"
api_key = "sk-a"

[[models]]
name = "model-b"
provider_model = "anthropic/claude-sonnet-4-20250514"
api_key = "sk-b"
"#,
        )
    }

    #[test]
    fn build_from_empty_config() {
        let pool = ServicePool::from_config(&ProxyConfig::default()).expect("empty config should build");
        assert!(pool.services.is_empty());
        assert!(pool.clients.is_empty());
        assert!(!pool.has_any_service());
    }

    #[test]
    fn build_from_config_with_one_model() {
        let pool = build_pool(&config_with_one_model());
        assert_eq!(pool.services.len(), 1);
        assert_eq!(pool.clients.len(), 1);
        assert!(pool.has_any_service());
    }

    #[test]
    fn get_service_for_unknown_model_returns_not_found() {
        let pool = build_pool(&config_with_one_model());
        let outcome = pool.get_service("nonexistent");
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not found"));
    }

    #[test]
    fn get_service_for_known_model_succeeds() {
        let pool = build_pool(&config_with_one_model());
        assert!(pool.get_service("test-model").is_ok());
    }

    #[test]
    fn get_client_for_known_model_succeeds() {
        let pool = build_pool(&config_with_one_model());
        assert!(pool.get_client("test-model").is_ok());
    }

    #[test]
    fn get_client_for_unknown_model_returns_not_found() {
        let pool = build_pool(&config_with_one_model());
        assert!(pool.get_client("nonexistent").is_err());
    }

    #[test]
    fn model_names_returns_correct_list() {
        let pool = build_pool(&config_with_two_models());
        // The pool gives no ordering guarantee, so sort before comparing.
        let mut listed = pool.model_names();
        listed.sort();
        assert_eq!(listed, vec!["model-a", "model-b"]);
    }

    #[test]
    fn has_any_service_returns_false_for_empty_pool() {
        let pool = build_pool(&ProxyConfig::default());
        assert!(!pool.has_any_service());
    }

    #[test]
    fn has_any_service_returns_true_for_nonempty_pool() {
        let pool = build_pool(&config_with_one_model());
        assert!(pool.has_any_service());
    }

    // Runs under a Tokio runtime, exercising every optional middleware layer.
    #[tokio::test]
    async fn build_with_middleware_config() {
        let config = parse_config(
            r#"
[general]
enable_cost_tracking = true
enable_tracing = true

[[models]]
name = "gpt"
provider_model = "openai/gpt-4o"
api_key = "sk-test"

[cache]
max_entries = 128
ttl_seconds = 60

[rate_limit]
rpm = 100

[budget]
global_limit = 50.0
enforcement = "soft"

[cooldown]
duration_secs = 30

[health]
interval_secs = 10
"#,
        );

        let pool = ServicePool::from_config(&config).expect("should build with middleware");
        assert!(pool.has_any_service());
        assert!(pool.get_service("gpt").is_ok());
    }

    #[test]
    fn duplicate_model_names_use_first_entry() {
        let config = parse_config(
            r#"
[[models]]
name = "gpt"
provider_model = "openai/gpt-4o"
api_key = "sk-1"

[[models]]
name = "gpt"
provider_model = "azure/gpt-4o"
api_key = "sk-2"
"#,
        );

        let pool = build_pool(&config);
        // Two config entries share a name, so the pool holds a single service.
        assert_eq!(pool.services.len(), 1);
        assert!(pool.get_service("gpt").is_ok());
    }
}