rust_engineio/
socket.rs

1use crate::callback::OptionalCallback;
2use crate::transport::TransportType;
3
4use crate::error::{Error, Result};
5use crate::packet::{HandshakePacket, Packet, PacketId, Payload};
6use bytes::Bytes;
7use std::convert::TryFrom;
8use std::sync::RwLock;
9use std::time::Duration;
10use std::{fmt::Debug, sync::atomic::Ordering};
11use std::{
12    sync::{atomic::AtomicBool, Arc, Mutex},
13    time::Instant,
14};
15
/// Default upper bound for how long to wait for the next ping from the server.
///
/// It is the sum of the default `pingInterval` (25 s) and `pingTimeout` (20 s), see
/// <https://socket.io/docs/v4/server-options/#pinginterval> and
/// <https://socket.io/docs/v4/server-options/#pingtimeout>.
pub const DEFAULT_MAX_POLL_TIMEOUT: Duration = Duration::from_millis(25_000 + 20_000);
20
21/// An `engine.io` socket which manages a connection with the server and allows
22/// it to register common callbacks.
23#[derive(Clone)]
24pub struct Socket {
25    transport: Arc<TransportType>,
26    on_close: OptionalCallback<()>,
27    on_data: OptionalCallback<Bytes>,
28    on_error: OptionalCallback<String>,
29    on_open: OptionalCallback<()>,
30    on_packet: OptionalCallback<Packet>,
31    connected: Arc<AtomicBool>,
32    last_ping: Arc<Mutex<Instant>>,
33    last_pong: Arc<Mutex<Instant>>,
34    connection_data: Arc<HandshakePacket>,
35    /// Since we get packets in payloads it's possible to have a state where only some of the packets have been consumed.
36    remaining_packets: Arc<RwLock<Option<crate::packet::IntoIter>>>,
37    max_ping_timeout: u64,
38}
39
40impl Socket {
41    pub(crate) fn new(
42        transport: TransportType,
43        handshake: HandshakePacket,
44        on_close: OptionalCallback<()>,
45        on_data: OptionalCallback<Bytes>,
46        on_error: OptionalCallback<String>,
47        on_open: OptionalCallback<()>,
48        on_packet: OptionalCallback<Packet>,
49    ) -> Self {
50        let max_ping_timeout = handshake.ping_interval + handshake.ping_timeout;
51
52        Socket {
53            on_close,
54            on_data,
55            on_error,
56            on_open,
57            on_packet,
58            transport: Arc::new(transport),
59            connected: Arc::new(AtomicBool::default()),
60            last_ping: Arc::new(Mutex::new(Instant::now())),
61            last_pong: Arc::new(Mutex::new(Instant::now())),
62            connection_data: Arc::new(handshake),
63            remaining_packets: Arc::new(RwLock::new(None)),
64            max_ping_timeout,
65        }
66    }
67
68    /// Opens the connection to a specified server. The first Pong packet is sent
69    /// to the server to trigger the Ping-cycle.
70    pub fn connect(&self) -> Result<()> {
71        // SAFETY: Has valid handshake due to type
72        self.connected.store(true, Ordering::Release);
73
74        if let Some(on_open) = self.on_open.as_ref() {
75            spawn_scoped!(on_open(()));
76        }
77
78        // set the last ping to now and set the connected state
79        *self.last_ping.lock()? = Instant::now();
80
81        // emit a pong packet to keep trigger the ping cycle on the server
82        self.emit(Packet::new(PacketId::Pong, Bytes::new()))?;
83
84        Ok(())
85    }
86
87    pub fn disconnect(&self) -> Result<()> {
88        if let Some(on_close) = self.on_close.as_ref() {
89            spawn_scoped!(on_close(()));
90        }
91
92        // will not succeed when connection to the server is interrupted
93        let _ = self.emit(Packet::new(PacketId::Close, Bytes::new()));
94
95        self.connected.store(false, Ordering::Release);
96
97        Ok(())
98    }
99
100    /// Sends a packet to the server.
101    pub fn emit(&self, packet: Packet) -> Result<()> {
102        if !self.connected.load(Ordering::Acquire) {
103            let error = Error::IllegalActionBeforeOpen();
104            self.call_error_callback(format!("{}", error));
105            return Err(error);
106        }
107
108        let is_binary = packet.packet_id == PacketId::MessageBinary;
109
110        // send a post request with the encoded payload as body
111        // if this is a binary attachment, then send the raw bytes
112        let data: Bytes = if is_binary {
113            packet.data
114        } else {
115            packet.into()
116        };
117
118        if let Err(error) = self.transport.as_transport().emit(data, is_binary) {
119            self.call_error_callback(error.to_string());
120            return Err(error);
121        }
122
123        Ok(())
124    }
125
126    /// Polls for next payload
127    pub(crate) fn poll(&self) -> Result<Option<Packet>> {
128        loop {
129            if self.connected.load(Ordering::Acquire) {
130                if self.remaining_packets.read()?.is_some() {
131                    // SAFETY: checked is some above
132                    let mut iter = self.remaining_packets.write()?;
133                    let iter = iter.as_mut().unwrap();
134                    if let Some(packet) = iter.next() {
135                        return Ok(Some(packet));
136                    }
137                }
138
139                // Iterator has run out of packets, get a new payload.
140                // Make sure that payload is received within time_to_next_ping, as otherwise the heart
141                // stopped beating and we disconnect.
142                let data = self
143                    .transport
144                    .as_transport()
145                    .poll(Duration::from_millis(self.time_to_next_ping()?))?;
146
147                if data.is_empty() {
148                    continue;
149                }
150
151                let payload = Payload::try_from(data)?;
152                let mut iter = payload.into_iter();
153
154                if let Some(packet) = iter.next() {
155                    *self.remaining_packets.write()? = Some(iter);
156                    return Ok(Some(packet));
157                }
158            } else {
159                return Ok(None);
160            }
161        }
162    }
163
164    /// Calls the error callback with a given message.
165    #[inline]
166    fn call_error_callback(&self, text: String) {
167        if let Some(function) = self.on_error.as_ref() {
168            spawn_scoped!(function(text));
169        }
170    }
171
172    // Check if the underlying transport client is connected.
173    pub(crate) fn is_connected(&self) -> Result<bool> {
174        Ok(self.connected.load(Ordering::Acquire))
175    }
176
177    pub(crate) fn pinged(&self) -> Result<()> {
178        *self.last_ping.lock()? = Instant::now();
179        Ok(())
180    }
181
182    /// Returns the time in milliseconds that is left until a new ping must be received.
183    /// This is used to detect whether we have been disconnected from the server.
184    /// See https://socket.io/docs/v4/how-it-works/#disconnection-detection
185    fn time_to_next_ping(&self) -> Result<u64> {
186        match Instant::now().checked_duration_since(*self.last_ping.lock()?) {
187            Some(since_last_ping) => {
188                let since_last_ping = since_last_ping.as_millis() as u64;
189                if since_last_ping > self.max_ping_timeout {
190                    Ok(0)
191                } else {
192                    Ok(self.max_ping_timeout - since_last_ping)
193                }
194            }
195            None => Ok(0),
196        }
197    }
198
199    pub(crate) fn handle_packet(&self, packet: Packet) {
200        if let Some(on_packet) = self.on_packet.as_ref() {
201            spawn_scoped!(on_packet(packet));
202        }
203    }
204
205    pub(crate) fn handle_data(&self, data: Bytes) {
206        if let Some(on_data) = self.on_data.as_ref() {
207            spawn_scoped!(on_data(data));
208        }
209    }
210
211    pub(crate) fn handle_close(&self) {
212        if let Some(on_close) = self.on_close.as_ref() {
213            spawn_scoped!(on_close(()));
214        }
215
216        self.connected.store(false, Ordering::Release);
217    }
218}
219
220#[cfg_attr(tarpaulin, ignore)]
221impl Debug for Socket {
222    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
223        f.write_fmt(format_args!(
224            "EngineSocket(transport: {:?}, on_error: {:?}, on_open: {:?}, on_close: {:?}, on_packet: {:?}, on_data: {:?}, connected: {:?}, last_ping: {:?}, last_pong: {:?}, connection_data: {:?})",
225            self.transport,
226            self.on_error,
227            self.on_open,
228            self.on_close,
229            self.on_packet,
230            self.on_data,
231            self.connected,
232            self.last_ping,
233            self.last_pong,
234            self.connection_data,
235        ))
236    }
237}