use either::Either::{self, *};
use fairly_unsafe_cell::*;
use std::{
cell::Cell,
convert::Infallible,
marker::PhantomData,
ops::{Deref, DerefMut},
};
use ufotofu::{BufferedConsumer, BufferedProducer, BulkConsumer, BulkProducer, Consumer, Producer};
use ufotofu_queues::Queue;
use crate::{extend_lifetime, extend_lifetime_mut, Mutex, TakeCell};
/// Shared state of a single-producer single-consumer channel, referenced by both the
/// [`Sender`] and the [`Receiver`] that [`new_spsc`] creates from it.
#[derive(Debug)]
pub struct State<Q, F, E> {
    /// Buffer for the items in transit; in this file every access goes through `write().await`.
    queue: Mutex<Q>,
    /// Final value (`Ok`) or error (`Err`) stored by the sender; taken by the receiver once
    /// the queue has no more items to hand out.
    last: FairlyUnsafeCell<Option<Result<F, E>>>,
    /// Number of items currently in `queue`, updated on every enqueue/dequeue so that it can
    /// be read without locking `queue`.
    len: Cell<usize>,
    /// Set by the receiver after it removes items; awaited by the sender while the queue
    /// has no room.
    notify_the_sender: TakeCell<()>,
    /// Set by the sender after it adds items or stores `last`; awaited by the receiver while
    /// it has nothing to produce.
    notify_the_receiver: TakeCell<()>,
    /// `true` until the first of the two endpoints is dropped.
    nothing_dropped_yet: Cell<bool>,
}
impl<Q: Queue, F, E> State<Q, F, E> {
    /// Wraps `queue` in fresh channel state.
    ///
    /// The cached length is initialised from `queue.len()`, no final value or error is
    /// stored, both notification cells start out holding `()`, and no endpoint counts as
    /// dropped yet.
    pub fn new(queue: Q) -> Self {
        let initial_len = queue.len();
        State {
            queue: Mutex::new(queue),
            last: FairlyUnsafeCell::new(None),
            len: Cell::new(initial_len),
            notify_the_sender: TakeCell::new_with(()),
            notify_the_receiver: TakeCell::new_with(()),
            nothing_dropped_yet: Cell::new(true),
        }
    }

    /// Number of items currently buffered, as tracked by the sender and the receiver.
    pub fn len(&self) -> usize {
        self.len.get()
    }

    /// `true` iff no items are currently buffered.
    pub fn is_empty(&self) -> bool {
        self.len.get() == 0
    }

    /// Stores `fin` as the final value and wakes the receiver.
    ///
    /// Must be called at most once, and never after [`State::cause_error`]; a violation
    /// trips a `debug_assert!`.
    pub fn close(&self, fin: F) {
        self.store_outcome(
            Ok(fin),
            "Must not call `close` or `close_sync` multiple times or after calling `cause_error`.",
        );
    }

    /// Stores `err` as the error to report and wakes the receiver.
    ///
    /// Must be called at most once, and never after [`State::close`]; a violation trips a
    /// `debug_assert!`.
    pub fn cause_error(&self, err: E) {
        self.store_outcome(
            Err(err),
            "Must not call `cause_error` multiple times or after calling `close` or `close_sync`.",
        );
    }

    /// Whether `close` or `cause_error` has stored an outcome that the receiver has not taken
    /// yet.
    pub fn has_been_closed_or_errored_yet(&self) -> bool {
        // SAFETY: no mutable borrow of `last` is alive here. This file never holds a borrow
        // of `last` across an `.await`, and `State` is not `Sync` because it contains `Cell`s.
        let last = unsafe { self.last.borrow() };
        last.is_some()
    }

