1use std::{
2 sync::{atomic::AtomicUsize, Arc},
3 time::Duration,
4};
5
6use bytes::Bytes;
7use dashmap::DashMap;
8use tokio::{
9 net::TcpListener,
10 sync::{
11 mpsc::{Receiver, Sender},
12 Mutex,
13 },
14 time::{interval, Instant},
15};
16use tracing::{trace, warn};
17
18use crate::{
19 error::Result,
20 packet::HandshakePacket,
21 server::http::{handle_http, PollingHandle},
22 socket::Socket,
23 transports::TransportType,
24 Event, Packet, PacketType, Sid,
25};
26
/// Cheaply cloneable handle to an engine.io server.
///
/// All state lives in the shared [`ServerInner`]; cloning a `Server` only bumps the
/// `Arc` reference count. `serve` hands a clone to every accepted connection's task,
/// and `start_ping_pong` moves a clone into each heartbeat task.
#[derive(Clone)]
pub struct Server {
    pub(super) inner: Arc<ServerInner>,
}
31
/// Shared state behind every [`Server`] handle.
pub(super) struct ServerInner {
    /// TCP port that `Server::serve` binds on `0.0.0.0`.
    pub(super) port: u16,
    /// Timing/payload options copied into the handshake and read by the ping loop.
    pub(super) server_option: ServerOption,
    /// Source of fresh session ids.
    pub(super) id_generator: SidGenerator,
    /// Polling handles (type from the `server::http` module), keyed by sid.
    pub(super) polling_handles: Arc<DashMap<Sid, PollingHandle>>,
    /// Value returned by `Server::polling_buffer`; presumably a channel capacity
    /// for the polling transport — NOTE(review): confirm in the http module.
    pub(super) polling_buffer: usize,
    /// Sending half of the event channel; every socket created by
    /// `store_transport` receives a clone of this sender.
    pub(super) event_tx: Arc<Sender<Event>>,
    /// Receiving half of the event channel, handed out by `Server::event_rx`.
    pub(super) event_rx: Arc<Mutex<Receiver<Event>>>,
    /// Live sockets, keyed by sid.
    pub(super) sockets: Arc<DashMap<Sid, Socket>>,
}
42
/// Tunable engine.io server options.
///
/// All three values are copied into the [`HandshakePacket`] built by
/// `Server::handshake_packet`.
#[derive(Debug, Clone, Copy)]
pub struct ServerOption {
    /// Extra time, in milliseconds, the heartbeat task tolerates beyond
    /// `ping_interval` without a pong before it closes the socket (the effective
    /// deadline is `ping_timeout + ping_interval`).
    pub ping_timeout: u64,
    /// Time, in milliseconds, between heartbeat pings.
    pub ping_interval: u64,
    /// Maximum payload size advertised to clients (presumably bytes — confirm).
    pub max_payload: usize,
}
49
/// Generates session ids from an atomic counter.
#[derive(Default)]
pub(super) struct SidGenerator {
    /// Next counter value to hand out; `Default` starts it at 0.
    seq: AtomicUsize,
}
54
55impl Server {
56 pub async fn serve(&self) {
57 let addr = format!("0.0.0.0:{}", self.inner.port);
58 let listener = TcpListener::bind(&addr)
59 .await
60 .expect("engine-io server can not listen port");
61
62 while let Ok((stream, peer_addr)) = listener.accept().await {
63 let server = self.clone();
64 tokio::spawn(async move { handle_http(server, stream, peer_addr).await });
65 }
66 }
67
68 pub async fn emit(&self, sid: &Sid, packet: Packet) -> Result<()> {
69 trace!("emit {} {:?}", sid, packet);
70 let sockets = &self.inner.sockets;
71 let socket = sockets.get(sid);
72 if let Some(s) = socket {
73 s.emit(packet).await?;
74 }
75 Ok(())
76 }
77
78 pub fn event_rx(&self) -> Arc<Mutex<Receiver<Event>>> {
79 self.inner.event_rx.clone()
80 }
81
82 pub async fn socket(&self, sid: &Sid) -> Option<Socket> {
83 let sockets = &self.inner.sockets;
84 sockets.get(sid).map(|x| x.to_owned())
85 }
86
87 pub async fn close_socket(&self, sid: &Sid) {
88 let sockets = &self.inner.sockets;
89 if let Some((_, socket)) = sockets.remove(sid) {
90 let _ = socket.disconnect().await;
91 }
92 }
93
94 pub(crate) fn polling_handles(&self) -> Arc<DashMap<Sid, PollingHandle>> {
95 self.inner.polling_handles.clone()
96 }
97
98 pub(crate) async fn polling_handle(&self, sid: &Sid) -> Option<PollingHandle> {
99 let handles = &self.inner.polling_handles;
100 let handle = handles.get(sid);
101 handle.map(|h| h.to_owned())
102 }
103
104 pub(crate) async fn drain_polling(&self, sid: &Sid) {
105 if let Some(socket) = self.socket(sid).await {
106 let _ = socket.emit(Packet::noop()).await;
107 }
108 }
109
110 pub(crate) fn polling_buffer(&self) -> usize {
111 self.inner.polling_buffer
112 }
113
114 pub(crate) fn generate_sid(&self) -> Sid {
115 self.inner.id_generator.generate()
116 }
117
118 pub(crate) fn handshake_packet(
119 &self,
120 upgrades: Vec<String>,
121 sid: Option<Sid>,
122 ) -> HandshakePacket {
123 let sid = match sid {
124 Some(sid) => sid,
125 None => self.inner.id_generator.generate(),
126 };
127
128 HandshakePacket {
129 sid,
130 upgrades,
131 ping_interval: self.inner.server_option.ping_interval,
132 ping_timeout: self.inner.server_option.ping_timeout,
133 max_payload: self.inner.server_option.max_payload,
134 }
135 }
136
137 pub(crate) async fn store_transport(
138 &self,
139 sid: Sid,
140 transport: TransportType,
141 is_upgrade: bool,
142 ) -> Result<()> {
143 trace!("store_transport {} {:?}", sid, transport);
144 let handshake = self.handshake_packet(vec!["webscocket".to_owned()], Some(sid.clone()));
145 if is_upgrade {
146 let sockets = &self.inner.sockets;
147 match sockets.get_mut(&sid) {
148 Some(socket) => socket.upgrade(transport).await,
149 None => warn!("upgrade polling not exist {:?}", sid),
150 };
151 } else {
152 let socket = Socket::new(
153 transport,
154 handshake,
155 Some(self.inner.event_tx.clone()),
156 false, true,
158 );
159
160 socket.connect().await?;
161
162 let sockets = &self.inner.sockets;
163 let _ = sockets.insert(sid.clone(), socket);
164 self.start_ping_pong(&sid);
165 }
166
167 Ok(())
168 }
169
170 pub(crate) fn start_ping_pong(&self, sid: &Sid) {
171 let sid = sid.to_owned();
172 let server = self.clone();
173 let option = server.inner.server_option;
174 let timeout = Duration::from_millis(option.ping_timeout + option.ping_interval);
175 let duration = Duration::from_millis(option.ping_interval);
176 trace!("start_ping_pong {} interval {:?}", sid, duration);
177 let mut interval = interval(duration);
178
179 tokio::spawn(async move {
180 loop {
181 interval.tick().await;
182 let ping_packet = Packet {
183 ptype: PacketType::Ping,
184 data: Bytes::new(),
185 };
186 if let Err(e) = server.emit(&sid, ping_packet).await {
187 trace!("emit ping error {} {}", sid, e);
188 break;
189 };
190 let last_pong = server.last_pong(&sid).await;
191 match last_pong {
192 Some(instant) if instant.elapsed() < timeout => {}
193 _ => break,
194 }
195 }
196 trace!("pong_timeout close {}", sid);
197 server.close_socket(&sid).await;
198 });
199 }
200
201 pub(crate) fn max_payload(&self) -> usize {
202 1000
203 }
204
205 async fn last_pong(&self, sid: &Sid) -> Option<Instant> {
206 let sockets = &self.inner.sockets;
207 Some(sockets.get(sid)?.last_pong().await)
208 }
209}
210
211impl Default for ServerOption {
212 fn default() -> Self {
213 Self {
214 ping_timeout: 25000,
215 ping_interval: 20000,
216 max_payload: 102400,
217 }
218 }
219}
220
221impl SidGenerator {
222 fn generate(&self) -> Sid {
223 let seq = self.seq.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
224 Arc::new(base64::encode(seq.to_string()))
225 }
226}
227
#[cfg(test)]
mod test {
    use super::*;

