1use std::sync::Arc;
11
12use dashmap::DashMap;
13use serde::{Deserialize, Serialize};
14use tracing::{debug, info};
15
16use punch_types::{FighterId, PunchResult};
17
18use crate::metering::{MeteringEngine, SpendPeriod};
19
/// Usage caps applied either to a single fighter or, via
/// [`BudgetEnforcer::set_global_limit`], to the combined spend of all fighters.
///
/// Every cap is optional; `None` means "no cap of this kind". Note that the
/// [`BudgetEnforcer`] in this module only enforces `max_cost_per_day_cents`;
/// the token and request caps are stored and reported but not checked.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BudgetLimit {
    /// Maximum tokens per hour (not currently enforced).
    pub max_tokens_per_hour: Option<u64>,
    /// Maximum tokens per day (not currently enforced).
    pub max_tokens_per_day: Option<u64>,
    /// Maximum spend per day, in cents. Compared against the metering engine's
    /// daily spend after converting it to whole cents (fractions truncated).
    pub max_cost_per_day_cents: Option<u64>,
    /// Maximum requests per hour (not currently enforced).
    pub max_requests_per_hour: Option<u64>,
    /// Percentage of the daily cost cap at which a `Warning` verdict is issued.
    /// Falls back to 80 when missing from deserialized input.
    #[serde(default = "default_warning_threshold")]
    pub warning_threshold_percent: u8,
}
39
/// Serde/`Default` fallback for `BudgetLimit::warning_threshold_percent`:
/// start warning once 80% of a cap has been used.
fn default_warning_threshold() -> u8 {
    const EIGHTY_PERCENT: u8 = 80;
    EIGHTY_PERCENT
}
43
44impl Default for BudgetLimit {
45 fn default() -> Self {
46 Self {
47 max_tokens_per_hour: None,
48 max_tokens_per_day: None,
49 max_cost_per_day_cents: None,
50 max_requests_per_hour: None,
51 warning_threshold_percent: default_warning_threshold(),
52 }
53 }
54}
55
56impl BudgetLimit {
57 pub fn has_any_limit(&self) -> bool {
59 self.max_tokens_per_hour.is_some()
60 || self.max_tokens_per_day.is_some()
61 || self.max_cost_per_day_cents.is_some()
62 || self.max_requests_per_hour.is_some()
63 }
64}
65
/// Outcome of a budget check. In increasing severity: `Allowed`, `Warning`, `Blocked`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum BudgetVerdict {
    /// No configured cap has reached its warning threshold (or no cap is configured).
    Allowed,
    /// A cap's warning threshold has been reached but the cap itself has not.
    Warning {
        /// Usage as a percentage of the cap (e.g. `90.0` means 90%).
        usage_percent: f64,
        /// Human-readable description of the approaching limit.
        message: String,
    },
    /// Usage has reached or exceeded a cap.
    Blocked {
        /// Human-readable description of the exceeded limit.
        reason: String,
        /// Seconds until the daily budget period resets (next UTC midnight).
        retry_after_secs: u64,
    },
}
86
/// Snapshot of budget usage for one fighter or for all fighters combined, as
/// returned by `BudgetEnforcer::get_fighter_status` / `get_global_status`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BudgetStatus {
    /// The configured limit, if any.
    pub limits: Option<BudgetLimit>,
    /// Tokens used in the current hour. NOTE: the enforcer currently always reports 0.
    pub tokens_used_hour: u64,
    /// Tokens used today. NOTE: the enforcer currently always reports 0.
    pub tokens_used_day: u64,
    /// Spend so far today, in whole cents (fractional cents truncated).
    pub cost_used_day_cents: u64,
    /// Requests made in the current hour. NOTE: the enforcer currently always reports 0.
    pub requests_used_hour: u64,
    /// Verdict of the corresponding budget check at the time of the snapshot.
    pub verdict: BudgetVerdict,
}
103
/// Checks per-fighter and global budgets against spend recorded by a [`MeteringEngine`].
///
/// Per-fighter limits live in a concurrent map, and the optional global limit sits
/// behind an async `RwLock`, so all methods take `&self`.
pub struct BudgetEnforcer {
    /// Source of spend figures.
    metering: Arc<MeteringEngine>,
    /// Per-fighter limits, keyed by fighter id.
    limits: DashMap<FighterId, BudgetLimit>,
    /// Limit checked against `MeteringEngine::get_total_spend`, if set.
    global_limit: Arc<tokio::sync::RwLock<Option<BudgetLimit>>>,
}
115
116impl BudgetEnforcer {
117 pub fn new(metering: Arc<MeteringEngine>) -> Self {
119 Self {
120 metering,
121 limits: DashMap::new(),
122 global_limit: Arc::new(tokio::sync::RwLock::new(None)),
123 }
124 }
125
126 pub fn set_fighter_limit(&self, fighter_id: FighterId, limit: BudgetLimit) {
128 info!(%fighter_id, "budget limit set for fighter");
129 self.limits.insert(fighter_id, limit);
130 }
131
132 pub fn remove_fighter_limit(&self, fighter_id: &FighterId) {
134 self.limits.remove(fighter_id);
135 }
136
137 pub fn get_fighter_limit(&self, fighter_id: &FighterId) -> Option<BudgetLimit> {
139 self.limits.get(fighter_id).map(|entry| entry.clone())
140 }
141
142 pub async fn set_global_limit(&self, limit: BudgetLimit) {
144 info!("global budget limit set");
145 let mut guard = self.global_limit.write().await;
146 *guard = Some(limit);
147 }
148
149 pub async fn clear_global_limit(&self) {
151 let mut guard = self.global_limit.write().await;
152 *guard = None;
153 }
154
155 pub async fn get_global_limit(&self) -> Option<BudgetLimit> {
157 let guard = self.global_limit.read().await;
158 guard.clone()
159 }
160
161 pub async fn check_budget(&self, fighter_id: &FighterId) -> PunchResult<BudgetVerdict> {
166 let fighter_verdict = if let Some(limit) = self.limits.get(fighter_id) {
168 self.evaluate_limit(fighter_id, &limit, false).await?
169 } else {
170 BudgetVerdict::Allowed
171 };
172
173 if matches!(fighter_verdict, BudgetVerdict::Blocked { .. }) {
175 return Ok(fighter_verdict);
176 }
177
178 let global_verdict = {
180 let guard = self.global_limit.read().await;
181 if let Some(ref limit) = *guard {
182 self.evaluate_global_limit(limit).await?
183 } else {
184 BudgetVerdict::Allowed
185 }
186 };
187
188 Ok(most_restrictive(fighter_verdict, global_verdict))
190 }
191
192 pub async fn get_fighter_status(&self, fighter_id: &FighterId) -> PunchResult<BudgetStatus> {
194 let limit = self.limits.get(fighter_id).map(|e| e.clone());
195
196 let daily_spend = self
197 .metering
198 .get_spend(fighter_id, SpendPeriod::Day)
199 .await?;
200
201 let verdict = self.check_budget(fighter_id).await?;
202
203 Ok(BudgetStatus {
204 limits: limit,
205 tokens_used_hour: 0, tokens_used_day: 0,
207 cost_used_day_cents: (daily_spend * 100.0) as u64,
208 requests_used_hour: 0,
209 verdict,
210 })
211 }
212
213 pub async fn get_global_status(&self) -> PunchResult<BudgetStatus> {
215 let limit = self.get_global_limit().await;
216
217 let daily_spend = self.metering.get_total_spend(SpendPeriod::Day).await?;
218
219 let global_verdict = if let Some(ref lim) = limit {
220 self.evaluate_global_limit(lim).await?
221 } else {
222 BudgetVerdict::Allowed
223 };
224
225 Ok(BudgetStatus {
226 limits: limit,
227 tokens_used_hour: 0,
228 tokens_used_day: 0,
229 cost_used_day_cents: (daily_spend * 100.0) as u64,
230 requests_used_hour: 0,
231 verdict: global_verdict,
232 })
233 }
234
235 async fn evaluate_limit(
237 &self,
238 fighter_id: &FighterId,
239 limit: &BudgetLimit,
240 _is_global: bool,
241 ) -> PunchResult<BudgetVerdict> {
242 if !limit.has_any_limit() {
243 return Ok(BudgetVerdict::Allowed);
244 }
245
246 let threshold = limit.warning_threshold_percent as f64 / 100.0;
247
248 if let Some(max_cents) = limit.max_cost_per_day_cents {
250 let daily_cost = self
251 .metering
252 .get_spend(fighter_id, SpendPeriod::Day)
253 .await?;
254 let daily_cents = (daily_cost * 100.0) as u64;
255
256 if daily_cents >= max_cents {
257 debug!(%fighter_id, daily_cents, max_cents, "fighter over daily cost budget");
258 return Ok(BudgetVerdict::Blocked {
259 reason: format!(
260 "daily cost budget exceeded: {}c / {}c",
261 daily_cents, max_cents
262 ),
263 retry_after_secs: seconds_until_day_reset(),
264 });
265 }
266
267 let usage_pct = daily_cents as f64 / max_cents as f64;
268 if usage_pct >= threshold {
269 return Ok(BudgetVerdict::Warning {
270 usage_percent: usage_pct * 100.0,
271 message: format!(
272 "approaching daily cost limit: {}c / {}c ({:.0}%)",
273 daily_cents,
274 max_cents,
275 usage_pct * 100.0
276 ),
277 });
278 }
279 }
280
281 if let Some(max_tokens_hour) = limit.max_tokens_per_hour {
283 let hourly_cost = self
284 .metering
285 .get_spend(fighter_id, SpendPeriod::Hour)
286 .await?;
287 let _hourly_cost_cents = (hourly_cost * 100.0) as u64;
290
291 debug!(
294 %fighter_id,
295 max_tokens_hour,
296 "hourly token limit configured (checked via cost proxy)"
297 );
298 }
299
300 Ok(BudgetVerdict::Allowed)
301 }
302
303 async fn evaluate_global_limit(&self, limit: &BudgetLimit) -> PunchResult<BudgetVerdict> {
305 if !limit.has_any_limit() {
306 return Ok(BudgetVerdict::Allowed);
307 }
308
309 let threshold = limit.warning_threshold_percent as f64 / 100.0;
310
311 if let Some(max_cents) = limit.max_cost_per_day_cents {
313 let daily_cost = self.metering.get_total_spend(SpendPeriod::Day).await?;
314 let daily_cents = (daily_cost * 100.0) as u64;
315
316 if daily_cents >= max_cents {
317 return Ok(BudgetVerdict::Blocked {
318 reason: format!(
319 "global daily cost budget exceeded: {}c / {}c",
320 daily_cents, max_cents
321 ),
322 retry_after_secs: seconds_until_day_reset(),
323 });
324 }
325
326 let usage_pct = daily_cents as f64 / max_cents as f64;
327 if usage_pct >= threshold {
328 return Ok(BudgetVerdict::Warning {
329 usage_percent: usage_pct * 100.0,
330 message: format!(
331 "approaching global daily cost limit: {}c / {}c ({:.0}%)",
332 daily_cents,
333 max_cents,
334 usage_pct * 100.0
335 ),
336 });
337 }
338 }
339
340 Ok(BudgetVerdict::Allowed)
341 }
342}
343
344fn most_restrictive(a: BudgetVerdict, b: BudgetVerdict) -> BudgetVerdict {
346 match (&a, &b) {
347 (BudgetVerdict::Blocked { .. }, _) => a,
348 (_, BudgetVerdict::Blocked { .. }) => b,
349 (
350 BudgetVerdict::Warning {
351 usage_percent: pa, ..
352 },
353 BudgetVerdict::Warning {
354 usage_percent: pb, ..
355 },
356 ) => {
357 if pa >= pb {
358 a
359 } else {
360 b
361 }
362 }
363 (BudgetVerdict::Warning { .. }, _) => a,
364 (_, BudgetVerdict::Warning { .. }) => b,
365 _ => BudgetVerdict::Allowed,
366 }
367}
368
/// Seconds remaining until the next UTC midnight, when daily budgets reset.
///
/// The result is rounded *up* to a whole second and always lies in `1..=86_400`.
/// A `retry_after_secs` derived from it therefore never tells callers to retry
/// before the reset. Truncating would yield 0 during the last second of the day.
/// A system clock set before the Unix epoch is treated as the epoch itself.
fn seconds_until_day_reset() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default();
    seconds_until_next_utc_midnight(since_epoch)
}

