use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use parking_lot::Mutex;
use super::bucket::{TokenBucket, TokenBucketConfig};
use super::policy::{QuotaPolicy, QuotaScope};
/// The quota axis a token bucket meters.
///
/// Each `(scope, dimension)` pair gets its own bucket in the registry below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BucketDimension {
    /// Request-rate axis; its limit is read from `QuotaPolicy::rps_limit`.
    Rps,
    /// Cost axis; its limit is read from `QuotaPolicy::cost_budget_per_sec`.
    Cost,
}

impl BucketDimension {
    /// Returns the short lowercase label for this dimension: `"rps"` or `"cost"`.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Cost => "cost",
            Self::Rps => "rps",
        }
    }
}
/// Identity of a single token bucket: one bucket per quota scope per dimension.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct BucketKey {
    /// The quota scope this bucket belongs to.
    pub scope: QuotaScope,
    /// Which axis (request rate or cost) the bucket meters.
    pub dimension: BucketDimension,
}
/// Result of a single admission attempt against a token bucket.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ConsumeOutcome {
    /// The tokens were taken; `remaining` is the bucket's balance afterwards.
    Allowed { remaining: u64 },
    /// The bucket could not cover the request.
    Rejected {
        /// Wait time reported by the bucket before `tokens_requested` could be
        /// available, or `None` when the bucket reports no such time.
        retry_after: Option<Duration>,
        /// Bucket balance observed at rejection time.
        tokens_available: u64,
        /// The amount the caller asked for.
        tokens_requested: u64,
    },
    /// The policy sets no limit for this dimension, so nothing was metered.
    NotConfigured,
}
#[derive(Clone, Default)]
pub struct BucketRegistry {
inner: Arc<Mutex<HashMap<BucketKey, Arc<Mutex<TokenBucket>>>>>,
}
impl BucketRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn len(&self) -> usize {
self.inner.lock().len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn clear(&self) {
self.inner.lock().clear();
}
pub fn consume(
&self,
scope: QuotaScope,
dimension: BucketDimension,
tokens: u64,
policy: &QuotaPolicy,
) -> ConsumeOutcome {
let limit = match dimension {
BucketDimension::Rps => policy.rps_limit,
BucketDimension::Cost => policy.cost_budget_per_sec,
};
let Some(rate) = limit else {
return ConsumeOutcome::NotConfigured;
};
let cfg = TokenBucketConfig::new(rate.max(1), rate);
let key = BucketKey { scope, dimension };
let bucket_arc = {
let mut guard = self.inner.lock();
guard
.entry(key)
.or_insert_with(|| Arc::new(Mutex::new(TokenBucket::new(cfg))))
.clone()
};
let mut bucket = bucket_arc.lock();
if bucket.try_consume(tokens) {
ConsumeOutcome::Allowed {
remaining: bucket.tokens(),
}
} else {
let retry_after = bucket.time_until_n(tokens);
ConsumeOutcome::Rejected {
retry_after,
tokens_available: bucket.tokens(),
tokens_requested: tokens,
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::admission::policy::PROVISIONAL_DEFAULTS;

    /// Provisional quota policy for the principal scope with the given id.
    fn provisional_principal(id: &str) -> QuotaPolicy {
        QuotaPolicy::provisional_for(QuotaScope::principal(id))
    }

    #[test]
    fn empty_registry_starts_with_no_buckets() {
        let r = BucketRegistry::new();
        assert!(r.is_empty());
    }

    #[test]
    fn first_consume_materializes_bucket() {
        let r = BucketRegistry::new();
        let pol = provisional_principal("alice");
        let _ = r.consume(
            QuotaScope::principal("alice"),
            BucketDimension::Rps,
            1,
            &pol,
        );
        assert_eq!(r.len(), 1);
    }

    #[test]
    fn rps_dimension_allows_within_limit() {
        let r = BucketRegistry::new();
        let pol = provisional_principal("alice");
        let outcome = r.consume(
            QuotaScope::principal("alice"),
            BucketDimension::Rps,
            1,
            &pol,
        );
        assert!(matches!(outcome, ConsumeOutcome::Allowed { .. }));
    }

    #[test]
    fn rps_dimension_rejects_when_burst_exhausted() {
        let r = BucketRegistry::new();
        let scope = QuotaScope::principal("alice");
        // The single policy this test exercises; the limit below is the one
        // every consume call sees.
        let pol = QuotaPolicy {
            rps_limit: Some(10),
            ..provisional_principal("alice")
        };
        let mut allowed = 0;
        let mut rejected = 0;
        for _ in 0..5 {
            match r.consume(scope.clone(), BucketDimension::Rps, 1, &pol) {
                ConsumeOutcome::Allowed { .. } => allowed += 1,
                ConsumeOutcome::Rejected { .. } => rejected += 1,
                ConsumeOutcome::NotConfigured => panic!("dimension is configured"),
            }
        }
        assert!(allowed >= 1, "at least the warm tokens should pass");
        assert!(rejected >= 1, "exhausting the warm tokens should reject");
    }

    #[test]
    fn cost_dimension_uses_cost_budget_field() {
        let r = BucketRegistry::new();
        let pol = QuotaPolicy {
            cost_budget_per_sec: Some(20),
            ..provisional_principal("alice")
        };
        let scope = QuotaScope::principal("alice");
        let first = r.consume(scope.clone(), BucketDimension::Cost, 5, &pol);
        assert!(matches!(first, ConsumeOutcome::Allowed { .. }));
        let second = r.consume(scope.clone(), BucketDimension::Cost, 5, &pol);
        assert!(matches!(second, ConsumeOutcome::Rejected { .. }));
    }

    #[test]
    fn rps_and_cost_buckets_are_independent() {
        let r = BucketRegistry::new();
        let pol = provisional_principal("alice");
        let scope = QuotaScope::principal("alice");
        let _ = r.consume(scope.clone(), BucketDimension::Rps, 1, &pol);
        let _ = r.consume(scope.clone(), BucketDimension::Cost, 5, &pol);
        assert_eq!(r.len(), 2);
    }

    #[test]
    fn dimension_not_configured_is_pass_through() {
        let r = BucketRegistry::new();
        let pol = QuotaPolicy {
            rps_limit: None,
            cost_budget_per_sec: Some(100),
            ..provisional_principal("alice")
        };
        let outcome = r.consume(
            QuotaScope::principal("alice"),
            BucketDimension::Rps,
            1,
            &pol,
        );
        assert_eq!(outcome, ConsumeOutcome::NotConfigured);
        assert_eq!(r.len(), 0);
    }

    #[test]
    fn different_scopes_have_independent_buckets() {
        let r = BucketRegistry::new();
        let pol_a = provisional_principal("alice");
        let pol_b = provisional_principal("bob");
        let _ = r.consume(
            QuotaScope::principal("alice"),
            BucketDimension::Rps,
            1,
            &pol_a,
        );
        let _ = r.consume(
            QuotaScope::principal("bob"),
            BucketDimension::Rps,
            1,
            &pol_b,
        );
        assert_eq!(r.len(), 2);
    }

    #[test]
    fn rejected_outcome_includes_diagnostics() {
        let r = BucketRegistry::new();
        let pol = QuotaPolicy {
            rps_limit: Some(10),
            ..provisional_principal("alice")
        };
        for _ in 0..10 {
            let _ = r.consume(
                QuotaScope::principal("alice"),
                BucketDimension::Rps,
                1,
                &pol,
            );
        }
        let outcome = r.consume(
            QuotaScope::principal("alice"),
            BucketDimension::Rps,
            100,
            &pol,
        );
        match outcome {
            ConsumeOutcome::Rejected {
                retry_after,
                tokens_available,
                tokens_requested,
            } => {
                assert_eq!(tokens_requested, 100);
                assert!(tokens_available < 100);
                assert!(retry_after.is_none());
            }
            other => panic!("expected Rejected, got {:?}", other),
        }
    }

    #[test]
    fn clear_drops_all_buckets() {
        let r = BucketRegistry::new();
        let pol = provisional_principal("alice");
        let _ = r.consume(
            QuotaScope::principal("alice"),
            BucketDimension::Rps,
            1,
            &pol,
        );
        let _ = r.consume(QuotaScope::namespace("ns1"), BucketDimension::Cost, 1, &pol);
        // Two distinct (scope, dimension) keys, both configured by the provisional
        // policy, so `clear` must drop a fully populated registry.
        assert_eq!(r.len(), 2);
        r.clear();
        assert!(r.is_empty());
    }

    #[test]
    fn registry_is_clone_cheap() {
        let r = BucketRegistry::new();
        let r2 = r.clone();
        let pol = provisional_principal("alice");
        let _ = r.consume(
            QuotaScope::principal("alice"),
            BucketDimension::Rps,
            1,
            &pol,
        );
        assert_eq!(r2.len(), 1);
    }

    #[test]
    fn provisional_defaults_yield_well_formed_buckets() {
        let r = BucketRegistry::new();
        let pol = QuotaPolicy::provisional_for(QuotaScope::global());
        let outcome = r.consume(QuotaScope::global(), BucketDimension::Cost, 100, &pol);
        match outcome {
            ConsumeOutcome::Allowed { remaining } => {
                assert!(remaining < PROVISIONAL_DEFAULTS.cost_budget);
            }
            other => panic!("expected Allowed, got {:?}", other),
        }
    }
}