    use std::time::Duration;

    use futures_util::{Stream, StreamExt};
    use reqwest::Url;

    use crate::{server::builder::ServerBuilder, socket::SocketBuilder, Packet};

    /// End-to-end round trip against one server instance for each client builder
    /// flavour: `build_polling`, `build`, `build_websocket`, and
    /// `build_websocket_with_upgrade`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 3)]
    async fn test_connection() -> Result<()> {
        let url = crate::test::rust_engine_io_server();
        // `_server` keeps the returned handle alive for the whole test.
        let (mut rx, _server) = start_server(url.clone()).await;

        let socket = SocketBuilder::new(url.clone()).build_polling().await?;
        test_data_transport(socket, &mut rx).await?;

        let socket = SocketBuilder::new(url.clone()).build().await?;
        test_data_transport(socket, &mut rx).await?;

        let socket = SocketBuilder::new(url.clone()).build_websocket().await?;
        test_data_transport(socket, &mut rx).await?;

        let socket = SocketBuilder::new(url)
            .build_websocket_with_upgrade()
            .await?;
        test_data_transport(socket, &mut rx).await?;

        Ok(())
    }

    /// Clients built with `should_pong_for_test(false)` must be dropped by the
    /// server's heartbeat once the (short, test-only) ping deadline passes.
    #[tokio::test]
    async fn test_pong_timeout() -> Result<()> {
        let url = crate::test::rust_engine_io_timeout_server();
        // The event receiver is discarded here; the spawned `serve` task keeps its
        // own clone of the server alive.
        let _ = start_server(url.clone()).await;

        let socket = SocketBuilder::new(url.clone())
            .should_pong_for_test(false)
            .build_polling()
            .await?;
        test_transport_timeout(socket).await?;

        let socket = SocketBuilder::new(url.clone())
            .should_pong_for_test(false)
            .build()
            .await?;
        test_transport_timeout(socket).await?;

        let socket = SocketBuilder::new(url.clone())
            .should_pong_for_test(false)
            .build_websocket()
            .await?;
        test_transport_timeout(socket).await?;

        let socket = SocketBuilder::new(url)
            .should_pong_for_test(false)
            .build_websocket_with_upgrade()
            .await?;
        test_transport_timeout(socket).await?;

        Ok(())
    }

