engineio_rs/server/
server.rs

1use std::{
2    sync::{atomic::AtomicUsize, Arc},
3    time::Duration,
4};
5
6use bytes::Bytes;
7use dashmap::DashMap;
8use tokio::{
9    net::TcpListener,
10    sync::{
11        mpsc::{Receiver, Sender},
12        Mutex,
13    },
14    time::{interval, Instant},
15};
16use tracing::{trace, warn};
17
18use crate::{
19    error::Result,
20    packet::HandshakePacket,
21    server::http::{handle_http, PollingHandle},
22    socket::Socket,
23    transports::TransportType,
24    Event, Packet, PacketType, Sid,
25};
26
27#[derive(Clone)]
28pub struct Server {
29    pub(super) inner: Arc<ServerInner>,
30}
31
32pub(super) struct ServerInner {
33    pub(super) port: u16,
34    pub(super) server_option: ServerOption,
35    pub(super) id_generator: SidGenerator,
36    pub(super) polling_handles: Arc<DashMap<Sid, PollingHandle>>,
37    pub(super) polling_buffer: usize,
38    pub(super) event_tx: Arc<Sender<Event>>,
39    pub(super) event_rx: Arc<Mutex<Receiver<Event>>>,
40    pub(super) sockets: Arc<DashMap<Sid, Socket>>,
41}
42
/// Server tunables that are advertised to clients in the handshake packet.
///
/// The two timing values are milliseconds: the heartbeat task converts them
/// with `Duration::from_millis`.
#[derive(Debug, Clone, Copy)]
pub struct ServerOption {
    /// Grace period added on top of `ping_interval` before a socket whose
    /// last pong is too old gets closed.
    pub ping_timeout: u64,
    /// Period between two pings sent by the server.
    pub ping_interval: u64,
    /// Maximum payload size advertised in the handshake (presumably bytes).
    pub max_payload: usize,
}
49
50#[derive(Default)]
51pub(super) struct SidGenerator {
52    seq: AtomicUsize,
53}
54
55impl Server {
56    pub async fn serve(&self) {
57        let addr = format!("0.0.0.0:{}", self.inner.port);
58        let listener = TcpListener::bind(&addr)
59            .await
60            .expect("engine-io server can not listen port");
61
62        while let Ok((stream, peer_addr)) = listener.accept().await {
63            let server = self.clone();
64            tokio::spawn(async move { handle_http(server, stream, peer_addr).await });
65        }
66    }
67
68    pub async fn emit(&self, sid: &Sid, packet: Packet) -> Result<()> {
69        trace!("emit {} {:?}", sid, packet);
70        let sockets = &self.inner.sockets;
71        let socket = sockets.get(sid);
72        if let Some(s) = socket {
73            s.emit(packet).await?;
74        }
75        Ok(())
76    }
77
78    pub fn event_rx(&self) -> Arc<Mutex<Receiver<Event>>> {
79        self.inner.event_rx.clone()
80    }
81
82    pub async fn socket(&self, sid: &Sid) -> Option<Socket> {
83        let sockets = &self.inner.sockets;
84        sockets.get(sid).map(|x| x.to_owned())
85    }
86
87    pub async fn close_socket(&self, sid: &Sid) {
88        let sockets = &self.inner.sockets;
89        if let Some((_, socket)) = sockets.remove(sid) {
90            let _ = socket.disconnect().await;
91        }
92    }
93
94    pub(crate) fn polling_handles(&self) -> Arc<DashMap<Sid, PollingHandle>> {
95        self.inner.polling_handles.clone()
96    }
97
98    pub(crate) async fn polling_handle(&self, sid: &Sid) -> Option<PollingHandle> {
99        let handles = &self.inner.polling_handles;
100        let handle = handles.get(sid);
101        handle.map(|h| h.to_owned())
102    }
103
104    pub(crate) async fn drain_polling(&self, sid: &Sid) {
105        if let Some(socket) = self.socket(sid).await {
106            let _ = socket.emit(Packet::noop()).await;
107        }
108    }
109
110    pub(crate) fn polling_buffer(&self) -> usize {
111        self.inner.polling_buffer
112    }
113
114    pub(crate) fn generate_sid(&self) -> Sid {
115        self.inner.id_generator.generate()
116    }
117
118    pub(crate) fn handshake_packet(
119        &self,
120        upgrades: Vec<String>,
121        sid: Option<Sid>,
122    ) -> HandshakePacket {
123        let sid = match sid {
124            Some(sid) => sid,
125            None => self.inner.id_generator.generate(),
126        };
127
128        HandshakePacket {
129            sid,
130            upgrades,
131            ping_interval: self.inner.server_option.ping_interval,
132            ping_timeout: self.inner.server_option.ping_timeout,
133            max_payload: self.inner.server_option.max_payload,
134        }
135    }
136
137    pub(crate) async fn store_transport(
138        &self,
139        sid: Sid,
140        transport: TransportType,
141        is_upgrade: bool,
142    ) -> Result<()> {
143        trace!("store_transport {} {:?}", sid, transport);
144        let handshake = self.handshake_packet(vec!["webscocket".to_owned()], Some(sid.clone()));
145        if is_upgrade {
146            let sockets = &self.inner.sockets;
147            match sockets.get_mut(&sid) {
148                Some(socket) => socket.upgrade(transport).await,
149                None => warn!("upgrade polling not exist {:?}", sid),
150            };
151        } else {
152            let socket = Socket::new(
153                transport,
154                handshake,
155                Some(self.inner.event_tx.clone()),
156                false, // server no need to pong
157                true,
158            );
159
160            socket.connect().await?;
161
162            let sockets = &self.inner.sockets;
163            let _ = sockets.insert(sid.clone(), socket);
164            self.start_ping_pong(&sid);
165        }
166
167        Ok(())
168    }
169
170    pub(crate) fn start_ping_pong(&self, sid: &Sid) {
171        let sid = sid.to_owned();
172        let server = self.clone();
173        let option = server.inner.server_option;
174        let timeout = Duration::from_millis(option.ping_timeout + option.ping_interval);
175        let duration = Duration::from_millis(option.ping_interval);
176        trace!("start_ping_pong {} interval {:?}", sid, duration);
177        let mut interval = interval(duration);
178
179        tokio::spawn(async move {
180            loop {
181                interval.tick().await;
182                let ping_packet = Packet {
183                    ptype: PacketType::Ping,
184                    data: Bytes::new(),
185                };
186                if let Err(e) = server.emit(&sid, ping_packet).await {
187                    trace!("emit ping error {} {}", sid, e);
188                    break;
189                };
190                let last_pong = server.last_pong(&sid).await;
191                match last_pong {
192                    Some(instant) if instant.elapsed() < timeout => {}
193                    _ => break,
194                }
195            }
196            trace!("pong_timeout close {}", sid);
197            server.close_socket(&sid).await;
198        });
199    }
200
201    pub(crate) fn max_payload(&self) -> usize {
202        1000
203    }
204
205    async fn last_pong(&self, sid: &Sid) -> Option<Instant> {
206        let sockets = &self.inner.sockets;
207        Some(sockets.get(sid)?.last_pong().await)
208    }
209}
210
211impl Default for ServerOption {
212    fn default() -> Self {
213        Self {
214            ping_timeout: 25000,
215            ping_interval: 20000,
216            max_payload: 102400,
217        }
218    }
219}
220
221impl SidGenerator {
222    fn generate(&self) -> Sid {
223        let seq = self.seq.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
224        Arc::new(base64::encode(seq.to_string()))
225    }
226}
227
#[cfg(test)]
mod test {
    use super::*;

