Skip to main content

litellm_rs/core/router/
budget_routing.rs

//! Budget-aware routing
//!
//! This module provides budget-aware routing capabilities that filter out
//! providers and models that have exceeded their budget limits.
5
use crate::core::budget::{BudgetStatus, UnifiedBudgetLimits};
use std::sync::Arc;
use tracing::{debug, warn};
9
10/// Budget-aware router wrapper
11///
12/// Wraps routing decisions with budget checking to skip providers
13/// that have exceeded their budget limits.
14#[derive(Clone)]
15pub struct BudgetAwareRouter {
16    /// Budget limits manager
17    budget_limits: Arc<UnifiedBudgetLimits>,
18    /// Whether to log warnings when approaching limits
19    log_warnings: bool,
20    /// Warning threshold percentage (0.0 to 1.0)
21    warning_threshold: f64,
22}
23
24impl BudgetAwareRouter {
25    /// Create a new budget-aware router
26    pub fn new(budget_limits: Arc<UnifiedBudgetLimits>) -> Self {
27        Self {
28            budget_limits,
29            log_warnings: true,
30            warning_threshold: 0.8,
31        }
32    }
33
34    /// Set whether to log warnings
35    pub fn with_warnings(mut self, log_warnings: bool) -> Self {
36        self.log_warnings = log_warnings;
37        self
38    }
39
40    /// Set warning threshold
41    pub fn with_warning_threshold(mut self, threshold: f64) -> Self {
42        self.warning_threshold = threshold.clamp(0.0, 1.0);
43        self
44    }
45
46    /// Get the budget limits manager
47    pub fn budget_limits(&self) -> &UnifiedBudgetLimits {
48        &self.budget_limits
49    }
50
51    /// Filter providers based on budget availability
52    ///
53    /// Returns providers that have not exceeded their budget limits.
54    /// Providers without configured budgets are always included.
55    pub fn filter_available_providers(&self, providers: Vec<String>) -> Vec<String> {
56        let available: Vec<String> = providers
57            .into_iter()
58            .filter(|provider| {
59                let status = self.budget_limits.providers.check_provider_budget(provider);
60                let is_available = status != BudgetStatus::Exceeded;
61
62                if !is_available {
63                    debug!("Provider '{}' filtered out: budget exceeded", provider);
64                } else if self.log_warnings && status == BudgetStatus::Warning {
65                    if let Some(usage) = self.budget_limits.providers.get_provider_usage(provider) {
66                        warn!(
67                            "Provider '{}' approaching budget limit: ${:.2} / ${:.2} ({:.1}%)",
68                            provider, usage.current_spend, usage.max_budget, usage.usage_percentage
69                        );
70                    }
71                }
72
73                is_available
74            })
75            .collect();
76
77        if available.is_empty() {
78            debug!("All providers have exceeded budget limits");
79        }
80
81        available
82    }
83
84    /// Filter models based on budget availability
85    pub fn filter_available_models(&self, models: Vec<String>) -> Vec<String> {
86        models
87            .into_iter()
88            .filter(|model| {
89                let status = self.budget_limits.models.check_model_budget(model);
90                let is_available = status != BudgetStatus::Exceeded;
91
92                if !is_available {
93                    debug!("Model '{}' filtered out: budget exceeded", model);
94                }
95
96                is_available
97            })
98            .collect()
99    }
100
101    /// Check if a specific provider is available
102    pub fn is_provider_available(&self, provider: &str) -> bool {
103        self.budget_limits.is_provider_available(provider)
104    }
105
106    /// Check if a specific model is available
107    pub fn is_model_available(&self, model: &str) -> bool {
108        self.budget_limits.is_model_available(model)
109    }
110
111    /// Check if a request can be made with the given provider and model
112    pub fn can_make_request(
113        &self,
114        provider: &str,
115        model: &str,
116        estimated_cost: f64,
117    ) -> RequestBudgetCheck {
118        let provider_check = self.check_provider(provider, estimated_cost);
119        let model_check = self.check_model(model, estimated_cost);
120
121        RequestBudgetCheck {
122            allowed: provider_check.allowed && model_check.allowed,
123            provider_status: provider_check.status,
124            model_status: model_check.status,
125            provider_remaining: provider_check.remaining,
126            model_remaining: model_check.remaining,
127            reason: if !provider_check.allowed {
128                Some(format!("Provider '{}' has exceeded budget", provider))
129            } else if !model_check.allowed {
130                Some(format!("Model '{}' has exceeded budget", model))
131            } else {
132                None
133            },
134        }
135    }
136
137    /// Check provider budget
138    fn check_provider(&self, provider: &str, estimated_cost: f64) -> BudgetCheckResult {
139        let can_spend = self
140            .budget_limits
141            .providers
142            .can_provider_spend(provider, estimated_cost);
143        let status = self.budget_limits.providers.check_provider_budget(provider);
144        let remaining = self
145            .budget_limits
146            .providers
147            .get_provider_usage(provider)
148            .map(|u| u.remaining)
149            .unwrap_or(f64::INFINITY);
150
151        BudgetCheckResult {
152            allowed: can_spend,
153            status,
154            remaining,
155        }
156    }
157
158    /// Check model budget
159    fn check_model(&self, model: &str, estimated_cost: f64) -> BudgetCheckResult {
160        let can_spend = self
161            .budget_limits
162            .models
163            .can_model_spend(model, estimated_cost);
164        let status = self.budget_limits.models.check_model_budget(model);
165        let remaining = self
166            .budget_limits
167            .models
168            .get_model_usage(model)
169            .map(|u| u.remaining)
170            .unwrap_or(f64::INFINITY);
171
172        BudgetCheckResult {
173            allowed: can_spend,
174            status,
175            remaining,
176        }
177    }
178
179    /// Record spend after a request completes
180    pub fn record_spend(&self, provider: &str, model: &str, cost: f64) {
181        self.budget_limits.record_spend(provider, model, cost);
182    }
183
184    /// Get a fallback provider when the primary is over budget
185    ///
186    /// Returns the first available provider from the fallback list,
187    /// or None if all providers have exceeded their budgets.
188    pub fn get_fallback_provider(&self, fallbacks: &[String]) -> Option<String> {
189        for provider in fallbacks {
190            if self.is_provider_available(provider) {
191                debug!("Using fallback provider: {}", provider);
192                return Some(provider.clone());
193            }
194        }
195        None
196    }
197
198    /// Get providers sorted by remaining budget (highest first)
199    pub fn get_providers_by_remaining_budget(&self, providers: Vec<String>) -> Vec<String> {
200        let mut provider_budgets: Vec<(String, f64)> = providers
201            .into_iter()
202            .filter_map(|p| {
203                let remaining = self
204                    .budget_limits
205                    .providers
206                    .get_provider_usage(&p)
207                    .map(|u| u.remaining)
208                    .unwrap_or(f64::INFINITY);
209
210                // Only include providers that aren't exceeded
211                if self.is_provider_available(&p) {
212                    Some((p, remaining))
213                } else {
214                    None
215                }
216            })
217            .collect();
218
219        // Sort by remaining budget (highest first)
220        provider_budgets.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
221
222        provider_budgets.into_iter().map(|(p, _)| p).collect()
223    }
224}
225
226/// Internal budget check result
227struct BudgetCheckResult {
228    allowed: bool,
229    status: BudgetStatus,
230    remaining: f64,
231}
232
233/// Result of checking if a request is within budget
234#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
235pub struct RequestBudgetCheck {
236    /// Whether the request is allowed
237    pub allowed: bool,
238    /// Provider budget status
239    pub provider_status: BudgetStatus,
240    /// Model budget status
241    pub model_status: BudgetStatus,
242    /// Remaining provider budget
243    pub provider_remaining: f64,
244    /// Remaining model budget
245    pub model_remaining: f64,
246    /// Reason if not allowed
247    pub reason: Option<String>,
248}
249
250/// Extension trait for adding budget awareness to routers
251pub trait BudgetAwareRouting {
252    /// Filter providers based on budget availability
253    fn filter_by_budget(
254        &self,
255        providers: Vec<String>,
256        budget_router: &BudgetAwareRouter,
257    ) -> Vec<String>;
258}
259
260impl<T> BudgetAwareRouting for T {
261    fn filter_by_budget(
262        &self,
263        providers: Vec<String>,
264        budget_router: &BudgetAwareRouter,
265    ) -> Vec<String> {
266        budget_router.filter_available_providers(providers)
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::budget::{ModelLimitConfig, ProviderLimitConfig, ResetPeriod};

    /// Router backed by a limits manager with no budgets configured.
    fn create_test_router() -> BudgetAwareRouter {
        let limits = Arc::new(UnifiedBudgetLimits::new());
        BudgetAwareRouter::new(limits)
    }

    /// Limits manager with a monthly budget of `max_budget` for each named provider.
    fn limits_with_provider_budgets(
        names: &[&'static str],
        max_budget: f64,
    ) -> Arc<UnifiedBudgetLimits> {
        let limits = Arc::new(UnifiedBudgetLimits::new());
        for &name in names {
            limits.providers.set_provider_limit(
                name,
                ProviderLimitConfig::new(max_budget, ResetPeriod::Monthly),
            );
        }
        limits
    }

    /// Convert string literals into the owned names the router API expects.
    fn names(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[test]
    fn test_budget_aware_router_creation() {
        let router = create_test_router();
        assert!(router.log_warnings);
        assert!((router.warning_threshold - 0.8).abs() < f64::EPSILON);
    }

    #[test]
    fn test_with_warning_threshold_clamps() {
        let too_high = create_test_router().with_warning_threshold(1.5);
        assert_eq!(too_high.warning_threshold, 1.0);

        let too_low = create_test_router().with_warning_threshold(-0.5);
        assert_eq!(too_low.warning_threshold, 0.0);
    }

    #[test]
    fn test_filter_available_providers_no_limits() {
        let router = create_test_router();
        let providers = names(&["openai", "anthropic"]);

        let available = router.filter_available_providers(providers.clone());
        assert_eq!(available, providers);
    }

    #[test]
    fn test_filter_available_providers_with_exceeded() {
        let limits = limits_with_provider_budgets(&["openai", "anthropic"], 100.0);

        // Exceed openai budget
        limits.providers.record_provider_spend("openai", 150.0);

        let router = BudgetAwareRouter::new(limits);

        let available = router.filter_available_providers(names(&["openai", "anthropic"]));
        assert_eq!(available.len(), 1);
        assert_eq!(available[0], "anthropic");
    }

    #[test]
    fn test_is_provider_available() {
        let limits = limits_with_provider_budgets(&["openai"], 100.0);
        let router = BudgetAwareRouter::new(limits.clone());

        assert!(router.is_provider_available("openai"));
        assert!(router.is_provider_available("unknown")); // No limit = available

        limits.providers.record_provider_spend("openai", 150.0);
        assert!(!router.is_provider_available("openai"));
    }

    #[test]
    fn test_can_make_request() {
        let limits = limits_with_provider_budgets(&["openai"], 100.0);
        limits
            .models
            .set_model_limit("gpt-4", ModelLimitConfig::new(50.0, ResetPeriod::Monthly));

        let router = BudgetAwareRouter::new(limits.clone());

        // Should allow
        let check = router.can_make_request("openai", "gpt-4", 10.0);
        assert!(check.allowed);
        assert!(check.reason.is_none());

        // Exceed model budget
        limits.models.record_model_spend("gpt-4", 60.0);

        let check = router.can_make_request("openai", "gpt-4", 10.0);
        assert!(!check.allowed);
        assert!(check.reason.is_some());
        assert!(check.reason.unwrap().contains("gpt-4"));
    }

    #[test]
    fn test_record_spend() {
        let limits = limits_with_provider_budgets(&["openai"], 100.0);
        limits
            .models
            .set_model_limit("gpt-4", ModelLimitConfig::new(100.0, ResetPeriod::Monthly));

        let router = BudgetAwareRouter::new(limits.clone());
        router.record_spend("openai", "gpt-4", 25.0);

        let provider_usage = limits.providers.get_provider_usage("openai").unwrap();
        let model_usage = limits.models.get_model_usage("gpt-4").unwrap();

        assert_eq!(provider_usage.current_spend, 25.0);
        assert_eq!(model_usage.current_spend, 25.0);
    }

    #[test]
    fn test_get_fallback_provider() {
        let limits = limits_with_provider_budgets(&["openai", "anthropic", "google"], 100.0);

        // Exceed openai and anthropic
        limits.providers.record_provider_spend("openai", 150.0);
        limits.providers.record_provider_spend("anthropic", 150.0);

        let router = BudgetAwareRouter::new(limits);
        let fallbacks = names(&["openai", "anthropic", "google"]);

        let fallback = router.get_fallback_provider(&fallbacks);
        assert_eq!(fallback, Some("google".to_string()));
    }

    #[test]
    fn test_get_fallback_provider_none_available() {
        let limits = limits_with_provider_budgets(&["openai"], 100.0);
        limits.providers.record_provider_spend("openai", 150.0);

        let router = BudgetAwareRouter::new(limits);

        // Every listed provider is over budget
        assert_eq!(router.get_fallback_provider(&names(&["openai"])), None);
        // Nothing to choose from
        assert_eq!(router.get_fallback_provider(&[]), None);
    }

    #[test]
    fn test_get_providers_by_remaining_budget() {
        let limits = limits_with_provider_budgets(&["openai", "anthropic", "google"], 100.0);

        // Different spend amounts
        limits.providers.record_provider_spend("openai", 80.0); // 20 remaining
        limits.providers.record_provider_spend("anthropic", 30.0); // 70 remaining
        limits.providers.record_provider_spend("google", 50.0); // 50 remaining

        let router = BudgetAwareRouter::new(limits);

        let sorted =
            router.get_providers_by_remaining_budget(names(&["openai", "anthropic", "google"]));

        // Should be sorted by remaining budget (highest first)
        assert_eq!(sorted[0], "anthropic"); // 70 remaining
        assert_eq!(sorted[1], "google"); // 50 remaining
        assert_eq!(sorted[2], "openai"); // 20 remaining
    }

    #[test]
    fn test_filter_available_models() {
        let limits = Arc::new(UnifiedBudgetLimits::new());
        for model in ["gpt-4", "gpt-3.5-turbo"] {
            limits
                .models
                .set_model_limit(model, ModelLimitConfig::new(100.0, ResetPeriod::Monthly));
        }

        // Exceed gpt-4 budget
        limits.models.record_model_spend("gpt-4", 150.0);

        let router = BudgetAwareRouter::new(limits);

        let available = router.filter_available_models(names(&["gpt-4", "gpt-3.5-turbo"]));

        assert_eq!(available.len(), 1);
        assert_eq!(available[0], "gpt-3.5-turbo");
    }
}