    /// Connects `client`, drains its stream on a background task, waits 200 ms
    /// and asserts the client has been disconnected.
    async fn test_transport_timeout(mut client: Socket) -> Result<()> {
        client.connect().await?;

        // Keep a clone to inspect the state after `client` moves into the task.
        let client_clone = client.clone();
        tokio::spawn(async move {
            loop {
                let next = client.next().await;
                if next.is_none() {
                    break;
                }
            }
        });

        // With ping_interval = ping_timeout = 20 ms (see `start_server`) the server
        // deadline is 40 ms, so 200 ms leaves ample slack.
        tokio::time::sleep(Duration::from_millis(200)).await;

        assert!(!client_clone.is_connected());

        Ok(())
    }

    /// Starts a server on `url`'s port with 20 ms ping settings and returns the
    /// receiver of stringified server events plus the server handle.
    ///
    /// Panics if `url` has no explicit port.
    async fn start_server(url: Url) -> (Receiver<String>, Server) {
        let port = url.port().unwrap();
        let server_option = ServerOption {
            ping_timeout: 20,
            ping_interval: 20,
            max_payload: 102400,
        };
        let (server, rx) = setup(port, server_option);
        let server_clone = server.clone();

        tokio::spawn(async move {
            server_clone.serve().await;
        });

        // Give the spawned task time to bind the listener before clients connect.
        tokio::time::sleep(Duration::from_millis(100)).await;

        (rx, server)
    }

