dagrs 0.8.1

Dagrs follows the concept of Flow-based Programming and is suitable for the execution of multiple tasks with graph-like dependencies. Dagrs has the characteristics of high performance and asynchronous execution. It provides users with a convenient programming interface.
Documentation
use std::{
    collections::{HashMap, HashSet},
    marker::PhantomData,
    sync::Arc,
};

use futures::future::{join_all, select_ok};
use tokio::sync::{Mutex, broadcast, mpsc};

use crate::NodeId;
use crate::graph::error::{DagrsError, DagrsResult, ErrorCode};

use super::information_packet::Content;

#[derive(Default)]
pub struct InChannels(
    pub(crate) HashMap<NodeId, Arc<Mutex<InChannel>>>,
    pub(crate) HashSet<NodeId>,
);

impl InChannels {
    pub async fn recv_from(&mut self, id: &NodeId) -> DagrsResult<Content> {
        if self.is_disabled(id) {
            return Err(disabled_channel(*id));
        }
        match self.get(id) {
            Some(channel) => channel.lock().await.recv(*id).await,
            None => Err(no_such_channel(*id)),
        }
    }

    pub async fn recv_any(&mut self) -> DagrsResult<(NodeId, Content)> {
        let mut futures = Vec::new();
        let ids = self.keys();

        for id in ids {
            let channel = self.get(&id).ok_or_else(|| no_such_channel(id))?;
            let fut = Box::pin(async move {
                let content = channel.lock().await.recv(id).await?;
                Ok::<_, DagrsError>((id, content))
            });
            futures.push(fut);
        }

        if futures.is_empty() {
            return Err(no_such_channel(NodeId(0)).with_detail("scope", "recv_any"));
        }

        match select_ok(futures).await {
            Ok((result, _)) => Ok(result),
            Err(err) => Err(err),
        }
    }

    pub async fn map<F, T>(&mut self, f: F) -> Vec<T>
    where
        F: FnMut(DagrsResult<Content>) -> T,
    {
        let disabled = self.1.clone();
        let futures = self
            .0
            .iter_mut()
            .filter(|(id, _)| !disabled.contains(id))
            .map(|(id, c)| async move { c.lock().await.recv(*id).await });
        join_all(futures).await.into_iter().map(f).collect()
    }

    pub async fn close(&mut self, id: &NodeId) {
        if let Some(c) = self.get(id) {
            c.lock().await.close();
            self.0.remove(id);
            self.1.remove(id);
        }
    }

    pub(crate) fn insert(&mut self, node_id: NodeId, channel: Arc<Mutex<InChannel>>) {
        self.0.insert(node_id, channel);
        self.1.remove(&node_id);
    }

    pub(crate) async fn close_all(&mut self) {
        let channels: Vec<_> = self.0.values().cloned().collect();
        for channel in channels {
            channel.lock().await.close();
        }
        self.1.clear();
    }

    pub(crate) fn disable_sender(&mut self, node_id: NodeId) {
        if self.0.contains_key(&node_id) {
            self.1.insert(node_id);
        }
    }

    pub(crate) fn enable_sender(&mut self, node_id: NodeId) {
        self.1.remove(&node_id);
    }

    pub(crate) fn clear_disabled(&mut self) {
        self.1.clear();
    }

    pub fn get_sender_ids(&self) -> Vec<NodeId> {
        self.keys()
    }

    fn get(&self, id: &NodeId) -> Option<Arc<Mutex<InChannel>>> {
        self.0.get(id).cloned()
    }

    fn keys(&self) -> Vec<NodeId> {
        self.0
            .keys()
            .filter(|id| !self.is_disabled(id))
            .copied()
            .collect()
    }

    fn is_disabled(&self, id: &NodeId) -> bool {
        self.1.contains(id)
    }
}

pub enum InChannel {
    Mpsc(mpsc::Receiver<Content>),
    Bcst(broadcast::Receiver<Content>),
}

