rs-pkg 0.0.3

rust pkg for programming
Documentation
use super::super::{ErrorHandlerType, MessageHandlerType};
use super::config::Config;
use crate::async_fn::wrap_fn;
use futures_util::{SinkExt, StreamExt};
use std::fmt::Debug;
use std::{
    error::Error,
    net::{Ipv4Addr, SocketAddr, SocketAddrV4},
    str::FromStr,
    sync::Arc,
};
use tokio::sync::broadcast;
use tokio::{
    net::TcpListener,
    select,
    sync::{
        Mutex, RwLock,
        mpsc::{self, Sender},
    },
};
use tokio_tungstenite::{accept_async, tungstenite::Message};
use tracing::{Instrument, debug_span, error, info, warn};

#[derive(Clone)]
enum ServerType {
    Independent(SocketAddr),
    AxumHandler,
}

#[derive(Clone)]
pub struct Server<M> {
    typ: ServerType,
    pub(crate) message_handler: Arc<MessageHandlerType<M>>,
    pub(crate) error_handler: Arc<ErrorHandlerType>,
    pub(crate) id: Arc<RwLock<isize>>,

    close: Arc<Sender<()>>,
    inner_close: Arc<RwLock<broadcast::Sender<()>>>,

    broadcast: Arc<RwLock<broadcast::Sender<M>>>,
}

impl<M> Server<M>
where
    M: Clone,
{
    pub fn new(cfg: &Config) -> Self {
        let mut server_type = ServerType::AxumHandler;
        if !cfg.is_router {
            server_type = ServerType::Independent(SocketAddr::V4(SocketAddrV4::new(
                Ipv4Addr::from_str(&cfg.host).expect("invalid ws server host"),
                cfg.port,
            )));
        }

        let (send, mut recv) = mpsc::channel(1);
        let (bs, _) = broadcast::channel(1);
        let bs = Arc::new(RwLock::new(bs));
        let bs_monitor = bs.clone();
        tokio::spawn(async move {
            let guard = bs_monitor.read().await;
            recv.recv().await;
            _ = guard.send(());
        });

        let (broadcast_sender, _) = broadcast::channel(1000);

        Server {
            typ: server_type,
            message_handler: wrap_fn(|_| async { None }),
            error_handler: wrap_fn(|_| async {}),
            id: Arc::new(RwLock::new(Default::default())),
            close: Arc::new(send),
            inner_close: bs.clone(),
            broadcast: Arc::new(RwLock::new(broadcast_sender)),
        }
    }

    pub async fn stop(&self) -> Result<(), mpsc::error::SendError<()>> {
        self.close.send(()).await
    }

    pub fn with_message_handler<F, Fut>(mut self, h: F) -> Self
    where
        F: Fn(M) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Option<M>> + Send + Sync + 'static,
    {
        self.message_handler = wrap_fn(h);
        self
    }

    pub fn with_error_handler<F, Fut>(mut self, h: F) -> Self
    where
        F: Fn(Box<dyn Error + Send + Sync + 'static>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + Sync + 'static,
    {
        self.error_handler = wrap_fn(h);
        self
    }

    pub async fn broadcast(&self, msg: M) -> Result<usize, broadcast::error::SendError<M>> {
        let guard = self.broadcast.read().await;
        guard.send(msg)
    }

    pub async fn handle_stream<S, E>(&self, s: S)
    where
        M: Debug + Send + Sync + 'static,
        S: StreamExt<Item = Result<M, E>> + SinkExt<M, Error = E> + Sized + Send + 'static,
        E: Error + Send + Sync + 'static,
    {
        let mut id = self.id.write().await;
        let current_id = *id;
        *id += 1;

        let close_guard = self.inner_close.read().await;
        let mut done = close_guard.subscribe();
        let (sink, stream) = s.split();
        let sink = Arc::new(Mutex::new(sink));
        let stream = Arc::new(Mutex::new(stream));
        let msg_handler = self.message_handler.clone();
        let err_handler = self.error_handler.clone();
        let broadcast_sender = self.broadcast.read().await;
        let mut broadcast_receiver = broadcast_sender.subscribe();

        tokio::spawn(async move {
            let mut sink = sink.lock().await;
            let mut guard = stream.lock().await;
            loop {
                select! {
                    _ = done.recv() => {
                        warn!("Conn {} Exit with done", current_id);
                        break
                    },

                    m = broadcast_receiver.recv() => {
                            match m {
                                Ok(msg) => {
                                    if let Err(e) = sink.send(msg).await {
                                        err_handler(Box::new(e)).await;
                                        return
                                    };
                                }

                                Err(e) => {
                                    err_handler(Box::new(e)).await;
                                    return
                                }
                            }
                    }

                    t = guard.next() => {
                        match t {
                            Some(Ok(msg)) => {
                                if let Some(msg) = msg_handler(msg).await {
                                    if let Err(e) = sink.send(msg).await {
                                        err_handler(Box::new(e)).await;
                                    };
                                }
                            },
                            Some(Err(err)) => {
                                err_handler(Box::new(err)).await;
                            },
                            None => return,
                        }
                    }
                }
            }
        });
    }
}

impl Server<Message> {
    pub async fn run(&self) {
        match self.typ {
            ServerType::Independent(addr) => {
                let l = TcpListener::bind(addr)
                    .await
                    .inspect_err(|e| error!("bind listener failed: {}", e))
                    .unwrap();

                info!("Websocket Server host on: {}", addr.to_string());

                let guard = self.inner_close.read().await;
                let mut server_done = guard.subscribe();

                let s = self.clone();
                tokio::spawn(
                    async move {
                    loop {
                        select! {
			                      Ok((tcp_stream,_)) = l.accept() => {
																if let Ok(ws_stream) = accept_async(tcp_stream) .await
		                                .inspect_err(|e| error!("new stream failed: {}", e))
		                            {
																		s.handle_stream(ws_stream).await;
		                            }
														}

														_ = server_done.recv() => {
															return
														}
                        }
                    }
                  }
                    .instrument(debug_span!("new_conn")),
                );
            }

            ServerType::AxumHandler => panic!("unexpected failed"),
        }
    }
}