use std::time::Duration;
/// Granularity at which a quota scope applies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScopeKind {
    /// Scoped to an individual principal.
    Principal,
    /// Scoped to a namespace.
    Namespace,
    /// The single, catch-all global scope.
    Global,
}

impl ScopeKind {
    /// Canonical lowercase string form of this kind.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Principal => "principal",
            Self::Namespace => "namespace",
            Self::Global => "global",
        }
    }

    /// Inverse of [`ScopeKind::as_str`].
    ///
    /// Matching is exact and case-sensitive. Any string that is not one of the
    /// canonical names yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        [Self::Principal, Self::Namespace, Self::Global]
            .iter()
            .copied()
            .find(|kind| kind.as_str() == s)
    }
}
/// A concrete quota scope: a [`ScopeKind`] plus the identifier within that kind.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct QuotaScope {
    /// Which level this scope targets.
    pub kind: ScopeKind,
    /// Identifier within `kind`. The constructors below store the principal id,
    /// the namespace name, or `"*"` for the global scope.
    pub value: String,
}

impl QuotaScope {
    /// Assembles a scope from its two parts; shared by the public constructors.
    fn from_parts(kind: ScopeKind, value: String) -> Self {
        QuotaScope { kind, value }
    }

    /// Scope targeting the principal identified by `id`.
    pub fn principal(id: impl Into<String>) -> Self {
        Self::from_parts(ScopeKind::Principal, id.into())
    }

    /// Scope targeting the namespace called `name`.
    pub fn namespace(name: impl Into<String>) -> Self {
        Self::from_parts(ScopeKind::Namespace, name.into())
    }

    /// The global scope; its value is always the wildcard `"*"`.
    pub fn global() -> Self {
        Self::from_parts(ScopeKind::Global, String::from("*"))
    }
}
/// Service tier that may be attached to a quota policy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Tier {
    Bronze,
    Silver,
    Gold,
}

impl Tier {
    /// Canonical lowercase name of the tier.
    pub fn as_str(self) -> &'static str {
        match self {
            Tier::Bronze => "bronze",
            Tier::Silver => "silver",
            Tier::Gold => "gold",
        }
    }

    /// Parses a tier name, ignoring ASCII case (`"gold"`, `"Gold"` and `"GOLD"` all match).
    ///
    /// Only ASCII letters are case-folded, and no trimming is performed. Any other
    /// input, including unknown names, yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        // `eq_ignore_ascii_case` compares in place. The previous version built a
        // lowercased `String` on every call just to match against it.
        [Tier::Bronze, Tier::Silver, Tier::Gold]
            .iter()
            .copied()
            .find(|tier| tier.as_str().eq_ignore_ascii_case(s))
    }
}
/// Effective quota limits for a single [`QuotaScope`].
///
/// NOTE(review): a `None` limit presumably means "no limit at this level". Nothing
/// in this type interprets it, so confirm against the enforcement code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QuotaPolicy {
    /// Scope the limits apply to.
    pub scope: QuotaScope,
    /// Requests-per-second limit.
    pub rps_limit: Option<u64>,
    /// Cost budget per second.
    pub cost_budget_per_sec: Option<u64>,
    /// Maximum number of concurrent "expanded" operations.
    pub max_concurrent_expanded: Option<u32>,
    /// Service tier, if one is assigned.
    pub tier: Option<Tier>,
    /// Policy version. The built-in constructors use `1` for provisional policies
    /// and the synthetic `0` for fallback policies.
    pub policy_version: u64,
}

impl QuotaPolicy {
    /// Builds a tier-less policy for `scope` with every limit taken from `defaults`.
    ///
    /// Shared by [`provisional_for`](Self::provisional_for) and
    /// [`fallback_for`](Self::fallback_for). Because of this, adding a limit to
    /// [`ProvisionalDefaults`] cannot leave one constructor out of date.
    fn from_defaults(
        scope: QuotaScope,
        defaults: ProvisionalDefaults,
        policy_version: u64,
    ) -> Self {
        Self {
            scope,
            rps_limit: Some(defaults.rps),
            cost_budget_per_sec: Some(defaults.cost_budget),
            max_concurrent_expanded: Some(defaults.max_concurrent_expanded),
            tier: None,
            policy_version,
        }
    }

    /// Provisional policy for `scope`: all limits from [`PROVISIONAL_DEFAULTS`],
    /// no tier, `policy_version` 1.
    pub fn provisional_for(scope: QuotaScope) -> Self {
        Self::from_defaults(scope, PROVISIONAL_DEFAULTS, 1)
    }

    /// Fallback policy for `scope`: all limits from [`FALLBACK_DEFAULTS`], no tier,
    /// and the synthetic `policy_version` 0.
    pub fn fallback_for(scope: QuotaScope) -> Self {
        Self::from_defaults(scope, FALLBACK_DEFAULTS, 0)
    }
}

