use std::{
cmp,
fmt::Debug,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
time::Duration,
};
use async_trait::async_trait;
use conv::ValueFrom;
use tokio::{
sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError},
time::timeout,
};
pub use partitioning::PartitionedLimiter;
pub use rejection_delay::RejectionDelay;
pub use token::Token;
use crate::limits::{LimitAlgorithm, Sample};
mod partitioning;
mod rejection_delay;
mod token;
/// Integer type used for limits, available permits and in-flight counts in this module.
type CapacityUnit = usize;
/// Atomic counterpart of [`CapacityUnit`], used for values shared across tasks.
type AtomicCapacityUnit = AtomicUsize;
/// A concurrency limiter that hands out [`Token`]s and is told the outcome when each
/// token is returned.
#[async_trait]
pub trait Limiter: Debug + Sync {
    /// Tries to obtain a token without waiting for capacity.
    ///
    /// Returns `None` when no capacity is available (this is what `DefaultLimiter` does).
    async fn try_acquire(&self) -> Option<Token>;
    /// Waits for at most `duration` to obtain a token.
    ///
    /// Returns `None` if no token became available in time (`DefaultLimiter` behaviour).
    async fn acquire_timeout(&self, duration: Duration) -> Option<Token>;
    /// Gives `token` back to the limiter, optionally reporting how the guarded work went.
    ///
    /// For `DefaultLimiter`, a `Some` outcome updates the limit algorithm, and the
    /// returned value is the resulting limit.
    async fn release(&self, token: Token, outcome: Option<Outcome>) -> CapacityUnit;
}
/// Default [`Limiter`] implementation.
///
/// Tokens are backed by permits of a tokio [`Semaphore`]. The semaphore's total permit
/// count is adjusted to follow the limit produced by a [`LimitAlgorithm`].
#[derive(Debug)]
pub struct DefaultLimiter<T> {
    /// Supplies the initial limit and computes new limits from reported samples.
    limit_algo: T,
    /// Permits backing issued tokens. On release its size is moved toward `limit`;
    /// growth is applied immediately, shrinking is applied asynchronously.
    semaphore: Arc<Semaphore>,
    /// Most recently applied limit.
    limit: AtomicCapacityUnit,
    /// Counter shared with every minted [`Token`].
    /// NOTE(review): presumably `Token` increments/decrements it — confirm in `token` module.
    in_flight: Arc<AtomicCapacityUnit>,
    /// Test-only hook, notified by `release` once a limit change has been applied.
    #[cfg(test)]
    notifier: Option<Arc<tokio::sync::Notify>>,
}
/// Point-in-time view of a limiter's capacity, as produced by `DefaultLimiter::state`.
///
/// The three values are read one after another rather than atomically together, so under
/// concurrent use they need not be mutually consistent.
#[derive(Debug, Clone, Copy)]
pub struct LimiterState {
    /// Current limit.
    limit: CapacityUnit,
    /// Permits currently available in the semaphore.
    available: CapacityUnit,
    /// Value of the shared in-flight counter.
    in_flight: CapacityUnit,
}
/// Result of a limited operation.
///
/// It is reported when a token is released, and fed to the limit algorithm as part of a `Sample`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Outcome {
    /// The operation completed normally.
    Success,
    /// The operation observed overload (NOTE(review): the criteria are chosen by the caller).
    Overload,
}
impl<T> DefaultLimiter<T>
where
    T: LimitAlgorithm,
{
    /// Builds a limiter sized to the algorithm's initial limit, with nothing in flight.
    ///
    /// # Panics
    ///
    /// Panics if `limit_algo.limit()` is zero. Also panics if it exceeds tokio's
    /// `Semaphore::MAX_PERMITS`.
    pub fn new(limit_algo: T) -> Self {
        let initial_permits = limit_algo.limit();
        assert!(initial_permits >= 1);

        let semaphore = Arc::new(Semaphore::new(initial_permits));
        let limit = AtomicCapacityUnit::new(initial_permits);
        let in_flight = Arc::new(AtomicCapacityUnit::new(0));

        Self {
            limit_algo,
            semaphore,
            limit,
            in_flight,
            #[cfg(test)]
            notifier: None,
        }
    }

    /// Test-only builder step: installs a notifier that `release` pings after applying a
    /// limit change.
    #[cfg(test)]
    pub fn with_release_notifier(mut self, n: Arc<tokio::sync::Notify>) -> Self {
        self.notifier = Some(n);
        self
    }

    /// Packages a latency/outcome pair with the current in-flight count.
    fn new_sample(&self, latency: Duration, outcome: Outcome) -> Sample {
        let in_flight = self.in_flight();
        Sample {
            latency,
            in_flight,
            outcome,
        }
    }

    /// Permits the semaphore can hand out right now.
    ///
    /// This can temporarily disagree with `limit` while a shrink is still pending.
    fn available(&self) -> CapacityUnit {
        self.semaphore.available_permits()
    }

    /// Most recently applied limit.
    pub(crate) fn limit(&self) -> CapacityUnit {
        self.limit.load(Ordering::Acquire)
    }

    /// Current value of the shared in-flight counter.
    fn in_flight(&self) -> CapacityUnit {
        self.in_flight.load(Ordering::Acquire)
    }

    /// Hands out another handle to the shared in-flight counter.
    pub(crate) fn in_flight_shared(&self) -> Arc<AtomicCapacityUnit> {
        Arc::clone(&self.in_flight)
    }

    /// Reads limit, availability and in-flight count, in that order. The reads are
    /// separate, so this is not an atomic snapshot.
    pub fn state(&self) -> LimiterState {
        let limit = self.limit();
        let available = self.available();
        let in_flight = self.in_flight();
        LimiterState {
            limit,
            available,
            in_flight,
        }
    }

    /// Wraps an acquired permit in a [`Token`] bound to this limiter's in-flight counter.
    pub(crate) fn mint_token(&self, permit: OwnedSemaphorePermit) -> Token {
        let counter = self.in_flight_shared();
        Token::new(permit, counter)
    }
}
#[async_trait]
impl<T> Limiter for DefaultLimiter<T>
where
    T: LimitAlgorithm + Sync + Debug,
{
    /// Takes a permit only if one is immediately available; never waits.
    ///
    /// # Panics
    ///
    /// Panics if the semaphore has been closed, which this limiter never does.
    async fn try_acquire(&self) -> Option<Token> {
        match Arc::clone(&self.semaphore).try_acquire_owned() {
            Ok(permit) => Some(self.mint_token(permit)),
            Err(TryAcquireError::NoPermits) => None,
            Err(TryAcquireError::Closed) => {
                panic!("we own the semaphore, we shouldn't have closed it")
            }
        }
    }

    /// Waits at most `duration` for a permit, returning `None` on timeout.
    ///
    /// # Panics
    ///
    /// Panics if the semaphore has been closed, which this limiter never does.
    async fn acquire_timeout(&self, duration: Duration) -> Option<Token> {
        match timeout(duration, Arc::clone(&self.semaphore).acquire_owned()).await {
            Ok(Ok(permit)) => Some(self.mint_token(permit)),
            Ok(Err(_)) => {
                panic!("we own the semaphore, we shouldn't have closed it")
            }
            Err(_) => None,
        }
    }

    /// Releases `token`.
    ///
    /// If `outcome` is `Some`, the limit algorithm is updated first and the semaphore is
    /// resized to the new limit.
    ///
    /// Returns the new limit when `outcome` is `Some`. Otherwise it returns the algorithm's
    /// current limit, and no update is performed.
    ///
    /// # Panics
    ///
    /// The background shrink task panics if the limit drops by more than `u32::MAX`, or if
    /// the semaphore has been closed.
    async fn release(&self, token: Token, outcome: Option<Outcome>) -> CapacityUnit {
        let limit = match outcome {
            Some(outcome) => {
                // Sampled while `token` is still alive, so the in-flight count includes
                // this request.
                let sample = self.new_sample(token.latency(), outcome);
                let new_limit = self.limit_algo.update(sample).await;
                let old_limit = self.limit.swap(new_limit, Ordering::SeqCst);

                match new_limit.cmp(&old_limit) {
                    cmp::Ordering::Greater => {
                        // Growing is immediate: the extra permits are usable right away.
                        self.semaphore.add_permits(new_limit - old_limit);
                        #[cfg(test)]
                        if let Some(n) = &self.notifier {
                            n.notify_one();
                        }
                    }
                    cmp::Ordering::Less => {
                        // Permits held by outstanding tokens cannot be revoked. A background
                        // task therefore waits until enough permits are free, then forgets
                        // them, so capacity shrinks as tokens come back.
                        let semaphore = Arc::clone(&self.semaphore);
                        #[cfg(test)]
                        let notifier = self.notifier.clone();
                        tokio::spawn(async move {
                            let permits = semaphore
                                .acquire_many(
                                    u32::value_from(old_limit - new_limit)
                                        .expect("change in limit shouldn't be > u32::MAX"),
                                )
                                .await
                                .expect("we own the semaphore, we shouldn't have closed it");
                            permits.forget();
                            #[cfg(test)]
                            if let Some(n) = notifier {
                                n.notify_one();
                            }
                        });
                    }
                    cmp::Ordering::Equal => {
                        #[cfg(test)]
                        if let Some(n) = &self.notifier {
                            n.notify_one();
                        }
                    }
                }
                new_limit
            }
            None => self.limit_algo.limit(),
        };
        // The token (and the permit it was minted with) is released only after any
        // resize above has been applied or scheduled.
        drop(token);
        limit
    }
}
impl LimiterState {
    /// Value of the in-flight counter at the time of the snapshot.
    pub fn in_flight(&self) -> CapacityUnit {
        self.in_flight
    }

    /// Semaphore permits that were available at the time of the snapshot.
    pub fn available(&self) -> CapacityUnit {
        self.available
    }

    /// Limit in effect at the time of the snapshot.
    pub fn limit(&self) -> CapacityUnit {
        self.limit
    }
}
impl Outcome {
pub(crate) fn overloaded_or(self, other: Outcome) -> Outcome {
use Outcome::*;
match (self, other) {
(Success, Overload) => Overload,
_ => self,
}
}
}
#[cfg(test)]
mod tests {
    use crate::{
        limiter::{DefaultLimiter, Limiter, Outcome},
        limits::Fixed,
    };

    /// With a fixed limit of 10, acquiring one token and releasing it as a success must
    /// leave the limit at 10.
    #[tokio::test]
    async fn it_works() {
        let fixed_ten = Fixed::new(10);
        let limiter = DefaultLimiter::new(fixed_ten);

        let acquired = limiter.try_acquire().await.unwrap();
        let _ = limiter.release(acquired, Some(Outcome::Success)).await;

        assert_eq!(limiter.limit(), 10);
    }
}