    use futures_util::{Stream, StreamExt};
    use reqwest::Url;

    use crate::{server::builder::ServerBuilder, socket::SocketBuilder};

    /// Runs a full data exchange against the server once for every
    /// `SocketBuilder` transport variant.
    #[tokio::test(flavor = "multi_thread", worker_threads = 3)]
    async fn test_connection() -> Result<()> {
        // To debug, install a `tracing_subscriber` with the `engineio=trace`
        // env filter before starting the server.
        let url = crate::test::rust_engine_io_server();
        let (mut rx, _server) = start_server(url.clone()).await;

        let socket = SocketBuilder::new(url.clone()).build_polling().await?;
        test_data_transport(socket, &mut rx).await?;

        let socket = SocketBuilder::new(url.clone()).build().await?;
        test_data_transport(socket, &mut rx).await?;

        let socket = SocketBuilder::new(url.clone()).build_websocket().await?;
        test_data_transport(socket, &mut rx).await?;

        let socket = SocketBuilder::new(url)
            .build_websocket_with_upgrade()
            .await?;
        test_data_transport(socket, &mut rx).await?;

        Ok(())
    }

    /// Clients that never answer pings must be disconnected by the server,
    /// whichever transport they use.
    #[tokio::test]
    async fn test_pong_timeout() -> Result<()> {
        let url = crate::test::rust_engine_io_timeout_server();
        // Server events are not inspected here, so the receiver is dropped at once.
        let _ = start_server(url.clone()).await;

        let socket = SocketBuilder::new(url.clone())
            .should_pong_for_test(false)
            .build_polling()
            .await?;
        test_transport_timeout(socket).await?;

        let socket = SocketBuilder::new(url.clone())
            .should_pong_for_test(false)
            .build()
            .await?;
        test_transport_timeout(socket).await?;

        let socket = SocketBuilder::new(url.clone())
            .should_pong_for_test(false)
            .build_websocket()
            .await?;
        test_transport_timeout(socket).await?;

        let socket = SocketBuilder::new(url)
            .should_pong_for_test(false)
            .build_websocket_with_upgrade()
            .await?;
        test_transport_timeout(socket).await?;

        Ok(())
    }

