use std::sync::Arc;
use crate::decision::Decision;
#[cfg(feature = "tokio")]
use crate::error::ThrottleError;
use crate::limiter::{Limiter, acquire_all, peek_all};
/// One named dimension: the dimension's name paired with the limiter that
/// guards it. The limiter is held behind an `Arc` as a `dyn Limiter` trait
/// object.
type Dimension = (Box<str>, Arc<dyn Limiter>);
/// A collection of named limiters ("dimensions") that are consulted together.
///
/// The dimensions live in a shared `Arc<[Dimension]>`, so the derived `Clone`
/// copies only the `Arc` handle: every clone refers to the same underlying
/// limiter instances.
#[derive(Clone)]
pub struct MultiLimiter {
/// The configured dimensions, shared between all clones of this limiter.
dimensions: Arc<[Dimension]>,
}
/// Returns the cost requested for the dimension called `name`.
///
/// `costs` is scanned front to back and the first entry whose name equals
/// `name` supplies the result; later entries with the same name are never
/// looked at. A name that does not appear in `costs` costs `0`.
#[inline]
fn cost_for(name: &str, costs: &[(&str, u32)]) -> u32 {
    for &(candidate, cost) in costs {
        if candidate == name {
            return cost;
        }
    }
    // Dimensions the caller did not mention are not charged.
    0
}
impl MultiLimiter {
#[must_use]
pub fn builder() -> MultiLimiterBuilder {
MultiLimiterBuilder {
dimensions: Vec::new(),
}
}
#[inline]
fn pairs<'a>(
&'a self,
costs: &'a [(&'a str, u32)],
) -> impl Iterator<Item = (&'a dyn Limiter, u32)> + Clone {
self.dimensions
.iter()
.map(move |(name, limiter)| (limiter.as_ref(), cost_for(name, costs)))
}
#[inline]
#[must_use]
pub fn peek_costs(&self, costs: &[(&str, u32)]) -> Decision {
peek_all(self.pairs(costs))
}
#[inline]
#[must_use]
pub fn try_acquire_costs(&self, costs: &[(&str, u32)]) -> bool {
acquire_all(self.pairs(costs)).is_acquired()
}
#[must_use]
pub fn available(&self, dimension: &str) -> Option<u32> {
self.dimensions
.iter()
.find(|(name, _)| name.as_ref() == dimension)
.map(|(_, limiter)| limiter.available())
}
}
#[cfg(feature = "tokio")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
impl MultiLimiter {
    /// Acquires the given per-dimension costs, sleeping and retrying for as
    /// long as the limiters answer with [`Decision::Retry`].
    ///
    /// Each round calls `acquire_all`; on `Retry { after }` the task sleeps
    /// for `after` with `tokio::time::sleep` and then tries again.
    ///
    /// # Errors
    ///
    /// Returns the [`ThrottleError`] built by
    /// [`MultiLimiter::capacity_error`] when the decision is
    /// [`Decision::Impossible`].
    pub async fn acquire_costs(&self, costs: &[(&str, u32)]) -> Result<(), ThrottleError> {
        loop {
            // Only the retry case falls through; the other two outcomes end
            // the call immediately.
            let wait = match acquire_all(self.pairs(costs)) {
                Decision::Acquired => return Ok(()),
                Decision::Impossible => return Err(self.capacity_error(costs)),
                Decision::Retry { after } => after,
            };
            tokio::time::sleep(wait).await;
        }
    }

