use super::{Policy, PolicyOutput, PolicyResult};
use crate::Context;
use parking_lot::Mutex;
use rama_utils::backoff::Backoff;
use std::fmt;
use std::sync::Arc;
/// A [`Policy`] that limits how many requests may be in flight at the same time.
///
/// On every check the `tracker` is asked for access. When access is denied the
/// `backoff` decides whether the caller should retry or the request is aborted
/// with the tracker's error.
pub struct ConcurrentPolicy<B, C> {
/// Tracker that grants or denies access on each check.
tracker: C,
/// Backoff consulted only after the tracker denied access.
backoff: B,
}
/// Debug output lists the `tracker` first and the `backoff` second,
/// under the struct name `ConcurrentPolicy`.
impl<B, C> fmt::Debug for ConcurrentPolicy<B, C>
where
    B: fmt::Debug,
    C: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Destructuring keeps this impl in sync with the struct's field list.
        let Self { tracker, backoff } = self;
        f.debug_struct("ConcurrentPolicy")
            .field("tracker", tracker)
            .field("backoff", backoff)
            .finish()
    }
}
impl<B, C> Clone for ConcurrentPolicy<B, C>
where
B: Clone,
C: Clone,
{
fn clone(&self) -> Self {
ConcurrentPolicy {
tracker: self.tracker.clone(),
backoff: self.backoff.clone(),
}
}
}
impl<B, C> ConcurrentPolicy<B, C> {
    /// Creates a concurrent policy from the given `backoff` and `tracker`.
    ///
    /// The `tracker` decides whether a request may proceed; the `backoff` is only
    /// consulted once the tracker has denied access.
    ///
    /// This is a `const fn`, like [`ConcurrentPolicy::new`], so a policy with an
    /// explicit backoff can also be built in const contexts.
    pub const fn with_backoff(backoff: B, tracker: C) -> Self {
        ConcurrentPolicy { tracker, backoff }
    }
}
impl<C> ConcurrentPolicy<(), C> {
pub const fn new(tracker: C) -> Self {
ConcurrentPolicy {
tracker,
backoff: (),
}
}
}
impl ConcurrentPolicy<(), ConcurrentCounter> {
    /// Creates a concurrent policy backed by a fresh [`ConcurrentCounter`] limited
    /// to `max`, with no backoff configured (`()`).
    pub fn max(max: usize) -> Self {
        let tracker = ConcurrentCounter::new(max);
        Self {
            tracker,
            backoff: (),
        }
    }
}
impl<B> ConcurrentPolicy<B, ConcurrentCounter> {
    /// Creates a concurrent policy backed by a fresh [`ConcurrentCounter`] limited
    /// to `max`, using `backoff` whenever that counter denies access.
    pub fn max_with_backoff(max: usize, backoff: B) -> Self {
        let tracker = ConcurrentCounter::new(max);
        Self { tracker, backoff }
    }
}
impl<B, C, State, Request> Policy<State, Request> for ConcurrentPolicy<B, C>
where
    B: Backoff,
    C: ConcurrentTracker,
    State: Clone + Send + Sync + 'static,
    Request: Send + 'static,
{
    type Guard = C::Guard;
    type Error = C::Error;

    /// Asks the tracker for access and reports the outcome.
    ///
    /// * Access granted: the output is `Ready` and carries the tracker's guard.
    /// * Access denied and the backoff yields `true`: the output is `Retry`.
    /// * Access denied and the backoff yields `false`: the output is `Abort` and
    ///   carries the tracker's error.
    ///
    /// `ctx` and `request` are handed back untouched in every case.
    async fn check(
        &self,
        ctx: Context<State>,
        request: Request,
    ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
        let output = match self.tracker.try_access() {
            Ok(guard) => PolicyOutput::Ready(guard),
            // The backoff is only awaited once the tracker has refused access.
            Err(denied) => {
                if self.backoff.next_backoff().await {
                    PolicyOutput::Retry
                } else {
                    PolicyOutput::Abort(denied)
                }
            }
        };

        PolicyResult {
            ctx,
            request,
            output,
        }
    }
}
// Error type that `ConcurrentCounter` returns when its limit is already reached.
// NOTE(review): the `#[doc = ...]` literal is an input to `static_str_error!`;
// presumably the macro also uses it as the error's display message. Confirm
// against the macro definition before rewording it.
rama_utils::macros::error::static_str_error! {
#[doc = "request aborted due to exhausted concurrency limit"]
pub struct LimitReached;
}
/// Decides whether one more concurrent access may proceed.
///
/// [`ConcurrentPolicy`] calls [`try_access`](ConcurrentTracker::try_access) on every
/// check. An `Ok` guard becomes the policy's ready output; an `Err` becomes the
/// abort error unless the policy's backoff asks for a retry.
pub trait ConcurrentTracker: Send + Sync + 'static {
/// Value handed out when access is granted.
// NOTE(review): presumably access is held for as long as the guard is alive, as
// `ConcurrentCounterGuard` does. Confirm for other implementations.
type Guard: Send + 'static;
/// Error returned when access is denied.
type Error: Send + Sync + 'static;
/// Attempts to obtain access, returning a guard on success and an error otherwise.
fn try_access(&self) -> Result<Self::Guard, Self::Error>;
}
/// A [`ConcurrentTracker`] that grants access while fewer than `max` guards are
/// outstanding.
///
/// The count lives behind an `Arc`, so clones of a counter share one count and
/// therefore enforce a single common limit.
#[derive(Debug, Clone)]
pub struct ConcurrentCounter {
/// Maximum number of guards that may be outstanding at the same time.
max: usize,
/// Number of guards currently outstanding; incremented when access is granted
/// and decremented when a guard is dropped.
current: Arc<Mutex<usize>>,
}
impl ConcurrentCounter {
    /// Creates a counter that allows at most `max` outstanding guards.
    ///
    /// The count starts at zero; a `max` of `0` therefore never grants access.
    pub fn new(max: usize) -> Self {
        let current = Arc::new(Mutex::new(0));
        Self { max, current }
    }
}
impl ConcurrentTracker for ConcurrentCounter {
type Guard = ConcurrentCounterGuard;
type Error = LimitReached;
fn try_access(&self) -> Result<Self::Guard, Self::Error> {
let mut current = self.current.lock();
if *current < self.max {
*current += 1;
Ok(ConcurrentCounterGuard {
current: self.current.clone(),
})
} else {
Err(LimitReached)
}
}
}
/// Guard handed out by [`ConcurrentCounter`] when access is granted.
///
/// Dropping the guard decrements the counter's shared count by one.
#[derive(Debug)]
pub struct ConcurrentCounterGuard {
/// Handle to the shared count of the counter that issued this guard.
current: Arc<Mutex<usize>>,
}
impl Drop for ConcurrentCounterGuard {
    /// Gives the slot back by decrementing the shared count by one.
    fn drop(&mut self) {
        // The lock guard is a temporary, released at the end of this statement.
        *self.current.lock() -= 1;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the guard from a `Ready` output and panics on any other output.
    fn assert_ready<S, R, G, E>(result: PolicyResult<S, R, G, E>) -> G {
        let PolicyOutput::Ready(guard) = result.output else {
            panic!("unexpected output, expected ready");
        };
        guard
    }

    /// Panics unless the output is `Abort`.
    fn assert_abort<S, R, G, E>(result: PolicyResult<S, R, G, E>) {
        if !matches!(result.output, PolicyOutput::Abort(_)) {
            panic!("unexpected output, expected abort");
        }
    }

    /// A limit of zero denies the very first request.
    #[tokio::test]
    async fn concurrent_policy_zero() {
        let policy = ConcurrentPolicy::max(0);

        assert_abort(policy.check(Context::default(), ()).await);
    }

    /// Slots are handed out up to the limit and become available again on guard drop.
    #[tokio::test]
    async fn concurrent_policy() {
        let policy = ConcurrentPolicy::max(2);

        // Fill both slots; a third request is refused.
        let first = assert_ready(policy.check(Context::default(), ()).await);
        let second = assert_ready(policy.check(Context::default(), ()).await);
        assert_abort(policy.check(Context::default(), ()).await);

        // Releasing one slot allows exactly one more request.
        drop(first);
        let _third = assert_ready(policy.check(Context::default(), ()).await);
        assert_abort(policy.check(Context::default(), ()).await);

        // Releasing the other slot allows another one.
        drop(second);
        assert_ready(policy.check(Context::default(), ()).await);
    }

    /// A cloned policy draws from the same limit as the policy it was cloned from.
    #[tokio::test]
    async fn concurrent_policy_clone() {
        let policy = ConcurrentPolicy::max(2);
        let cloned = policy.clone();

        let first = assert_ready(policy.check(Context::default(), ()).await);
        let _second = assert_ready(cloned.check(Context::default(), ()).await);
        assert_abort(policy.check(Context::default(), ()).await);

        drop(first);
        assert_ready(policy.check(Context::default(), ()).await);
    }
}