rs_pkg/network/websocket/server/
server.rs1use super::super::{ErrorHandlerType, MessageHandlerType};
2use super::config::Config;
3use crate::async_fn::wrap_fn;
4use futures_util::{SinkExt, StreamExt};
5use std::fmt::Debug;
6use std::{
7 error::Error,
8 net::{Ipv4Addr, SocketAddr, SocketAddrV4},
9 str::FromStr,
10 sync::Arc,
11};
12use tokio::sync::broadcast;
13use tokio::{
14 net::TcpListener,
15 select,
16 sync::{
17 Mutex, RwLock,
18 mpsc::{self, Sender},
19 },
20};
21use tokio_tungstenite::{accept_async, tungstenite::Message};
22use tracing::{Instrument, debug_span, error, info, warn};
23
/// How a [`Server`] obtains its connections.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ServerType {
    /// Owns a TCP listener bound to this address; connections are accepted by `run`.
    Independent(SocketAddr),
    /// No listener of its own (built when `cfg.is_router` is set); streams are
    /// presumably handed in by an external router via `handle_stream`.
    AxumHandler,
}
29
30#[derive(Clone)]
31pub struct Server<M> {
32 typ: ServerType,
33 pub(crate) message_handler: Arc<MessageHandlerType<M>>,
34 pub(crate) error_handler: Arc<ErrorHandlerType>,
35 pub(crate) id: Arc<RwLock<isize>>,
36
37 close: Arc<Sender<()>>,
38 inner_close: Arc<RwLock<broadcast::Sender<()>>>,
39
40 broadcast: Arc<RwLock<broadcast::Sender<M>>>,
41}
42
43impl<M> Server<M>
44where
45 M: Clone,
46{
47 pub fn new(cfg: &Config) -> Self {
48 let mut server_type = ServerType::AxumHandler;
49 if !cfg.is_router {
50 server_type = ServerType::Independent(SocketAddr::V4(SocketAddrV4::new(
51 Ipv4Addr::from_str(&cfg.host).expect("invalid ws server host"),
52 cfg.port,
53 )));
54 }
55
56 let (send, mut recv) = mpsc::channel(1);
57 let (bs, _) = broadcast::channel(1);
58 let bs = Arc::new(RwLock::new(bs));
59 let bs_monitor = bs.clone();
60 tokio::spawn(async move {
61 let guard = bs_monitor.read().await;
62 recv.recv().await;
63 _ = guard.send(());
64 });
65
66 let (broadcast_sender, _) = broadcast::channel(1000);
67
68 Server {
69 typ: server_type,
70 message_handler: wrap_fn(|_| async { None }),
71 error_handler: wrap_fn(|_| async {}),
72 id: Arc::new(RwLock::new(Default::default())),
73 close: Arc::new(send),
74 inner_close: bs.clone(),
75 broadcast: Arc::new(RwLock::new(broadcast_sender)),
76 }
77 }
78
79 pub async fn stop(&self) -> Result<(), mpsc::error::SendError<()>> {
80 self.close.send(()).await
81 }
82
83 pub fn with_message_handler<F, Fut>(mut self, h: F) -> Self
84 where
85 F: Fn(M) -> Fut + Send + Sync + 'static,
86 Fut: Future<Output = Option<M>> + Send + Sync + 'static,
87 {
88 self.message_handler = wrap_fn(h);
89 self
90 }
91
92 pub fn with_error_handler<F, Fut>(mut self, h: F) -> Self
93 where
94 F: Fn(Box<dyn Error + Send + Sync + 'static>) -> Fut + Send + Sync + 'static,
95 Fut: Future<Output = ()> + Send + Sync + 'static,
96 {
97 self.error_handler = wrap_fn(h);
98 self
99 }
100
101 pub async fn broadcast(&self, msg: M) -> Result<usize, broadcast::error::SendError<M>> {
102 let guard = self.broadcast.read().await;
103 guard.send(msg)
104 }
105
106 pub async fn handle_stream<S, E>(&self, s: S)
107 where
108 M: Debug + Send + Sync + 'static,
109 S: StreamExt<Item = Result<M, E>> + SinkExt<M, Error = E> + Sized + Send + 'static,
110 E: Error + Send + Sync + 'static,
111 {
112 let mut id = self.id.write().await;
113 let current_id = *id;
114 *id += 1;
115
116 let close_guard = self.inner_close.read().await;
117 let mut done = close_guard.subscribe();
118 let (sink, stream) = s.split();
119 let sink = Arc::new(Mutex::new(sink));
120 let stream = Arc::new(Mutex::new(stream));
121 let msg_handler = self.message_handler.clone();
122 let err_handler = self.error_handler.clone();
123 let broadcast_sender = self.broadcast.read().await;
124 let mut broadcast_receiver = broadcast_sender.subscribe();
125
126 tokio::spawn(async move {
127 let mut sink = sink.lock().await;
128 let mut guard = stream.lock().await;
129 loop {
130 select! {
131 _ = done.recv() => {
132 warn!("Conn {} Exit with done", current_id);
133 break
134 },
135
136 m = broadcast_receiver.recv() => {
137 match m {
138 Ok(msg) => {
139 if let Err(e) = sink.send(msg).await {
140 err_handler(Box::new(e)).await;
141 return
142 };
143 }
144
145 Err(e) => {
146 err_handler(Box::new(e)).await;
147 return
148 }
149 }
150 }
151
152 t = guard.next() => {
153 match t {
154 Some(Ok(msg)) => {
155 if let Some(msg) = msg_handler(msg).await {
156 if let Err(e) = sink.send(msg).await {
157 err_handler(Box::new(e)).await;
158 };
159 }
160 },
161 Some(Err(err)) => {
162 err_handler(Box::new(err)).await;
163 },
164 None => return,
165 }
166 }
167 }
168 }
169 });
170 }
171}
172
173impl Server<Message> {
174 pub async fn run(&self) {
175 match self.typ {
176 ServerType::Independent(addr) => {
177 let l = TcpListener::bind(addr)
178 .await
179 .inspect_err(|e| error!("bind listener failed: {}", e))
180 .unwrap();
181
182 info!("Websocket Server host on: {}", addr.to_string());
183
184 let guard = self.inner_close.read().await;
185 let mut server_done = guard.subscribe();
186
187 let s = self.clone();
188 tokio::spawn(
189 async move {
190 loop {
191 select! {
192 Ok((tcp_stream,_)) = l.accept() => {
193 if let Ok(ws_stream) = accept_async(tcp_stream) .await
194 .inspect_err(|e| error!("new stream failed: {}", e))
195 {
196 s.handle_stream(ws_stream).await;
197 }
198 }
199
200 _ = server_done.recv() => {
201 return
202 }
203 }
204 }
205 }
206 .instrument(debug_span!("new_conn")),
207 );
208 }
209
210 ServerType::AxumHandler => panic!("unexpected failed"),
211 }
212 }
213}