use super::mpsc::{
self,
error::{SendError, TryRecvError, TrySendError},
};
use crate::sync::Mutex;
use futures::Stream;
use std::{
collections::HashMap,
hash::Hash,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
/// Delivery token for a single message sent through a tracked [`Sender`].
///
/// Each guard owns one sequence number allocated by the channel's tracker and,
/// optionally, one slot in a batch counter. Dropping the guard acknowledges the
/// message: its sequence is marked delivered, the shared watermark advances over
/// any contiguous run of delivered sequences, and the batch count (if any) is
/// decremented.
///
/// `Guard` is intentionally *not* `Clone`: a cloned guard would run the drop
/// accounting twice for the same sequence, which can decrement another message's
/// batch count or panic on an already-removed entry. Share a guard through the
/// `Arc` held by [`Message`] instead.
pub struct Guard<B: Eq + Hash + Clone> {
    /// Sequence number handed out by the tracker for this message.
    sequence: u64,
    /// Shared tracking state updated when this guard is dropped.
    tracker: Arc<Mutex<State<B>>>,
    /// Batch this message is counted against, if any.
    batch: Option<B>,
}
impl<B: Eq + Hash + Clone> Drop for Guard<B> {
fn drop(&mut self) {
let mut state = self.tracker.lock();
*state.pending.get_mut(&self.sequence).unwrap() = true;
let mut current_watermark = state.watermark;
while let Some(delivered) = state.pending.get(&(current_watermark + 1)) {
if !*delivered {
break;
}
state.pending.remove(&(current_watermark + 1));
current_watermark += 1;
state.watermark = current_watermark;
}
if let Some(batch) = &self.batch {
let count = state.batches.get_mut(batch).unwrap();
if *count > 1 {
*count -= 1;
} else {
state.batches.remove(batch);
}
}
}
}
/// A value travelling through a tracked channel, paired with its delivery guard.
///
/// The message counts as delivered once the last clone of `guard` is dropped.
pub struct Message<T, B>
where
    B: Eq + Hash + Clone,
{
    /// The payload given to `Sender::send` / `Sender::try_send`.
    pub data: T,
    /// Shared delivery guard; clone the `Arc` to postpone acknowledgement.
    pub guard: Arc<Guard<B>>,
}
/// Bookkeeping shared by a channel's sender(s) and every outstanding guard.
struct State<B> {
    /// Sequence number to hand to the next guard; the first one is 1.
    next: u64,
    /// Highest sequence `w` such that every sequence `1..=w` has been delivered.
    watermark: u64,
    /// Outstanding (undelivered) message count per batch key; keys are removed at zero.
    batches: HashMap<B, usize>,
    /// Delivery flag for each allocated sequence the watermark has not yet passed.
    pending: HashMap<u64, bool>,
}

impl<B> Default for State<B> {
    /// Fresh state: nothing pending, watermark 0, first sequence 1.
    fn default() -> Self {
        State {
            watermark: 0,
            next: 1,
            pending: HashMap::default(),
            batches: HashMap::default(),
        }
    }
}
/// Cheaply clonable handle to a channel's shared [`State`]; all clones
/// allocate sequence numbers from the same counter.
#[derive(Clone)]
struct Tracker<B>
where
    B: Eq + Hash + Clone,
{
    state: Arc<Mutex<State<B>>>,
}
impl<B: Eq + Hash + Clone> Tracker<B> {
fn new() -> Self {
Self {
state: Arc::new(Mutex::new(State::default())),
}
}
fn guard(&self, batch: Option<B>) -> Guard<B> {
let mut state = self.state.lock();
let sequence = state.next;
state.next += 1;
state.pending.insert(sequence, false);
if let Some(batch) = &batch {
*state.batches.entry(batch.clone()).or_insert(0) += 1;
}
Guard {
sequence,
tracker: self.state.clone(),
batch,
}
}
}
/// Sending half of a tracked channel.
///
/// Every message is wrapped with a [`Guard`] carrying a fresh sequence number,
/// so delivery progress can be observed via [`Sender::watermark`] and
/// per-batch progress via [`Sender::pending`].
pub struct Sender<T, B: Eq + Hash + Clone> {
    inner: mpsc::Sender<Message<T, B>>,
    tracker: Tracker<B>,
}

// Implemented by hand because `#[derive(Clone)]` would add an implicit
// `T: Clone` bound. Cloning a sender only duplicates the channel handle and the
// shared tracker, never a payload, so senders of non-`Clone` payloads must be
// clonable too (otherwise multi-producer use is impossible for such `T`).
impl<T, B: Eq + Hash + Clone> Clone for Sender<T, B> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            tracker: self.tracker.clone(),
        }
    }
}
impl<T, B> Sender<T, B>
where
    B: Eq + Hash + Clone,
{
    /// Allocates a guard (optionally counted against `batch`), wraps `data`
    /// with it, and returns the guard's sequence number with the message.
    fn tracked_message(&self, batch: Option<B>, data: T) -> (u64, Message<T, B>) {
        let guard = self.tracker.guard(batch);
        let sequence = guard.sequence;
        let message = Message {
            data,
            guard: Arc::new(guard),
        };
        (sequence, message)
    }

    /// Sends `data`, optionally counted against `batch`, by awaiting the
    /// underlying channel's `send`.
    ///
    /// Returns the sequence number assigned to the message; once
    /// [`Sender::watermark`] reaches it, the message's guard has been dropped.
    ///
    /// # Errors
    /// Propagates the underlying channel's [`SendError`] for the unsent [`Message`].
    pub async fn send(&self, batch: Option<B>, data: T) -> Result<u64, SendError<Message<T, B>>> {
        let (sequence, message) = self.tracked_message(batch, data);
        self.inner.send(message).await?;
        Ok(sequence)
    }

    /// Non-waiting variant of [`Sender::send`]; delegates to the underlying
    /// channel's `try_send`.
    ///
    /// # Errors
    /// Propagates the underlying channel's [`TrySendError`] for the unsent [`Message`].
    pub fn try_send(&self, batch: Option<B>, data: T) -> Result<u64, TrySendError<Message<T, B>>> {
        let (sequence, message) = self.tracked_message(batch, data);
        self.inner.try_send(message)?;
        Ok(sequence)
    }

    /// Highest sequence `w` such that every message `1..=w` has been
    /// acknowledged (guard dropped); 0 before any acknowledgement.
    pub fn watermark(&self) -> u64 {
        let state = self.tracker.state.lock();
        state.watermark
    }

    /// Number of messages in `batch` whose guards are still alive (0 if unknown).
    pub fn pending(&self, batch: B) -> usize {
        let state = self.tracker.state.lock();
        state.batches.get(&batch).map_or(0, |&count| count)
    }
}
/// Receiving half of a tracked channel.
///
/// Each received [`Message`] carries a guard; dropping every clone of that
/// guard acknowledges the message to the sender side.
pub struct Receiver<T, B>
where
    B: Eq + Hash + Clone,
{
    inner: mpsc::Receiver<Message<T, B>>,
}

