use core::{
cell::Cell,
fmt::Debug,
ops::{Deref, DerefMut},
};
use ufotofu::{BufferedConsumer, BulkConsumer, Consumer};
use crate::{mutex::WriteGuard, Mutex};
/// Shared state backing a family of [`SharedConsumer`] handles.
///
/// Holds the wrapped consumer, together with the first error it reported,
/// behind a [`Mutex`], and counts how many handles have not yet called `close`.
#[derive(Debug)]
pub struct State<C, ConsumerErr> {
    /// The wrapped consumer plus its recorded error.
    m: Mutex<MutexState<C, ConsumerErr>>,
    /// Number of handles that still have to call `close` before the wrapped
    /// consumer itself gets closed.
    unclosed_handle_count: Cell<usize>,
}

impl<C, ConsumerErr> State<C, ConsumerErr> {
    /// Wraps `consumer`, starting out with no recorded error and exactly one
    /// outstanding (not yet closed) handle.
    pub fn new(consumer: C) -> Self {
        let initial = MutexState {
            c: consumer,
            error: None,
        };

        Self {
            m: Mutex::new(initial),
            unclosed_handle_count: Cell::new(1),
        }
    }
}

/// The data protected by the mutex inside [`State`].
#[derive(Debug)]
struct MutexState<C, ConsumerErr> {
    /// The wrapped consumer.
    c: C,
    /// The first error reported by `c`; once set, it is handed back to every
    /// subsequent operation instead of calling `c` again.
    error: Option<ConsumerErr>,
}
/// A cloneable handle through which several tasks can feed one shared consumer.
///
/// `R` is any `Clone` type that dereferences to a [`State`] (the tests in this
/// module use a plain `&State`). Every clone counts as a separate handle, and
/// each handle is expected to call `close` before the wrapped consumer is
/// closed.
#[derive(Debug)]
pub struct SharedConsumer<R, C, ConsumerErr>
where
R: Deref<Target = State<C, ConsumerErr>> + Clone,
{
// Reference to the state shared by all clones of this handle.
state_ref: R,
}
impl<R, C, ConsumerErr> Clone for SharedConsumer<R, C, ConsumerErr>
where
    R: Deref<Target = State<C, ConsumerErr>> + Clone,
{
    /// Creates another handle to the same shared consumer.
    ///
    /// The new handle is registered as unclosed, so it too must call `close`
    /// before the wrapped consumer is closed.
    fn clone(&self) -> Self {
        let counter = &self.state_ref.unclosed_handle_count;
        counter.set(counter.get() + 1);

        let state_ref = self.state_ref.clone();
        Self { state_ref }
    }
}
impl<R, C, ConsumerErr> SharedConsumer<R, C, ConsumerErr>
where
R: Deref<Target = State<C, ConsumerErr>> + Clone,
{
pub fn new(state_ref: R) -> Self {
Self { state_ref }
}
pub async fn access_consumer(&self) -> SharedConsumerAccess<C, ConsumerErr> {
SharedConsumerAccess {
c: self.state_ref.deref().m.write().await,
unclosed_handle_count: &self.state_ref.deref().unclosed_handle_count,
}
}
}
/// Exclusive access to the shared consumer, obtained from
/// [`SharedConsumer::access_consumer`].
///
/// Implements `Consumer`, `BufferedConsumer` and `BulkConsumer` (whenever the
/// wrapped consumer does) by forwarding to the wrapped consumer while holding
/// the state's write guard.
#[derive(Debug)]
pub struct SharedConsumerAccess<'shared_consumer, C, ConsumerErr> {
// Write guard over the wrapped consumer and its recorded error.
c: WriteGuard<'shared_consumer, MutexState<C, ConsumerErr>>,
// Count of not-yet-closed handles, shared by all handles of the same state.
unclosed_handle_count: &'shared_consumer Cell<usize>,
}
impl<C, ConsumerErr> Consumer for SharedConsumerAccess<'_, C, ConsumerErr>
where
    C: Consumer<Error = ConsumerErr>,
    C::Final: Clone,
    ConsumerErr: Clone,
{
    type Item = C::Item;
    type Final = C::Final;
    type Error = C::Error;

    /// Forwards `item` to the shared consumer.
    ///
    /// If the shared consumer has already reported an error (through any
    /// handle), a clone of that error is returned without calling the
    /// consumer. A fresh error is recorded so that later operations on every
    /// handle see it as well.
    async fn consume(&mut self, item: Self::Item) -> Result<(), Self::Error> {
        let inner_state = self.c.deref_mut();

        if let Some(err) = inner_state.error.as_ref() {
            return Err(err.clone());
        }

        match inner_state.c.consume(item).await {
            Ok(()) => Ok(()),
            Err(err) => {
                inner_state.error = Some(err.clone());
                Err(err)
            }
        }
    }

    /// Marks one handle as closed, and closes the shared consumer once the
    /// last outstanding handle has done so.
    ///
    /// Only the `fin` of that last `close` call reaches the shared consumer;
    /// the values passed by earlier calls are dropped.
    ///
    /// # Errors
    /// Returns a clone of a previously recorded error (without changing the
    /// handle count), or the error reported by the shared consumer's own
    /// `close`, which is then recorded for all handles.
    ///
    /// # Panics
    /// Panics if `close` is called while the count of unclosed handles is
    /// already zero, i.e. more often than there are handles.
    async fn close(&mut self, fin: Self::Final) -> Result<(), Self::Error> {
        let inner_state = self.c.deref_mut();

        if let Some(err) = inner_state.error.as_ref() {
            return Err(err.clone());
        }

        // Checked instead of `get() - 1`: an unchecked subtraction would wrap
        // to `usize::MAX` in release builds, after which the shared consumer
        // would silently never be closed.
        let remaining = self
            .unclosed_handle_count
            .get()
            .checked_sub(1)
            .expect("SharedConsumer: `close` called more times than there are handles");
        self.unclosed_handle_count.set(remaining);

        if remaining > 0 {
            // Other handles are still open; their eventual `close` decides.
            return Ok(());
        }

        match inner_state.c.close(fin).await {
            Ok(()) => Ok(()),
            Err(err) => {
                inner_state.error = Some(err.clone());
                Err(err)
            }
        }
    }
}
impl<C, ConsumerErr> BufferedConsumer for SharedConsumerAccess<'_, C, ConsumerErr>
where
    C: BufferedConsumer<Error = ConsumerErr>,
    C::Final: Clone,
    ConsumerErr: Clone,
{
    /// Flushes the shared consumer.
    ///
    /// A previously recorded error is replayed instead of flushing; a new
    /// flush error is recorded so that every handle observes it afterwards.
    async fn flush(&mut self) -> Result<(), Self::Error> {
        let shared = self.c.deref_mut();

        if let Some(earlier) = &shared.error {
            return Err(earlier.clone());
        }

        let outcome = shared.c.flush().await;
        if let Err(fresh) = &outcome {
            shared.error = Some(fresh.clone());
        }
        outcome
    }
}
impl<C, ConsumerErr> BulkConsumer for SharedConsumerAccess<'_, C, ConsumerErr>
where
    C: BulkConsumer<Error = ConsumerErr>,
    C::Final: Clone,
    ConsumerErr: Clone,
{
    /// Exposes the shared consumer's free slots.
    ///
    /// A previously recorded error is replayed instead; a new error is
    /// recorded for all handles before being returned.
    async fn expose_slots<'a>(&'a mut self) -> Result<&'a mut [Self::Item], Self::Error>
    where
        Self::Item: 'a,
    {
        let shared = self.c.deref_mut();

        if let Some(earlier) = &shared.error {
            return Err(earlier.clone());
        }

        match shared.c.expose_slots().await {
            Err(fresh) => {
                shared.error = Some(fresh.clone());
                Err(fresh)
            }
            Ok(slots) => Ok(slots),
        }
    }

    /// Commits `amount` previously exposed slots to the shared consumer.
    ///
    /// A previously recorded error is replayed instead; a new error is
    /// recorded for all handles before being returned.
    async fn consume_slots(&mut self, amount: usize) -> Result<(), Self::Error> {
        let shared = self.c.deref_mut();

        if let Some(earlier) = &shared.error {
            return Err(earlier.clone());
        }

        let outcome = shared.c.consume_slots(amount).await;
        if let Err(fresh) = &outcome {
            shared.error = Some(fresh.clone());
        }
        outcome
    }
}
#[cfg(test)]
mod tests {
use super::*;
use core::time::Duration;
use either::Either::{Left, Right};
use smol::{block_on, Timer};
use ufotofu::{
consumer::{TestConsumer, TestConsumerBuilder},
Consumer, Producer,
};
use ufotofu_queues::Fixed;
use crate::spsc::{self, new_spsc};
// Checks that an error reported through one handle is replayed to the other.
#[test]
fn test_shared_consumer_errors() {
// Judging by the assertions below, this consumer accepts three items and
// then fails with -4 (TODO confirm against TestConsumerBuilder docs).
let underlying_c: TestConsumer<u8, (), i16> = TestConsumerBuilder::new(-4, 3).build();
let state = State::new(underlying_c);
let shared1 = SharedConsumer::new(&state);
let shared2 = shared1.clone();
let write_some_items1 = async {
{
// Holds access for 50 ms so that task 2 must wait for the lock.
let mut c_handle = shared1.access_consumer().await;
Timer::after(Duration::from_millis(50)).await; assert_eq!(Ok(()), c_handle.consume(1).await);
}
// NOTE(review): presumably this timer starts just before task 2's, so
// task 1 reaches the consumer first and hits the error itself.
Timer::after(Duration::from_millis(50)).await;
{
let mut c_handle = shared1.access_consumer().await;
assert_eq!(Ok(()), c_handle.consume(3).await);
assert_eq!(Err(-4), c_handle.consume(4).await);
}
};
let write_some_items2 = async {
Timer::after(Duration::from_millis(10)).await;
{
// Blocks until task 1 releases its first access.
let mut c_handle = shared2.access_consumer().await;
assert_eq!(Ok(()), c_handle.consume(2).await);
}
Timer::after(Duration::from_millis(50)).await;
// The error recorded by task 1 is replayed through this handle.
let mut c_handle = shared2.access_consumer().await;
assert_eq!(Err(-4), c_handle.consume(4).await); };
block_on(futures::future::join(write_some_items1, write_some_items2));
}
// Checks that the wrapped consumer is closed only by the last handle, with
// that handle's final value (-2); the earlier close value (-1) is dropped.
#[test]
fn test_shared_consumer_closing() {
let spsc_state: spsc::State<Fixed<u8>, i16, ()> =
spsc::State::new(Fixed::new(16 ));
let (sender, mut receiver) = new_spsc(&spsc_state);
let state = State::new(sender);
let shared1 = SharedConsumer::new(&state);
let shared2 = shared1.clone();
let write_some_items1 = async {
{
let mut c_handle = shared1.access_consumer().await;
Timer::after(Duration::from_millis(50)).await; assert_eq!(Ok(()), c_handle.consume(1).await);
}
Timer::after(Duration::from_millis(50)).await;
{
let mut c_handle = shared1.access_consumer().await;
assert_eq!(Ok(()), c_handle.consume(3).await);
// First of two handles to close: not forwarded to the sender.
assert_eq!(Ok(()), c_handle.close(-1).await);
}
};
let write_some_items2 = async {
Timer::after(Duration::from_millis(10)).await;
{
let mut c_handle = shared2.access_consumer().await;
assert_eq!(Ok(()), c_handle.consume(2).await);
}
Timer::after(Duration::from_millis(50)).await;
let mut c_handle = shared2.access_consumer().await;
// Last handle to close: this close reaches the sender.
assert_eq!(Ok(()), c_handle.close(-2).await);
// `c_handle` is still held here; the receiver must already see all items.
assert_eq!(Ok(Left(1)), receiver.produce().await);
assert_eq!(Ok(Left(2)), receiver.produce().await);
assert_eq!(Ok(Left(3)), receiver.produce().await);
assert_eq!(Ok(Right(-2)), receiver.produce().await);
};
block_on(futures::future::join(write_some_items1, write_some_items2));
}
}