flowly-service 0.4.13

Flowly is a library of modular, reusable components for building robust pipelines that process audio, video, and other data.
Documentation
use std::{
    marker::PhantomData,
    pin::{Pin, pin},
    sync::Arc,
    task::{Poll, ready},
};

use futures::{FutureExt, Stream, StreamExt};
use tokio::sync::{Mutex, OwnedMutexGuard};

use crate::{Context, Service};

/// Stream of values read from a mutex-protected `flowly_spsc` receiver.
///
/// The value owns the mutex guard, so the receiver stays locked for as long
/// as this stream is alive and is unlocked again when it is dropped.
pub struct ConcurrentRx<T: Send> {
    /// Owned lock over the receiver. The channel carries `Option<T>`:
    /// `Some(value)` is a data item, `None` is an end-of-stream marker.
    guard: OwnedMutexGuard<flowly_spsc::Receiver<Option<T>>>,
}

impl<T: Send> Stream for ConcurrentRx<T> {
    type Item = T;

    /// Polls the locked receiver for the next value.
    ///
    /// Yields `Some(value)` for every `Some(value)` frame. The stream ends
    /// (`None`) either on an explicit `None` frame or when the receiver
    /// itself reports that no more frames will arrive.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let frame = ready!(self.guard.poll_recv(cx));

        // `Some(Some(v))` -> `Some(v)`; both `Some(None)` (end marker) and
        // `None` (receiver exhausted) collapse to `None`.
        Poll::Ready(frame.flatten())
    }
}

/// Handle to one spawned worker task that runs a service of type `S`.
///
/// Inputs are pushed through `tx`; the worker's outputs come back through
/// the receiver stored in `rx`.
struct ConcurrentTask<I: Send, S: Service<I>> {
    // Worker identifier. Nothing in the visible code reads it, hence the
    // `dead_code` allowance.
    #[allow(dead_code)]
    id: u32,
    /// Sending half of the worker's input channel.
    tx: flowly_spsc::Sender<I>,
    /// Marker tying the handle to the service type without storing a service.
    m: PhantomData<S>,
    /// Join handle of the spawned worker; it is stored but not awaited in
    /// the visible code.
    _handle: tokio::task::JoinHandle<()>,
    /// Receiving half of the worker's output channel, shared behind a mutex
    /// so that an owned lock guard can be handed out to a reader.
    rx: Arc<Mutex<flowly_spsc::Receiver<Option<S::Out>>>>,
}

impl<I, S> ConcurrentTask<I, S>
where
    S::Out: Send + 'static,
    I: Send + 'static,
    S: Service<I> + Send + 'static,
{
    /// Spawns a worker task that owns `s` and `cx` and returns its handle.
    ///
    /// For every input received, the worker runs `s.handle(input, &cx)`,
    /// forwards each produced value as `Some(value)` on the output channel
    /// and then sends a single `None` to mark the end of that input's
    /// outputs. The worker stops when the input channel yields no more items
    /// or when sending an output fails.
    fn new(id: u32, mut s: S, cx: Context) -> Self {
        let (tx, mut in_rx) = flowly_spsc::channel(1);
        let (mut out_tx, out_rx) = flowly_spsc::channel(1);

        let _handle = tokio::spawn(async move {
            'recv: loop {
                let Some(item) = in_rx.recv().await else {
                    break 'recv;
                };

                let mut outputs = pin!(s.handle(item, &cx));

                loop {
                    // `frame` is `Some(value)` for a produced value and `None`
                    // once the service stream is exhausted; the `None` is
                    // forwarded as the end marker.
                    let frame = outputs.next().await;
                    let exhausted = frame.is_none();

                    if out_tx.send(frame).await.is_err() {
                        log::error!("cannot send the message. channel closed!");
                        break 'recv;
                    }

                    if exhausted {
                        break;
                    }
                }
            }
        });

        let rx = Arc::new(tokio::sync::Mutex::new(out_rx));

        Self {
            id,
            tx,
            m: PhantomData,
            _handle,
            rx,
        }
    }

    /// Returns `true` when the output receiver's mutex could be locked right
    /// now, i.e. no previously handed-out [`ConcurrentRx`] still holds it.
    ///
    /// The probing guard is released immediately.
    #[inline]
    fn is_available(&self) -> bool {
        self.rx.try_lock().is_ok()
    }

    /// Sends `input` to the worker and returns a stream over its outputs.
    ///
    /// The input is sent first; afterwards this waits for the output
    /// receiver's lock, so it does not complete while another
    /// [`ConcurrentRx`] for this worker is still alive.
    ///
    /// # Errors
    ///
    /// Propagates the error of the input channel's `send`.
    #[inline]
    async fn send(
        &mut self,
        input: I,
    ) -> Result<ConcurrentRx<S::Out>, flowly_spsc::TrySendError<I>> {
        self.tx.send(input).await?;

        let guard = Arc::clone(&self.rx).lock_owned().await;

        Ok(ConcurrentRx { guard })
    }
}

