1use crate::core::budget::{BudgetStatus, UnifiedBudgetLimits};
7use std::sync::Arc;
8use tracing::{debug, warn};
9
/// Routing helper that consults shared budget limits before a provider or
/// model is selected.
///
/// Cloning is cheap: the limits live behind an [`Arc`], so every clone
/// observes and updates the same budget state.
#[derive(Clone)]
pub struct BudgetAwareRouter {
    /// Shared budget state; every availability query and spend record made
    /// through this router goes through this handle.
    budget_limits: Arc<UnifiedBudgetLimits>,
    /// When `true`, `filter_available_providers` emits a `warn!` line for
    /// providers whose status is `BudgetStatus::Warning`.
    log_warnings: bool,
    /// Warning threshold; `new` initialises it to `0.8` and
    /// `with_warning_threshold` clamps it to `0.0..=1.0`.
    ///
    /// NOTE(review): no non-test code visible in this file reads this value;
    /// the warning decision relies on `BudgetStatus::Warning` instead.
    /// Confirm whether it is meant to drive that decision.
    warning_threshold: f64,
}
23
24impl BudgetAwareRouter {
25 pub fn new(budget_limits: Arc<UnifiedBudgetLimits>) -> Self {
27 Self {
28 budget_limits,
29 log_warnings: true,
30 warning_threshold: 0.8,
31 }
32 }
33
34 pub fn with_warnings(mut self, log_warnings: bool) -> Self {
36 self.log_warnings = log_warnings;
37 self
38 }
39
40 pub fn with_warning_threshold(mut self, threshold: f64) -> Self {
42 self.warning_threshold = threshold.clamp(0.0, 1.0);
43 self
44 }
45
46 pub fn budget_limits(&self) -> &UnifiedBudgetLimits {
48 &self.budget_limits
49 }
50
51 pub fn filter_available_providers(&self, providers: Vec<String>) -> Vec<String> {
56 let available: Vec<String> = providers
57 .into_iter()
58 .filter(|provider| {
59 let status = self.budget_limits.providers.check_provider_budget(provider);
60 let is_available = status != BudgetStatus::Exceeded;
61
62 if !is_available {
63 debug!("Provider '{}' filtered out: budget exceeded", provider);
64 } else if self.log_warnings && status == BudgetStatus::Warning {
65 if let Some(usage) = self.budget_limits.providers.get_provider_usage(provider) {
66 warn!(
67 "Provider '{}' approaching budget limit: ${:.2} / ${:.2} ({:.1}%)",
68 provider, usage.current_spend, usage.max_budget, usage.usage_percentage
69 );
70 }
71 }
72
73 is_available
74 })
75 .collect();
76
77 if available.is_empty() {
78 debug!("All providers have exceeded budget limits");
79 }
80
81 available
82 }
83
84 pub fn filter_available_models(&self, models: Vec<String>) -> Vec<String> {
86 models
87 .into_iter()
88 .filter(|model| {
89 let status = self.budget_limits.models.check_model_budget(model);
90 let is_available = status != BudgetStatus::Exceeded;
91
92 if !is_available {
93 debug!("Model '{}' filtered out: budget exceeded", model);
94 }
95
96 is_available
97 })
98 .collect()
99 }
100
101 pub fn is_provider_available(&self, provider: &str) -> bool {
103 self.budget_limits.is_provider_available(provider)
104 }
105
106 pub fn is_model_available(&self, model: &str) -> bool {
108 self.budget_limits.is_model_available(model)
109 }
110
111 pub fn can_make_request(
113 &self,
114 provider: &str,
115 model: &str,
116 estimated_cost: f64,
117 ) -> RequestBudgetCheck {
118 let provider_check = self.check_provider(provider, estimated_cost);
119 let model_check = self.check_model(model, estimated_cost);
120
121 RequestBudgetCheck {
122 allowed: provider_check.allowed && model_check.allowed,
123 provider_status: provider_check.status,
124 model_status: model_check.status,
125 provider_remaining: provider_check.remaining,
126 model_remaining: model_check.remaining,
127 reason: if !provider_check.allowed {
128 Some(format!("Provider '{}' has exceeded budget", provider))
129 } else if !model_check.allowed {
130 Some(format!("Model '{}' has exceeded budget", model))
131 } else {
132 None
133 },
134 }
135 }
136
137 fn check_provider(&self, provider: &str, estimated_cost: f64) -> BudgetCheckResult {
139 let can_spend = self
140 .budget_limits
141 .providers
142 .can_provider_spend(provider, estimated_cost);
143 let status = self.budget_limits.providers.check_provider_budget(provider);
144 let remaining = self
145 .budget_limits
146 .providers
147 .get_provider_usage(provider)
148 .map(|u| u.remaining)
149 .unwrap_or(f64::INFINITY);
150
151 BudgetCheckResult {
152 allowed: can_spend,
153 status,
154 remaining,
155 }
156 }
157
158 fn check_model(&self, model: &str, estimated_cost: f64) -> BudgetCheckResult {
160 let can_spend = self
161 .budget_limits
162 .models
163 .can_model_spend(model, estimated_cost);
164 let status = self.budget_limits.models.check_model_budget(model);
165 let remaining = self
166 .budget_limits
167 .models
168 .get_model_usage(model)
169 .map(|u| u.remaining)
170 .unwrap_or(f64::INFINITY);
171
172 BudgetCheckResult {
173 allowed: can_spend,
174 status,
175 remaining,
176 }
177 }
178
179 pub fn record_spend(&self, provider: &str, model: &str, cost: f64) {
181 self.budget_limits.record_spend(provider, model, cost);
182 }
183
184 pub fn get_fallback_provider(&self, fallbacks: &[String]) -> Option<String> {
189 for provider in fallbacks {
190 if self.is_provider_available(provider) {
191 debug!("Using fallback provider: {}", provider);
192 return Some(provider.clone());
193 }
194 }
195 None
196 }
197
198 pub fn get_providers_by_remaining_budget(&self, providers: Vec<String>) -> Vec<String> {
200 let mut provider_budgets: Vec<(String, f64)> = providers
201 .into_iter()
202 .filter_map(|p| {
203 let remaining = self
204 .budget_limits
205 .providers
206 .get_provider_usage(&p)
207 .map(|u| u.remaining)
208 .unwrap_or(f64::INFINITY);
209
210 if self.is_provider_available(&p) {
212 Some((p, remaining))
213 } else {
214 None
215 }
216 })
217 .collect();
218
219 provider_budgets.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
221
222 provider_budgets.into_iter().map(|(p, _)| p).collect()
223 }
224}
225
/// Outcome of checking a single entity (a provider or a model) against its
/// budget; produced by `check_provider` and `check_model`.
struct BudgetCheckResult {
    /// Whether the estimated cost may be spent, as reported by
    /// `can_provider_spend` / `can_model_spend`.
    allowed: bool,
    /// Budget status reported by `check_provider_budget` / `check_model_budget`.
    status: BudgetStatus,
    /// Remaining budget taken from the usage record, or `f64::INFINITY` when
    /// no usage record exists for the entity.
    remaining: f64,
}
232
/// Combined provider and model budget verdict for one prospective request,
/// as returned by `BudgetAwareRouter::can_make_request`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RequestBudgetCheck {
    /// `true` only when both the provider check and the model check allow
    /// the request.
    pub allowed: bool,
    /// Budget status of the provider at the time of the check.
    pub provider_status: BudgetStatus,
    /// Budget status of the model at the time of the check.
    pub model_status: BudgetStatus,
    /// Remaining provider budget; `f64::INFINITY` when the provider has no
    /// usage record.
    pub provider_remaining: f64,
    /// Remaining model budget; `f64::INFINITY` when the model has no usage
    /// record.
    pub model_remaining: f64,
    /// Human-readable explanation when the request is denied; `None` when it
    /// is allowed. A failing provider is reported in preference to a failing
    /// model.
    pub reason: Option<String>,
}
249
/// Extension trait that lets a value filter provider candidates through a
/// [`BudgetAwareRouter`].
///
/// This file provides a blanket implementation for every sized type, which
/// delegates to `BudgetAwareRouter::filter_available_providers`.
pub trait BudgetAwareRouting {
    /// Returns the entries of `providers` that `budget_router` accepts.
    ///
    /// `providers` is consumed; the survivors are returned as a new vector.
    fn filter_by_budget(
        &self,
        providers: Vec<String>,
        budget_router: &BudgetAwareRouter,
    ) -> Vec<String>;
}
259
260impl<T> BudgetAwareRouting for T {
261 fn filter_by_budget(
262 &self,
263 providers: Vec<String>,
264 budget_router: &BudgetAwareRouter,
265 ) -> Vec<String> {
266 budget_router.filter_available_providers(providers)
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::budget::{ModelLimitConfig, ProviderLimitConfig, ResetPeriod};

    /// Router over fresh limits with nothing configured.
    fn create_test_router() -> BudgetAwareRouter {
        BudgetAwareRouter::new(Arc::new(UnifiedBudgetLimits::new()))
    }

    /// Fresh limits in which each named provider gets a 100.0 monthly cap,
    /// registered in the order given.
    fn limits_with_provider_caps(names: &[&str]) -> Arc<UnifiedBudgetLimits> {
        let limits = Arc::new(UnifiedBudgetLimits::new());
        for &name in names {
            limits
                .providers
                .set_provider_limit(name, ProviderLimitConfig::new(100.0, ResetPeriod::Monthly));
        }
        limits
    }

    /// Converts string literals into the owned names the router API expects.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_budget_aware_router_creation() {
        let router = create_test_router();

        assert!(router.log_warnings);
        assert!((router.warning_threshold - 0.8).abs() < f64::EPSILON);
    }

    #[test]
    fn test_filter_available_providers_no_limits() {
        let router = create_test_router();
        let candidates = owned(&["openai", "anthropic"]);

        let kept = router.filter_available_providers(candidates.clone());

        assert_eq!(kept, candidates);
    }

    #[test]
    fn test_filter_available_providers_with_exceeded() {
        let limits = limits_with_provider_caps(&["openai", "anthropic"]);
        // Push openai past its 100.0 cap.
        limits.providers.record_provider_spend("openai", 150.0);
        let router = BudgetAwareRouter::new(limits);

        let kept = router.filter_available_providers(owned(&["openai", "anthropic"]));

        assert_eq!(kept, ["anthropic"]);
    }

    #[test]
    fn test_is_provider_available() {
        let limits = limits_with_provider_caps(&["openai"]);
        let router = BudgetAwareRouter::new(Arc::clone(&limits));

        assert!(router.is_provider_available("openai"));
        // A provider with no configured limit is treated as available.
        assert!(router.is_provider_available("unknown"));

        limits.providers.record_provider_spend("openai", 150.0);
        assert!(!router.is_provider_available("openai"));
    }

    #[test]
    fn test_can_make_request() {
        let limits = limits_with_provider_caps(&["openai"]);
        limits
            .models
            .set_model_limit("gpt-4", ModelLimitConfig::new(50.0, ResetPeriod::Monthly));
        let router = BudgetAwareRouter::new(Arc::clone(&limits));

        let before = router.can_make_request("openai", "gpt-4", 10.0);
        assert!(before.allowed);
        assert!(before.reason.is_none());

        // Overspend the model's 50.0 cap; the provider stays within budget.
        limits.models.record_model_spend("gpt-4", 60.0);

        let after = router.can_make_request("openai", "gpt-4", 10.0);
        assert!(!after.allowed);
        let reason = after
            .reason
            .expect("a denied request must carry a reason");
        assert!(reason.contains("gpt-4"));
    }

    #[test]
    fn test_record_spend() {
        let limits = limits_with_provider_caps(&["openai"]);
        limits
            .models
            .set_model_limit("gpt-4", ModelLimitConfig::new(100.0, ResetPeriod::Monthly));
        let router = BudgetAwareRouter::new(Arc::clone(&limits));

        router.record_spend("openai", "gpt-4", 25.0);

        let provider_usage = limits.providers.get_provider_usage("openai").unwrap();
        let model_usage = limits.models.get_model_usage("gpt-4").unwrap();
        assert_eq!(provider_usage.current_spend, 25.0);
        assert_eq!(model_usage.current_spend, 25.0);
    }

    #[test]
    fn test_get_fallback_provider() {
        let limits = limits_with_provider_caps(&["openai", "anthropic", "google"]);
        // The first two candidates are over budget, leaving only google.
        limits.providers.record_provider_spend("openai", 150.0);
        limits.providers.record_provider_spend("anthropic", 150.0);
        let router = BudgetAwareRouter::new(limits);

        let chain = owned(&["openai", "anthropic", "google"]);

        assert_eq!(
            router.get_fallback_provider(&chain),
            Some("google".to_string())
        );
    }

    #[test]
    fn test_get_providers_by_remaining_budget() {
        let limits = limits_with_provider_caps(&["openai", "anthropic", "google"]);
        // Remaining after these spends: openai 20, anthropic 70, google 50.
        limits.providers.record_provider_spend("openai", 80.0);
        limits.providers.record_provider_spend("anthropic", 30.0);
        limits.providers.record_provider_spend("google", 50.0);
        let router = BudgetAwareRouter::new(limits);

        let sorted =
            router.get_providers_by_remaining_budget(owned(&["openai", "anthropic", "google"]));

        assert_eq!(sorted[..3], ["anthropic", "google", "openai"]);
    }

    #[test]
    fn test_filter_available_models() {
        let limits = Arc::new(UnifiedBudgetLimits::new());
        for &model in &["gpt-4", "gpt-3.5-turbo"] {
            limits
                .models
                .set_model_limit(model, ModelLimitConfig::new(100.0, ResetPeriod::Monthly));
        }
        // Push gpt-4 past its 100.0 cap.
        limits.models.record_model_spend("gpt-4", 150.0);
        let router = BudgetAwareRouter::new(limits);

        let kept = router.filter_available_models(owned(&["gpt-4", "gpt-3.5-turbo"]));

        assert_eq!(kept, ["gpt-3.5-turbo"]);
    }
}