1use std::time::Duration;
2
3use dashmap::DashMap;
4use dome_core::DomeError;
5use tokio::time::Instant;
6use tracing::warn;
7
/// A spending allowance for one identity over a fixed-length window.
///
/// `spent` accumulates accepted amounts. Once `window` has elapsed since `window_start`, the
/// next spend attempt zeroes `spent` and starts a new window at that moment.
#[derive(Debug, Clone)]
pub struct Budget {
    /// Amount accepted so far in the current window.
    pub spent: f64,
    /// Maximum total amount accepted per window.
    pub cap: f64,
    /// Unit label reported alongside rejections (e.g. `"calls"`, `"usd"`).
    pub unit: String,
    /// Length of each accounting window.
    pub window: Duration,
    /// Start of the current window.
    pub window_start: Instant,
}

impl Budget {
    /// Creates an empty budget whose first window starts now.
    pub fn new(cap: f64, unit: impl Into<String>, window: Duration) -> Self {
        Self::new_at(cap, unit, window, Instant::now())
    }

    /// Creates an empty budget whose first window starts at `now`.
    ///
    /// Taking the start time explicitly makes window behaviour deterministic in tests.
    pub fn new_at(cap: f64, unit: impl Into<String>, window: Duration, now: Instant) -> Self {
        Self {
            spent: 0.0,
            cap,
            unit: unit.into(),
            window,
            window_start: now,
        }
    }

    /// Returns the unspent allowance recorded for the current window, clamped at zero.
    ///
    /// This does not consult the clock: after the window has elapsed the value stays stale
    /// until the next spend attempt resets it. Use [`Budget::remaining_at`] for a
    /// window-aware answer.
    pub fn remaining(&self) -> f64 {
        (self.cap - self.spent).max(0.0)
    }

    /// Returns the allowance a spend attempt at `now` would see, accounting for window expiry.
    pub fn remaining_at(&self, now: Instant) -> f64 {
        if self.window_elapsed(now) {
            // The next spend would reset `spent` to zero first.
            self.cap.max(0.0)
        } else {
            self.remaining()
        }
    }

    /// True once at least `window` has passed since `window_start`.
    fn window_elapsed(&self, now: Instant) -> bool {
        now.duration_since(self.window_start) >= self.window
    }

    /// Starts a fresh window at `now` if the current one has elapsed; returns whether it did.
    fn maybe_reset(&mut self, now: Instant) -> bool {
        if self.window_elapsed(now) {
            self.spent = 0.0;
            self.window_start = now;
            true
        } else {
            false
        }
    }

    /// Attempts to record `amount` at time `now`, resetting the window first if it elapsed.
    ///
    /// On rejection nothing is recorded and `(spent, cap, unit)` is returned for reporting.
    /// Amounts that are NaN, infinite or negative are always rejected: the cap check alone
    /// cannot catch them, and accepting them would corrupt `spent`.
    fn try_spend_inner(&mut self, amount: f64, now: Instant) -> Result<(), (f64, f64, String)> {
        self.maybe_reset(now);

        // NaN or -inf compare false against the cap and would poison `spent` (disabling the
        // cap from then on); finite negative amounts would silently refund the budget.
        let valid_amount = amount.is_finite() && amount >= 0.0;
        if !valid_amount || self.spent + amount > self.cap {
            return Err((self.spent, self.cap, self.unit.clone()));
        }

        self.spent += amount;
        Ok(())
    }
}
68
/// Defaults applied to identities that have no explicitly installed budget.
#[derive(Debug, Clone)]
pub struct BudgetTrackerConfig {
    /// Per-window cap given to identities on first use.
    pub default_cap: f64,
    /// Unit label attached to default budgets.
    pub default_unit: String,
    /// Window length for default budgets.
    pub default_window: Duration,
}