impl<T, B> Receiver<T, B>
where
    B: Eq + Hash + Clone,
{
    /// Awaits the next message; delegates to the underlying channel's `recv`.
    pub async fn recv(&mut self) -> Option<Message<T, B>> {
        self.inner.recv().await
    }

    /// Non-waiting receive; delegates to the underlying channel's `try_recv`.
    ///
    /// # Errors
    /// Forwards the underlying channel's [`TryRecvError`].
    pub fn try_recv(&mut self) -> Result<Message<T, B>, TryRecvError> {
        self.inner.try_recv()
    }
}

impl<T, B> Stream for Receiver<T, B>
where
    B: Eq + Hash + Clone,
{
    type Item = Message<T, B>;

    /// Delegates to the underlying channel's `poll_recv`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = Pin::get_mut(self);
        this.inner.poll_recv(cx)
    }
}
/// Creates a tracked channel; `buffer` is passed straight to the underlying
/// `mpsc::channel`. Sequence numbers start at 1 and the watermark at 0.
pub fn bounded<T, B: Eq + Hash + Clone>(buffer: usize) -> (Sender<T, B>, Receiver<T, B>) {
    let (inner_tx, inner_rx) = mpsc::channel(buffer);
    (
        Sender {
            inner: inner_tx,
            tracker: Tracker::new(),
        },
        Receiver { inner: inner_rx },
    )
}
#[cfg(test)]
mod tests {
    // Behavioral tests for the tracked channel: sequence numbering, watermark
    // advancement (including out-of-order acknowledgement), per-batch pending
    // counts, and guard sharing through `Arc`.
    use super::*;
    use futures::executor::block_on;

