use dashmap::{mapref::entry::Entry, DashMap};
use futures_util::future::{self, Either};
use std::{fmt, hash::Hash, mem, sync::Arc};
use tokio::sync::{oneshot, OwnedMutexGuard};
/// Groups concurrently submitted items by key so that, per key, one caller at
/// a time receives every queued item as a [`Batch`] to process.
///
/// Cloning is cheap: clones share the same underlying queue (held in an `Arc`).
pub struct BatchMutex<Key: Eq + Hash, Item, T = ()> {
    queue: Arc<DashMap<Key, BatchState<Key, Item, T>>>,
}

// Written by hand instead of `#[derive(Clone)]`: the derive would require
// `Key`, `Item` and `T` to be `Clone`, but cloning only bumps the `Arc`.
impl<Key: Eq + Hash, Item, T> Clone for BatchMutex<Key, Item, T> {
    fn clone(&self) -> Self {
        Self {
            queue: Arc::clone(&self.queue),
        }
    }
}
impl<Key: Eq + Hash, Item, T> Default for BatchMutex<Key, Item, T> {
fn default() -> Self {
Self {
queue: <_>::default(),
}
}
}
// Opaque `Debug`: neither keys nor queued items are printed, so no `Debug`
// bound on `Key`, `Item` or `T` is needed.
impl<Key, Item, T> fmt::Debug for BatchMutex<Key, Item, T>
where
    Key: Eq + Hash,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("BatchMutex").finish_non_exhaustive()
    }
}
impl<Key, Item, T> BatchMutex<Key, Item, T>
where
Key: Eq + Hash + Clone,
{
pub async fn submit(&self, batch_key: Key, item: Item) -> BatchResult<Key, Item, T> {
let (tx, rx) = tokio::sync::oneshot::channel();
let batch_lock = {
let mut state = self.queue.entry(batch_key.clone()).or_default();
state.items.push(item);
state.senders.push(tx);
Arc::clone(&state.lock)
};
match future::select(rx, Box::pin(batch_lock.lock_owned())).await {
Either::Left((result, guard)) => {
drop(guard);
match result {
Ok(Ok((val, _))) => BatchResult::Done(val),
Err(_) | Ok(Err(_)) => BatchResult::Failed,
}
}
Either::Right((guard, rx)) => {
let batch = {
let mut state = self.queue.get_mut(&batch_key).unwrap(); Batch {
guard: Some(guard),
items: mem::take(&mut state.items),
senders: state.senders.drain(..).map(Some).collect(),
cleaner: Cleaner {
queue: Arc::clone(&self.queue),
key: Some(batch_key),
},
local_rx: rx,
}
};
BatchResult::Work(batch)
}
}
}
}
/// Per-key queue state: the items waiting to be claimed, one notification
/// sender per item (pushed in step with `items`), and the mutex whose holder
/// acts as the worker for this key.
struct BatchState<Key, Item, T>
where
    Key: Eq + Hash,
{
    items: Vec<Item>,
    senders: Vec<Sender<Key, Item, T>>,
    lock: Arc<tokio::sync::Mutex<()>>,
}
/// Sending half of a per-item notification: `Ok((value, cleaner))` reports
/// success, `Err(cleaner)` reports failure.
type Sender<Key, Item, T> =
    oneshot::Sender<Result<(T, Cleaner<Key, Item, T>), Cleaner<Key, Item, T>>>;
/// Receiving half of a per-item notification (same payload as `Sender`).
type Receiver<Key, Item, T> =
    oneshot::Receiver<Result<(T, Cleaner<Key, Item, T>), Cleaner<Key, Item, T>>>;
