1use std::collections::HashMap;
4use std::hash::Hash;
5use std::sync::RwLock;
6use std::time::{Duration, Instant};
7
8use crate::window::FixedWindowCounter;
9
10#[derive(Clone, bon::Builder)]
12pub struct QuotaDimension {
13 pub name: String,
15 pub window: Duration,
17 pub limit: u64,
19 #[builder(default = 6)]
22 pub resolution: usize,
23}
24
25#[derive(Clone, bon::Builder)]
27pub struct QuotaConfig {
28 pub dimensions: Vec<QuotaDimension>,
30}
31
32struct NodeQuota {
34 dimensions: Vec<(String, u64, FixedWindowCounter)>,
36 exhausted_until: Option<Instant>,
38}
39
40pub struct QuotaTracker<Id: Eq + Hash + Clone> {
67 nodes: RwLock<HashMap<Id, NodeQuota>>,
68}
69
70impl<Id: Eq + Hash + Clone> std::fmt::Debug for QuotaTracker<Id> {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 let count = self.nodes.read().unwrap().len();
73 f.debug_struct("QuotaTracker")
74 .field("tracked_nodes", &count)
75 .finish_non_exhaustive()
76 }
77}
78
79impl<Id: Eq + Hash + Clone> QuotaTracker<Id> {
80 pub fn new() -> Self {
82 Self {
83 nodes: RwLock::new(HashMap::new()),
84 }
85 }
86
87 pub fn register(&self, id: &Id, config: QuotaConfig) {
89 let dimensions = config
90 .dimensions
91 .into_iter()
92 .map(|d| {
93 let counter = FixedWindowCounter::new(d.window, d.resolution);
94 (d.name, d.limit, counter)
95 })
96 .collect();
97
98 let mut nodes = self.nodes.write().unwrap();
99 nodes.insert(
100 id.clone(),
101 NodeQuota {
102 dimensions,
103 exhausted_until: None,
104 },
105 );
106 }
107
108 pub fn record_usage(&self, id: &Id, amounts: &[(&str, u64)]) -> bool {
114 let mut nodes = self.nodes.write().unwrap();
115 let Some(node) = nodes.get_mut(id) else {
116 return false;
117 };
118
119 if let Some(t) = node.exhausted_until {
121 if Instant::now() >= t {
122 node.exhausted_until = None;
123 }
124 }
125
126 for &(dim_name, amount) in amounts {
127 for (name, _, counter) in &node.dimensions {
128 if name == dim_name {
129 counter.record(amount);
130 break;
131 }
132 }
133 }
134 true
135 }
136
137 pub fn mark_exhausted(&self, id: &Id, duration: Duration) {
139 let mut nodes = self.nodes.write().unwrap();
140 if let Some(node) = nodes.get_mut(id) {
141 node.exhausted_until = Some(Instant::now() + duration);
142 }
143 }
144
145 pub fn has_capacity(&self, id: &Id, estimated: &[(&str, u64)]) -> bool {
149 let nodes = self.nodes.read().unwrap();
150 let Some(node) = nodes.get(id) else {
151 return true;
152 };
153
154 if let Some(t) = node.exhausted_until {
156 if Instant::now() < t {
157 return false;
158 }
159 }
160
161 for &(dim_name, est) in estimated {
162 for (name, limit, counter) in &node.dimensions {
163 if name == dim_name {
164 if counter.remaining(*limit) < est {
165 return false;
166 }
167 break;
168 }
169 }
170 }
171
172 true
173 }
174
175 pub fn remaining(&self, id: &Id, dimension: &str) -> u64 {
177 let nodes = self.nodes.read().unwrap();
178 let Some(node) = nodes.get(id) else {
179 return u64::MAX;
180 };
181
182 for (name, limit, counter) in &node.dimensions {
183 if name == dimension {
184 return counter.remaining(*limit);
185 }
186 }
187
188 u64::MAX
189 }
190
191 pub fn pressure(&self, id: &Id) -> f64 {
196 let nodes = self.nodes.read().unwrap();
197 let Some(node) = nodes.get(id) else {
198 return 0.0;
199 };
200
201 if let Some(t) = node.exhausted_until {
203 if Instant::now() < t {
204 return 1.0;
205 }
206 }
207
208 let mut max_pressure: f64 = 0.0;
209 for (_, limit, counter) in &node.dimensions {
210 if *limit == 0 {
211 continue;
212 }
213 let usage = counter.sum() as f64 / *limit as f64;
214 if usage > max_pressure {
215 max_pressure = usage;
216 }
217 }
218
219 max_pressure.min(1.0)
220 }
221}
222
223impl<Id: Eq + Hash + Clone> Default for QuotaTracker<Id> {
224 fn default() -> Self {
225 Self::new()
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tracker holding one node with a single 60-second `"rpm"`
    /// dimension capped at `limit`.
    fn make_tracker_with_node(id: &str, limit: u64) -> (QuotaTracker<String>, String) {
        let rpm = QuotaDimension::builder()
            .name("rpm".into())
            .window(Duration::from_secs(60))
            .limit(limit)
            .build();
        let config = QuotaConfig::builder().dimensions(vec![rpm]).build();

        let node_id = id.to_string();
        let tracker = QuotaTracker::<String>::new();
        tracker.register(&node_id, config);
        (tracker, node_id)
    }

    /// Asserts that two floats differ by less than `f64::EPSILON`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < f64::EPSILON);
    }

    #[test]
    fn basic_register_and_record() {
        let (tracker, node) = make_tracker_with_node("n1", 100);

        tracker.record_usage(&node, &[("rpm", 10)]);
        assert_eq!(tracker.remaining(&node, "rpm"), 90);

        tracker.record_usage(&node, &[("rpm", 5)]);
        assert_eq!(tracker.remaining(&node, "rpm"), 85);
    }

    #[test]
    fn pressure_increases_with_usage() {
        let (tracker, node) = make_tracker_with_node("n1", 100);

        // No usage yet.
        assert_close(tracker.pressure(&node), 0.0);

        // Half of the limit consumed.
        tracker.record_usage(&node, &[("rpm", 50)]);
        assert_close(tracker.pressure(&node), 0.5);

        // Limit fully consumed.
        tracker.record_usage(&node, &[("rpm", 50)]);
        assert_close(tracker.pressure(&node), 1.0);
    }

    #[test]
    fn has_capacity_with_estimation() {
        let (tracker, node) = make_tracker_with_node("n1", 100);

        tracker.record_usage(&node, &[("rpm", 90)]);
        // Exactly the remaining quota fits; one more does not.
        assert!(tracker.has_capacity(&node, &[("rpm", 10)]));
        assert!(!tracker.has_capacity(&node, &[("rpm", 11)]));
    }

    #[test]
    fn mark_exhausted_blocks_capacity() {
        let (tracker, node) = make_tracker_with_node("n1", 100);

        tracker.mark_exhausted(&node, Duration::from_secs(60));
        assert!(!tracker.has_capacity(&node, &[("rpm", 1)]));
        assert_close(tracker.pressure(&node), 1.0);
    }

    #[test]
    fn exhausted_expires() {
        let (tracker, node) = make_tracker_with_node("n1", 100);

        tracker.mark_exhausted(&node, Duration::from_millis(20));
        assert!(!tracker.has_capacity(&node, &[("rpm", 1)]));

        // Wait past the exhaustion deadline.
        std::thread::sleep(Duration::from_millis(30));
        assert!(tracker.has_capacity(&node, &[("rpm", 1)]));
    }

    #[test]
    fn unknown_node_has_unlimited_capacity() {
        let tracker = QuotaTracker::<String>::new();
        let stranger = "unknown".to_string();

        assert!(tracker.has_capacity(&stranger, &[("rpm", 999)]));
        assert_eq!(tracker.remaining(&stranger, "rpm"), u64::MAX);
        assert_close(tracker.pressure(&stranger), 0.0);
    }
}