use std::{
convert::Infallible,
sync::{Arc, Mutex},
};
use crate::util::backoff::Backoff;
use super::{Policy, PolicyOutput};
/// A policy that caps how many requests may be in flight at the same time.
///
/// The in-flight count lives behind an `Arc<Mutex<usize>>`, so clones of a
/// policy share one combined limit. With `B = ()` a full policy aborts the
/// request with [`LimitReached`]. With a `Backoff` strategy it awaits the
/// backoff and asks the caller to retry.
#[derive(Debug)]
pub struct ConcurrentPolicy<B> {
    /// Maximum number of requests allowed in flight concurrently.
    max: usize,
    /// Shared count of requests currently holding a [`ConcurrentGuard`].
    current: Arc<Mutex<usize>>,
    /// Backoff strategy awaited when the limit is reached (`()` means none).
    backoff: B,
}
impl<B> Clone for ConcurrentPolicy<B>
where
B: Clone,
{
fn clone(&self) -> Self {
ConcurrentPolicy {
max: self.max,
current: self.current.clone(),
backoff: self.backoff.clone(),
}
}
}
impl ConcurrentPolicy<()> {
pub fn new(max: usize) -> Self {
ConcurrentPolicy {
max,
current: Arc::new(Mutex::new(0)),
backoff: (),
}
}
}
impl<B> ConcurrentPolicy<B> {
pub fn with_backoff(max: usize, backoff: B) -> Self {
ConcurrentPolicy {
max,
current: Arc::new(Mutex::new(0)),
backoff,
}
}
}
/// Slot held by a request that a [`ConcurrentPolicy`] admitted.
///
/// Dropping the guard decrements the shared in-flight counter, which frees
/// the slot for another request.
#[derive(Debug)]
pub struct ConcurrentGuard {
    /// Counter shared with the policy that issued this guard.
    current: Arc<Mutex<usize>>,
}
impl Drop for ConcurrentGuard {
    /// Releases this guard's slot by decrementing the shared in-flight counter.
    fn drop(&mut self) {
        // The counter only ever changes by a single `+= 1` / `-= 1`, so a
        // poisoned lock still holds a valid value. Recovering it, instead of
        // calling `unwrap()`, keeps `drop` from panicking. A panic here while
        // the thread is already unwinding would abort the process. Skipping
        // the decrement would leak the slot permanently.
        let mut current = self
            .current
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *current -= 1;
    }
}
impl<B, Request> Policy<Request> for ConcurrentPolicy<B>
where
    B: Backoff,
{
    type Guard = ConcurrentGuard;
    type Error = Infallible;

    /// Claims a slot if one is free. Otherwise awaits the backoff and asks the
    /// caller to retry. The only outcomes are `Ready` and `Retry`, hence
    /// `Error = Infallible`.
    async fn check(&self, _: &mut Request) -> PolicyOutput<Self::Guard, Self::Error> {
        // The std `MutexGuard` must be released before the `.await` below, so
        // the counter is only touched inside this block.
        let admitted = {
            let mut in_flight = self.current.lock().unwrap();
            if *in_flight < self.max {
                *in_flight += 1;
                Some(ConcurrentGuard {
                    current: Arc::clone(&self.current),
                })
            } else {
                None
            }
        };

        match admitted {
            Some(guard) => PolicyOutput::Ready(guard),
            None => {
                self.backoff.next_backoff().await;
                PolicyOutput::Retry
            }
        }
    }
}
/// Error returned by the backoff-less `ConcurrentPolicy<()>` when every slot
/// is already in use. The request is aborted rather than retried.
///
/// It is a unit value, so it is cheap to copy and can be compared directly
/// (e.g. `assert_eq!(err, LimitReached)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LimitReached;

impl std::fmt::Display for LimitReached {
    /// Formats as the bare type name, `LimitReached`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("LimitReached")
    }
}

impl std::error::Error for LimitReached {}
impl<Request> Policy<Request> for ConcurrentPolicy<()> {
    type Guard = ConcurrentGuard;
    type Error = LimitReached;

    /// Claims a slot if one is free. Otherwise aborts immediately with
    /// [`LimitReached`].
    async fn check(&self, _: &mut Request) -> PolicyOutput<Self::Guard, Self::Error> {
        let mut in_flight = self.current.lock().unwrap();
        if *in_flight >= self.max {
            return PolicyOutput::Abort(LimitReached);
        }
        *in_flight += 1;
        PolicyOutput::Ready(ConcurrentGuard {
            current: Arc::clone(&self.current),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps a `Ready` output into its guard, panicking on any other variant.
    fn assert_ready<G, E>(output: PolicyOutput<G, E>) -> G {
        match output {
            PolicyOutput::Ready(guard) => guard,
            _ => panic!("unexpected output, expected ready"),
        }
    }

    /// Unwraps an `Abort` output into its error, panicking on any other variant.
    fn assert_abort<G, E>(output: PolicyOutput<G, E>) -> E {
        match output {
            PolicyOutput::Abort(error) => error,
            _ => panic!("unexpected output, expected abort"),
        }
    }

    #[tokio::test]
    async fn concurrent_policy() {
        let policy = ConcurrentPolicy::new(2);
        let guard_1 = assert_ready(policy.check(&mut ()).await);
        let guard_2 = assert_ready(policy.check(&mut ()).await);
        assert_abort(policy.check(&mut ()).await);
        drop(guard_1);
        let _guard_3 = assert_ready(policy.check(&mut ()).await);
        assert_abort(policy.check(&mut ()).await);
        drop(guard_2);
        assert_ready(policy.check(&mut ()).await);
    }

    #[tokio::test]
    async fn zero_limit_always_aborts() {
        let policy = ConcurrentPolicy::new(0);
        let _: LimitReached = assert_abort(policy.check(&mut ()).await);
    }

    #[tokio::test]
    async fn clones_share_the_limit() {
        let policy = ConcurrentPolicy::new(1);
        let cloned = policy.clone();
        let guard = assert_ready(policy.check(&mut ()).await);
        // The clone shares the in-flight counter, so the only slot is taken.
        let _: LimitReached = assert_abort(cloned.check(&mut ()).await);
        drop(guard);
        assert_ready(cloned.check(&mut ()).await);
    }
}