Skip to main content

litellm_rs/core/router/
budget_routing.rs

//! Budget-aware routing
//!
//! This module provides budget-aware routing capabilities that filter out
//! providers and models that have exceeded their budget limits.

6use crate::core::budget::{BudgetStatus, UnifiedBudgetLimits};
7use std::sync::Arc;
8use tracing::{debug, warn};
9
10/// Budget-aware router wrapper
11///
12/// Wraps routing decisions with budget checking to skip providers
13/// that have exceeded their budget limits.
14#[derive(Clone)]
15pub struct BudgetAwareRouter {
16    /// Budget limits manager
17    budget_limits: Arc<UnifiedBudgetLimits>,
18    /// Whether to log warnings when approaching limits
19    log_warnings: bool,
20    /// Warning threshold percentage (0.0 to 1.0)
21    warning_threshold: f64,
22}
23
24impl BudgetAwareRouter {
25    /// Create a new budget-aware router
26    pub fn new(budget_limits: Arc<UnifiedBudgetLimits>) -> Self {
27        Self {
28            budget_limits,
29            log_warnings: true,
30            warning_threshold: 0.8,
31        }
32    }
33
34    /// Set whether to log warnings
35    pub fn with_warnings(mut self, log_warnings: bool) -> Self {
36        self.log_warnings = log_warnings;
37        self
38    }
39
40    /// Set warning threshold
41    pub fn with_warning_threshold(mut self, threshold: f64) -> Self {
42        self.warning_threshold = threshold.clamp(0.0, 1.0);
43        self
44    }
45
46    /// Get the budget limits manager
47    pub fn budget_limits(&self) -> &UnifiedBudgetLimits {
48        &self.budget_limits
49    }
50
51    /// Filter providers based on budget availability
52    ///
53    /// Returns providers that have not exceeded their budget limits.
54    /// Providers without configured budgets are always included.
55    pub fn filter_available_providers(&self, providers: Vec<String>) -> Vec<String> {
56        let available: Vec<String> = providers
57            .into_iter()
58            .filter(|provider| {
59                let status = self.budget_limits.providers.check_provider_budget(provider);
60                let is_available = status != BudgetStatus::Exceeded;
61
62                if !is_available {
63                    debug!("Provider '{}' filtered out: budget exceeded", provider);
64                } else if self.log_warnings
65                    && status == BudgetStatus::Warning
66                    && let Some(usage) = self.budget_limits.providers.get_provider_usage(provider)
67                {
68                    warn!(
69                        "Provider '{}' approaching budget limit: ${:.2} / ${:.2} ({:.1}%)",
70                        provider, usage.current_spend, usage.max_budget, usage.usage_percentage
71                    );
72                }
73
74                is_available
75            })
76            .collect();
77
78        if available.is_empty() {
79            debug!("All providers have exceeded budget limits");
80        }
81
82        available
83    }
84
85    /// Filter models based on budget availability
86    pub fn filter_available_models(&self, models: Vec<String>) -> Vec<String> {
87        models
88            .into_iter()
89            .filter(|model| {
90                let status = self.budget_limits.models.check_model_budget(model);
91                let is_available = status != BudgetStatus::Exceeded;
92
93                if !is_available {
94                    debug!("Model '{}' filtered out: budget exceeded", model);
95                }
96
97                is_available
98            })
99            .collect()
100    }
101
102    /// Check if a specific provider is available
103    pub fn is_provider_available(&self, provider: &str) -> bool {
104        self.budget_limits.is_provider_available(provider)
105    }
106
107    /// Check if a specific model is available
108    pub fn is_model_available(&self, model: &str) -> bool {
109        self.budget_limits.is_model_available(model)
110    }
111
112    /// Check if a request can be made with the given provider and model
113    pub fn can_make_request(
114        &self,
115        provider: &str,
116        model: &str,
117        estimated_cost: f64,
118    ) -> RequestBudgetCheck {
119        let provider_check = self.check_provider(provider, estimated_cost);
120        let model_check = self.check_model(model, estimated_cost);
121
122        RequestBudgetCheck {
123            allowed: provider_check.allowed && model_check.allowed,
124            provider_status: provider_check.status,
125            model_status: model_check.status,
126            provider_remaining: provider_check.remaining,
127            model_remaining: model_check.remaining,
128            reason: if !provider_check.allowed {
129                Some(format!("Provider '{}' has exceeded budget", provider))
130            } else if !model_check.allowed {
131                Some(format!("Model '{}' has exceeded budget", model))
132            } else {
133                None
134            },
135        }
136    }
137
138    /// Check provider budget
139    fn check_provider(&self, provider: &str, estimated_cost: f64) -> BudgetCheckResult {
140        let can_spend = self
141            .budget_limits
142            .providers
143            .can_provider_spend(provider, estimated_cost);
144        let status = self.budget_limits.providers.check_provider_budget(provider);
145        let remaining = self
146            .budget_limits
147            .providers
148            .get_provider_usage(provider)
149            .map(|u| u.remaining)
150            .unwrap_or(f64::INFINITY);
151
152        BudgetCheckResult {
153            allowed: can_spend,
154            status,
155            remaining,
156        }
157    }
158
159    /// Check model budget
160    fn check_model(&self, model: &str, estimated_cost: f64) -> BudgetCheckResult {
161        let can_spend = self
162            .budget_limits
163            .models
164            .can_model_spend(model, estimated_cost);
165        let status = self.budget_limits.models.check_model_budget(model);
166        let remaining = self
167            .budget_limits
168            .models
169            .get_model_usage(model)
170            .map(|u| u.remaining)
171            .unwrap_or(f64::INFINITY);
172
173        BudgetCheckResult {
174            allowed: can_spend,
175            status,
176            remaining,
177        }
178    }
179
180    /// Record spend after a request completes
181    pub fn record_spend(&self, provider: &str, model: &str, cost: f64) {
182        self.budget_limits.record_spend(provider, model, cost);
183    }
184
185    /// Get a fallback provider when the primary is over budget
186    ///
187    /// Returns the first available provider from the fallback list,
188    /// or None if all providers have exceeded their budgets.
189    pub fn get_fallback_provider(&self, fallbacks: &[String]) -> Option<String> {
190        for provider in fallbacks {
191            if self.is_provider_available(provider) {
192                debug!("Using fallback provider: {}", provider);
193                return Some(provider.clone());
194            }
195        }
196        None
197    }
198
199    /// Get providers sorted by remaining budget (highest first)
200    pub fn get_providers_by_remaining_budget(&self, providers: Vec<String>) -> Vec<String> {
201        let mut provider_budgets: Vec<(String, f64)> = providers
202            .into_iter()
203            .filter_map(|p| {
204                let remaining = self
205                    .budget_limits
206                    .providers
207                    .get_provider_usage(&p)
208                    .map(|u| u.remaining)
209                    .unwrap_or(f64::INFINITY);
210
211                // Only include providers that aren't exceeded
212                if self.is_provider_available(&p) {
213                    Some((p, remaining))
214                } else {
215                    None
216                }
217            })
218            .collect();
219
220        // Sort by remaining budget (highest first)
221        provider_budgets.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
222
223        provider_budgets.into_iter().map(|(p, _)| p).collect()
224    }
225}
226
227/// Internal budget check result
228struct BudgetCheckResult {
229    allowed: bool,
230    status: BudgetStatus,
231    remaining: f64,
232}
233
234/// Result of checking if a request is within budget
235#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
236pub struct RequestBudgetCheck {
237    /// Whether the request is allowed
238    pub allowed: bool,
239    /// Provider budget status
240    pub provider_status: BudgetStatus,
241    /// Model budget status
242    pub model_status: BudgetStatus,
243    /// Remaining provider budget
244    pub provider_remaining: f64,
245    /// Remaining model budget
246    pub model_remaining: f64,
247    /// Reason if not allowed
248    pub reason: Option<String>,
249}
250
251/// Extension trait for adding budget awareness to routers
252pub trait BudgetAwareRouting {
253    /// Filter providers based on budget availability
254    fn filter_by_budget(
255        &self,
256        providers: Vec<String>,
257        budget_router: &BudgetAwareRouter,
258    ) -> Vec<String>;
259}
260
261impl<T> BudgetAwareRouting for T {
262    fn filter_by_budget(
263        &self,
264        providers: Vec<String>,
265        budget_router: &BudgetAwareRouter,
266    ) -> Vec<String> {
267        budget_router.filter_available_providers(providers)
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::budget::{ModelLimitConfig, ProviderLimitConfig, ResetPeriod};

    /// Router over an empty limits manager (no budgets configured).
    fn create_test_router() -> BudgetAwareRouter {
        let limits = Arc::new(UnifiedBudgetLimits::new());
        BudgetAwareRouter::new(limits)
    }

    /// Limits manager where every listed provider has a $100 monthly budget.
    fn limits_with_providers(names: &[&str]) -> Arc<UnifiedBudgetLimits> {
        let limits = Arc::new(UnifiedBudgetLimits::new());
        for &name in names {
            limits
                .providers
                .set_provider_limit(name, ProviderLimitConfig::new(100.0, ResetPeriod::Monthly));
        }
        limits
    }

    /// Convert string literals into the owned `Vec<String>` the router APIs take.
    fn strings(names: &[&str]) -> Vec<String> {
        names.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_budget_aware_router_creation() {
        let router = create_test_router();
        assert!(router.log_warnings);
        assert!((router.warning_threshold - 0.8).abs() < f64::EPSILON);
    }

    #[test]
    fn test_builder_options() {
        assert!(!create_test_router().with_warnings(false).log_warnings);

        let high = create_test_router().with_warning_threshold(1.5);
        assert!((high.warning_threshold - 1.0).abs() < f64::EPSILON);

        let low = create_test_router().with_warning_threshold(-0.5);
        assert!(low.warning_threshold.abs() < f64::EPSILON);
    }

    #[test]
    fn test_filter_available_providers_no_limits() {
        let router = create_test_router();
        let providers = strings(&["openai", "anthropic"]);

        let available = router.filter_available_providers(providers.clone());
        assert_eq!(available, providers);
    }

    #[test]
    fn test_filter_available_providers_empty_input() {
        let router = create_test_router();
        assert!(router.filter_available_providers(Vec::new()).is_empty());
    }

    #[test]
    fn test_filter_available_providers_with_exceeded() {
        let limits = limits_with_providers(&["openai", "anthropic"]);

        // Exceed openai budget
        limits.providers.record_provider_spend("openai", 150.0);

        let router = BudgetAwareRouter::new(limits);
        let available = router.filter_available_providers(strings(&["openai", "anthropic"]));
        assert_eq!(available, strings(&["anthropic"]));
    }

    #[test]
    fn test_is_provider_available() {
        let limits = limits_with_providers(&["openai"]);
        let router = BudgetAwareRouter::new(Arc::clone(&limits));

        assert!(router.is_provider_available("openai"));
        assert!(router.is_provider_available("unknown")); // No limit = available

        limits.providers.record_provider_spend("openai", 150.0);
        assert!(!router.is_provider_available("openai"));
    }

    #[test]
    fn test_can_make_request() {
        let limits = limits_with_providers(&["openai"]);
        limits
            .models
            .set_model_limit("gpt-4", ModelLimitConfig::new(50.0, ResetPeriod::Monthly));

        let router = BudgetAwareRouter::new(Arc::clone(&limits));

        // Should allow
        let check = router.can_make_request("openai", "gpt-4", 10.0);
        assert!(check.allowed);
        assert!(check.reason.is_none());

        // Exceed model budget
        limits.models.record_model_spend("gpt-4", 60.0);

        let check = router.can_make_request("openai", "gpt-4", 10.0);
        assert!(!check.allowed);
        assert!(check.reason.is_some());
        assert!(check.reason.unwrap().contains("gpt-4"));
    }

    #[test]
    fn test_can_make_request_provider_exceeded() {
        let limits = limits_with_providers(&["openai"]);
        limits.providers.record_provider_spend("openai", 150.0);

        let router = BudgetAwareRouter::new(limits);
        let check = router.can_make_request("openai", "gpt-4", 10.0);

        assert!(!check.allowed);
        assert!(check.provider_status == BudgetStatus::Exceeded);
        // The provider is reported first when it is the one that failed.
        assert!(check.reason.unwrap().contains("openai"));
    }

    #[test]
    fn test_record_spend() {
        let limits = limits_with_providers(&["openai"]);
        limits
            .models
            .set_model_limit("gpt-4", ModelLimitConfig::new(100.0, ResetPeriod::Monthly));

        let router = BudgetAwareRouter::new(Arc::clone(&limits));
        router.record_spend("openai", "gpt-4", 25.0);

        let provider_usage = limits.providers.get_provider_usage("openai").unwrap();
        let model_usage = limits.models.get_model_usage("gpt-4").unwrap();

        assert_eq!(provider_usage.current_spend, 25.0);
        assert_eq!(model_usage.current_spend, 25.0);
    }

    #[test]
    fn test_get_fallback_provider() {
        let limits = limits_with_providers(&["openai", "anthropic", "google"]);

        // Exceed openai and anthropic
        limits.providers.record_provider_spend("openai", 150.0);
        limits.providers.record_provider_spend("anthropic", 150.0);

        let router = BudgetAwareRouter::new(limits);
        let fallbacks = strings(&["openai", "anthropic", "google"]);

        let fallback = router.get_fallback_provider(&fallbacks);
        assert_eq!(fallback, Some("google".to_string()));
    }

    #[test]
    fn test_get_fallback_provider_none_available() {
        let limits = limits_with_providers(&["openai", "anthropic"]);
        limits.providers.record_provider_spend("openai", 150.0);
        limits.providers.record_provider_spend("anthropic", 150.0);

        let router = BudgetAwareRouter::new(limits);

        assert_eq!(
            router.get_fallback_provider(&strings(&["openai", "anthropic"])),
            None
        );
        assert_eq!(router.get_fallback_provider(&[]), None);
    }

    #[test]
    fn test_get_providers_by_remaining_budget() {
        let limits = limits_with_providers(&["openai", "anthropic", "google"]);

        // Different spend amounts
        limits.providers.record_provider_spend("openai", 80.0); // 20 remaining
        limits.providers.record_provider_spend("anthropic", 30.0); // 70 remaining
        limits.providers.record_provider_spend("google", 50.0); // 50 remaining

        let router = BudgetAwareRouter::new(limits);
        let sorted =
            router.get_providers_by_remaining_budget(strings(&["openai", "anthropic", "google"]));

        // Should be sorted by remaining budget (highest first)
        assert_eq!(sorted, strings(&["anthropic", "google", "openai"]));
    }

    #[test]
    fn test_get_providers_by_remaining_budget_excludes_exceeded() {
        let limits = limits_with_providers(&["openai", "google"]);
        limits.providers.record_provider_spend("openai", 150.0); // exceeded
        limits.providers.record_provider_spend("google", 50.0); // 50 remaining

        let router = BudgetAwareRouter::new(limits);
        let sorted =
            router.get_providers_by_remaining_budget(strings(&["openai", "google", "unlimited"]));

        // Exceeded providers are dropped; providers without a budget rank first.
        assert_eq!(sorted, strings(&["unlimited", "google"]));
    }

    #[test]
    fn test_filter_available_models() {
        let limits = Arc::new(UnifiedBudgetLimits::new());
        limits
            .models
            .set_model_limit("gpt-4", ModelLimitConfig::new(100.0, ResetPeriod::Monthly));
        limits.models.set_model_limit(
            "gpt-3.5-turbo",
            ModelLimitConfig::new(100.0, ResetPeriod::Monthly),
        );

        // Exceed gpt-4 budget
        limits.models.record_model_spend("gpt-4", 150.0);

        let router = BudgetAwareRouter::new(limits);
        let available = router.filter_available_models(strings(&["gpt-4", "gpt-3.5-turbo"]));

        assert_eq!(available, strings(&["gpt-3.5-turbo"]));
    }

    #[test]
    fn test_filter_by_budget_extension_trait() {
        let limits = limits_with_providers(&["openai"]);
        limits.providers.record_provider_spend("openai", 150.0);

        let router = BudgetAwareRouter::new(limits);
        // Any value can act as the receiver; the decision comes from `router`.
        let available = ().filter_by_budget(strings(&["openai", "anthropic"]), &router);

        assert_eq!(available, strings(&["anthropic"]));
    }
}