1use std::sync::Arc;
11
12use dashmap::DashMap;
13use serde::{Deserialize, Serialize};
14use tracing::{debug, info};
15
16use punch_types::{FighterId, PunchResult};
17
18use crate::metering::{MeteringEngine, SpendPeriod};
19
/// Optional usage caps that can be attached to a single fighter or applied
/// globally.
///
/// Each cap is an `Option`; `None` leaves that dimension unlimited.
///
/// NOTE(review): in this file only `max_cost_per_day_usd` is actually compared
/// against metered usage. The token and request caps are stored but do not
/// influence any verdict yet — confirm before relying on them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BudgetLimit {
    /// Cap on tokens per hour, if any.
    pub max_tokens_per_hour: Option<u64>,
    /// Cap on tokens per day, if any.
    pub max_tokens_per_day: Option<u64>,
    /// Cap on spend per day in US dollars, if any.
    pub max_cost_per_day_usd: Option<f64>,
    /// Cap on requests per hour, if any.
    pub max_requests_per_hour: Option<u64>,
    /// Percentage of a cap at which a warning is raised. When the field is
    /// absent during deserialization, `default_warning_threshold` supplies it.
    #[serde(default = "default_warning_threshold")]
    pub warning_threshold_percent: u8,
}
39
/// Warning threshold (in percent) used when a limit does not specify one.
fn default_warning_threshold() -> u8 {
    const DEFAULT_WARNING_THRESHOLD_PERCENT: u8 = 80;
    DEFAULT_WARNING_THRESHOLD_PERCENT
}
43
44impl Default for BudgetLimit {
45 fn default() -> Self {
46 Self {
47 max_tokens_per_hour: None,
48 max_tokens_per_day: None,
49 max_cost_per_day_usd: None,
50 max_requests_per_hour: None,
51 warning_threshold_percent: default_warning_threshold(),
52 }
53 }
54}
55
56impl BudgetLimit {
57 pub fn has_any_limit(&self) -> bool {
59 self.max_tokens_per_hour.is_some()
60 || self.max_tokens_per_day.is_some()
61 || self.max_cost_per_day_usd.is_some()
62 || self.max_requests_per_hour.is_some()
63 }
64}
65
/// Outcome of checking usage against a budget.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum BudgetVerdict {
    /// No limit was hit and no warning threshold was crossed.
    Allowed,
    /// Usage is still under the cap but has reached the warning threshold.
    Warning {
        /// Usage relative to the cap, on a 0–100 scale (the enforcer
        /// multiplies the ratio by 100 before storing it here).
        usage_percent: f64,
        /// Human-readable description of which limit is being approached.
        message: String,
    },
    /// Usage has reached or exceeded the cap.
    Blocked {
        /// Human-readable description of which limit was exceeded.
        reason: String,
        /// Seconds the caller should wait before retrying.
        retry_after_secs: u64,
    },
}
86
/// Snapshot of configured limits, current usage, and the resulting verdict.
///
/// NOTE(review): the status builders in this file always fill
/// `tokens_used_hour`, `tokens_used_day` and `requests_used_hour` with `0`;
/// only `cost_used_day_usd` carries a metered value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BudgetStatus {
    /// The limit this status was evaluated against, if one is configured.
    pub limits: Option<BudgetLimit>,
    /// Tokens used in the current hour.
    pub tokens_used_hour: u64,
    /// Tokens used in the current day.
    pub tokens_used_day: u64,
    /// Spend in the current day, in US dollars.
    pub cost_used_day_usd: f64,
    /// Requests made in the current hour.
    pub requests_used_hour: u64,
    /// Verdict computed at the time the status was built.
    pub verdict: BudgetVerdict,
}
103
/// Checks metered usage against per-fighter and global budget limits.
pub struct BudgetEnforcer {
    /// Source of spend figures used for every budget decision.
    metering: Arc<MeteringEngine>,
    /// Per-fighter limits, keyed by fighter id.
    limits: DashMap<FighterId, BudgetLimit>,
    /// Optional limit evaluated against total spend across all fighters.
    global_limit: std::sync::RwLock<Option<BudgetLimit>>,
}
118
119impl BudgetEnforcer {
120 pub fn new(metering: Arc<MeteringEngine>) -> Self {
122 Self {
123 metering,
124 limits: DashMap::new(),
125 global_limit: std::sync::RwLock::new(None),
126 }
127 }
128
129 pub fn set_fighter_limit(&self, fighter_id: FighterId, limit: BudgetLimit) {
131 info!(%fighter_id, "budget limit set for fighter");
132 self.limits.insert(fighter_id, limit);
133 }
134
135 pub fn remove_fighter_limit(&self, fighter_id: &FighterId) {
137 self.limits.remove(fighter_id);
138 }
139
140 pub fn get_fighter_limit(&self, fighter_id: &FighterId) -> Option<BudgetLimit> {
142 self.limits.get(fighter_id).map(|entry| entry.clone())
143 }
144
145 pub fn set_global_limit(&self, limit: BudgetLimit) {
149 info!("global budget limit set");
150 let mut guard = self
151 .global_limit
152 .write()
153 .expect("global_limit lock poisoned");
154 *guard = Some(limit);
155 }
156
157 pub fn clear_global_limit(&self) {
159 let mut guard = self
160 .global_limit
161 .write()
162 .expect("global_limit lock poisoned");
163 *guard = None;
164 }
165
166 pub fn get_global_limit(&self) -> Option<BudgetLimit> {
168 let guard = self
169 .global_limit
170 .read()
171 .expect("global_limit lock poisoned");
172 guard.clone()
173 }
174
175 pub async fn check_budget(&self, fighter_id: &FighterId) -> PunchResult<BudgetVerdict> {
180 let fighter_verdict = if let Some(limit) = self.limits.get(fighter_id) {
182 self.evaluate_limit(fighter_id, &limit, false).await?
183 } else {
184 BudgetVerdict::Allowed
185 };
186
187 if matches!(fighter_verdict, BudgetVerdict::Blocked { .. }) {
189 return Ok(fighter_verdict);
190 }
191
192 let global_verdict = {
194 let global = self.get_global_limit();
195 if let Some(ref limit) = global {
196 self.evaluate_global_limit(limit).await?
197 } else {
198 BudgetVerdict::Allowed
199 }
200 };
201
202 Ok(most_restrictive(fighter_verdict, global_verdict))
204 }
205
206 pub async fn get_fighter_status(&self, fighter_id: &FighterId) -> PunchResult<BudgetStatus> {
208 let limit = self.limits.get(fighter_id).map(|e| e.clone());
209
210 let daily_spend = self
211 .metering
212 .get_spend(fighter_id, SpendPeriod::Day)
213 .await?;
214
215 let verdict = self.check_budget(fighter_id).await?;
216
217 Ok(BudgetStatus {
218 limits: limit,
219 tokens_used_hour: 0, tokens_used_day: 0,
221 cost_used_day_usd: daily_spend,
222 requests_used_hour: 0,
223 verdict,
224 })
225 }
226
227 pub async fn get_global_status(&self) -> PunchResult<BudgetStatus> {
229 let limit = self.get_global_limit();
230
231 let daily_spend = self.metering.get_total_spend(SpendPeriod::Day).await?;
232
233 let global_verdict = if let Some(ref lim) = limit {
234 self.evaluate_global_limit(lim).await?
235 } else {
236 BudgetVerdict::Allowed
237 };
238
239 Ok(BudgetStatus {
240 limits: limit,
241 tokens_used_hour: 0,
242 tokens_used_day: 0,
243 cost_used_day_usd: daily_spend,
244 requests_used_hour: 0,
245 verdict: global_verdict,
246 })
247 }
248
249 async fn evaluate_limit(
251 &self,
252 fighter_id: &FighterId,
253 limit: &BudgetLimit,
254 _is_global: bool,
255 ) -> PunchResult<BudgetVerdict> {
256 if !limit.has_any_limit() {
257 return Ok(BudgetVerdict::Allowed);
258 }
259
260 let threshold = limit.warning_threshold_percent as f64 / 100.0;
261
262 if let Some(max_usd) = limit.max_cost_per_day_usd {
264 let daily_cost = self
265 .metering
266 .get_spend(fighter_id, SpendPeriod::Day)
267 .await?;
268
269 if daily_cost >= max_usd {
270 debug!(%fighter_id, daily_cost, max_usd, "fighter over daily cost budget");
271 return Ok(BudgetVerdict::Blocked {
272 reason: format!(
273 "daily cost budget exceeded: ${:.4} / ${:.4}",
274 daily_cost, max_usd
275 ),
276 retry_after_secs: seconds_until_day_reset(),
277 });
278 }
279
280 let usage_pct = if max_usd > 0.0 {
281 daily_cost / max_usd
282 } else {
283 1.0 };
285 if usage_pct >= threshold {
286 return Ok(BudgetVerdict::Warning {
287 usage_percent: usage_pct * 100.0,
288 message: format!(
289 "approaching daily cost limit: ${:.4} / ${:.4} ({:.0}%)",
290 daily_cost,
291 max_usd,
292 usage_pct * 100.0
293 ),
294 });
295 }
296 }
297
298 if let Some(max_tokens_hour) = limit.max_tokens_per_hour {
300 let hourly_cost = self
301 .metering
302 .get_spend(fighter_id, SpendPeriod::Hour)
303 .await?;
304 let _hourly_cost_cents = (hourly_cost * 100.0) as u64;
307
308 debug!(
311 %fighter_id,
312 max_tokens_hour,
313 "hourly token limit configured (checked via cost proxy)"
314 );
315 }
316
317 Ok(BudgetVerdict::Allowed)
318 }
319
320 async fn evaluate_global_limit(&self, limit: &BudgetLimit) -> PunchResult<BudgetVerdict> {
322 if !limit.has_any_limit() {
323 return Ok(BudgetVerdict::Allowed);
324 }
325
326 let threshold = limit.warning_threshold_percent as f64 / 100.0;
327
328 if let Some(max_usd) = limit.max_cost_per_day_usd {
330 let daily_cost = self.metering.get_total_spend(SpendPeriod::Day).await?;
331
332 if daily_cost >= max_usd {
333 return Ok(BudgetVerdict::Blocked {
334 reason: format!(
335 "global daily cost budget exceeded: ${:.4} / ${:.4}",
336 daily_cost, max_usd
337 ),
338 retry_after_secs: seconds_until_day_reset(),
339 });
340 }
341
342 let usage_pct = if max_usd > 0.0 {
343 daily_cost / max_usd
344 } else {
345 1.0
346 };
347 if usage_pct >= threshold {
348 return Ok(BudgetVerdict::Warning {
349 usage_percent: usage_pct * 100.0,
350 message: format!(
351 "approaching global daily cost limit: ${:.4} / ${:.4} ({:.0}%)",
352 daily_cost,
353 max_usd,
354 usage_pct * 100.0
355 ),
356 });
357 }
358 }
359
360 Ok(BudgetVerdict::Allowed)
361 }
362}
363
364fn most_restrictive(a: BudgetVerdict, b: BudgetVerdict) -> BudgetVerdict {
366 match (&a, &b) {
367 (BudgetVerdict::Blocked { .. }, _) => a,
368 (_, BudgetVerdict::Blocked { .. }) => b,
369 (
370 BudgetVerdict::Warning {
371 usage_percent: pa, ..
372 },
373 BudgetVerdict::Warning {
374 usage_percent: pb, ..
375 },
376 ) => {
377 if pa >= pb {
378 a
379 } else {
380 b
381 }
382 }
383 (BudgetVerdict::Warning { .. }, _) => a,
384 (_, BudgetVerdict::Warning { .. }) => b,
385 _ => BudgetVerdict::Allowed,
386 }
387}
388
389fn seconds_until_day_reset() -> u64 {
391 let now = chrono::Utc::now();
392 let tomorrow = (now + chrono::Duration::days(1))
393 .date_naive()
394 .and_hms_opt(0, 0, 0);
395
396 match tomorrow {
397 Some(t) => {
398 let reset = chrono::DateTime::<chrono::Utc>::from_naive_utc_and_offset(t, chrono::Utc);
399 (reset - now).num_seconds().max(0) as u64
400 }
401 None => 3600, }
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use punch_memory::MemorySubstrate;
    use punch_types::{FighterManifest, FighterStatus, ModelConfig, Provider, WeightClass};

    /// Model name used by the test manifest and by most usage records.
    const SONNET_MODEL: &str = "claude-sonnet-4-20250514";

    /// Builds a minimal manifest for fighters registered by these tests.
    fn test_manifest() -> FighterManifest {
        let model = ModelConfig {
            provider: Provider::Anthropic,
            model: SONNET_MODEL.into(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };

        FighterManifest {
            name: "budget-test".into(),
            description: "test".into(),
            model,
            system_prompt: "test".into(),
            capabilities: Vec::new(),
            weight_class: WeightClass::Featherweight,
            tenant_id: None,
        }
    }

    /// Limit whose only cap is a daily cost of `usd` dollars.
    fn daily_cost_cap(usd: f64) -> BudgetLimit {
        BudgetLimit {
            max_cost_per_day_usd: Some(usd),
            ..Default::default()
        }
    }

    /// Creates an in-memory substrate and a metering engine on top of it.
    async fn setup() -> (Arc<MeteringEngine>, Arc<MemorySubstrate>) {
        let memory = Arc::new(MemorySubstrate::in_memory().expect("in-memory substrate"));
        let metering = Arc::new(MeteringEngine::new(Arc::clone(&memory)));
        (metering, memory)
    }

    /// Registers a fresh idle fighter and returns its id.
    async fn setup_fighter(memory: &MemorySubstrate) -> FighterId {
        let id = FighterId::new();
        memory
            .save_fighter(&id, &test_manifest(), FighterStatus::Idle)
            .await
            .expect("save fighter");
        id
    }

    /// A fighter with no recorded usage stays under a $10 cap.
    #[tokio::test]
    async fn under_budget_allowed() {
        let (metering, memory) = setup().await;
        let fighter = setup_fighter(&memory).await;

        let enforcer = BudgetEnforcer::new(metering);
        enforcer.set_fighter_limit(fighter, daily_cost_cap(10.0));

        let outcome = enforcer.check_budget(&fighter).await.expect("check budget");
        assert_eq!(outcome, BudgetVerdict::Allowed);
    }

    /// Usage past the default 80% threshold (presumably ~90% of $1, per the
    /// assertion message) yields a warning.
    #[tokio::test]
    async fn at_80_percent_warning() {
        let (metering, memory) = setup().await;
        let fighter = setup_fighter(&memory).await;

        metering
            .record_usage(&fighter, SONNET_MODEL, 50_000, 50_000)
            .await
            .expect("record usage");

        let enforcer = BudgetEnforcer::new(Arc::clone(&metering));
        enforcer.set_fighter_limit(fighter, daily_cost_cap(1.0));

        let outcome = enforcer.check_budget(&fighter).await.expect("check budget");
        assert!(
            matches!(outcome, BudgetVerdict::Warning { .. }),
            "expected warning at ~90%, got {:?}",
            outcome
        );
    }

    /// Spend above the cap blocks and reports a positive retry delay.
    #[tokio::test]
    async fn over_budget_blocked() {
        let (metering, memory) = setup().await;
        let fighter = setup_fighter(&memory).await;

        metering
            .record_usage(&fighter, SONNET_MODEL, 100_000, 100_000)
            .await
            .expect("record usage");

        let enforcer = BudgetEnforcer::new(Arc::clone(&metering));
        enforcer.set_fighter_limit(fighter, daily_cost_cap(1.0));

        let outcome = enforcer.check_budget(&fighter).await.expect("check budget");
        match outcome {
            BudgetVerdict::Blocked {
                retry_after_secs, ..
            } => assert!(retry_after_secs > 0),
            other => panic!("expected blocked, got {:?}", other),
        }
    }

    /// With no usage recorded in the current period, the fighter is allowed.
    ///
    /// NOTE(review): this does not actually cross a period boundary; it only
    /// checks the state a fresh period would start from.
    #[tokio::test]
    async fn budget_resets_at_period_boundary() {
        let (metering, memory) = setup().await;
        let fighter = setup_fighter(&memory).await;

        let enforcer = BudgetEnforcer::new(Arc::clone(&metering));
        enforcer.set_fighter_limit(fighter, daily_cost_cap(1.0));

        let outcome = enforcer.check_budget(&fighter).await.expect("check budget");
        assert_eq!(outcome, BudgetVerdict::Allowed);
    }

    /// One fighter exceeding its cap does not affect another fighter's verdict.
    #[tokio::test]
    async fn per_fighter_limits_independent() {
        let (metering, memory) = setup().await;
        let spender = setup_fighter(&memory).await;
        let idle = setup_fighter(&memory).await;

        metering
            .record_usage(&spender, SONNET_MODEL, 100_000, 100_000)
            .await
            .expect("record usage");

        let enforcer = BudgetEnforcer::new(Arc::clone(&metering));
        enforcer.set_fighter_limit(spender, daily_cost_cap(1.0));
        enforcer.set_fighter_limit(idle, daily_cost_cap(1.0));

        let spender_outcome = enforcer.check_budget(&spender).await.expect("check fid1");
        let idle_outcome = enforcer.check_budget(&idle).await.expect("check fid2");

        assert!(matches!(spender_outcome, BudgetVerdict::Blocked { .. }));
        assert_eq!(idle_outcome, BudgetVerdict::Allowed);
    }

    /// A global cap exhausted by one fighter blocks a different fighter too.
    #[tokio::test]
    async fn global_limit_applies_to_all_fighters() {
        let (metering, memory) = setup().await;
        let spender = setup_fighter(&memory).await;

        metering
            .record_usage(&spender, SONNET_MODEL, 100_000, 100_000)
            .await
            .expect("record usage");

        let enforcer = BudgetEnforcer::new(Arc::clone(&metering));
        enforcer.set_global_limit(daily_cost_cap(1.0));

        let bystander = setup_fighter(&memory).await;
        let outcome = enforcer
            .check_budget(&bystander)
            .await
            .expect("check budget");
        assert!(
            matches!(outcome, BudgetVerdict::Blocked { .. }),
            "global limit should block: {:?}",
            outcome
        );
    }

    /// Without any configured limit, heavy usage is still allowed.
    #[tokio::test]
    async fn no_limit_always_allowed() {
        let (metering, memory) = setup().await;
        let fighter = setup_fighter(&memory).await;

        metering
            .record_usage(&fighter, SONNET_MODEL, 1_000_000, 1_000_000)
            .await
            .expect("record usage");

        let enforcer = BudgetEnforcer::new(Arc::clone(&metering));

        let outcome = enforcer.check_budget(&fighter).await.expect("check budget");
        assert_eq!(outcome, BudgetVerdict::Allowed);
    }

    /// A $0 cap blocks even when nothing has been spent.
    #[tokio::test]
    async fn zero_limit_always_blocked() {
        let (metering, memory) = setup().await;
        let fighter = setup_fighter(&memory).await;

        let enforcer = BudgetEnforcer::new(Arc::clone(&metering));
        enforcer.set_fighter_limit(fighter, daily_cost_cap(0.0));

        let outcome = enforcer.check_budget(&fighter).await.expect("check budget");
        assert!(
            matches!(outcome, BudgetVerdict::Blocked { .. }),
            "zero limit should block: {:?}",
            outcome
        );
    }

    /// Three fighters with separate caps: only the one with usage is blocked.
    #[tokio::test]
    async fn multiple_fighters_dont_interfere() {
        let (metering, memory) = setup().await;
        let first = setup_fighter(&memory).await;
        let second = setup_fighter(&memory).await;
        let third = setup_fighter(&memory).await;

        metering
            .record_usage(&first, SONNET_MODEL, 100_000, 100_000)
            .await
            .expect("record usage");

        let enforcer = BudgetEnforcer::new(Arc::clone(&metering));
        enforcer.set_fighter_limit(first, daily_cost_cap(1.0));
        enforcer.set_fighter_limit(second, daily_cost_cap(1.0));
        enforcer.set_fighter_limit(third, daily_cost_cap(0.50));

        let first_outcome = enforcer.check_budget(&first).await.expect("check fid1");
        let second_outcome = enforcer.check_budget(&second).await.expect("check fid2");
        let third_outcome = enforcer.check_budget(&third).await.expect("check fid3");

        assert!(matches!(first_outcome, BudgetVerdict::Blocked { .. }));
        assert_eq!(second_outcome, BudgetVerdict::Allowed);
        assert_eq!(third_outcome, BudgetVerdict::Allowed);
    }

    /// The same usage is allowed under a 95% threshold and warned about
    /// under a 50% threshold.
    #[tokio::test]
    async fn warning_threshold_configurable() {
        let (metering, memory) = setup().await;
        let fighter = setup_fighter(&memory).await;

        metering
            .record_usage(&fighter, SONNET_MODEL, 50_000, 50_000)
            .await
            .expect("record usage");

        let enforcer = BudgetEnforcer::new(Arc::clone(&metering));

        // High threshold: current usage must not trigger a warning.
        enforcer.set_fighter_limit(
            fighter,
            BudgetLimit {
                warning_threshold_percent: 95,
                ..daily_cost_cap(1.0)
            },
        );
        let lenient = enforcer.check_budget(&fighter).await.expect("check budget");
        assert_eq!(
            lenient,
            BudgetVerdict::Allowed,
            "95% threshold should not warn at 90%: {:?}",
            lenient
        );

        // Low threshold: the same usage now warns.
        enforcer.set_fighter_limit(
            fighter,
            BudgetLimit {
                warning_threshold_percent: 50,
                ..daily_cost_cap(1.0)
            },
        );
        let strict = enforcer.check_budget(&fighter).await.expect("check budget");
        assert!(
            matches!(strict, BudgetVerdict::Warning { .. }),
            "50% threshold should warn at 90%: {:?}",
            strict
        );
    }

    /// A small spend on a cheaper model stays under a one-cent cap.
    #[tokio::test]
    async fn cost_based_budget() {
        let (metering, memory) = setup().await;
        let fighter = setup_fighter(&memory).await;

        metering
            .record_usage(&fighter, "gpt-4o-mini", 10_000, 10_000)
            .await
            .expect("record usage");

        let enforcer = BudgetEnforcer::new(Arc::clone(&metering));
        enforcer.set_fighter_limit(fighter, daily_cost_cap(0.01));

        let outcome = enforcer.check_budget(&fighter).await.expect("check budget");
        assert_eq!(
            outcome,
            BudgetVerdict::Allowed,
            "$0.0075 should be under $0.01 limit at 80% threshold: {:?}",
            outcome
        );
    }

    /// `Blocked` outranks `Warning` regardless of argument order.
    #[test]
    fn most_restrictive_selects_blocked_over_warning() {
        let warning = BudgetVerdict::Warning {
            usage_percent: 85.0,
            message: "warning".to_string(),
        };
        let blocked = BudgetVerdict::Blocked {
            reason: "blocked".to_string(),
            retry_after_secs: 100,
        };

        let chosen = most_restrictive(warning, blocked);
        assert!(matches!(chosen, BudgetVerdict::Blocked { .. }));
    }

    /// `Warning` outranks `Allowed`.
    #[test]
    fn most_restrictive_selects_warning_over_allowed() {
        let warning = BudgetVerdict::Warning {
            usage_percent: 85.0,
            message: "warning".to_string(),
        };

        let chosen = most_restrictive(BudgetVerdict::Allowed, warning);
        assert!(matches!(chosen, BudgetVerdict::Warning { .. }));
    }

    /// Two `Allowed` verdicts combine to `Allowed`.
    #[test]
    fn most_restrictive_both_allowed() {
        let chosen = most_restrictive(BudgetVerdict::Allowed, BudgetVerdict::Allowed);
        assert_eq!(chosen, BudgetVerdict::Allowed);
    }

    /// The default limit has no caps and an 80% warning threshold.
    #[test]
    fn budget_limit_default() {
        let defaults = BudgetLimit::default();
        assert!(!defaults.has_any_limit());
        assert_eq!(defaults.warning_threshold_percent, 80);
    }

    /// The reset delay is positive and never exceeds one day.
    #[test]
    fn seconds_until_day_reset_positive() {
        let remaining = seconds_until_day_reset();
        assert!(remaining > 0);
        assert!(remaining <= 86400);
    }
}