    /// Connects `client`, drains its stream in the background, and asserts that
    /// the server has closed it after 200 ms — well past the 40 ms deadline
    /// (`ping_interval + ping_timeout`) configured by `start_server`.
    async fn test_transport_timeout(client: Socket) -> Result<()> {
        client.connect().await?;

        let observer = client.clone();
        poll_stream(client);

        tokio::time::sleep(Duration::from_millis(200)).await;

        // closed by server
        assert!(!observer.is_connected());

        Ok(())
    }

    /// Starts a server on `url`'s port with 20 ms ping interval/timeout and
    /// returns the receiver of its stringified events plus the server handle.
    async fn start_server(url: Url) -> (Receiver<String>, Server) {
        let port = url.port().unwrap();
        let server_option = ServerOption {
            ping_timeout: 20,
            ping_interval: 20,
            max_payload: 102400,
        };
        let (server, rx) = setup(port, server_option);
        let server_clone = server.clone();

        tokio::spawn(async move {
            server_clone.serve().await;
        });

        // wait server start
        tokio::time::sleep(Duration::from_millis(100)).await;

        (rx, server)
    }

    /// Builds a server and spawns a task that turns its events into strings:
    /// `"open <sid>"`, the packet type, the UTF-8 data, or `"close"`.
    fn setup(port: u16, server_option: ServerOption) -> (Server, Receiver<String>) {
        let (tx, rx) = tokio::sync::mpsc::channel(100);
        let server = ServerBuilder::new(port)
            .polling_buffer(100)
            .event_size(100)
            .server_option(server_option)
            .build();

        let event_rx = server.event_rx();
        let server_clone = server.clone();

        tokio::spawn(async move {
            let mut event_rx = event_rx.lock().await;

            while let Some(event) = event_rx.recv().await {
                match event {
                    Event::OnOpen(sid) => {
                        // The server-side socket must be driven so that it
                        // processes incoming packets.
                        let socket = server_clone
                            .socket(&sid)
                            .await
                            .expect("socket should be registered when its open event arrives");
                        poll_stream(socket);
                        let _ = tx.send(format!("open {}", sid)).await;
                    }
                    Event::OnPacket(_sid, packet) => {
                        let _ = tx.send(String::from(packet.ptype)).await;
                    }
                    Event::OnData(_sid, data) => {
                        let data = std::str::from_utf8(&data).unwrap();
                        let _ = tx.send(data.to_owned()).await;
                    }
                    Event::OnClose(_sid) => {
                        let _ = tx.send("close".to_owned()).await;
                    }
                    _ => {}
                };
            }
        });

        (server, rx)
    }

    /// Connects `client`, sends a `"msg"` message, lets at least one heartbeat
    /// happen, disconnects, then checks that the server reported a pong (`"3"`),
    /// the message, and the close.
    async fn test_data_transport(client: Socket, server_rx: &mut Receiver<String>) -> Result<()> {
        client.connect().await?;
        let client_clone = client.clone();

        // ignore item send by last client
        while let Some(item) = server_rx.recv().await {
            if item.starts_with("open") {
                break;
            }
        }
        poll_stream(client_clone);

        client
            .emit(Packet::new(PacketType::Message, Bytes::from("msg")))
            .await?;

        // wait ping pong
        tokio::time::sleep(Duration::from_millis(100)).await;

        client.disconnect().await?;

        let mut receive_pong = false;
        let mut receive_msg = false;

        while let Some(item) = server_rx.recv().await {
            match item.as_str() {
                "3" => receive_pong = true,
                "msg" => receive_msg = true,
                "close" => break,
                _ => {}
            }
        }

        assert!(receive_pong);
        assert!(receive_msg);
        assert!(!client.is_connected());

        Ok(())
    }

    /// Spawns a task that drains `stream` until it ends, discarding every item.
    fn poll_stream(mut stream: impl Stream + Unpin + Send + 'static) {
        tokio::spawn(async move { while stream.next().await.is_some() {} });
    }
}