use std::collections::HashMap;
use std::hash::Hash;
use std::sync::RwLock;
use std::time::{Duration, Instant};
use crate::window::FixedWindowCounter;
/// One limited dimension of a node's quota (for example requests per minute).
#[derive(Clone, Debug, bon::Builder)]
pub struct QuotaDimension {
    /// Key matched against the dimension names passed to
    /// [`QuotaTracker::record_usage`], [`QuotaTracker::has_capacity`] and
    /// [`QuotaTracker::remaining`].
    pub name: String,
    /// Window handed to `FixedWindowCounter::new` for this dimension.
    pub window: Duration,
    /// Usage ceiling for this dimension; `pressure` reports `sum / limit`.
    /// A limit of `0` is skipped by `pressure`.
    pub limit: u64,
    /// Second argument to `FixedWindowCounter::new` (defaults to `6`);
    /// presumably the number of sub-buckets per window — confirm in `crate::window`.
    #[builder(default = 6)]
    pub resolution: usize,
}

/// Quota configuration for a single node: the dimensions it is limited on.
#[derive(Clone, Debug, bon::Builder)]
pub struct QuotaConfig {
    /// Dimensions tracked for the node; each gets its own independent counter.
    pub dimensions: Vec<QuotaDimension>,
}
/// Per-node quota state stored in the tracker's map.
struct NodeQuota {
/// One `(name, limit, counter)` entry per configured dimension, kept in the
/// order the config listed them; lookups by name use the first match.
dimensions: Vec<(String, u64, FixedWindowCounter)>,
/// While `Some(t)` with `t` still in the future, the node reports no capacity
/// and full pressure regardless of its counters.
exhausted_until: Option<Instant>,
}
pub struct QuotaTracker<Id: Eq + Hash + Clone> {
nodes: RwLock<HashMap<Id, NodeQuota>>,
}
impl<Id: Eq + Hash + Clone> std::fmt::Debug for QuotaTracker<Id> {
    /// Prints the number of tracked nodes; per-node state is elided (`..`).
    ///
    /// A poisoned lock is recovered instead of unwrapped, so `{:?}` never panics just
    /// because another thread panicked while holding the write lock.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The guard is a temporary of this statement, so the lock is released
        // before any formatting happens.
        let tracked_nodes = self
            .nodes
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .len();
        f.debug_struct("QuotaTracker")
            .field("tracked_nodes", &tracked_nodes)
            .finish_non_exhaustive()
    }
}
impl<Id: Eq + Hash + Clone> QuotaTracker<Id> {
pub fn new() -> Self {
Self {
nodes: RwLock::new(HashMap::new()),
}
}
pub fn register(&self, id: &Id, config: QuotaConfig) {
let dimensions = config
.dimensions
.into_iter()
.map(|d| {
let counter = FixedWindowCounter::new(d.window, d.resolution);
(d.name, d.limit, counter)
})
.collect();
let mut nodes = self.nodes.write().unwrap();
nodes.insert(
id.clone(),
NodeQuota {
dimensions,
exhausted_until: None,
},
);
}
pub fn record_usage(&self, id: &Id, amounts: &[(&str, u64)]) -> bool {
let mut nodes = self.nodes.write().unwrap();
let Some(node) = nodes.get_mut(id) else {
return false;
};
if let Some(t) = node.exhausted_until {
if Instant::now() >= t {
node.exhausted_until = None;
}
}
for &(dim_name, amount) in amounts {
for (name, _, counter) in &node.dimensions {
if name == dim_name {
counter.record(amount);
break;
}
}
}
true
}
pub fn mark_exhausted(&self, id: &Id, duration: Duration) {
let mut nodes = self.nodes.write().unwrap();
if let Some(node) = nodes.get_mut(id) {
node.exhausted_until = Some(Instant::now() + duration);
}
}
pub fn has_capacity(&self, id: &Id, estimated: &[(&str, u64)]) -> bool {
let nodes = self.nodes.read().unwrap();
let Some(node) = nodes.get(id) else {
return true;
};
if let Some(t) = node.exhausted_until {
if Instant::now() < t {
return false;
}
}
for &(dim_name, est) in estimated {
for (name, limit, counter) in &node.dimensions {
if name == dim_name {
if counter.remaining(*limit) < est {
return false;
}
break;
}
}
}
true
}
pub fn remaining(&self, id: &Id, dimension: &str) -> u64 {
let nodes = self.nodes.read().unwrap();
let Some(node) = nodes.get(id) else {
return u64::MAX;
};
for (name, limit, counter) in &node.dimensions {
if name == dimension {
return counter.remaining(*limit);
}
}
u64::MAX
}
pub fn pressure(&self, id: &Id) -> f64 {
let nodes = self.nodes.read().unwrap();
let Some(node) = nodes.get(id) else {
return 0.0;
};
if let Some(t) = node.exhausted_until {
if Instant::now() < t {
return 1.0;
}
}
let mut max_pressure: f64 = 0.0;
for (_, limit, counter) in &node.dimensions {
if *limit == 0 {
continue;
}
let usage = counter.sum() as f64 / *limit as f64;
if usage > max_pressure {
max_pressure = usage;
}
}
max_pressure.min(1.0)
}
}
impl<Id: Eq + Hash + Clone> Default for QuotaTracker<Id> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tracker with one registered node `id` that has a single
    /// 60-second "rpm" dimension capped at `limit`.
    fn make_tracker_with_node(id: &str, limit: u64) -> (QuotaTracker<String>, String) {
        let tracker = QuotaTracker::<String>::new();
        let node_id = id.to_string();
        tracker.register(
            &node_id,
            QuotaConfig::builder()
                .dimensions(vec![QuotaDimension::builder()
                    .name("rpm".into())
                    .window(Duration::from_secs(60))
                    .limit(limit)
                    .build()])
                .build(),
        );
        (tracker, node_id)
    }

    #[test]
    fn basic_register_and_record() {
        let (tracker, id) = make_tracker_with_node("n1", 100);
        tracker.record_usage(&id, &[("rpm", 10)]);
        assert_eq!(tracker.remaining(&id, "rpm"), 90);
        tracker.record_usage(&id, &[("rpm", 5)]);
        assert_eq!(tracker.remaining(&id, "rpm"), 85);
    }

    #[test]
    fn pressure_increases_with_usage() {
        let (tracker, id) = make_tracker_with_node("n1", 100);
        let p0 = tracker.pressure(&id);
        assert!((p0 - 0.0).abs() < f64::EPSILON);
        tracker.record_usage(&id, &[("rpm", 50)]);
        let p50 = tracker.pressure(&id);
        assert!((p50 - 0.5).abs() < f64::EPSILON);
        tracker.record_usage(&id, &[("rpm", 50)]);
        let p100 = tracker.pressure(&id);
        assert!((p100 - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn has_capacity_with_estimation() {
        let (tracker, id) = make_tracker_with_node("n1", 100);
        tracker.record_usage(&id, &[("rpm", 90)]);
        assert!(tracker.has_capacity(&id, &[("rpm", 10)]));
        assert!(!tracker.has_capacity(&id, &[("rpm", 11)]));
    }

    #[test]
    fn mark_exhausted_blocks_capacity() {
        let (tracker, id) = make_tracker_with_node("n1", 100);
        tracker.mark_exhausted(&id, Duration::from_secs(60));
        assert!(!tracker.has_capacity(&id, &[("rpm", 1)]));
        assert!((tracker.pressure(&id) - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn exhausted_expires() {
        let window = Duration::from_millis(20);
        let (tracker, id) = make_tracker_with_node("n1", 100);
        // Taken before the mark, so `start + window` is no later than the deadline.
        let start = Instant::now();
        tracker.mark_exhausted(&id, window);
        let blocked = !tracker.has_capacity(&id, &[("rpm", 1)]);
        // A descheduled test thread may legitimately observe an already-expired mark;
        // only assert the block when the check provably ran before the deadline.
        if start.elapsed() < window {
            assert!(blocked);
        }
        std::thread::sleep(window + Duration::from_millis(10));
        assert!(tracker.has_capacity(&id, &[("rpm", 1)]));
    }

    #[test]
    fn zero_duration_exhaustion_is_already_expired() {
        let (tracker, id) = make_tracker_with_node("n1", 100);
        tracker.mark_exhausted(&id, Duration::ZERO);
        assert!(tracker.has_capacity(&id, &[("rpm", 1)]));
    }

    #[test]
    fn record_usage_on_unknown_node_returns_false() {
        let tracker = QuotaTracker::<String>::new();
        assert!(!tracker.record_usage(&"unknown".to_string(), &[("rpm", 1)]));
    }

    #[test]
    fn unknown_node_has_unlimited_capacity() {
        let tracker = QuotaTracker::<String>::new();
        let unknown = "unknown".to_string();
        assert!(tracker.has_capacity(&unknown, &[("rpm", 999)]));
        assert_eq!(tracker.remaining(&unknown, "rpm"), u64::MAX);
        assert!((tracker.pressure(&unknown) - 0.0).abs() < f64::EPSILON);
    }
}