    /// Stores the sender's outcome in `last` and wakes the receiver. `misuse` is the panic
    /// message used (in debug builds) when an outcome had already been stored.
    fn store_outcome(&self, outcome: Result<F, E>, misuse: &'static str) {
        // SAFETY: see `has_been_closed_or_errored_yet`; no other borrow of `last` is alive.
        let mut slot = unsafe { self.last.borrow_mut() };
        debug_assert!(slot.is_none(), "{}", misuse);
        *slot = Some(outcome);
        self.notify_the_receiver.set(());
    }
}
/// Creates the two endpoints of an SPSC channel that both operate on the state behind
/// `state_ref`.
///
/// `R` is any cloneable handle that dereferences to the [`State`] (e.g. a plain `&State`).
/// One clone goes to the [`Sender`] and the original handle to the [`Receiver`].
pub fn new_spsc<R, Q, F, E>(state_ref: R) -> (Sender<R, Q, F, E>, Receiver<R, Q, F, E>)
where
    R: Deref<Target = State<Q, F, E>> + Clone,
{
    let sender = Sender {
        state: R::clone(&state_ref),
        phantom: PhantomData,
    };
    let receiver = Receiver {
        state: state_ref,
        phantom: PhantomData,
    };
    (sender, receiver)
}
/// The sending end of an SPSC channel, created by [`new_spsc`].
///
/// `R` is a handle that dereferences to the shared [`State`]. `Q`, `F` and `E` only occur
/// in `R`'s bound, so they are recorded in a `PhantomData`.
#[derive(Debug)]
pub struct Sender<R: Deref<Target = State<Q, F, E>>, Q, F, E> {
    state: R,
    phantom: PhantomData<(Q, F, E)>,
}
/// The receiving end of an SPSC channel, created by [`new_spsc`].
///
/// `R` is a handle that dereferences to the shared [`State`]. `Q`, `F` and `E` only occur
/// in `R`'s bound, so they are recorded in a `PhantomData`.
#[derive(Debug)]
pub struct Receiver<R: Deref<Target = State<Q, F, E>>, Q, F, E> {
    state: R,
    phantom: PhantomData<(Q, F, E)>,
}
impl<R: Deref<Target = State<Q, F, E>>, Q: Queue, F, E> Sender<R, Q, F, E> {
    /// Returns how many items are currently buffered in the shared queue.
    pub fn len(&self) -> usize {
        self.state.len.get()
    }

    /// Returns whether the shared queue currently buffers no items.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Makes the [`Receiver`] report `err` once it has no buffered items left to produce.
    ///
    /// Must be called at most once, and not after `close`/`close_sync`. This is checked by a
    /// `debug_assert!` in [`State::cause_error`].
    pub fn cause_error(&mut self, err: E) {
        self.state.cause_error(err)
    }

    /// Sets the final value that the [`Receiver`] produces once it has no buffered items left.
    /// This is the synchronous counterpart of `Consumer::close`.
    ///
    /// Must be called at most once, and not after `cause_error`. This is checked by a
    /// `debug_assert!` in [`State::close`].
    pub fn close_sync(&mut self, fin: F) {
        self.state.close(fin)
    }

    /// Returns whether the corresponding [`Receiver`] has been dropped already.
    ///
    /// The shared `nothing_dropped_yet` flag is cleared by whichever endpoint is dropped
    /// first. This `Sender` is still alive, so a cleared flag means the receiver is gone
    /// (assuming the state is shared by a single sender/receiver pair, as created by
    /// [`new_spsc`]).
    pub fn is_receiver_dropped(&self) -> bool {
        // The flag is `true` while nothing has been dropped, so it must be negated here.
        !self.state.nothing_dropped_yet.get()
    }
}
impl<R: Deref<Target = State<Q, F, E>>, Q, F, E> Drop for Sender<R, Q, F, E> {
    /// Records in the shared state that an endpoint is gone, so that the receiver can
    /// observe it.
    fn drop(&mut self) {
        let shared = &*self.state;
        shared.nothing_dropped_yet.set(false);
    }
}
impl<R: Deref<Target = State<Q, F, E>>, Q: Queue, F, E> Consumer for Sender<R, Q, F, E> {
    type Item = Q::Item;
    type Final = F;
    type Error = Infallible;

    /// Enqueues `item`. Whenever the queue hands the item back (no room), waits for the
    /// receiver's signal and retries.
    ///
    /// Never fails: the error type is `Infallible`.
    async fn consume(&mut self, item: Self::Item) -> Result<(), Self::Error> {
        let mut pending = item;
        loop {
            // Hold the queue lock only for the enqueue attempt itself.
            let mut queue = self.state.queue.write().await;
            let rejected = queue.deref_mut().enqueue(pending);
            drop(queue);

            let Some(returned) = rejected else {
                // The item went in: account for it and wake the receiver.
                self.state.len.set(self.state.len.get() + 1);
                self.state.notify_the_receiver.set(());
                return Ok(());
            };

            // No room: wait until the receiver signals that it removed something, then retry
            // with the item the queue handed back.
            let () = self.state.notify_the_sender.take().await;
            pending = returned;
        }
    }