/// Seconds (rounded up) from the instant `since_epoch` after the Unix epoch to the
/// next UTC midnight.
///
/// Unix time has exactly 86 400 seconds per day, so day boundaries are multiples of
/// 86 400 s. An instant exactly at midnight yields a full day (86 400).
fn seconds_until_next_utc_midnight(since_epoch: std::time::Duration) -> u64 {
    const NANOS_PER_SEC: u128 = 1_000_000_000;
    const NANOS_PER_DAY: u128 = 86_400 * NANOS_PER_SEC;

    let into_day = since_epoch.as_nanos() % NANOS_PER_DAY;
    // In 1..=NANOS_PER_DAY, so the ceiling below is in 1..=86_400 and fits in u64.
    let remaining = NANOS_PER_DAY - into_day;
    ((remaining + NANOS_PER_SEC - 1) / NANOS_PER_SEC) as u64
}
384
#[cfg(test)]
mod tests {
    //! Behavioural tests for `BudgetEnforcer`. The spend levels quoted in comments are
    //! the ones each test's own assertions expect from the metering engine's pricing.

    use super::*;
    use punch_memory::MemorySubstrate;
    use punch_types::{FighterManifest, FighterStatus, ModelConfig, Provider, WeightClass};

    /// Model whose pricing most tests depend on.
    const SONNET: &str = "claude-sonnet-4-20250514";