impl Default for BudgetTrackerConfig {
    /// 100 "calls" per one-hour window.
    fn default() -> Self {
        const ONE_HOUR: Duration = Duration::from_secs(60 * 60);

        let default_unit = String::from("calls");
        Self {
            default_window: ONE_HOUR,
            default_unit,
            default_cap: 100.0,
        }
    }
}
86
87pub struct BudgetTracker {
94 budgets: DashMap<String, Budget>,
95 config: BudgetTrackerConfig,
96 max_entries: usize,
98 insert_counter: std::sync::atomic::AtomicU64,
100}
101
102impl BudgetTracker {
103 pub fn new(config: BudgetTrackerConfig) -> Self {
104 Self {
105 budgets: DashMap::new(),
106 max_entries: 10_000,
107 insert_counter: std::sync::atomic::AtomicU64::new(0),
108 config,
109 }
110 }
111
112 pub fn with_max_entries(config: BudgetTrackerConfig, max_entries: usize) -> Self {
114 Self {
115 budgets: DashMap::new(),
116 max_entries,
117 insert_counter: std::sync::atomic::AtomicU64::new(0),
118 config,
119 }
120 }
121
122 pub fn try_spend(&self, identity: &str, amount: f64) -> Result<(), DomeError> {
127 self.try_spend_at(identity, amount, Instant::now())
128 }
129
130 pub fn try_spend_at(&self, identity: &str, amount: f64, now: Instant) -> Result<(), DomeError> {
132 let is_new = !self.budgets.contains_key(identity);
133
134 let mut entry = self.budgets.entry(identity.to_string()).or_insert_with(|| {
135 Budget::new_at(
136 self.config.default_cap,
137 &self.config.default_unit,
138 self.config.default_window,
139 now,
140 )
141 });
142
143 let result = entry
144 .try_spend_inner(amount, now)
145 .map_err(|(spent, cap, unit)| {
146 warn!(
147 identity = identity,
148 spent = spent,
149 cap = cap,
150 unit = %unit,
151 "budget exhausted"
152 );
153 DomeError::BudgetExhausted { spent, cap, unit }
154 });
155
156 if is_new {
158 let count = self
159 .insert_counter
160 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
161 if count % 100 == 99 {
162 drop(entry);
163 self.maybe_cleanup(now);
164 }
165 }
166
167 result
168 }
169
170 fn maybe_cleanup(&self, now: Instant) {
173 if self.budgets.len() <= self.max_entries {
174 return;
175 }
176
177 self.budgets.retain(|_key, budget| {
178 now.duration_since(budget.window_start) < budget.window
180 });
181 }
182
183 pub fn cleanup(&self) {
185 let now = Instant::now();
186 self.budgets
187 .retain(|_key, budget| now.duration_since(budget.window_start) < budget.window);
188 }
189
190 pub fn set_budget(&self, identity: impl Into<String>, budget: Budget) {
192 self.budgets.insert(identity.into(), budget);
193 }
194
195 pub fn current_spend(&self, identity: &str) -> Option<f64> {
197 self.budgets.get(identity).map(|b| b.spent)
198 }
199
200 pub fn tracked_count(&self) -> usize {
202 self.budgets.len()
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tracker whose default budget is `cap` "usd" per `window_secs` seconds.
    fn tracker_with_cap(cap: f64, window_secs: u64) -> BudgetTracker {
        let config = BudgetTrackerConfig {
            default_window: Duration::from_secs(window_secs),
            default_unit: String::from("usd"),
            default_cap: cap,
        };
        BudgetTracker::new(config)
    }

    #[tokio::test(start_paused = true)]
    async fn spend_within_cap_succeeds() {
        let tracker = tracker_with_cap(10.0, 3600);
        let t0 = Instant::now();

        for amount in [3.0, 3.0, 4.0] {
            assert!(tracker.try_spend_at("user-a", amount, t0).is_ok());
        }
        assert_eq!(tracker.current_spend("user-a"), Some(10.0));
    }

    #[tokio::test(start_paused = true)]
    async fn rejects_when_exceeding_cap() {
        let tracker = tracker_with_cap(5.0, 3600);
        let t0 = Instant::now();

        assert!(tracker.try_spend_at("user-b", 4.0, t0).is_ok());

        match tracker.try_spend_at("user-b", 2.0, t0).unwrap_err() {
            DomeError::BudgetExhausted { spent, cap, unit } => {
                assert!((spent - 4.0).abs() < f64::EPSILON);
                assert!((cap - 5.0).abs() < f64::EPSILON);
                assert_eq!(unit, "usd");
            }
            other => panic!("expected BudgetExhausted, got: {other:?}"),
        }

        // The rejected spend must not have been recorded.
        assert_eq!(tracker.current_spend("user-b"), Some(4.0));
    }

    #[tokio::test(start_paused = true)]
    async fn window_reset_clears_spend() {
        let tracker = tracker_with_cap(5.0, 60);
        let t0 = Instant::now();
        let after_window = t0 + Duration::from_secs(61);

        assert!(tracker.try_spend_at("user-c", 5.0, t0).is_ok());
        assert!(tracker.try_spend_at("user-c", 1.0, t0).is_err());

        assert!(
            tracker.try_spend_at("user-c", 3.0, after_window).is_ok(),
            "should succeed after window reset"
        );
        assert_eq!(tracker.current_spend("user-c"), Some(3.0));
    }

    #[tokio::test(start_paused = true)]
    async fn separate_identities_have_separate_budgets() {
        let tracker = tracker_with_cap(5.0, 3600);
        let t0 = Instant::now();
        let spend = |who: &str, amount: f64| tracker.try_spend_at(who, amount, t0);

        assert!(spend("alice", 5.0).is_ok());
        assert!(spend("alice", 1.0).is_err());
        assert!(spend("bob", 5.0).is_ok());
    }

    #[tokio::test(start_paused = true)]
    async fn custom_budget_overrides_defaults() {
        let tracker = tracker_with_cap(100.0, 3600);
        let t0 = Instant::now();

        let restricted = Budget::new_at(2.0, "tokens", Duration::from_secs(60), t0);
        tracker.set_budget("restricted-user", restricted);

        for _ in 0..2 {
            assert!(tracker.try_spend_at("restricted-user", 1.0, t0).is_ok());
        }
        assert!(tracker.try_spend_at("restricted-user", 1.0, t0).is_err());
    }

    #[test]
    fn concurrent_budget_tracking() {
        use std::sync::Arc;
        use std::thread;

        let tracker = Arc::new(tracker_with_cap(1000.0, 3600));

        let workers: Vec<_> = (0..10)
            .map(|t| {
                let tracker = Arc::clone(&tracker);
                thread::spawn(move || {
                    let id = format!("concurrent-{t}");
                    (0..5).fold(0u32, |ok, _| {
                        ok + u32::from(tracker.try_spend(&id, 1.0).is_ok())
                    })
                })
            })
            .collect();

        let total: u32 = workers.into_iter().map(|w| w.join().unwrap()).sum();
        assert_eq!(total, 50);
    }

    #[test]
    fn concurrent_same_identity_respects_cap() {
        use std::sync::Arc;
        use std::thread;

        let tracker = Arc::new(tracker_with_cap(10.0, 3600));

        let workers: Vec<_> = (0..20)
            .map(|_| {
                let tracker = Arc::clone(&tracker);
                thread::spawn(move || u32::from(tracker.try_spend("shared-user", 1.0).is_ok()))
            })
            .collect();

        let total: u32 = workers.into_iter().map(|w| w.join().unwrap()).sum();
        assert_eq!(total, 10);
    }
}