    /// Stores the final value, exactly like `close_sync`. Never fails.
    async fn close(&mut self, fin: Self::Final) -> Result<(), Self::Error> {
        self.state.close(fin);
        Ok(())
    }
}
impl<R: Deref<Target = State<Q, F, E>>, Q: Queue, F, E> BufferedConsumer for Sender<R, Q, F, E> {
    /// A no-op. `consume` and `consume_slots` write straight into the shared queue and wake
    /// the receiver themselves, so the sender never keeps items of its own that would need
    /// flushing.
    async fn flush(&mut self) -> Result<(), Self::Error> {
        Ok(())
    }
}
impl<R: Deref<Target = State<Q, F, E>>, Q: Queue, F, E> BulkConsumer for Sender<R, Q, F, E> {
    /// Exposes the queue's free slots for writing, waiting for the receiver while the queue
    /// exposes none.
    ///
    /// This method does not enqueue anything by itself. `consume_slots` is what calls
    /// `consider_enqueued`, updates the cached length and wakes the receiver.
    async fn expose_slots<'a>(&'a mut self) -> Result<&'a mut [Self::Item], Self::Error>
    where
        Self::Item: 'a,
    {
        loop {
            // The queue's write guard is a temporary of this `match` scrutinee and is released
            // when the `match` statement ends (or when returning from inside it).
            match self.state.queue.write().await.deref_mut().expose_slots() {
                None => {
                    // No free slots exposed (presumably the queue is full): fall through and
                    // wait for the receiver.
                }
                Some(slots) => {
                    // NOTE(review): the slice is handed out beyond the lifetime of the queue
                    // guard it was obtained through. This presumably relies on the receiver
                    // never touching the queue's free slots; confirm against
                    // `extend_lifetime_mut`'s contract.
                    let slots: &'a mut [Q::Item] = unsafe { extend_lifetime_mut(slots) };
                    return Ok(slots);
                }
            }
            // Wait until the receiver signals (via `notify_the_sender`) that it removed items.
            let () = self.state.notify_the_sender.take().await;
        }
    }

    /// Commits `amount` previously exposed slots: tells the queue to consider them enqueued,
    /// bumps the cached length and wakes the receiver.
    async fn consume_slots(&mut self, amount: usize) -> Result<(), Self::Error> {
        self.state
            .queue
            .write()
            .await
            .deref_mut()
            .consider_enqueued(amount);
        self.state.len.set(self.state.len.get() + amount);
        self.state.notify_the_receiver.set(());
        Ok(())
    }
}
impl<R: Deref<Target = State<Q, F, E>>, Q: Queue, F, E> Receiver<R, Q, F, E> {
    /// Returns how many items are currently buffered in the shared queue.
    pub fn len(&self) -> usize {
        self.state.len.get()
    }

    /// Returns whether the shared queue currently buffers no items.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns whether the corresponding [`Sender`] has been dropped already.
    ///
    /// The shared `nothing_dropped_yet` flag is cleared by whichever endpoint is dropped
    /// first. This `Receiver` is still alive, so a cleared flag means the sender is gone
    /// (assuming the state is shared by a single sender/receiver pair, as created by
    /// [`new_spsc`]).
    pub fn is_sender_dropped(&self) -> bool {
        // The flag is `true` while nothing has been dropped, so it must be negated here.
        !self.state.nothing_dropped_yet.get()
    }

    /// Misnamed alias of [`Receiver::is_sender_dropped`], kept for backwards compatibility.
    /// A `Receiver` can only ever observe whether the *sender* has been dropped.
    pub fn is_receiver_dropped(&self) -> bool {
        self.is_sender_dropped()
    }
}
impl<R: Deref<Target = State<Q, F, E>>, Q, F, E> Drop for Receiver<R, Q, F, E> {
    /// Records in the shared state that an endpoint is gone, so that the sender can observe
    /// it.
    fn drop(&mut self) {
        let shared = &*self.state;
        shared.nothing_dropped_yet.set(false);
    }
}
impl<R: Deref<Target = State<Q, F, E>>, Q: Queue, F, E> Producer for Receiver<R, Q, F, E> {
    type Item = Q::Item;
    type Final = F;
    type Error = E;