    #[test]
    fn test_basic() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<i32, u64>(10);
            let watermark = sender.send(None, 42).await.unwrap();
            assert_eq!(watermark, 1);
            assert_eq!(sender.watermark(), 0);
            let msg = receiver.recv().await.unwrap();
            assert_eq!(msg.data, 42);
            assert_eq!(sender.watermark(), 0);
            drop(msg.guard);
            assert_eq!(sender.watermark(), 1);
        });
    }

    #[test]
    fn test_batch_tracking() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<String, u64>(10);
            let watermark1 = sender.send(Some(100), "msg1".into()).await.unwrap();
            let watermark2 = sender.send(Some(100), "msg2".into()).await.unwrap();
            let watermark3 = sender.send(Some(200), "msg3".into()).await.unwrap();
            assert_eq!(watermark1, 1);
            assert_eq!(watermark2, 2);
            assert_eq!(watermark3, 3);
            assert_eq!(sender.pending(100), 2);
            assert_eq!(sender.pending(200), 1);
            assert_eq!(sender.pending(300), 0);
            let msg1 = receiver.recv().await.unwrap();
            assert_eq!(msg1.data, "msg1");
            drop(msg1.guard);
            assert_eq!(sender.pending(100), 1);
            assert_eq!(sender.pending(200), 1);
            let msg2 = receiver.recv().await.unwrap();
            let msg3 = receiver.recv().await.unwrap();
            drop(msg2.guard);
            drop(msg3.guard);
            assert_eq!(sender.pending(100), 0);
            assert_eq!(sender.pending(200), 0);
        });
    }

    #[test]
    fn test_cloned_guards() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<&str, u64>(10);
            let watermark = sender.send(Some(1), "test").await.unwrap();
            assert_eq!(watermark, 1);
            let msg = receiver.recv().await.unwrap();
            assert_eq!(msg.data, "test");
            // Cloning the `Arc` (not the guard) postpones acknowledgement until
            // the last clone is dropped.
            let msg_guard_clone1 = msg.guard.clone();
            let msg_guard_clone2 = msg.guard.clone();
            assert_eq!(sender.pending(1), 1);
            assert_eq!(sender.watermark(), 0);
            drop(msg.guard);
            drop(msg_guard_clone1);
            assert_eq!(sender.pending(1), 1);
            assert_eq!(sender.watermark(), 0);
            drop(msg_guard_clone2);
            assert_eq!(sender.pending(1), 0);
            assert_eq!(sender.watermark(), 1);
        });
    }

    #[test]
    fn test_try_send() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<i32, u64>(2);
            let watermark1 = sender.try_send(Some(10), 1).unwrap();
            let watermark2 = sender.try_send(Some(10), 2).unwrap();
            assert_eq!(sender.pending(10), 2);
            assert_eq!(watermark1, 1);
            assert_eq!(watermark2, 2);
            let msg1 = receiver.recv().await.unwrap();
            assert_eq!(msg1.data, 1);
            drop(msg1.guard);
            assert_eq!(sender.pending(10), 1);
            let msg2 = receiver.recv().await.unwrap();
            drop(msg2.guard);
            assert_eq!(sender.pending(10), 0);
        });
    }

    #[test]
    fn test_channel_closure() {
        block_on(async move {
            let (sender, receiver) = bounded::<i32, u64>(10);
            // `send` returns the message's sequence number, not a guard.
            let _sequence = sender.send(None, 1).await.unwrap();
            drop(receiver);
            assert!(sender.send(None, 2).await.is_err());
        });
    }

    #[test]
    fn test_out_of_order_acknowledgement() {
        block_on(async move {
            let (sender, mut receiver) = bounded::<i32, u64>(10);
            for value in 1..=3 {
                sender.send(None, value).await.unwrap();
            }
            let msg1 = receiver.recv().await.unwrap();
            let msg2 = receiver.recv().await.unwrap();
            let msg3 = receiver.recv().await.unwrap();
            assert_eq!((msg1.data, msg2.data, msg3.data), (1, 2, 3));
            // Acknowledging later messages first must not move the watermark
            // past the still-undelivered sequence 1.
            drop(msg3.guard);
            assert_eq!(sender.watermark(), 0);
            drop(msg2.guard);
            assert_eq!(sender.watermark(), 0);
            // Filling the gap releases the whole contiguous delivered run at once.
            drop(msg1.guard);
            assert_eq!(sender.watermark(), 3);
        });
    }
}