Skip to main content

dome_throttle/
budget.rs

1use std::time::Duration;
2
3use dashmap::DashMap;
4use dome_core::DomeError;
5use tokio::time::Instant;
6use tracing::warn;
7
/// Per-identity budget with a rolling window.
///
/// `spent` accumulates within the window that began at `window_start`. Once
/// `window` has elapsed, the next spend attempt resets `spent` to zero and
/// restarts the window at that moment.
#[derive(Debug, Clone)]
pub struct Budget {
    /// Amount consumed in the current window.
    pub spent: f64,
    /// Maximum amount that may be consumed per window.
    pub cap: f64,
    /// Label for what is being budgeted; returned alongside a rejected spend.
    pub unit: String,
    /// Length of each budget window.
    pub window: Duration,
    /// When the current window began.
    pub window_start: Instant,
}

impl Budget {
    /// Create a budget with nothing spent whose first window starts now.
    pub fn new(cap: f64, unit: impl Into<String>, window: Duration) -> Self {
        Self::new_at(cap, unit, window, Instant::now())
    }

    /// Create with an explicit start time (for testing).
    pub fn new_at(cap: f64, unit: impl Into<String>, window: Duration, now: Instant) -> Self {
        Self {
            spent: 0.0,
            cap,
            unit: unit.into(),
            window,
            window_start: now,
        }
    }

    /// Remaining budget in the current window, clamped at zero.
    ///
    /// This only reads the stored state. If the window has expired but no spend
    /// has been attempted since, the pre-reset (stale) value is returned.
    pub fn remaining(&self) -> f64 {
        (self.cap - self.spent).max(0.0)
    }

    /// Check if the window has expired and reset if so. Returns true if reset happened.
    fn maybe_reset(&mut self, now: Instant) -> bool {
        if now.duration_since(self.window_start) >= self.window {
            self.spent = 0.0;
            self.window_start = now;
            true
        } else {
            false
        }
    }

    /// Try to spend `amount` from this budget, resetting the window first if it expired.
    ///
    /// On rejection nothing is charged, and `(spent, cap, unit)` is returned so the
    /// caller can report it.
    ///
    /// `amount` must be a non-negative number, so NaN and negative amounts are
    /// rejected:
    /// - `spent + NaN > cap` is false, so a NaN would be accepted. It would then
    ///   poison `spent`, disabling the cap for the rest of the window.
    /// - A negative amount would lower `spent` and silently raise the effective cap.
    fn try_spend_inner(&mut self, amount: f64, now: Instant) -> Result<(), (f64, f64, String)> {
        self.maybe_reset(now);

        if amount.is_nan() || amount < 0.0 || self.spent + amount > self.cap {
            return Err((self.spent, self.cap, self.unit.clone()));
        }

        self.spent += amount;
        Ok(())
    }
}
68
/// Default settings for the budget given to an identity the first time it is seen.
#[derive(Debug, Clone)]
pub struct BudgetTrackerConfig {
    /// Cap assigned to a newly created budget.
    pub default_cap: f64,
    /// Unit label assigned to a newly created budget.
    pub default_unit: String,
    /// Window length assigned to a newly created budget.
    pub default_window: Duration,
}