    /// Produces the next buffered item (`Left`).
    ///
    /// When nothing is buffered, it instead produces the sender's final value (`Right`) or
    /// error (`Err`) if one is stored. Otherwise it waits for the sender.
    async fn produce(&mut self) -> Result<Either<Self::Item, Self::Final>, Self::Error> {
        loop {
            let mut queue = self.state.queue.write().await;

            if let Some(item) = queue.deref_mut().dequeue() {
                // One item left the queue: update the count and wake a sender waiting for room.
                self.state.len.set(self.state.len.get() - 1);
                self.state.notify_the_sender.set(());
                return Ok(Left(item));
            }

            // Nothing buffered, so the final value or error (if any) comes next.
            // SAFETY: no other borrow of `last` is alive. Borrows of `last` in this file are
            // never held across an `.await`, and `State` is not `Sync` (it contains `Cell`s).
            let outcome = unsafe { self.state.last.borrow_mut().take() };
            match outcome {
                Some(Ok(fin)) => return Ok(Right(fin)),
                Some(Err(err)) => return Err(err),
                None => {}
            }

            // Release the queue before suspending so the sender can enqueue, then wait for it.
            drop(queue);
            let () = self.state.notify_the_receiver.take().await;
        }
    }
}
impl<R: Deref<Target = State<Q, F, E>>, Q: Queue, F, E> BufferedProducer for Receiver<R, Q, F, E> {
    /// The receiver has no buffer of its own to fill, so this only reports a pending error
    /// early.
    ///
    /// If the queue is empty and the sender has stored an error, the error is taken out of
    /// the shared state and returned. A stored final value is left in place for `produce` /
    /// `expose_items` to hand out.
    ///
    /// # Errors
    ///
    /// Returns the sender's error if the queue is empty and an error has been stored.
    async fn slurp(&mut self) -> Result<(), Self::Error> {
        if self.is_empty() {
            // SAFETY: no other borrow of `last` is alive. Borrows of `last` in this file are
            // never held across an `.await`, and `State` is not `Sync` (it contains `Cell`s).
            let mut last = unsafe { self.state.last.borrow_mut() };
            // Take the value out and (unless it is an error) put it back through the SAME
            // guard. Re-borrowing `last` to restore it would overlap with this borrow: before
            // edition 2024, a guard created in a `match` scrutinee's block tail expression
            // lives until the end of the match.
            match last.take() {
                Some(Err(err)) => return Err(err),
                untouched => *last = untouched,
            }
        }
        Ok(())
    }
}
impl<R: Deref<Target = State<Q, F, E>>, Q: Queue, F, E> BulkProducer for Receiver<R, Q, F, E> {
    /// Exposes the queue's buffered items for reading (`Left`), waiting for the sender while
    /// the queue exposes none.
    ///
    /// When nothing is buffered, it instead returns the sender's final value (`Right`) or
    /// error (`Err`) if one is stored. This method does not dequeue anything by itself;
    /// `consider_produced` does.
    async fn expose_items<'a>(
        &'a mut self,
    ) -> Result<Either<&'a [Self::Item], Self::Final>, Self::Error>
    where
        Self::Item: 'a,
    {
        loop {
            // The queue's write guard is a temporary of this `match` scrutinee and is released
            // when the `match` statement ends (or when returning from inside it).
            match self.state.queue.write().await.deref_mut().expose_items() {
                None => {
                    // No buffered items: the final value or error (if any) comes next.
                    match unsafe { self.state.last.borrow_mut().take() } {
                        Some(Ok(fin)) => {
                            return Ok(Right(fin));
                        }
                        Some(Err(err)) => {
                            return Err(err);
                        }
                        None => {
                            // Nothing stored yet: fall through and wait for the sender.
                        }
                    }
                }
                Some(items) => {
                    // NOTE(review): the slice is handed out beyond the lifetime of the queue
                    // guard it was obtained through. This presumably relies on the sender
                    // never overwriting slots that still hold buffered items; confirm against
                    // `extend_lifetime`'s contract.
                    let items: &'a [Q::Item] = unsafe { extend_lifetime(items) };
                    return Ok(Left(items));
                }
            }
            // Wait until the sender signals (via `notify_the_receiver`) new items or an outcome.
            let () = self.state.notify_the_receiver.take().await;
        }
    }

    /// Tells the queue to consider `amount` items dequeued, lowers the cached length
    /// accordingly, and wakes the sender (which may be waiting for room).
    async fn consider_produced(&mut self, amount: usize) -> Result<(), Self::Error> {
        self.state
            .queue
            .write()
            .await
            .deref_mut()
            .consider_dequeued(amount);
        self.state.len.set(self.state.len.get() - amount);
        self.state.notify_the_sender.set(());
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::join;
    use ufotofu_queues::Fixed;

    #[test]
    fn test_spsc_sufficient_capacity() {
        // Capacity 99 is plenty, so everything can be sent before anything is received,
        // even though both ends run sequentially inside a single future.
        let state: State<_, _, ()> = State::new(Fixed::new(99));
        let (mut sender, mut receiver) = new_spsc(&state);

        pollster::block_on(async {
            for item in [300, 400, 500] {
                assert!(sender.consume(item).await.is_ok());
            }
            assert!(sender.close(-17).await.is_ok());

            for expected in [300, 400, 500] {
                assert_eq!(expected, receiver.produce().await.unwrap().unwrap_left());
            }
            assert_eq!(-17, receiver.produce().await.unwrap().unwrap_right());
        });
    }

    #[test]
    fn test_spsc_low_capacity() {
        pollster::block_on(async {
            // Small capacity; both ends run concurrently via `join!`.
            let state: State<_, _, ()> = State::new(Fixed::new(3));
            let (mut sender, mut receiver) = new_spsc(&state);

            let send_things = async {
                for item in [300, 400, 500] {
                    assert!(sender.consume(item).await.is_ok());
                }
                assert!(sender.close(-17).await.is_ok());
            };
            let receive_things = async {
                for expected in [300, 400, 500] {
                    assert_eq!(expected, receiver.produce().await.unwrap().unwrap_left());
                }
                assert_eq!(-17, receiver.produce().await.unwrap().unwrap_right());
            };

            join!(send_things, receive_things);
        });
    }

    #[test]
    fn test_spsc_immediate_final() {
        pollster::block_on(async {
            // No items at all: the very first `produce` must yield the final value.
            let state: State<Fixed<u8>, i16, ()> = State::new(Fixed::new(3));
            let (mut sender, mut receiver) = new_spsc(&state);

            let send_things = async {
                assert!(sender.close(-17).await.is_ok());
            };
            let receive_things = async {
                let fin = receiver.produce().await.unwrap().unwrap_right();
                assert_eq!(-17, fin);
            };

            join!(send_things, receive_things);
        });
    }

    #[test]
    fn test_spsc_immediate_error() {
        pollster::block_on(async {
            // No items at all: the very first `produce` must yield the error.
            let state: State<Fixed<u8>, i16, i16> = State::new(Fixed::new(3));
            let (mut sender, mut receiver) = new_spsc(&state);

            let send_things = async {
                sender.cause_error(-17);
            };
            let receive_things = async {
                let err = receiver.produce().await.unwrap_err();
                assert_eq!(-17, err);
            };

            join!(send_things, receive_things);
        });
    }

    #[test]
    fn test_spsc_slurp() {
        pollster::block_on(async {
            // With an empty queue, `slurp` reports a stored error right away.
            let state: State<Fixed<u8>, i16, i16> = State::new(Fixed::new(3));
            let (mut sender, mut receiver) = new_spsc(&state);

            let send_things = async {
                sender.cause_error(-17);
            };
            let receive_things = async {
                let err = receiver.slurp().await.unwrap_err();
                assert_eq!(-17, err);
            };

            join!(send_things, receive_things);
        });
    }

    #[test]
    fn test_spsc_receive_then_send_concurrently() {
        pollster::block_on(async {
            // Same as the low-capacity test, but the receiving future is polled first.
            let state: State<Fixed<u64>, i16, i16> = State::new(Fixed::new(3));
            let (mut sender, mut receiver) = new_spsc(&state);

            let send_things = async {
                for item in [300, 400, 500] {
                    assert!(sender.consume(item).await.is_ok());
                }
                assert!(sender.close(-17).await.is_ok());
            };
            let receive_things = async {
                for expected in [300, 400, 500] {
                    assert_eq!(expected, receiver.produce().await.unwrap().unwrap_left());
                }
                assert_eq!(-17, receiver.produce().await.unwrap().unwrap_right());
            };

            join!(receive_things, send_things);
        });
    }
}