    /// Builds a server and spawns a task that turns its events into strings:
    /// `"open <sid>"`, the `String` form of a packet's type, UTF-8 data payloads,
    /// and `"close"`. Other events are ignored.
    fn setup(port: u16, server_option: ServerOption) -> (Server, Receiver<String>) {
        let (tx, rx) = tokio::sync::mpsc::channel(100);
        let server = ServerBuilder::new(port)
            .polling_buffer(100)
            .event_size(100)
            .server_option(server_option)
            .build();

        let event_rx = server.event_rx();
        let server_clone = server.clone();

        tokio::spawn(async move {
            // The receiver lock is held for the lifetime of this task.
            let mut event_rx = event_rx.lock().await;

            while let Some(event) = event_rx.recv().await {
                match event {
                    Event::OnOpen(sid) => {
                        // Drain the server-side socket stream so it keeps being polled.
                        let socket = server_clone.socket(&sid).await;
                        poll_stream(socket.unwrap());
                        let _ = tx.send(format!("open {}", sid)).await;
                    }
                    Event::OnPacket(_sid, packet) => {
                        let _ = tx.send(String::from(packet.ptype)).await;
                    }
                    Event::OnData(_sid, data) => {
                        let data = std::str::from_utf8(&data).unwrap();
                        let _ = tx.send(data.to_owned()).await;
                    }
                    Event::OnClose(_sid) => {
                        let _ = tx.send("close".to_owned()).await;
                    }
                    _ => {}
                };
            }
        });

        (server, rx)
    }

    /// Connects `client`, waits for the server's `"open"` event, sends `"msg"`,
    /// disconnects, then reads events up to `"close"` and asserts that both a
    /// `"3"` packet and the `"msg"` data were seen.
    ///
    /// NOTE(review): `"3"` is presumably the string form of `PacketType::Pong`
    /// (i.e. the client answered a server ping) — confirm.
    async fn test_data_transport(client: Socket, server_rx: &mut Receiver<String>) -> Result<()> {
        client.connect().await?;
        let client_clone = client.clone();

        while let Some(item) = server_rx.recv().await {
            if item.starts_with("open") {
                break;
            }
        }
        poll_stream(client_clone);

        client
            .emit(Packet::new(crate::PacketType::Message, Bytes::from("msg")))
            .await?;

        // Allow at least one heartbeat round trip before disconnecting.
        tokio::time::sleep(Duration::from_millis(100)).await;

        client.disconnect().await?;

        let mut receive_pong = false;
        let mut receive_msg = false;

        while let Some(item) = server_rx.recv().await {
            match item.as_str() {
                "3" => receive_pong = true,
                "msg" => receive_msg = true,
                "close" => break,
                _ => {}
            }
        }

        assert!(receive_pong);
        assert!(receive_msg);
        assert!(!client.is_connected());

        Ok(())
    }

    /// Spawns a task that polls `stream` until it ends, discarding every item.
    fn poll_stream(mut stream: impl Stream + Unpin + Send + 'static) {
        tokio::spawn(async move { while stream.next().await.is_some() {} });
    }
}
420}