impl Default for BudgetTrackerConfig {
    /// 100 "calls" per one-hour window.
    fn default() -> Self {
        let one_hour = Duration::from_secs(60 * 60);
        BudgetTrackerConfig {
            default_unit: String::from("calls"),
            default_window: one_hour,
            default_cap: 100.0,
        }
    }
}
86
87/// Concurrent budget tracker backed by DashMap.
88///
89/// Tracks cumulative spend per identity within rolling time windows.
90/// Budgets are created lazily with default config on first access.
91/// Stale entries (past their window) are periodically cleaned up to
92/// prevent unbounded memory growth.
93pub struct BudgetTracker {
94    budgets: DashMap<String, Budget>,
95    config: BudgetTrackerConfig,
96    /// Maximum number of tracked identities before cleanup triggers.
97    max_entries: usize,
98    /// Counter for periodic cleanup scheduling.
99    insert_counter: std::sync::atomic::AtomicU64,
100}
101
102impl BudgetTracker {
103    pub fn new(config: BudgetTrackerConfig) -> Self {
104        Self {
105            budgets: DashMap::new(),
106            max_entries: 10_000,
107            insert_counter: std::sync::atomic::AtomicU64::new(0),
108            config,
109        }
110    }
111
112    /// Create a tracker with a custom max entries limit.
113    pub fn with_max_entries(config: BudgetTrackerConfig, max_entries: usize) -> Self {
114        Self {
115            budgets: DashMap::new(),
116            max_entries,
117            insert_counter: std::sync::atomic::AtomicU64::new(0),
118            config,
119        }
120    }
121
122    /// Try to spend `amount` for the given identity.
123    ///
124    /// If the budget window has expired, it resets automatically before checking.
125    /// Returns `Ok(())` if spend is within cap, or `DomeError::BudgetExhausted` otherwise.
126    pub fn try_spend(&self, identity: &str, amount: f64) -> Result<(), DomeError> {
127        self.try_spend_at(identity, amount, Instant::now())
128    }
129
130    /// Same as `try_spend` but with explicit timestamp (for testing).
131    pub fn try_spend_at(&self, identity: &str, amount: f64, now: Instant) -> Result<(), DomeError> {
132        let is_new = !self.budgets.contains_key(identity);
133
134        let mut entry = self.budgets.entry(identity.to_string()).or_insert_with(|| {
135            Budget::new_at(
136                self.config.default_cap,
137                &self.config.default_unit,
138                self.config.default_window,
139                now,
140            )
141        });
142
143        let result = entry
144            .try_spend_inner(amount, now)
145            .map_err(|(spent, cap, unit)| {
146                warn!(
147                    identity = identity,
148                    spent = spent,
149                    cap = cap,
150                    unit = %unit,
151                    "budget exhausted"
152                );
153                DomeError::BudgetExhausted { spent, cap, unit }
154            });
155
156        // Periodic cleanup on new insertions
157        if is_new {
158            let count = self
159                .insert_counter
160                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
161            if count % 100 == 99 {
162                drop(entry);
163                self.maybe_cleanup(now);
164            }
165        }
166
167        result
168    }
169
170    /// Remove entries whose windows have expired (stale budgets).
171    /// Called periodically to prevent unbounded memory growth.
172    fn maybe_cleanup(&self, now: Instant) {
173        if self.budgets.len() <= self.max_entries {
174            return;
175        }
176
177        self.budgets.retain(|_key, budget| {
178            // Keep entries whose window hasn't expired yet
179            now.duration_since(budget.window_start) < budget.window
180        });
181    }
182
183    /// Explicitly run cleanup, removing entries whose windows have expired.
184    pub fn cleanup(&self) {
185        let now = Instant::now();
186        self.budgets
187            .retain(|_key, budget| now.duration_since(budget.window_start) < budget.window);
188    }
189
190    /// Register a custom budget for an identity (overrides defaults).
191    pub fn set_budget(&self, identity: impl Into<String>, budget: Budget) {
192        self.budgets.insert(identity.into(), budget);
193    }
194
195    /// Current spend for an identity, if tracked.
196    pub fn current_spend(&self, identity: &str) -> Option<f64> {
197        self.budgets.get(identity).map(|b| b.spent)
198    }
199
200    /// Number of tracked identities.
201    pub fn tracked_count(&self) -> usize {
202        self.budgets.len()
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    fn tracker_with_cap(cap: f64, window_secs: u64) -> BudgetTracker {
        BudgetTracker::new(BudgetTrackerConfig {
            default_cap: cap,
            default_unit: "usd".to_string(),
            default_window: Duration::from_secs(window_secs),
        })
    }

    #[tokio::test(start_paused = true)]
    async fn spend_within_cap_succeeds() {
        let tracker = tracker_with_cap(10.0, 3600);
        let now = Instant::now();

        assert!(tracker.try_spend_at("user-a", 3.0, now).is_ok());
        assert!(tracker.try_spend_at("user-a", 3.0, now).is_ok());
        assert!(tracker.try_spend_at("user-a", 4.0, now).is_ok());
        // Exactly at cap
        assert_eq!(tracker.current_spend("user-a"), Some(10.0));
    }

    #[tokio::test(start_paused = true)]
    async fn rejects_when_exceeding_cap() {
        let tracker = tracker_with_cap(5.0, 3600);
        let now = Instant::now();

        assert!(tracker.try_spend_at("user-b", 4.0, now).is_ok());

        // This would push to 6.0, exceeding cap of 5.0
        let err = tracker.try_spend_at("user-b", 2.0, now).unwrap_err();
        match err {
            DomeError::BudgetExhausted { spent, cap, unit } => {
                assert!((spent - 4.0).abs() < f64::EPSILON);
                assert!((cap - 5.0).abs() < f64::EPSILON);
                assert_eq!(unit, "usd");
            }
            other => panic!("expected BudgetExhausted, got: {other:?}"),
        }

        // Spend should not have changed after rejection
        assert_eq!(tracker.current_spend("user-b"), Some(4.0));
    }

    #[tokio::test(start_paused = true)]
    async fn window_reset_clears_spend() {
        let tracker = tracker_with_cap(5.0, 60); // 60 second window
        let now = Instant::now();

        // Spend to the cap
        assert!(tracker.try_spend_at("user-c", 5.0, now).is_ok());
        assert!(tracker.try_spend_at("user-c", 1.0, now).is_err());

        // Advance past the window
        let later = now + Duration::from_secs(61);
        assert!(
            tracker.try_spend_at("user-c", 3.0, later).is_ok(),
            "should succeed after window reset"
        );
        assert_eq!(tracker.current_spend("user-c"), Some(3.0));
    }

    #[tokio::test(start_paused = true)]
    async fn separate_identities_have_separate_budgets() {
        let tracker = tracker_with_cap(5.0, 3600);
        let now = Instant::now();

        assert!(tracker.try_spend_at("alice", 5.0, now).is_ok());
        assert!(tracker.try_spend_at("alice", 1.0, now).is_err());

        // Bob is unaffected
        assert!(tracker.try_spend_at("bob", 5.0, now).is_ok());
    }

    #[tokio::test(start_paused = true)]
    async fn custom_budget_overrides_defaults() {
        let tracker = tracker_with_cap(100.0, 3600);
        let now = Instant::now();

        // Set a tight budget for a specific identity
        tracker.set_budget(
            "restricted-user",
            Budget::new_at(2.0, "tokens", Duration::from_secs(60), now),
        );

        assert!(tracker.try_spend_at("restricted-user", 1.0, now).is_ok());
        assert!(tracker.try_spend_at("restricted-user", 1.0, now).is_ok());
        assert!(tracker.try_spend_at("restricted-user", 1.0, now).is_err());
    }

    #[tokio::test(start_paused = true)]
    async fn explicit_cleanup_removes_expired_default_entries() {
        let tracker = tracker_with_cap(5.0, 60);
        let now = Instant::now();

        assert!(tracker.try_spend_at("stale", 1.0, now).is_ok());
        assert_eq!(tracker.tracked_count(), 1);

        // Still inside the window: nothing to remove.
        tracker.cleanup();
        assert_eq!(tracker.tracked_count(), 1);

        // Move the paused clock past the window; the entry is now stale.
        tokio::time::advance(Duration::from_secs(61)).await;
        tracker.cleanup();
        assert_eq!(tracker.tracked_count(), 0);
        assert_eq!(tracker.current_spend("stale"), None);
    }

    #[tokio::test(start_paused = true)]
    async fn periodic_cleanup_purges_expired_entries_over_limit() {
        let config = BudgetTrackerConfig {
            default_cap: 5.0,
            default_unit: "usd".to_string(),
            default_window: Duration::from_secs(60),
        };
        let tracker = BudgetTracker::with_max_entries(config, 10);
        let start = Instant::now();

        // 99 new identities: cleanup is only considered on every 100th insertion.
        for i in 0..99 {
            assert!(tracker.try_spend_at(&format!("id-{i}"), 1.0, start).is_ok());
        }
        assert_eq!(tracker.tracked_count(), 99);

        // The 100th insertion happens after every earlier window has expired and
        // the map is over its limit, so the stale entries are purged.
        let later = start + Duration::from_secs(61);
        assert!(tracker.try_spend_at("id-99", 1.0, later).is_ok());
        assert_eq!(tracker.tracked_count(), 1);
        assert_eq!(tracker.current_spend("id-99"), Some(1.0));
    }

    #[test]
    fn concurrent_budget_tracking() {
        use std::sync::Arc;
        use std::thread;

        // 1000 cap per identity, 10 threads x 5 requests of 1.0 each = 50 total,
        // each thread using its own identity.
        let tracker = Arc::new(tracker_with_cap(1000.0, 3600));
        let mut handles = vec![];

        for t in 0..10 {
            let tracker = Arc::clone(&tracker);
            handles.push(thread::spawn(move || {
                let id = format!("concurrent-{t}");
                let mut ok = 0u32;
                for _ in 0..5 {
                    if tracker.try_spend(&id, 1.0).is_ok() {
                        ok += 1;
                    }
                }
                ok
            }));
        }

        let total: u32 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        // Each identity has its own budget of 1000, so all 50 should pass
        assert_eq!(total, 50);
    }

    #[test]
    fn concurrent_same_identity_respects_cap() {
        use std::sync::Arc;
        use std::thread;

        // Single identity, cap = 10, 20 threads each trying to spend 1.0
        let tracker = Arc::new(tracker_with_cap(10.0, 3600));
        let mut handles = vec![];

        for _ in 0..20 {
            let tracker = Arc::clone(&tracker);
            handles.push(thread::spawn(move || {
                if tracker.try_spend("shared-user", 1.0).is_ok() {
                    1u32
                } else {
                    0u32
                }
            }));
        }

        let total: u32 = handles.into_iter().map(|h| h.join().unwrap()).sum();
        // Exactly 10 should succeed (cap = 10, each spends 1.0)
        assert_eq!(total, 10);
    }
}