impl InChannel {
    async fn recv(&mut self, channel_id: NodeId) -> DagrsResult<Content> {
        match self {
            InChannel::Mpsc(receiver) => receiver.recv().await.ok_or_else(|| {
                DagrsError::new(ErrorCode::DgChn0002Closed, "channel is closed")
                    .with_channel(channel_id.as_usize())
            }),
            InChannel::Bcst(receiver) => match receiver.recv().await {
                Ok(v) => Ok(v),
                Err(broadcast::error::RecvError::Closed) => Err(DagrsError::new(
                    ErrorCode::DgChn0002Closed,
                    "channel is closed",
                )
                .with_channel(channel_id.as_usize())),
                Err(broadcast::error::RecvError::Lagged(x)) => Err(DagrsError::new(
                    ErrorCode::DgChn0003Lagged,
                    "channel receiver lagged behind broadcast sender",
                )
                .with_channel(channel_id.as_usize())
                .with_detail("lagged", x.to_string())),
            },
        }
    }

    fn close(&mut self) {
        match self {
            InChannel::Mpsc(receiver) => receiver.close(),
            InChannel::Bcst(_) => (),
        }
    }
}

#[derive(Default)]
pub struct TypedInChannels<T: Send + Sync + 'static>(
    pub(crate) HashMap<NodeId, Arc<Mutex<InChannel>>>,
    pub(crate) HashSet<NodeId>,
    pub(crate) PhantomData<T>,
);

impl<T: Send + Sync + 'static> TypedInChannels<T> {
    pub async fn recv_from(&mut self, id: &NodeId) -> DagrsResult<Option<Arc<T>>> {
        if self.is_disabled(id) {
            return Err(disabled_channel(*id));
        }
        match self.get(id) {
            Some(channel) => {
                let content: Content = channel.lock().await.recv(*id).await?;
                Ok(content.into_inner())
            }
            None => Err(no_such_channel(*id)),
        }
    }

    pub async fn recv_any(&mut self) -> DagrsResult<(NodeId, Option<Arc<T>>)> {
        let mut futures = Vec::new();
        let ids = self.keys();

        for id in ids {
            let channel = self.get(&id).ok_or_else(|| no_such_channel(id))?;
            let fut = Box::pin(async move {
                let content: Content = channel.lock().await.recv(id).await?;
                Ok::<_, DagrsError>((id, content.into_inner()))
            });
            futures.push(fut);
        }

        if futures.is_empty() {
            return Err(no_such_channel(NodeId(0)).with_detail("scope", "typed_recv_any"));
        }

        match select_ok(futures).await {
            Ok((result, _)) => Ok(result),
            Err(err) => Err(err),
        }
    }

    pub async fn map<F, U>(&mut self, f: F) -> Vec<U>
    where
        F: FnMut(DagrsResult<Option<Arc<T>>>) -> U,
    {
        let disabled = self.1.clone();
        let futures = self
            .0
            .iter_mut()
            .filter(|(id, _)| !disabled.contains(id))
            .map(|(id, c)| async move {
                let content: Content = c.lock().await.recv(*id).await?;
                Ok(content.into_inner())
            });
        join_all(futures).await.into_iter().map(f).collect()
    }

    pub async fn close(&mut self, id: &NodeId) {
        if let Some(c) = self.get(id) {
            c.lock().await.close();
            self.0.remove(id);
            self.1.remove(id);
        }
    }

    fn get(&self, id: &NodeId) -> Option<Arc<Mutex<InChannel>>> {
        self.0.get(id).cloned()
    }

    fn keys(&self) -> Vec<NodeId> {
        self.0
            .keys()
            .filter(|id| !self.is_disabled(id))
            .copied()
            .collect()
    }

    fn is_disabled(&self, id: &NodeId) -> bool {
        self.1.contains(id)
    }
}

fn no_such_channel(id: NodeId) -> DagrsError {
    DagrsError::new(ErrorCode::DgChn0001NoSuchChannel, "channel not found")
        .with_channel(id.as_usize())
}

fn disabled_channel(id: NodeId) -> DagrsError {
    DagrsError::new(
        ErrorCode::DgChn0002Closed,
        "channel is disabled by restored control flow state",
    )
    .with_channel(id.as_usize())
    .with_detail("state", "disabled")
}