    /// Minimal manifest used when persisting throwaway fighters.
    fn test_manifest() -> FighterManifest {
        let model = ModelConfig {
            provider: Provider::Anthropic,
            model: SONNET.into(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };
        FighterManifest {
            name: "budget-test".into(),
            description: "test".into(),
            model,
            system_prompt: "test".into(),
            capabilities: Vec::new(),
            weight_class: WeightClass::Featherweight,
            tenant_id: None,
        }
    }

    /// A limit that only caps daily cost (in cents); every other field is default.
    fn daily_cap(cents: u64) -> BudgetLimit {
        BudgetLimit {
            max_cost_per_day_cents: Some(cents),
            ..Default::default()
        }
    }

    /// Fresh in-memory substrate plus a metering engine backed by it.
    async fn setup() -> (Arc<MeteringEngine>, Arc<MemorySubstrate>) {
        let substrate = MemorySubstrate::in_memory().expect("in-memory substrate");
        let memory = Arc::new(substrate);
        let metering = MeteringEngine::new(Arc::clone(&memory));
        (Arc::new(metering), memory)
    }

    /// Persists a new idle fighter and returns its id.
    async fn setup_fighter(memory: &MemorySubstrate) -> FighterId {
        let id = FighterId::new();
        memory
            .save_fighter(&id, &test_manifest(), FighterStatus::Idle)
            .await
            .expect("save fighter");
        id
    }

    #[tokio::test]
    async fn under_budget_allowed() {
        let (metering, memory) = setup().await;
        let fighter = setup_fighter(&memory).await;
        let enforcer = BudgetEnforcer::new(metering);

        enforcer.set_fighter_limit(fighter, daily_cap(1000));

        let verdict = enforcer.check_budget(&fighter).await.expect("check budget");
        assert_eq!(verdict, BudgetVerdict::Allowed);
    }

    #[tokio::test]
    async fn at_80_percent_warning() {
        let (metering, memory) = setup().await;
        let fighter = setup_fighter(&memory).await;

        // Roughly 90 cents of usage against a 100-cent cap.
        metering
            .record_usage(&fighter, SONNET, 50000, 50000)
            .await
            .expect("record usage");

        let enforcer = BudgetEnforcer::new(Arc::clone(&metering));
        enforcer.set_fighter_limit(fighter, daily_cap(100));

        let verdict = enforcer.check_budget(&fighter).await.expect("check budget");
        assert!(
            matches!(verdict, BudgetVerdict::Warning { .. }),
            "expected warning at ~90%, got {:?}",
            verdict
        );
    }

    #[tokio::test]
    async fn over_budget_blocked() {
        let (metering, memory) = setup().await;
        let fighter = setup_fighter(&memory).await;

        // Comfortably more than 100 cents of usage.
        metering
            .record_usage(&fighter, SONNET, 100_000, 100_000)
            .await
            .expect("record usage");

        let enforcer = BudgetEnforcer::new(Arc::clone(&metering));
        enforcer.set_fighter_limit(fighter, daily_cap(100));

        let verdict = enforcer.check_budget(&fighter).await.expect("check budget");
        match verdict {
            BudgetVerdict::Blocked {
                retry_after_secs, ..
            } => assert!(retry_after_secs > 0),
            other => panic!("expected blocked, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn budget_resets_at_period_boundary() {
        // Only checks that a period with no recorded usage is allowed; it does not
        // advance time across an actual reset.
        let (metering, memory) = setup().await;
        let fighter = setup_fighter(&memory).await;
        let enforcer = BudgetEnforcer::new(Arc::clone(&metering));

        enforcer.set_fighter_limit(fighter, daily_cap(100));

        let verdict = enforcer.check_budget(&fighter).await.expect("check budget");
        assert_eq!(verdict, BudgetVerdict::Allowed);
    }

    #[tokio::test]
    async fn per_fighter_limits_independent() {
        let (metering, memory) = setup().await;
        let spender = setup_fighter(&memory).await;
        let idle = setup_fighter(&memory).await;

        metering
            .record_usage(&spender, SONNET, 100_000, 100_000)
            .await
            .expect("record usage");

        let enforcer = BudgetEnforcer::new(Arc::clone(&metering));
        for id in [spender, idle] {
            enforcer.set_fighter_limit(id, daily_cap(100));
        }

        let spender_verdict = enforcer.check_budget(&spender).await.expect("check fid1");
        let idle_verdict = enforcer.check_budget(&idle).await.expect("check fid2");

        assert!(matches!(spender_verdict, BudgetVerdict::Blocked { .. }));
        assert_eq!(idle_verdict, BudgetVerdict::Allowed);
    }

    #[tokio::test]
    async fn global_limit_applies_to_all_fighters() {
        let (metering, memory) = setup().await;
        let spender = setup_fighter(&memory).await;

        metering
            .record_usage(&spender, SONNET, 100_000, 100_000)
            .await
            .expect("record usage");

        let enforcer = BudgetEnforcer::new(Arc::clone(&metering));
        enforcer.set_global_limit(daily_cap(100)).await;

        // A new fighter with no usage of its own is still blocked by the global total.
        let newcomer = setup_fighter(&memory).await;
        let verdict = enforcer.check_budget(&newcomer).await.expect("check budget");
        assert!(
            matches!(verdict, BudgetVerdict::Blocked { .. }),
            "global limit should block: {:?}",
            verdict
        );
    }

    #[tokio::test]
    async fn no_limit_always_allowed() {
        let (metering, memory) = setup().await;
        let fighter = setup_fighter(&memory).await;

        metering
            .record_usage(&fighter, SONNET, 1_000_000, 1_000_000)
            .await
            .expect("record usage");

        let enforcer = BudgetEnforcer::new(Arc::clone(&metering));

        let verdict = enforcer.check_budget(&fighter).await.expect("check budget");
        assert_eq!(verdict, BudgetVerdict::Allowed);
    }

    #[tokio::test]
    async fn zero_limit_always_blocked() {
        let (metering, memory) = setup().await;
        let fighter = setup_fighter(&memory).await;

        let enforcer = BudgetEnforcer::new(Arc::clone(&metering));
        enforcer.set_fighter_limit(fighter, daily_cap(0));

        let verdict = enforcer.check_budget(&fighter).await.expect("check budget");
        assert!(
            matches!(verdict, BudgetVerdict::Blocked { .. }),
            "zero limit should block: {:?}",
            verdict
        );
    }

    #[tokio::test]
    async fn multiple_fighters_dont_interfere() {
        let (metering, memory) = setup().await;
        let heavy = setup_fighter(&memory).await;
        let light = setup_fighter(&memory).await;
        let tight = setup_fighter(&memory).await;

        metering
            .record_usage(&heavy, SONNET, 100_000, 100_000)
            .await
            .expect("record usage");

        let enforcer = BudgetEnforcer::new(Arc::clone(&metering));

        // Only `heavy` has spent anything; `tight` has a smaller cap but no usage.
        for (id, cap_cents) in [(heavy, 100), (light, 100), (tight, 50)] {
            enforcer.set_fighter_limit(id, daily_cap(cap_cents));
        }

        let heavy_verdict = enforcer.check_budget(&heavy).await.expect("check fid1");
        let light_verdict = enforcer.check_budget(&light).await.expect("check fid2");
        let tight_verdict = enforcer.check_budget(&tight).await.expect("check fid3");

        assert!(matches!(heavy_verdict, BudgetVerdict::Blocked { .. }));
        assert_eq!(light_verdict, BudgetVerdict::Allowed);
        assert_eq!(tight_verdict, BudgetVerdict::Allowed);
    }

    #[tokio::test]
    async fn warning_threshold_configurable() {
        let (metering, memory) = setup().await;
        let fighter = setup_fighter(&memory).await;

        // Roughly 90 cents of usage against a 100-cent cap.
        metering
            .record_usage(&fighter, SONNET, 50000, 50000)
            .await
            .expect("record usage");

        let enforcer = BudgetEnforcer::new(Arc::clone(&metering));

        enforcer.set_fighter_limit(
            fighter,
            BudgetLimit {
                warning_threshold_percent: 95,
                ..daily_cap(100)
            },
        );
        let strict = enforcer.check_budget(&fighter).await.expect("check budget");
        assert_eq!(
            strict,
            BudgetVerdict::Allowed,
            "95% threshold should not warn at 90%: {:?}",
            strict
        );

        enforcer.set_fighter_limit(
            fighter,
            BudgetLimit {
                warning_threshold_percent: 50,
                ..daily_cap(100)
            },
        );
        let lenient = enforcer.check_budget(&fighter).await.expect("check budget");
        assert!(
            matches!(lenient, BudgetVerdict::Warning { .. }),
            "50% threshold should warn at 90%: {:?}",
            lenient
        );
    }

    #[tokio::test]
    async fn cost_based_budget() {
        let (metering, memory) = setup().await;
        let fighter = setup_fighter(&memory).await;

        // Well under a cent of usage on a cheap model.
        metering
            .record_usage(&fighter, "gpt-4o-mini", 10_000, 10_000)
            .await
            .expect("record usage");

        let enforcer = BudgetEnforcer::new(Arc::clone(&metering));
        enforcer.set_fighter_limit(fighter, daily_cap(1));

        let verdict = enforcer.check_budget(&fighter).await.expect("check budget");
        assert_eq!(
            verdict,
            BudgetVerdict::Allowed,
            "0.75 cents should be under 1 cent limit at 80% threshold: {:?}",
            verdict
        );
    }

    #[test]
    fn most_restrictive_selects_blocked_over_warning() {
        let warning = BudgetVerdict::Warning {
            usage_percent: 85.0,
            message: "warning".to_string(),
        };
        let blocked = BudgetVerdict::Blocked {
            reason: "blocked".to_string(),
            retry_after_secs: 100,
        };

        let combined = most_restrictive(warning, blocked);
        assert!(matches!(combined, BudgetVerdict::Blocked { .. }));
    }

    #[test]
    fn most_restrictive_selects_warning_over_allowed() {
        let warning = BudgetVerdict::Warning {
            usage_percent: 85.0,
            message: "warning".to_string(),
        };

        let combined = most_restrictive(BudgetVerdict::Allowed, warning);
        assert!(matches!(combined, BudgetVerdict::Warning { .. }));
    }

    #[test]
    fn most_restrictive_both_allowed() {
        let combined = most_restrictive(BudgetVerdict::Allowed, BudgetVerdict::Allowed);
        assert_eq!(combined, BudgetVerdict::Allowed);
    }

    #[test]
    fn budget_limit_default() {
        let defaults = BudgetLimit::default();
        assert!(!defaults.has_any_limit());
        assert_eq!(defaults.warning_threshold_percent, 80);
    }

    #[test]
    fn seconds_until_day_reset_positive() {
        let remaining = seconds_until_day_reset();
        assert!(remaining > 0);
        assert!(remaining <= 86_400);
    }
}