    /// Builds the error describing why `costs` was judged impossible.
    ///
    /// The dimensions are searched in two passes, each in storage order:
    ///
    /// 1. the first dimension whose requested cost is greater than its
    ///    limiter's `capacity()`;
    /// 2. otherwise, the first dimension whose own `peek` for its cost
    ///    reports [`Decision::Impossible`].
    ///
    /// If neither pass finds a culprit, the error carries `cost: 0` and
    /// `capacity: 0`.
    fn capacity_error(&self, costs: &[(&str, u32)]) -> ThrottleError {
        let exceeds_capacity = || {
            self.dimensions.iter().find_map(|(name, limiter)| {
                let cost = cost_for(name, costs);
                (cost > limiter.capacity()).then(|| ThrottleError::CostExceedsCapacity {
                    cost,
                    capacity: limiter.capacity(),
                })
            })
        };
        let individually_impossible = || {
            self.dimensions.iter().find_map(|(name, limiter)| {
                let cost = cost_for(name, costs);
                (limiter.peek(cost) == Decision::Impossible).then(|| {
                    ThrottleError::CostExceedsCapacity {
                        cost,
                        capacity: limiter.capacity(),
                    }
                })
            })
        };

        // The second pass runs only when the first one found nothing.
        exceeds_capacity()
            .or_else(individually_impossible)
            .unwrap_or(ThrottleError::CostExceedsCapacity {
                cost: 0,
                capacity: 0,
            })
    }
}
/// Builder for [`MultiLimiter`].
///
/// Dimensions are collected in a `Vec` in the order they are added and are
/// handed to the limiter when `build` is called. The derived `Default`
/// produces a builder with no dimensions.
#[derive(Default)]
pub struct MultiLimiterBuilder {
/// Dimensions registered so far, in insertion order.
dimensions: Vec<Dimension>,
}
impl MultiLimiterBuilder {
#[must_use]
pub fn dimension(mut self, name: impl Into<Box<str>>, limiter: impl Limiter + 'static) -> Self {
self.dimensions.push((name.into(), Arc::new(limiter)));
self
}
#[must_use]
pub fn shared(mut self, name: impl Into<Box<str>>, limiter: Arc<dyn Limiter>) -> Self {
self.dimensions.push((name.into(), limiter));
self
}
#[must_use]
pub fn build(self) -> MultiLimiter {
MultiLimiter {
dimensions: self.dimensions.into(),
}
}
}
// Unit tests for `MultiLimiter`, built from `Throttle` limiters and, where
// time must be controlled, a `ManualClock`.
#[cfg(test)]
mod tests {
// The async test below calls `unwrap_err`; a panic is an acceptable failure
// mode inside tests.
#![allow(clippy::unwrap_used)]
use super::MultiLimiter;
use crate::throttle::Throttle;
use clock_lib::ManualClock;
use core::time::Duration;
use std::sync::Arc;
// Compiles only when `T` is both `Send` and `Sync`; never does anything at
// run time.
fn assert_send_sync<T: Send + Sync>() {}
// Compile-time guarantee that a `MultiLimiter` can be shared across threads.
#[test]
fn test_multi_limiter_is_send_sync() {
assert_send_sync::<MultiLimiter>();
}
// A request is granted only if every dimension can pay its share, and a
// rejected request leaves the other dimensions uncharged.
// NOTE(review): these throttles use their default clock; if refill is
// continuous wall-clock time this test may be timing-sensitive — confirm.
#[test]
fn test_all_dimensions_must_afford_their_share() {
let limiter = MultiLimiter::builder()
.dimension("requests", Throttle::per_second(10))
.dimension("tokens", Throttle::per_second(1000))
.build();
// Spends the whole token budget in one call.
assert!(limiter.try_acquire_costs(&[("requests", 1), ("tokens", 1000)]));
// Tokens are exhausted, so the combined request is refused...
assert!(!limiter.try_acquire_costs(&[("requests", 1), ("tokens", 1)]));
// ...and "requests" still shows only the first call's charge.
assert_eq!(limiter.available("requests"), Some(9));
}
// A dimension absent from the cost list keeps its full budget.
#[test]
fn test_unmentioned_dimension_is_not_charged() {
let limiter = MultiLimiter::builder()
.dimension("requests", Throttle::per_second(2))
.dimension("tokens", Throttle::per_second(100))
.build();
assert!(limiter.try_acquire_costs(&[("requests", 1)]));
assert_eq!(limiter.available("tokens"), Some(100));
assert_eq!(limiter.available("requests"), Some(1));
}
// A cost entry naming a dimension that was never registered has no effect,
// even when its cost is large.
#[test]
fn test_unknown_dimension_name_is_ignored() {
let limiter = MultiLimiter::builder()
.dimension("requests", Throttle::per_second(1))
.build();
assert!(limiter.try_acquire_costs(&[("requests", 1), ("nonexistent", 999)]));
}
// `available` distinguishes "unknown dimension" (`None`) from a known one.
#[test]
fn test_available_is_none_for_unknown_dimension() {
let limiter = MultiLimiter::builder()
.dimension("requests", Throttle::per_second(1))
.build();
assert_eq!(limiter.available("missing"), None);
}
// Peeking reports that the cost is affordable without spending anything.
#[test]
fn test_peek_costs_does_not_consume() {
let limiter = MultiLimiter::builder()
.dimension("requests", Throttle::per_second(5))
.build();
assert!(limiter.peek_costs(&[("requests", 5)]).is_acquired());
assert_eq!(limiter.available("requests"), Some(5));
}
// With both throttles driven by the same manual clock, draining them blocks
// further requests until the clock is advanced by one second.
#[test]
fn test_refill_recovers_each_dimension_under_manual_clock() {
let clock = Arc::new(ManualClock::new());
let limiter = MultiLimiter::builder()
.dimension(
"requests",
Throttle::per_second(2).with_clock(clock.clone()),
)
.dimension("tokens", Throttle::per_second(10).with_clock(clock.clone()))
.build();
// Drain both dimensions completely.
assert!(limiter.try_acquire_costs(&[("requests", 2), ("tokens", 10)]));
assert!(!limiter.try_acquire_costs(&[("requests", 1), ("tokens", 1)]));
// After one second of manual time the full budgets are affordable again.
clock.advance(Duration::from_secs(1));
assert!(limiter.try_acquire_costs(&[("requests", 2), ("tokens", 10)]));
}
// Asking for more than a dimension's capacity must fail immediately, and the
// error must carry the cost and capacity of the dimension that was overspent.
#[cfg(feature = "tokio")]
#[tokio::test]
async fn test_acquire_costs_errors_and_names_the_overspent_dimension() {
use crate::error::ThrottleError;
let limiter = MultiLimiter::builder()
.dimension("requests", Throttle::per_second(100))
.dimension("tokens", Throttle::per_second(1000))
.build();
// 2000 tokens can never be granted by a throttle of 1000.
let err = limiter
.acquire_costs(&[("requests", 1), ("tokens", 2000)])
.await
.unwrap_err();
assert_eq!(
err,
ThrottleError::CostExceedsCapacity {
cost: 2000,
capacity: 1000,
}
);
}
}