Skip to main content

rs_pkg/network/websocket/server/
server.rs

1use super::super::{ErrorHandlerType, MessageHandlerType};
2use super::config::Config;
3use crate::async_fn::wrap_fn;
4use futures_util::{SinkExt, StreamExt};
5use std::fmt::Debug;
6use std::{
7    error::Error,
8    net::{Ipv4Addr, SocketAddr, SocketAddrV4},
9    str::FromStr,
10    sync::Arc,
11};
12use tokio::sync::broadcast;
13use tokio::{
14    net::TcpListener,
15    select,
16    sync::{
17        Mutex, RwLock,
18        mpsc::{self, Sender},
19    },
20};
21use tokio_tungstenite::{accept_async, tungstenite::Message};
22use tracing::{Instrument, debug_span, error, info, warn};
23
/// How a [`Server`] is hosted.
#[derive(Clone)]
enum ServerType {
    /// The server owns a TCP listener bound to this address (see `Server::run`).
    Independent(SocketAddr),
    /// Selected when `Config::is_router` is set. Connections are expected to be
    /// fed in through `handle_stream` (presumably from an Axum route — confirm
    /// with callers); `Server::run` panics in this mode.
    AxumHandler,
}
29
30#[derive(Clone)]
31pub struct Server<M> {
32    typ: ServerType,
33    pub(crate) message_handler: Arc<MessageHandlerType<M>>,
34    pub(crate) error_handler: Arc<ErrorHandlerType>,
35    pub(crate) id: Arc<RwLock<isize>>,
36
37    close: Arc<Sender<()>>,
38    inner_close: Arc<RwLock<broadcast::Sender<()>>>,
39
40    broadcast: Arc<RwLock<broadcast::Sender<M>>>,
41}
42
43impl<M> Server<M>
44where
45    M: Clone,
46{
47    pub fn new(cfg: &Config) -> Self {
48        let mut server_type = ServerType::AxumHandler;
49        if !cfg.is_router {
50            server_type = ServerType::Independent(SocketAddr::V4(SocketAddrV4::new(
51                Ipv4Addr::from_str(&cfg.host).expect("invalid ws server host"),
52                cfg.port,
53            )));
54        }
55
56        let (send, mut recv) = mpsc::channel(1);
57        let (bs, _) = broadcast::channel(1);
58        let bs = Arc::new(RwLock::new(bs));
59        let bs_monitor = bs.clone();
60        tokio::spawn(async move {
61            let guard = bs_monitor.read().await;
62            recv.recv().await;
63            _ = guard.send(());
64        });
65
66        let (broadcast_sender, _) = broadcast::channel(1000);
67
68        Server {
69            typ: server_type,
70            message_handler: wrap_fn(|_| async { None }),
71            error_handler: wrap_fn(|_| async {}),
72            id: Arc::new(RwLock::new(Default::default())),
73            close: Arc::new(send),
74            inner_close: bs.clone(),
75            broadcast: Arc::new(RwLock::new(broadcast_sender)),
76        }
77    }
78
79    pub async fn stop(&self) -> Result<(), mpsc::error::SendError<()>> {
80        self.close.send(()).await
81    }
82
83    pub fn with_message_handler<F, Fut>(mut self, h: F) -> Self
84    where
85        F: Fn(M) -> Fut + Send + Sync + 'static,
86        Fut: Future<Output = Option<M>> + Send + Sync + 'static,
87    {
88        self.message_handler = wrap_fn(h);
89        self
90    }
91
92    pub fn with_error_handler<F, Fut>(mut self, h: F) -> Self
93    where
94        F: Fn(Box<dyn Error + Send + Sync + 'static>) -> Fut + Send + Sync + 'static,
95        Fut: Future<Output = ()> + Send + Sync + 'static,
96    {
97        self.error_handler = wrap_fn(h);
98        self
99    }
100
101    pub async fn broadcast(&self, msg: M) -> Result<usize, broadcast::error::SendError<M>> {
102        let guard = self.broadcast.read().await;
103        guard.send(msg)
104    }
105
106    pub async fn handle_stream<S, E>(&self, s: S)
107    where
108        M: Debug + Send + Sync + 'static,
109        S: StreamExt<Item = Result<M, E>> + SinkExt<M, Error = E> + Sized + Send + 'static,
110        E: Error + Send + Sync + 'static,
111    {
112        let mut id = self.id.write().await;
113        let current_id = *id;
114        *id += 1;
115
116        let close_guard = self.inner_close.read().await;
117        let mut done = close_guard.subscribe();
118        let (sink, stream) = s.split();
119        let sink = Arc::new(Mutex::new(sink));
120        let stream = Arc::new(Mutex::new(stream));
121        let msg_handler = self.message_handler.clone();
122        let err_handler = self.error_handler.clone();
123        let broadcast_sender = self.broadcast.read().await;
124        let mut broadcast_receiver = broadcast_sender.subscribe();
125
126        tokio::spawn(async move {
127            let mut sink = sink.lock().await;
128            let mut guard = stream.lock().await;
129            loop {
130                select! {
131                    _ = done.recv() => {
132                        warn!("Conn {} Exit with done", current_id);
133                        break
134                    },
135
136                    m = broadcast_receiver.recv() => {
137                            match m {
138                                Ok(msg) => {
139                                    if let Err(e) = sink.send(msg).await {
140                                        err_handler(Box::new(e)).await;
141                                        return
142                                    };
143                                }
144
145                                Err(e) => {
146                                    err_handler(Box::new(e)).await;
147                                    return
148                                }
149                            }
150                    }
151
152                    t = guard.next() => {
153                        match t {
154                            Some(Ok(msg)) => {
155                                if let Some(msg) = msg_handler(msg).await {
156                                    if let Err(e) = sink.send(msg).await {
157                                        err_handler(Box::new(e)).await;
158                                    };
159                                }
160                            },
161                            Some(Err(err)) => {
162                                err_handler(Box::new(err)).await;
163                            },
164                            None => return,
165                        }
166                    }
167                }
168            }
169        });
170    }
171}
172
173impl Server<Message> {
174    pub async fn run(&self) {
175        match self.typ {
176            ServerType::Independent(addr) => {
177                let l = TcpListener::bind(addr)
178                    .await
179                    .inspect_err(|e| error!("bind listener failed: {}", e))
180                    .unwrap();
181
182                info!("Websocket Server host on: {}", addr.to_string());
183
184                let guard = self.inner_close.read().await;
185                let mut server_done = guard.subscribe();
186
187                let s = self.clone();
188                tokio::spawn(
189                    async move {
190                    loop {
191                        select! {
192			                      Ok((tcp_stream,_)) = l.accept() => {
193																if let Ok(ws_stream) = accept_async(tcp_stream) .await
194		                                .inspect_err(|e| error!("new stream failed: {}", e))
195		                            {
196																		s.handle_stream(ws_stream).await;
197		                            }
198														}
199
200														_ = server_done.recv() => {
201															return
202														}
203                        }
204                    }
205                  }
206                    .instrument(debug_span!("new_conn")),
207                );
208            }
209
210            ServerType::AxumHandler => panic!("unexpected failed"),
211        }
212    }
213}