use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use tokio::sync::{OwnedSemaphorePermit, Semaphore, broadcast};
use hitbox_core::{CacheValue, CacheableResponse};
use crate::CacheKey;
use crate::policy::ConcurrencyLimit;
/// Capacity of each per-key broadcast channel; each channel is used to send a single value.
const CHANNEL_CAPACITY: usize = 1;

/// In-flight state stored per cache key: the sender used to publish the resolved
/// value to waiters, and the semaphore bounding how many requests may proceed.
type InFlightEntry<T> = (broadcast::Sender<Arc<CacheValue<T>>>, Arc<Semaphore>);
/// Reason a [`ConcurrencyDecision::Await`] future could not yield a response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConcurrencyError {
    /// The waiter's broadcast receiver fell behind and skipped this many messages.
    Lagged(u64),
    /// The channel was closed (its sender dropped) before a value was received.
    Closed,
}

impl std::fmt::Display for ConcurrencyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Lagged(skipped) => write!(
                f,
                "concurrency waiter lagged behind and skipped {skipped} message(s)"
            ),
            Self::Closed => f.write_str("concurrency channel closed before a value was published"),
        }
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` / error enums.
impl std::error::Error for ConcurrencyError {}
/// Outcome of [`ConcurrencyManager::check`] for a single request.
pub enum ConcurrencyDecision<Res> {
    /// The request may go to the upstream. The permit holds one of the key's
    /// semaphore slots and releases it back to the semaphore when dropped.
    Proceed(OwnedSemaphorePermit),
    /// The request may go to the upstream; no permit is involved.
    ProceedWithoutPermit,
    /// Another request for the same key is in flight; awaiting this future yields
    /// the response it publishes, or a [`ConcurrencyError`] if none arrives.
    Await(Pin<Box<dyn Future<Output = Result<Res, ConcurrencyError>> + Send>>),
}
/// Coordinates concurrent requests that target the same cache key.
pub trait ConcurrencyManager<Res: CacheableResponse>: Send + Sync {
    /// Decides whether the request for `cache_key` should go to the upstream or
    /// wait for a value produced by another in-flight request.
    ///
    /// `concurrency` is the limit the implementation may apply to the key
    /// (the broadcast implementation uses it as the semaphore's permit count).
    fn check(
        &self,
        cache_key: &CacheKey,
        concurrency: ConcurrencyLimit,
    ) -> ConcurrencyDecision<Res>;

    /// Hands `cache_value` to requests waiting on `cache_key`, if the
    /// implementation tracks any.
    fn resolve(&self, cache_key: &CacheKey, cache_value: &CacheValue<Res::Cached>);

    /// Drops any in-flight state for `cache_key` without publishing a value.
    // NOTE(review): presumably called on error / non-cacheable paths — confirm against callers.
    fn cleanup(&self, cache_key: &CacheKey);
}
/// Forwards every call to the shared inner manager, so an `Arc<T>` can be used
/// wherever a [`ConcurrencyManager`] is expected.
impl<Res, T> ConcurrencyManager<Res> for Arc<T>
where
    T: ConcurrencyManager<Res>,
    Res: CacheableResponse,
{
    fn check(
        &self,
        cache_key: &CacheKey,
        concurrency: ConcurrencyLimit,
    ) -> ConcurrencyDecision<Res> {
        // `self: &Arc<T>` deref-coerces to `&T`, selecting the inner impl explicitly.
        <T as ConcurrencyManager<Res>>::check(self, cache_key, concurrency)
    }

    fn resolve(&self, cache_key: &CacheKey, cache_value: &CacheValue<Res::Cached>) {
        <T as ConcurrencyManager<Res>>::resolve(self, cache_key, cache_value)
    }

    fn cleanup(&self, cache_key: &CacheKey) {
        <T as ConcurrencyManager<Res>>::cleanup(self, cache_key)
    }
}
/// A [`ConcurrencyManager`] that never limits or coalesces: every request proceeds.
///
/// Stateless, so it is `Copy` and has a trivial `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopConcurrencyManager;
/// Pass-through implementation: no permits, no waiters, no state.
impl<Res> ConcurrencyManager<Res> for NoopConcurrencyManager
where
    Res: CacheableResponse + Send + 'static,
{
    /// Always lets the request through without taking a permit.
    fn check(&self, _: &CacheKey, _: ConcurrencyLimit) -> ConcurrencyDecision<Res> {
        ConcurrencyDecision::ProceedWithoutPermit
    }

    /// Nobody ever waits, so there is nothing to publish.
    fn resolve(&self, _: &CacheKey, _: &CacheValue<Res::Cached>) {}

    /// No state is kept, so there is nothing to release.
    fn cleanup(&self, _: &CacheKey) {}
}
/// Request-coalescing [`ConcurrencyManager`] built on a per-key semaphore and a
/// per-key broadcast channel.
///
/// Up to the configured limit of requests per key proceed to the upstream; any
/// further requests wait for the value published through `resolve`.
pub struct BroadcastConcurrencyManager<Res: CacheableResponse> {
    /// In-flight entries keyed by cache key; the `Arc` lets clones share the table.
    in_flight: Arc<DashMap<CacheKey, InFlightEntry<<Res as CacheableResponse>::Cached>>>,
}
impl<Res> Clone for BroadcastConcurrencyManager<Res>
where
Res: CacheableResponse,
{
fn clone(&self) -> Self {
Self {
in_flight: Arc::clone(&self.in_flight),
}
}
}
impl<Res> Default for BroadcastConcurrencyManager<Res>
where
Res: CacheableResponse,
{
fn default() -> Self {
Self::new()
}
}
impl<Res: CacheableResponse> BroadcastConcurrencyManager<Res> {
    /// Creates a manager with no requests in flight.
    pub fn new() -> Self {
        let in_flight = Arc::new(DashMap::new());
        Self { in_flight }
    }
}
impl<Res> ConcurrencyManager<Res> for BroadcastConcurrencyManager<Res>
where
Res: CacheableResponse + Send + 'static,
Res::Cached: Send + Sync + Clone + 'static,
{
fn check(
&self,
cache_key: &CacheKey,
concurrency: ConcurrencyLimit,
) -> ConcurrencyDecision<Res> {
let concurrency: usize = concurrency.get().into();
match self.in_flight.entry(cache_key.clone()) {
Entry::Occupied(entry) => {
let (sender, semaphore) = entry.get();
if let Ok(permit) = semaphore.clone().try_acquire_owned() {
ConcurrencyDecision::Proceed(permit)
} else {
let mut receiver = sender.subscribe();
let future = Box::pin(async move {
match receiver.recv().await {
Ok(cache_value) => {
Ok(Res::from_cached(cache_value.data().clone()).await)
}
Err(broadcast::error::RecvError::Lagged(n)) => {
Err(ConcurrencyError::Lagged(n))
}
Err(broadcast::error::RecvError::Closed) => {
Err(ConcurrencyError::Closed)
}
}
});
ConcurrencyDecision::Await(future)
}
}
Entry::Vacant(entry) => {
let (sender, _receiver) = broadcast::channel(CHANNEL_CAPACITY);
let semaphore = Arc::new(Semaphore::new(concurrency));
let permit = semaphore
.clone()
.try_acquire_owned()
.expect("First permit acquisition should never fail");
entry.insert((sender, semaphore));
ConcurrencyDecision::Proceed(permit)
}
}
}
fn resolve(&self, cache_key: &CacheKey, cache_value: &CacheValue<Res::Cached>) {
if let Some((_, (sender, _semaphore))) = self.in_flight.remove(cache_key) {
let shared_value = Arc::new(cache_value.clone());
let _ = sender.send(shared_value);
}
}
fn cleanup(&self, cache_key: &CacheKey) {
self.in_flight.remove(cache_key);
}
}