// engineio_rs/socket/socket.rs

1use std::{
2    fmt::Debug,
3    pin::Pin,
4    sync::{
5        atomic::{AtomicBool, Ordering},
6        Arc,
7    },
8    task::{ready, Poll},
9};
10
11use async_stream::try_stream;
12use bytes::Bytes;
13use futures_util::{FutureExt, Stream, StreamExt};
14use tokio::{
15    sync::{mpsc::Sender, Mutex},
16    time::Instant,
17};
18use tracing::trace;
19
20use crate::{
21    error::Result,
22    packet::{HandshakePacket, Payload},
23    transports::{Data, TransportType},
24    Error, Packet, PacketType, Sid, StreamGenerator,
25};
26
27#[derive(Clone)]
28pub struct Socket {
29    transport: Arc<Mutex<TransportType>>,
30    event_tx: Option<Arc<Sender<Event>>>,
31    connected: Arc<AtomicBool>,
32    last_ping: Arc<Mutex<Instant>>,
33    last_pong: Arc<Mutex<Instant>>,
34    connection_data: Arc<HandshakePacket>,
35    generator: Arc<Mutex<StreamGenerator<Packet, Error>>>,
36    server_end: bool,
37    should_pong: bool,
38}
39
40#[derive(Debug)]
41pub enum Event {
42    OnOpen(Sid),
43    OnClose(Sid),
44    OnData(Sid, Bytes),
45    OnPacket(Sid, Packet),
46    OnError(Sid, String),
47}
48
49impl Socket {
50    #[allow(clippy::too_many_arguments)]
51    pub(crate) fn new(
52        transport: TransportType,
53        handshake: HandshakePacket,
54        event_tx: Option<Arc<Sender<Event>>>,
55        should_pong: bool,
56        server_end: bool,
57    ) -> Self {
58        Socket {
59            transport: Arc::new(Mutex::new(transport.clone())),
60            connected: Arc::new(AtomicBool::default()),
61            last_ping: Arc::new(Mutex::new(Instant::now())),
62            last_pong: Arc::new(Mutex::new(Instant::now())),
63            connection_data: Arc::new(handshake),
64            generator: Arc::new(Mutex::new(StreamGenerator::new(Self::stream(transport)))),
65            event_tx,
66            server_end,
67            should_pong,
68        }
69    }
70
71    /// Opens the connection to a specified server. The first Pong packet is sent
72    /// to the server to trigger the Ping-cycle.
73    pub async fn connect(&self) -> Result<()> {
74        // SAFETY: Has valid handshake due to type
75        self.connected.store(true, Ordering::Release);
76        if let Some(ref event_tx) = self.event_tx {
77            event_tx.send(Event::OnOpen(self.sid())).await?;
78        }
79
80        // set the last ping to now and set the connected state
81        *self.last_ping.lock().await = Instant::now();
82
83        if !self.server_end {
84            // emit a pong packet to keep trigger the ping cycle on the server
85            self.emit(Packet::new(PacketType::Pong, Bytes::new()))
86                .await?;
87        }
88
89        Ok(())
90    }
91
92    #[cfg(feature = "server")]
93    pub(crate) async fn last_pong(&self) -> Instant {
94        *(self.last_pong.lock().await)
95    }
96
97    async fn handle_incoming_packet(&self, packet: Packet) {
98        trace!("handle_incoming_packet {:?}", packet);
99        // update last_pong on any packet, incoming data is a good sign of other side's liveness
100        self.ponged().await;
101        // check for the appropriate action or callback
102        self.handle_packet(packet.clone()).await;
103        match packet.ptype {
104            PacketType::MessageBinary | PacketType::Message => {
105                self.handle_data(packet.data).await;
106            }
107            PacketType::Close => {
108                self.handle_close().await;
109            }
110            PacketType::Upgrade => {
111                // this is already checked during the handshake, so just do nothing here
112            }
113            PacketType::Ping => {
114                self.pinged().await;
115                // server and pong timeout test case should not pong
116                if self.should_pong {
117                    let _ = self.emit(Packet::new(PacketType::Pong, packet.data)).await;
118                }
119            }
120            PacketType::Pong | PacketType::Open | PacketType::Noop => (),
121        }
122    }
123
124    fn sid(&self) -> Sid {
125        Arc::clone(&self.connection_data.sid)
126    }
127
128    pub async fn disconnect(&self) -> Result<()> {
129        if !self.is_connected() {
130            return Ok(());
131        }
132
133        if let Some(ref event_tx) = self.event_tx {
134            event_tx.send(Event::OnClose(self.sid())).await?;
135        }
136
137        self.emit(Packet::new(PacketType::Close, Bytes::new()))
138            .await?;
139
140        self.connected.store(false, Ordering::Release);
141
142        Ok(())
143    }
144
145    /// Sends a packet to the server.
146    pub async fn emit_multi(&self, packets: Vec<Packet>) -> Result<()> {
147        if !self.connected.load(Ordering::Acquire) {
148            let error = Error::IllegalActionBeforeOpen();
149            self.on_error(format!("{}", error)).await;
150            return Err(error);
151        }
152
153        trace!("socket emit {:?}", packets);
154        let lock = self.transport.lock().await;
155        for packet in packets {
156            // send a post request with the encoded payload as body
157            // if this is a binary attachment, then send the raw bytes
158            let data = match packet.ptype {
159                PacketType::MessageBinary => Data::Binary(packet.data),
160                _ => Data::Text(packet.into()),
161            };
162
163            let fut = lock.as_transport().emit(data);
164
165            if let Err(error) = fut.await {
166                self.on_error(error.to_string()).await;
167                return Err(error);
168            }
169        }
170
171        Ok(())
172    }
173
174    /// Sends a packet to the server.
175    pub async fn emit(&self, packet: Packet) -> Result<()> {
176        if !self.connected.load(Ordering::Acquire) {
177            let error = Error::IllegalActionBeforeOpen();
178            self.on_error(format!("{}", error)).await;
179            return Err(error);
180        }
181
182        // send a post request with the encoded payload as body
183        // if this is a binary attachment, then send the raw bytes
184        let data = match packet.ptype {
185            PacketType::MessageBinary => Data::Binary(packet.data),
186            _ => Data::Text(packet.into()),
187        };
188
189        let lock = self.transport.lock().await;
190        trace!("socket emit {:?} through {:?}", data, lock);
191        let fut = lock.as_transport().emit(data);
192
193        if let Err(error) = fut.await {
194            self.on_error(error.to_string()).await;
195            return Err(error);
196        }
197
198        Ok(())
199    }
200
201    /// Calls the error callback with a given message.
202    #[inline]
203    async fn on_error(&self, text: String) {
204        trace!("socket on_error {}", text);
205        if let Some(ref event_tx) = self.event_tx {
206            let _ = event_tx.send(Event::OnError(self.sid(), text)).await;
207        }
208    }
209
210    // Check if the underlying transport client is connected.
211    pub fn is_connected(&self) -> bool {
212        self.connected.load(Ordering::Acquire)
213    }
214
215    pub(crate) async fn pinged(&self) {
216        *self.last_ping.lock().await = Instant::now();
217    }
218
219    pub(crate) async fn ponged(&self) {
220        *self.last_pong.lock().await = Instant::now();
221    }
222
223    pub(crate) async fn handle_packet(&self, packet: Packet) {
224        if let Some(ref event_tx) = self.event_tx {
225            let _ = event_tx.send(Event::OnPacket(self.sid(), packet)).await;
226        }
227    }
228
229    pub(crate) async fn handle_data(&self, data: Bytes) {
230        if let Some(ref event_tx) = self.event_tx {
231            let _ = event_tx.send(Event::OnData(self.sid(), data)).await;
232        }
233    }
234
235    pub(crate) async fn handle_close(&self) {
236        if !self.is_connected() {
237            return;
238        }
239        if let Some(ref event_tx) = self.event_tx {
240            let _ = event_tx.send(Event::OnClose(self.sid())).await;
241        }
242
243        self.connected.store(false, Ordering::Release);
244    }
245
246    pub(crate) async fn upgrade(&self, transport: TransportType) {
247        trace!("socket upgrade from {:?}", transport);
248        let mut lock = self.transport.lock().await;
249        *lock = transport.clone();
250
251        let mut lock = self.generator.lock().await;
252        *lock = StreamGenerator::new(Self::stream(transport));
253    }
254
255    /// Helper method that parses bytes and returns an iterator over the elements.
256    fn parse_payload(bytes: Bytes) -> impl Stream<Item = Result<Packet>> {
257        try_stream! {
258            let payload = Payload::try_from(bytes);
259
260            for elem in payload?.into_iter() {
261                trace!("parse_payload yield {:?}", elem);
262                yield elem;
263            }
264        }
265    }
266
267    /// Creates a stream over the incoming packets, uses the streams provided by the
268    /// underlying transport types.
269    fn stream(
270        mut transport: TransportType,
271    ) -> Pin<Box<impl Stream<Item = Result<Packet>> + 'static + Send>> {
272        // map the byte stream of the underlying transport
273        // to a packet stream
274        Box::pin(try_stream! {
275            for await payload in transport.as_pin_box() {
276                for await packet in Self::parse_payload(payload?) {
277                    yield packet?;
278                }
279            }
280        })
281    }
282}
283
284impl Stream for Socket {
285    type Item = Result<Packet>;
286
287    fn poll_next(
288        self: Pin<&mut Self>,
289        cx: &mut std::task::Context<'_>,
290    ) -> std::task::Poll<Option<Self::Item>> {
291        let mut lock = ready!(Box::pin(self.generator.lock()).poll_unpin(cx));
292        let item = lock.poll_next_unpin(cx);
293        if let Poll::Ready(Some(Ok(packet))) = &item {
294            ready!(Box::pin(self.handle_incoming_packet(packet.clone())).poll_unpin(cx));
295        }
296        item
297    }
298}
299
300// impl Stre
301#[cfg_attr(tarpaulin, ignore)]
302impl Debug for Socket {
303    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
304        f.debug_struct("Socket")
305            .field("transport", &self.transport)
306            .field("connected", &self.connected)
307            .field("last_ping", &self.last_ping)
308            .field("last_pong", &self.last_pong)
309            .field("connection_data", &self.connection_data)
310            .field("server_end", &self.server_end)
311            .finish()
312    }
313}