Skip to main content

loadwise_core/
quota.rs

1//! Per-node multi-dimensional quota tracking.
2
3use std::collections::HashMap;
4use std::hash::Hash;
5use std::sync::RwLock;
6use std::time::{Duration, Instant};
7
8use crate::window::FixedWindowCounter;
9
10/// Configuration for a single quota dimension (e.g., "rpm" or "tpm").
11#[derive(Clone, bon::Builder)]
12pub struct QuotaDimension {
13    /// Dimension name (e.g., "rpm", "tpm").
14    pub name: String,
15    /// Time window for this dimension.
16    pub window: Duration,
17    /// Maximum allowed units within the window.
18    pub limit: u64,
19    /// Number of time buckets within the window (higher = more precision, more memory).
20    /// **Default: 6.**
21    #[builder(default = 6)]
22    pub resolution: usize,
23}
24
25/// Per-node quota configuration.
26#[derive(Clone, bon::Builder)]
27pub struct QuotaConfig {
28    /// Quota dimensions to track for this node.
29    pub dimensions: Vec<QuotaDimension>,
30}
31
32/// Internal per-node state.
33struct NodeQuota {
34    /// (dimension_name, limit, counter)
35    dimensions: Vec<(String, u64, FixedWindowCounter)>,
36    /// If set, node is temporarily exhausted until this instant.
37    exhausted_until: Option<Instant>,
38}
39
40/// Tracks rate/quota usage for a set of nodes.
41///
42/// Thread-safe — all methods take `&self` and synchronise internally via [`RwLock`].
43///
44/// # Examples
45///
46/// ```
47/// # extern crate loadwise_core as loadwise;
48/// use loadwise::quota::{QuotaTracker, QuotaConfig, QuotaDimension};
49/// use std::time::Duration;
50///
51/// let tracker = QuotaTracker::<String>::new();
52///
53/// tracker.register(&"key-1".into(), QuotaConfig::builder()
54///     .dimensions(vec![
55///         QuotaDimension::builder()
56///             .name("rpm".into())
57///             .window(Duration::from_secs(60))
58///             .limit(100)
59///             .build(),
60///     ])
61///     .build());
62///
63/// tracker.record_usage(&"key-1".into(), &[("rpm", 1)]);
64/// assert_eq!(tracker.remaining(&"key-1".into(), "rpm"), 99);
65/// ```
66pub struct QuotaTracker<Id: Eq + Hash + Clone> {
67    nodes: RwLock<HashMap<Id, NodeQuota>>,
68}
69
70impl<Id: Eq + Hash + Clone> std::fmt::Debug for QuotaTracker<Id> {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        let count = self.nodes.read().unwrap().len();
73        f.debug_struct("QuotaTracker")
74            .field("tracked_nodes", &count)
75            .finish_non_exhaustive()
76    }
77}
78
79impl<Id: Eq + Hash + Clone> QuotaTracker<Id> {
80    /// Creates a new, empty tracker.
81    pub fn new() -> Self {
82        Self {
83            nodes: RwLock::new(HashMap::new()),
84        }
85    }
86
87    /// Register a node with its quota configuration. If already registered, replaces config.
88    pub fn register(&self, id: &Id, config: QuotaConfig) {
89        let dimensions = config
90            .dimensions
91            .into_iter()
92            .map(|d| {
93                let counter = FixedWindowCounter::new(d.window, d.resolution);
94                (d.name, d.limit, counter)
95            })
96            .collect();
97
98        let mut nodes = self.nodes.write().unwrap();
99        nodes.insert(
100            id.clone(),
101            NodeQuota {
102                dimensions,
103                exhausted_until: None,
104            },
105        );
106    }
107
108    /// Record usage. `amounts` is a slice of `(dimension_name, amount)` pairs.
109    /// Unknown dimensions are silently ignored.
110    ///
111    /// Returns `true` if the node was registered (usage recorded),
112    /// `false` if the node ID was not found.
113    pub fn record_usage(&self, id: &Id, amounts: &[(&str, u64)]) -> bool {
114        let mut nodes = self.nodes.write().unwrap();
115        let Some(node) = nodes.get_mut(id) else {
116            return false;
117        };
118
119        // Clear expired exhaustion deadline while we hold the write lock.
120        if let Some(t) = node.exhausted_until {
121            if Instant::now() >= t {
122                node.exhausted_until = None;
123            }
124        }
125
126        for &(dim_name, amount) in amounts {
127            for (name, _, counter) in &node.dimensions {
128                if name == dim_name {
129                    counter.record(amount);
130                    break;
131                }
132            }
133        }
134        true
135    }
136
137    /// Mark a node as exhausted for `duration` (e.g., from a 429 Retry-After header).
138    pub fn mark_exhausted(&self, id: &Id, duration: Duration) {
139        let mut nodes = self.nodes.write().unwrap();
140        if let Some(node) = nodes.get_mut(id) {
141            node.exhausted_until = Some(Instant::now() + duration);
142        }
143    }
144
145    /// Whether the node has capacity for the estimated usage across ALL dimensions.
146    /// Returns `false` if the node is marked exhausted or any dimension would exceed its limit.
147    /// Returns `true` for unregistered nodes (unknown = no quota constraint).
148    pub fn has_capacity(&self, id: &Id, estimated: &[(&str, u64)]) -> bool {
149        let nodes = self.nodes.read().unwrap();
150        let Some(node) = nodes.get(id) else {
151            return true;
152        };
153
154        // Check exhaustion (stale deadlines are treated as not-exhausted).
155        if let Some(t) = node.exhausted_until {
156            if Instant::now() < t {
157                return false;
158            }
159        }
160
161        for &(dim_name, est) in estimated {
162            for (name, limit, counter) in &node.dimensions {
163                if name == dim_name {
164                    if counter.remaining(*limit) < est {
165                        return false;
166                    }
167                    break;
168                }
169            }
170        }
171
172        true
173    }
174
175    /// Remaining capacity in a specific dimension. Returns `u64::MAX` for unknown nodes/dimensions.
176    pub fn remaining(&self, id: &Id, dimension: &str) -> u64 {
177        let nodes = self.nodes.read().unwrap();
178        let Some(node) = nodes.get(id) else {
179            return u64::MAX;
180        };
181
182        for (name, limit, counter) in &node.dimensions {
183            if name == dimension {
184                return counter.remaining(*limit);
185            }
186        }
187
188        u64::MAX
189    }
190
191    /// Pressure score across all dimensions: `0.0` = fully idle, `1.0` = at least one dimension full.
192    /// Computed as `max(usage_i / limit_i)` across all dimensions.
193    /// Dimensions with `limit == 0` are skipped (treated as unconstrained).
194    /// Returns `0.0` for unregistered nodes. Returns `1.0` if the node is marked exhausted.
195    pub fn pressure(&self, id: &Id) -> f64 {
196        let nodes = self.nodes.read().unwrap();
197        let Some(node) = nodes.get(id) else {
198            return 0.0;
199        };
200
201        // Check exhaustion (stale deadlines are treated as not-exhausted).
202        if let Some(t) = node.exhausted_until {
203            if Instant::now() < t {
204                return 1.0;
205            }
206        }
207
208        let mut max_pressure: f64 = 0.0;
209        for (_, limit, counter) in &node.dimensions {
210            if *limit == 0 {
211                continue;
212            }
213            let usage = counter.sum() as f64 / *limit as f64;
214            if usage > max_pressure {
215                max_pressure = usage;
216            }
217        }
218
219        max_pressure.min(1.0)
220    }
221}
222
223impl<Id: Eq + Hash + Clone> Default for QuotaTracker<Id> {
224    fn default() -> Self {
225        Self::new()
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a dimension with a 60-second window and the given limit.
    fn dimension(name: &str, limit: u64) -> QuotaDimension {
        QuotaDimension::builder()
            .name(name.into())
            .window(Duration::from_secs(60))
            .limit(limit)
            .build()
    }

    /// Builds a tracker with a single node `id` that has one `rpm` dimension.
    fn make_tracker_with_node(id: &str, limit: u64) -> (QuotaTracker<String>, String) {
        let tracker = QuotaTracker::<String>::new();
        let node_id = id.to_string();
        tracker.register(
            &node_id,
            QuotaConfig::builder()
                .dimensions(vec![dimension("rpm", limit)])
                .build(),
        );
        (tracker, node_id)
    }

    #[test]
    fn basic_register_and_record() {
        let (tracker, id) = make_tracker_with_node("n1", 100);

        tracker.record_usage(&id, &[("rpm", 10)]);
        assert_eq!(tracker.remaining(&id, "rpm"), 90);

        tracker.record_usage(&id, &[("rpm", 5)]);
        assert_eq!(tracker.remaining(&id, "rpm"), 85);
    }

    #[test]
    fn record_usage_reports_whether_node_is_registered() {
        let (tracker, id) = make_tracker_with_node("n1", 100);

        assert!(tracker.record_usage(&id, &[("rpm", 1)]));
        assert!(!tracker.record_usage(&"missing".to_string(), &[("rpm", 1)]));
    }

    #[test]
    fn unknown_dimensions_are_ignored() {
        let (tracker, id) = make_tracker_with_node("n1", 100);

        assert!(tracker.record_usage(&id, &[("tpm", 500)]));
        assert_eq!(tracker.remaining(&id, "rpm"), 100);
        assert_eq!(tracker.remaining(&id, "tpm"), u64::MAX);
        assert!(tracker.has_capacity(&id, &[("tpm", u64::MAX)]));
    }

    #[test]
    fn register_replaces_config_and_clears_exhaustion() {
        let (tracker, id) = make_tracker_with_node("n1", 100);
        tracker.record_usage(&id, &[("rpm", 10)]);
        tracker.mark_exhausted(&id, Duration::from_secs(60));

        tracker.register(
            &id,
            QuotaConfig::builder()
                .dimensions(vec![dimension("rpm", 50)])
                .build(),
        );

        assert_eq!(tracker.remaining(&id, "rpm"), 50);
        assert!(tracker.has_capacity(&id, &[("rpm", 50)]));
    }

    #[test]
    fn pressure_increases_with_usage() {
        let (tracker, id) = make_tracker_with_node("n1", 100);

        let p0 = tracker.pressure(&id);
        assert!((p0 - 0.0).abs() < f64::EPSILON);

        tracker.record_usage(&id, &[("rpm", 50)]);
        let p50 = tracker.pressure(&id);
        assert!((p50 - 0.5).abs() < f64::EPSILON);

        tracker.record_usage(&id, &[("rpm", 50)]);
        let p100 = tracker.pressure(&id);
        assert!((p100 - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn pressure_is_max_across_dimensions() {
        let tracker = QuotaTracker::<String>::new();
        let id = "n1".to_string();
        tracker.register(
            &id,
            QuotaConfig::builder()
                .dimensions(vec![dimension("rpm", 100), dimension("tpm", 1000)])
                .build(),
        );

        tracker.record_usage(&id, &[("rpm", 10), ("tpm", 500)]);
        assert!((tracker.pressure(&id) - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn pressure_skips_zero_limits_and_caps_at_one() {
        let (zero_tracker, zero_id) = make_tracker_with_node("zero", 0);
        zero_tracker.record_usage(&zero_id, &[("rpm", 5)]);
        assert!((zero_tracker.pressure(&zero_id) - 0.0).abs() < f64::EPSILON);

        let (tracker, id) = make_tracker_with_node("n1", 100);
        tracker.record_usage(&id, &[("rpm", 150)]);
        assert!((tracker.pressure(&id) - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn has_capacity_with_estimation() {
        let (tracker, id) = make_tracker_with_node("n1", 100);

        tracker.record_usage(&id, &[("rpm", 90)]);
        assert!(tracker.has_capacity(&id, &[("rpm", 10)]));
        assert!(!tracker.has_capacity(&id, &[("rpm", 11)]));
    }

    #[test]
    fn mark_exhausted_blocks_capacity() {
        let (tracker, id) = make_tracker_with_node("n1", 100);

        tracker.mark_exhausted(&id, Duration::from_secs(60));
        assert!(!tracker.has_capacity(&id, &[("rpm", 1)]));
        assert!((tracker.pressure(&id) - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn exhausted_expires() {
        let (tracker, id) = make_tracker_with_node("n1", 100);

        // A zero-length mark is already expired when queried. This exercises the
        // expiry path with no dependence on scheduler timing.
        tracker.mark_exhausted(&id, Duration::ZERO);
        assert!(tracker.has_capacity(&id, &[("rpm", 1)]));
        tracker.record_usage(&id, &[("rpm", 25)]);
        assert!((tracker.pressure(&id) - 0.25).abs() < f64::EPSILON);

        // `sleep` waits at least the requested time, so the deadline has passed
        // by the time we check. (The "still blocked" side is covered by
        // `mark_exhausted_blocks_capacity` with a long deadline, avoiding a
        // race against a 20 ms window.)
        tracker.mark_exhausted(&id, Duration::from_millis(20));
        std::thread::sleep(Duration::from_millis(30));
        assert!(tracker.has_capacity(&id, &[("rpm", 1)]));
    }

    #[test]
    fn unknown_node_has_unlimited_capacity() {
        let tracker = QuotaTracker::<String>::new();
        let unknown = "unknown".to_string();

        assert!(tracker.has_capacity(&unknown, &[("rpm", 999)]));
        assert_eq!(tracker.remaining(&unknown, "rpm"), u64::MAX);
        assert!((tracker.pressure(&unknown) - 0.0).abs() < f64::EPSILON);

        // Marking an unregistered node is a no-op and does not register it.
        tracker.mark_exhausted(&unknown, Duration::from_secs(60));
        assert!(tracker.has_capacity(&unknown, &[("rpm", 999)]));
    }
}