/// Service adapter that spreads inputs over a pool of worker tasks, each
/// running its own clone of the wrapped service.
pub struct ConcurrentEach<I: Send + 'static, S: Service<I>> {
    /// Prototype service; cloned once for every worker that gets spawned.
    service: S,
    /// Workers spawned so far; grows on demand while below `limit`.
    tasks: Vec<ConcurrentTask<I, S>>,
    /// Marker for the input type.
    _m: PhantomData<I>,
    /// Worker count below which a new worker is spawned for an input.
    limit: usize,
}

impl<I: Send + 'static + Clone, S: Service<I> + Clone> Clone for ConcurrentEach<I, S> {
    fn clone(&self) -> Self {
        Self {
            service: self.service.clone(),
            tasks: Vec::new(),
            _m: self._m,
            limit: self.limit,
        }
    }
}

impl<I, S> ConcurrentEach<I, S>
where
    I: Send,
    S: Service<I> + Send,
    S::Out: Send,
{
    /// Creates an adapter around `service` with room for `limit` workers.
    ///
    /// No worker is spawned here; only the backing vector is pre-allocated
    /// with capacity `limit`.
    pub fn new(service: S, limit: usize) -> Self {
        let tasks = Vec::with_capacity(limit);

        Self {
            limit,
            tasks,
            service,
            _m: PhantomData,
        }
    }
}

impl<I, R, E, S> Service<I> for ConcurrentEach<I, S>
where
    I: Send,
    R: Send + 'static,
    E: Send + 'static,
    S: Service<I, Out = Result<R, E>> + Clone + Send + 'static,
{
    type Out = Result<ConcurrentRx<S::Out>, E>;

    /// Dispatches `input` to one worker and yields a single item: a
    /// [`ConcurrentRx`] stream over that worker's outputs for this input.
    ///
    /// While fewer workers than the limit exist, a new worker is spawned
    /// with a clone of the wrapped service and of `cx`. Once the pool is
    /// full, workers are probed starting at a random position and the first
    /// available one is used; if none is available, the randomly chosen
    /// starting worker is used and the call waits for it.
    ///
    /// A limit of `0` is treated as `1`, because an empty pool cannot accept
    /// any input.
    ///
    /// # Panics
    ///
    /// Panics if the chosen worker's input channel rejects the input.
    fn handle(&mut self, input: I, cx: &Context) -> impl Stream<Item = Self::Out> + Send {
        async move {
            // An empty pool has nothing to pick from (and an empty range
            // makes the random pick panic), so always allow one worker.
            let capacity = self.limit.max(1);

            let index = if self.tasks.len() < capacity {
                let index = self.tasks.len();
                self.tasks.push(ConcurrentTask::new(
                    // The id is informational only, so a truncating cast is
                    // acceptable here.
                    index as u32,
                    self.service.clone(),
                    cx.clone(),
                ));
                index
            } else {
                let len = self.tasks.len();
                let start = fastrand::usize(0..len);

                // Probe every worker exactly once, beginning at the random
                // offset, so that low indices are not systematically
                // preferred. Fall back to the starting worker when all of
                // them are busy.
                (0..len)
                    .map(|offset| (start + offset) % len)
                    .find(|&idx| self.tasks[idx].is_available())
                    .unwrap_or(start)
            };

            Ok(self.tasks[index]
                .send(input)
                .await
                .expect("worker input channel rejected the input"))
        }
        .into_stream()
    }
}

/// Convenience constructor for [`ConcurrentEach`].
///
/// Equivalent to `ConcurrentEach::new(service, limit)`.
pub fn concurrent_each<I, S>(service: S, limit: usize) -> ConcurrentEach<I, S>
where
    I: Send,
    S: Send + Service<I> + Clone + 'static,
    S::Out: Send,
{
    <ConcurrentEach<I, S>>::new(service, limit)
}