/// A bundle of default limits from which [`QuotaPolicy`] constructors build policies.
///
/// Despite its name, this type backs both [`PROVISIONAL_DEFAULTS`] and
/// [`FALLBACK_DEFAULTS`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProvisionalDefaults {
    /// Requests-per-second limit.
    pub rps: u64,
    /// Cost budget per second.
    pub cost_budget: u64,
    /// Maximum number of concurrent expanded operations.
    pub max_concurrent_expanded: u32,
    /// NOTE(review): not read in this file. Presumably this is how long a resolved
    /// policy may be cached; confirm with the caching layer.
    pub cache_ttl: Duration,
}

/// Limits used by [`QuotaPolicy::provisional_for`].
pub const PROVISIONAL_DEFAULTS: ProvisionalDefaults = ProvisionalDefaults {
    rps: 100,
    cost_budget: 1000,
    max_concurrent_expanded: 4,
    cache_ttl: Duration::from_secs(600),
};

/// Limits used by [`QuotaPolicy::fallback_for`].
///
/// The unit tests check that each limit here is no larger than its
/// [`PROVISIONAL_DEFAULTS`] counterpart.
pub const FALLBACK_DEFAULTS: ProvisionalDefaults = ProvisionalDefaults {
    rps: 50,
    cost_budget: 500,
    max_concurrent_expanded: 2,
    cache_ttl: Duration::from_secs(60),
};
/// Looks up the quota policy that applies to a request.
///
/// The `Send + Sync` supertraits allow a resolver to be shared across threads.
// NOTE(review): precedence between principal, namespace and global policies is up to
// each implementation. The test resolver in this file tries the principal, then the
// namespace, then falls back to a global policy; confirm production resolvers match.
pub trait PolicyResolver: Send + Sync {
/// Returns the policy for the given `principal` and/or `namespace` (either may be `None`).
///
/// # Errors
///
/// Returns a [`ResolveError`] when no policy can be produced; see its variants.
fn resolve(
&self,
principal: Option<&str>,
namespace: Option<&str>,
) -> Result<QuotaPolicy, ResolveError>;
}
/// Failure modes of [`PolicyResolver::resolve`].
#[derive(Debug, thiserror::Error)]
pub enum ResolveError {
/// The control DB could not be reached and no cached policy was available.
#[error("control DB unavailable and no cached policy")]
Unavailable,
/// Any other backend failure. The payload is rendered as `backend error: {0}`.
#[error("backend error: {0}")]
Backend(String),
}
/// Convenience constructor: the provisional policy for the global (`"*"`) scope.
pub fn provisional_global() -> QuotaPolicy {
    let scope = QuotaScope::global();
    QuotaPolicy::provisional_for(scope)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn scope_kind_round_trip() {
        for k in [ScopeKind::Principal, ScopeKind::Namespace, ScopeKind::Global] {
            assert_eq!(ScopeKind::parse(k.as_str()), Some(k));
        }
        assert!(ScopeKind::parse("nonsense").is_none());
    }

    #[test]
    fn tier_round_trip_case_insensitive() {
        for t in [Tier::Bronze, Tier::Silver, Tier::Gold] {
            assert_eq!(Tier::parse(t.as_str()), Some(t));
        }
        assert_eq!(Tier::parse("Gold"), Some(Tier::Gold));
        assert_eq!(Tier::parse("SILVER"), Some(Tier::Silver));
        assert_eq!(Tier::parse("bronze"), Some(Tier::Bronze));
        assert!(Tier::parse("platinum").is_none());
    }

    #[test]
    fn provisional_for_is_full_policy() {
        let p = QuotaPolicy::provisional_for(QuotaScope::principal("alice"));
        assert_eq!(p.scope, QuotaScope::principal("alice"));
        assert_eq!(p.rps_limit, Some(PROVISIONAL_DEFAULTS.rps));
        assert_eq!(p.cost_budget_per_sec, Some(PROVISIONAL_DEFAULTS.cost_budget));
        assert_eq!(
            p.max_concurrent_expanded,
            Some(PROVISIONAL_DEFAULTS.max_concurrent_expanded)
        );
        assert_eq!(p.tier, None);
        assert_eq!(p.policy_version, 1);
    }

    #[test]
    fn fallback_for_is_full_policy() {
        let p = QuotaPolicy::fallback_for(QuotaScope::principal("bob"));
        assert_eq!(p.scope, QuotaScope::principal("bob"));
        assert_eq!(p.rps_limit, Some(FALLBACK_DEFAULTS.rps));
        assert_eq!(p.cost_budget_per_sec, Some(FALLBACK_DEFAULTS.cost_budget));
        assert_eq!(
            p.max_concurrent_expanded,
            Some(FALLBACK_DEFAULTS.max_concurrent_expanded)
        );
        assert_eq!(p.tier, None);
    }

    #[test]
    fn fallback_is_strictly_smaller_than_provisional() {
        // `<`, not `<=`: the name promises strictly smaller limits, so an accidental
        // copy of the provisional values into the fallback must fail here.
        assert!(FALLBACK_DEFAULTS.rps < PROVISIONAL_DEFAULTS.rps);
        assert!(FALLBACK_DEFAULTS.cost_budget < PROVISIONAL_DEFAULTS.cost_budget);
        assert!(
            FALLBACK_DEFAULTS.max_concurrent_expanded
                < PROVISIONAL_DEFAULTS.max_concurrent_expanded
        );
    }

    #[test]
    fn fallback_policy_has_synthetic_version_zero() {
        let p = QuotaPolicy::fallback_for(QuotaScope::namespace("hot_ns"));
        assert_eq!(p.policy_version, 0);
    }

    #[test]
    fn provisional_global_helper_returns_global_scope() {
        let p = provisional_global();
        assert_eq!(p.scope.kind, ScopeKind::Global);
        assert_eq!(p.scope.value, "*");
        assert_eq!(p, QuotaPolicy::provisional_for(QuotaScope::global()));
    }

    #[test]
    fn quota_scope_constructors_set_kind() {
        let alice = QuotaScope::principal("alice");
        assert_eq!(alice.kind, ScopeKind::Principal);
        assert_eq!(alice.value, "alice");
        let ns = QuotaScope::namespace("ns1");
        assert_eq!(ns.kind, ScopeKind::Namespace);
        assert_eq!(ns.value, "ns1");
        assert_eq!(QuotaScope::global().kind, ScopeKind::Global);
        assert_eq!(QuotaScope::global().value, "*");
    }

    /// In-memory resolver for exercising the trait. It tries a principal match,
    /// then a namespace match, then returns the global policy.
    struct StaticResolver {
        principals: HashMap<String, QuotaPolicy>,
        namespaces: HashMap<String, QuotaPolicy>,
        global: QuotaPolicy,
    }

    impl StaticResolver {
        /// Resolver over the given maps, falling back to the provisional global policy.
        fn new(
            principals: HashMap<String, QuotaPolicy>,
            namespaces: HashMap<String, QuotaPolicy>,
        ) -> Self {
            Self {
                principals,
                namespaces,
                global: provisional_global(),
            }
        }
    }

    impl PolicyResolver for StaticResolver {
        fn resolve(
            &self,
            principal: Option<&str>,
            namespace: Option<&str>,
        ) -> Result<QuotaPolicy, ResolveError> {
            let by_principal = principal.and_then(|p| self.principals.get(p));
            let by_namespace = || namespace.and_then(|n| self.namespaces.get(n));
            Ok(by_principal
                .or_else(by_namespace)
                .unwrap_or(&self.global)
                .clone())
        }
    }

    /// Provisional policy for `scope` with a distinctive `rps_limit` for identification.
    fn policy_with_rps(scope: QuotaScope, rps: u64) -> QuotaPolicy {
        QuotaPolicy {
            rps_limit: Some(rps),
            ..QuotaPolicy::provisional_for(scope)
        }
    }

    #[test]
    fn resolver_picks_principal_first() {
        // A matching namespace policy is present as well, so this checks that the
        // principal wins, not merely that it is the only candidate.
        let r = StaticResolver::new(
            HashMap::from([(
                "alice".to_string(),
                policy_with_rps(QuotaScope::principal("alice"), 999),
            )]),
            HashMap::from([(
                "any_ns".to_string(),
                policy_with_rps(QuotaScope::namespace("any_ns"), 7),
            )]),
        );
        let pol = r.resolve(Some("alice"), Some("any_ns")).unwrap();
        assert_eq!(pol.rps_limit, Some(999));
        assert_eq!(pol.scope.kind, ScopeKind::Principal);
    }

    #[test]
    fn resolver_falls_back_to_namespace_when_no_principal() {
        let r = StaticResolver::new(
            HashMap::new(),
            HashMap::from([(
                "shared".to_string(),
                policy_with_rps(QuotaScope::namespace("shared"), 50),
            )]),
        );
        let pol = r.resolve(Some("never_seen"), Some("shared")).unwrap();
        assert_eq!(pol.rps_limit, Some(50));
        assert_eq!(pol.scope.kind, ScopeKind::Namespace);
    }

    #[test]
    fn resolver_falls_back_to_global_when_no_match() {
        let r = StaticResolver::new(HashMap::new(), HashMap::new());
        let pol = r.resolve(Some("ghost"), Some("nowhere")).unwrap();
        assert_eq!(pol.scope.kind, ScopeKind::Global);
    }

    #[test]
    fn resolver_with_no_identifiers_uses_global() {
        let r = StaticResolver::new(
            HashMap::from([(
                "alice".to_string(),
                policy_with_rps(QuotaScope::principal("alice"), 999),
            )]),
            HashMap::new(),
        );
        assert_eq!(r.resolve(None, None).unwrap(), provisional_global());
    }
}