impl<Key: Eq + Hash, Item, T> Default for BatchState<Key, Item, T> {
fn default() -> Self {
Self {
items: <_>::default(),
senders: <_>::default(),
lock: <_>::default(),
}
}
}
#[must_use]
#[derive(Debug)]
pub enum BatchResult<Key: Eq + Hash + Clone, Item, T> {
Work(Batch<Key, Item, T>),
Done(T),
Failed,
}
/// Items claimed by the caller of [`BatchMutex::submit`] that received
/// [`BatchResult::Work`]. The key's lock is held while this value is alive.
pub struct Batch<Key, Item, T>
where
    Key: Eq + Hash + Clone,
{
    /// The claimed items; index `i` is what `notify_done(i, ..)` refers to.
    pub items: Vec<Item>,
    // Holds the per-key lock; `Some` until it is released in `Drop`.
    guard: Option<OwnedMutexGuard<()>>,
    // One sender per claimed item, built in step with `items`; an entry
    // becomes `None` once that item has been notified.
    senders: Vec<Option<Sender<Key, Item, T>>>,
    cleaner: Cleaner<Key, Item, T>,
    // Receiver for the worker's own submitted item.
    local_rx: Receiver<Key, Item, T>,
}
impl<Key, Item, T> fmt::Debug for Batch<Key, Item, T>
where
    Key: Eq + Hash + Clone,
    Item: fmt::Debug,
{
    /// Shows only the claimed `items`; the lock and notification plumbing are
    /// elided with `..`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Batch");
        out.field("items", &self.items);
        out.finish_non_exhaustive()
    }
}
impl<Key: Eq + Hash + Clone, Item, T> Batch<Key, Item, T> {
pub fn notify_done(&mut self, item_index: usize, val: T) {
if let Some(tx) = self.senders.get_mut(item_index).and_then(|i| i.take()) {
let _ = tx.send(Ok((val, self.cleaner.clone())));
}
}
pub fn recv_local_notify_done(&mut self) -> Option<T> {
self.local_rx.try_recv().ok().and_then(|v| Some(v.ok()?.0))
}
pub fn pull_waiting_items(&mut self) -> bool {
if let Some(mut next) = self
.cleaner
.key
.as_ref()
.and_then(|k| self.cleaner.queue.get_mut(k))
.filter(|n| !n.items.is_empty())
{
self.items.append(&mut next.items);
self.senders.extend(next.senders.drain(..).map(Some));
true
} else {
false
}
}
fn notify_all_failed(&mut self) {
for tx in &mut self.senders {
let _ = tx.take().map(|tx| tx.send(Err(self.cleaner.clone())));
}
}
}
impl<Key: Eq + Hash + Clone, Item> Batch<Key, Item, ()> {
    /// Reports every item not yet notified as done.
    ///
    /// This is a convenience for batches whose per-item result is `()`. It is
    /// equivalent to calling `notify_done(i, ())` for each remaining index `i`.
    pub fn notify_all_done(&mut self) {
        for tx in self.senders.iter_mut().filter_map(Option::take) {
            // The receiver may already be gone; that is not an error here.
            let _ = tx.send(Ok(((), self.cleaner.clone())));
        }
    }
}
// Releasing a batch: fail every item that was not reported as done, then let
// the per-key lock go.
impl<Key: Eq + Hash + Clone, Item, T> Drop for Batch<Key, Item, T> {
fn drop(&mut self) {
// Order matters. Waiters whose items belong to this batch get their result
// (`Err` for anything not yet notified) *before* the lock is released.
// `submit` polls its receiver ahead of the lock, so such a waiter resolves to
// `Done`/`Failed` instead of proceeding as a new worker. Taking the guard here
// also drops the lock's `Arc` before the `cleaner` field is dropped (fields
// drop after this body), so the Cleaner's strong-count check no longer
// counts this batch.
self.notify_all_failed(); self.guard.take(); }
}
/// Handle that, when dropped, removes a key's entry from the shared queue if
/// the map's own `Arc` is the last strong reference to that key's lock.
/// A clone travels with every notification sent from a `Batch`.
struct Cleaner<Key, Item, T>
where
    Key: Eq + Hash,
{
    queue: Arc<DashMap<Key, BatchState<Key, Item, T>>>,
    // Taken (left as `None`) when the drop logic runs.
    key: Option<Key>,
}
impl<Key: Eq + Hash + Clone, Item, T> Clone for Cleaner<Key, Item, T> {
fn clone(&self) -> Self {
Self {
queue: Arc::clone(&self.queue),
key: self.key.clone(),
}
}
}
impl<Key: Eq + Hash, Item, T> Drop for Cleaner<Key, Item, T> {
    /// Removes this key's entry when the only strong reference left to its
    /// lock is the one stored in the map itself. The check and the removal
    /// both happen while the map entry is held.
    fn drop(&mut self) {
        let key = match self.key.take() {
            Some(key) => key,
            None => return,
        };
        match self.queue.entry(key) {
            Entry::Occupied(entry) if Arc::strong_count(&entry.get().lock) == 1 => {
                entry.remove();
            }
            _ => {}
        }
    }
}