1use crate::core::budget::{BudgetStatus, UnifiedBudgetLimits};
7use std::sync::Arc;
8use tracing::{debug, warn};
9
10#[derive(Clone)]
15pub struct BudgetAwareRouter {
16 budget_limits: Arc<UnifiedBudgetLimits>,
18 log_warnings: bool,
20 warning_threshold: f64,
22}
23
24impl BudgetAwareRouter {
25 pub fn new(budget_limits: Arc<UnifiedBudgetLimits>) -> Self {
27 Self {
28 budget_limits,
29 log_warnings: true,
30 warning_threshold: 0.8,
31 }
32 }
33
34 pub fn with_warnings(mut self, log_warnings: bool) -> Self {
36 self.log_warnings = log_warnings;
37 self
38 }
39
40 pub fn with_warning_threshold(mut self, threshold: f64) -> Self {
42 self.warning_threshold = threshold.clamp(0.0, 1.0);
43 self
44 }
45
46 pub fn budget_limits(&self) -> &UnifiedBudgetLimits {
48 &self.budget_limits
49 }
50
51 pub fn filter_available_providers(&self, providers: Vec<String>) -> Vec<String> {
56 let available: Vec<String> = providers
57 .into_iter()
58 .filter(|provider| {
59 let status = self.budget_limits.providers.check_provider_budget(provider);
60 let is_available = status != BudgetStatus::Exceeded;
61
62 if !is_available {
63 debug!("Provider '{}' filtered out: budget exceeded", provider);
64 } else if self.log_warnings
65 && status == BudgetStatus::Warning
66 && let Some(usage) = self.budget_limits.providers.get_provider_usage(provider)
67 {
68 warn!(
69 "Provider '{}' approaching budget limit: ${:.2} / ${:.2} ({:.1}%)",
70 provider, usage.current_spend, usage.max_budget, usage.usage_percentage
71 );
72 }
73
74 is_available
75 })
76 .collect();
77
78 if available.is_empty() {
79 debug!("All providers have exceeded budget limits");
80 }
81
82 available
83 }
84
85 pub fn filter_available_models(&self, models: Vec<String>) -> Vec<String> {
87 models
88 .into_iter()
89 .filter(|model| {
90 let status = self.budget_limits.models.check_model_budget(model);
91 let is_available = status != BudgetStatus::Exceeded;
92
93 if !is_available {
94 debug!("Model '{}' filtered out: budget exceeded", model);
95 }
96
97 is_available
98 })
99 .collect()
100 }
101
102 pub fn is_provider_available(&self, provider: &str) -> bool {
104 self.budget_limits.is_provider_available(provider)
105 }
106
107 pub fn is_model_available(&self, model: &str) -> bool {
109 self.budget_limits.is_model_available(model)
110 }
111
112 pub fn can_make_request(
114 &self,
115 provider: &str,
116 model: &str,
117 estimated_cost: f64,
118 ) -> RequestBudgetCheck {
119 let provider_check = self.check_provider(provider, estimated_cost);
120 let model_check = self.check_model(model, estimated_cost);
121
122 RequestBudgetCheck {
123 allowed: provider_check.allowed && model_check.allowed,
124 provider_status: provider_check.status,
125 model_status: model_check.status,
126 provider_remaining: provider_check.remaining,
127 model_remaining: model_check.remaining,
128 reason: if !provider_check.allowed {
129 Some(format!("Provider '{}' has exceeded budget", provider))
130 } else if !model_check.allowed {
131 Some(format!("Model '{}' has exceeded budget", model))
132 } else {
133 None
134 },
135 }
136 }
137
138 fn check_provider(&self, provider: &str, estimated_cost: f64) -> BudgetCheckResult {
140 let can_spend = self
141 .budget_limits
142 .providers
143 .can_provider_spend(provider, estimated_cost);
144 let status = self.budget_limits.providers.check_provider_budget(provider);
145 let remaining = self
146 .budget_limits
147 .providers
148 .get_provider_usage(provider)
149 .map(|u| u.remaining)
150 .unwrap_or(f64::INFINITY);
151
152 BudgetCheckResult {
153 allowed: can_spend,
154 status,
155 remaining,
156 }
157 }
158
159 fn check_model(&self, model: &str, estimated_cost: f64) -> BudgetCheckResult {
161 let can_spend = self
162 .budget_limits
163 .models
164 .can_model_spend(model, estimated_cost);
165 let status = self.budget_limits.models.check_model_budget(model);
166 let remaining = self
167 .budget_limits
168 .models
169 .get_model_usage(model)
170 .map(|u| u.remaining)
171 .unwrap_or(f64::INFINITY);
172
173 BudgetCheckResult {
174 allowed: can_spend,
175 status,
176 remaining,
177 }
178 }
179
180 pub fn record_spend(&self, provider: &str, model: &str, cost: f64) {
182 self.budget_limits.record_spend(provider, model, cost);
183 }
184
185 pub fn get_fallback_provider(&self, fallbacks: &[String]) -> Option<String> {
190 for provider in fallbacks {
191 if self.is_provider_available(provider) {
192 debug!("Using fallback provider: {}", provider);
193 return Some(provider.clone());
194 }
195 }
196 None
197 }
198
199 pub fn get_providers_by_remaining_budget(&self, providers: Vec<String>) -> Vec<String> {
201 let mut provider_budgets: Vec<(String, f64)> = providers
202 .into_iter()
203 .filter_map(|p| {
204 let remaining = self
205 .budget_limits
206 .providers
207 .get_provider_usage(&p)
208 .map(|u| u.remaining)
209 .unwrap_or(f64::INFINITY);
210
211 if self.is_provider_available(&p) {
213 Some((p, remaining))
214 } else {
215 None
216 }
217 })
218 .collect();
219
220 provider_budgets.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
222
223 provider_budgets.into_iter().map(|(p, _)| p).collect()
224 }
225}
226
227struct BudgetCheckResult {
229 allowed: bool,
230 status: BudgetStatus,
231 remaining: f64,
232}
233
234#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
236pub struct RequestBudgetCheck {
237 pub allowed: bool,
239 pub provider_status: BudgetStatus,
241 pub model_status: BudgetStatus,
243 pub provider_remaining: f64,
245 pub model_remaining: f64,
247 pub reason: Option<String>,
249}
250
251pub trait BudgetAwareRouting {
253 fn filter_by_budget(
255 &self,
256 providers: Vec<String>,
257 budget_router: &BudgetAwareRouter,
258 ) -> Vec<String>;
259}
260
261impl<T> BudgetAwareRouting for T {
262 fn filter_by_budget(
263 &self,
264 providers: Vec<String>,
265 budget_router: &BudgetAwareRouter,
266 ) -> Vec<String> {
267 budget_router.filter_available_providers(providers)
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::budget::{ModelLimitConfig, ProviderLimitConfig, ResetPeriod};

    /// Monthly cap applied to every provider created by `limits_with_providers`.
    const PROVIDER_CAP: f64 = 100.0;

    /// Builds fresh limits with a monthly `PROVIDER_CAP` for each named provider.
    fn limits_with_providers(providers: &[&str]) -> Arc<UnifiedBudgetLimits> {
        let limits = Arc::new(UnifiedBudgetLimits::new());
        for name in providers {
            limits.providers.set_provider_limit(
                *name,
                ProviderLimitConfig::new(PROVIDER_CAP, ResetPeriod::Monthly),
            );
        }
        limits
    }

    /// Converts string literals into the owned names the router API expects.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().copied().map(String::from).collect()
    }

    /// Router over limits that have no provider or model configured.
    fn create_test_router() -> BudgetAwareRouter {
        BudgetAwareRouter::new(Arc::new(UnifiedBudgetLimits::new()))
    }

    #[test]
    fn test_budget_aware_router_creation() {
        let router = create_test_router();

        assert!(router.log_warnings);
        assert!((router.warning_threshold - 0.8).abs() < f64::EPSILON);
    }

    #[test]
    fn test_filter_available_providers_no_limits() {
        let router = create_test_router();
        let providers = owned(&["openai", "anthropic"]);

        let available = router.filter_available_providers(providers.clone());

        assert_eq!(available, providers);
    }

    #[test]
    fn test_filter_available_providers_with_exceeded() {
        let limits = limits_with_providers(&["openai", "anthropic"]);
        // Push openai past its cap; anthropic stays untouched.
        limits.providers.record_provider_spend("openai", 150.0);
        let router = BudgetAwareRouter::new(limits);

        let available = router.filter_available_providers(owned(&["openai", "anthropic"]));

        assert_eq!(available, owned(&["anthropic"]));
    }

    #[test]
    fn test_is_provider_available() {
        let limits = limits_with_providers(&["openai"]);
        let router = BudgetAwareRouter::new(Arc::clone(&limits));

        assert!(router.is_provider_available("openai"));
        // A provider with no configured limit is treated as available.
        assert!(router.is_provider_available("unknown"));

        limits.providers.record_provider_spend("openai", 150.0);
        assert!(!router.is_provider_available("openai"));
    }

    #[test]
    fn test_can_make_request() {
        let limits = limits_with_providers(&["openai"]);
        limits
            .models
            .set_model_limit("gpt-4", ModelLimitConfig::new(50.0, ResetPeriod::Monthly));
        let router = BudgetAwareRouter::new(Arc::clone(&limits));

        let before = router.can_make_request("openai", "gpt-4", 10.0);
        assert!(before.allowed);
        assert!(before.reason.is_none());

        // Exhaust the model budget while the provider still has headroom.
        limits.models.record_model_spend("gpt-4", 60.0);

        let after = router.can_make_request("openai", "gpt-4", 10.0);
        assert!(!after.allowed);
        assert!(after.reason.is_some());
        assert!(after.reason.unwrap().contains("gpt-4"));
    }

    #[test]
    fn test_record_spend() {
        let limits = limits_with_providers(&["openai"]);
        limits
            .models
            .set_model_limit("gpt-4", ModelLimitConfig::new(100.0, ResetPeriod::Monthly));
        let router = BudgetAwareRouter::new(Arc::clone(&limits));

        router.record_spend("openai", "gpt-4", 25.0);

        let provider_usage = limits.providers.get_provider_usage("openai").unwrap();
        let model_usage = limits.models.get_model_usage("gpt-4").unwrap();
        assert_eq!(provider_usage.current_spend, 25.0);
        assert_eq!(model_usage.current_spend, 25.0);
    }

    #[test]
    fn test_get_fallback_provider() {
        let limits = limits_with_providers(&["openai", "anthropic", "google"]);
        // Only google is left within budget.
        limits.providers.record_provider_spend("openai", 150.0);
        limits.providers.record_provider_spend("anthropic", 150.0);
        let router = BudgetAwareRouter::new(limits);

        let fallbacks = owned(&["openai", "anthropic", "google"]);
        let fallback = router.get_fallback_provider(&fallbacks);

        assert_eq!(fallback, Some("google".to_string()));
    }

    #[test]
    fn test_get_providers_by_remaining_budget() {
        let limits = limits_with_providers(&["openai", "anthropic", "google"]);
        // Remaining after spend: openai 20, anthropic 70, google 50.
        limits.providers.record_provider_spend("openai", 80.0);
        limits.providers.record_provider_spend("anthropic", 30.0);
        limits.providers.record_provider_spend("google", 50.0);
        let router = BudgetAwareRouter::new(limits);

        let sorted =
            router.get_providers_by_remaining_budget(owned(&["openai", "anthropic", "google"]));

        assert_eq!(sorted[0], "anthropic");
        assert_eq!(sorted[1], "google");
        assert_eq!(sorted[2], "openai");
    }

    #[test]
    fn test_filter_available_models() {
        let limits = Arc::new(UnifiedBudgetLimits::new());
        for model in ["gpt-4", "gpt-3.5-turbo"] {
            limits
                .models
                .set_model_limit(model, ModelLimitConfig::new(100.0, ResetPeriod::Monthly));
        }
        // Push gpt-4 past its cap.
        limits.models.record_model_spend("gpt-4", 150.0);
        let router = BudgetAwareRouter::new(limits);

        let available = router.filter_available_models(owned(&["gpt-4", "gpt-3.5-turbo"]));

        assert_eq!(available, owned(&["gpt-3.5-turbo"]));
    }
}