1use crate::callback::OptionalCallback;
2use crate::transport::TransportType;
3
4use crate::error::{Error, Result};
5use crate::packet::{HandshakePacket, Packet, PacketId, Payload};
6use bytes::Bytes;
7use std::convert::TryFrom;
8use std::sync::RwLock;
9use std::time::Duration;
10use std::{fmt::Debug, sync::atomic::Ordering};
11use std::{
12 sync::{atomic::AtomicBool, Arc, Mutex},
13 time::Instant,
14};
15
/// Default maximum poll timeout: 45 seconds.
///
/// NOTE(review): not referenced in this file; presumably consumed by the
/// transport/client configuration elsewhere in the crate — confirm.
pub const DEFAULT_MAX_POLL_TIMEOUT: Duration = Duration::from_millis(45_000);
20
21#[derive(Clone)]
24pub struct Socket {
25 transport: Arc<TransportType>,
26 on_close: OptionalCallback<()>,
27 on_data: OptionalCallback<Bytes>,
28 on_error: OptionalCallback<String>,
29 on_open: OptionalCallback<()>,
30 on_packet: OptionalCallback<Packet>,
31 connected: Arc<AtomicBool>,
32 last_ping: Arc<Mutex<Instant>>,
33 last_pong: Arc<Mutex<Instant>>,
34 connection_data: Arc<HandshakePacket>,
35 remaining_packets: Arc<RwLock<Option<crate::packet::IntoIter>>>,
37 max_ping_timeout: u64,
38}
39
40impl Socket {
41 pub(crate) fn new(
42 transport: TransportType,
43 handshake: HandshakePacket,
44 on_close: OptionalCallback<()>,
45 on_data: OptionalCallback<Bytes>,
46 on_error: OptionalCallback<String>,
47 on_open: OptionalCallback<()>,
48 on_packet: OptionalCallback<Packet>,
49 ) -> Self {
50 let max_ping_timeout = handshake.ping_interval + handshake.ping_timeout;
51
52 Socket {
53 on_close,
54 on_data,
55 on_error,
56 on_open,
57 on_packet,
58 transport: Arc::new(transport),
59 connected: Arc::new(AtomicBool::default()),
60 last_ping: Arc::new(Mutex::new(Instant::now())),
61 last_pong: Arc::new(Mutex::new(Instant::now())),
62 connection_data: Arc::new(handshake),
63 remaining_packets: Arc::new(RwLock::new(None)),
64 max_ping_timeout,
65 }
66 }
67
68 pub fn connect(&self) -> Result<()> {
71 self.connected.store(true, Ordering::Release);
73
74 if let Some(on_open) = self.on_open.as_ref() {
75 spawn_scoped!(on_open(()));
76 }
77
78 *self.last_ping.lock()? = Instant::now();
80
81 self.emit(Packet::new(PacketId::Pong, Bytes::new()))?;
83
84 Ok(())
85 }
86
87 pub fn disconnect(&self) -> Result<()> {
88 if let Some(on_close) = self.on_close.as_ref() {
89 spawn_scoped!(on_close(()));
90 }
91
92 let _ = self.emit(Packet::new(PacketId::Close, Bytes::new()));
94
95 self.connected.store(false, Ordering::Release);
96
97 Ok(())
98 }
99
100 pub fn emit(&self, packet: Packet) -> Result<()> {
102 if !self.connected.load(Ordering::Acquire) {
103 let error = Error::IllegalActionBeforeOpen();
104 self.call_error_callback(format!("{}", error));
105 return Err(error);
106 }
107
108 let is_binary = packet.packet_id == PacketId::MessageBinary;
109
110 let data: Bytes = if is_binary {
113 packet.data
114 } else {
115 packet.into()
116 };
117
118 if let Err(error) = self.transport.as_transport().emit(data, is_binary) {
119 self.call_error_callback(error.to_string());
120 return Err(error);
121 }
122
123 Ok(())
124 }
125
126 pub(crate) fn poll(&self) -> Result<Option<Packet>> {
128 loop {
129 if self.connected.load(Ordering::Acquire) {
130 if self.remaining_packets.read()?.is_some() {
131 let mut iter = self.remaining_packets.write()?;
133 let iter = iter.as_mut().unwrap();
134 if let Some(packet) = iter.next() {
135 return Ok(Some(packet));
136 }
137 }
138
139 let data = self
143 .transport
144 .as_transport()
145 .poll(Duration::from_millis(self.time_to_next_ping()?))?;
146
147 if data.is_empty() {
148 continue;
149 }
150
151 let payload = Payload::try_from(data)?;
152 let mut iter = payload.into_iter();
153
154 if let Some(packet) = iter.next() {
155 *self.remaining_packets.write()? = Some(iter);
156 return Ok(Some(packet));
157 }
158 } else {
159 return Ok(None);
160 }
161 }
162 }
163
164 #[inline]
166 fn call_error_callback(&self, text: String) {
167 if let Some(function) = self.on_error.as_ref() {
168 spawn_scoped!(function(text));
169 }
170 }
171
172 pub(crate) fn is_connected(&self) -> Result<bool> {
174 Ok(self.connected.load(Ordering::Acquire))
175 }
176
177 pub(crate) fn pinged(&self) -> Result<()> {
178 *self.last_ping.lock()? = Instant::now();
179 Ok(())
180 }
181
182 fn time_to_next_ping(&self) -> Result<u64> {
186 match Instant::now().checked_duration_since(*self.last_ping.lock()?) {
187 Some(since_last_ping) => {
188 let since_last_ping = since_last_ping.as_millis() as u64;
189 if since_last_ping > self.max_ping_timeout {
190 Ok(0)
191 } else {
192 Ok(self.max_ping_timeout - since_last_ping)
193 }
194 }
195 None => Ok(0),
196 }
197 }
198
199 pub(crate) fn handle_packet(&self, packet: Packet) {
200 if let Some(on_packet) = self.on_packet.as_ref() {
201 spawn_scoped!(on_packet(packet));
202 }
203 }
204
205 pub(crate) fn handle_data(&self, data: Bytes) {
206 if let Some(on_data) = self.on_data.as_ref() {
207 spawn_scoped!(on_data(data));
208 }
209 }
210
211 pub(crate) fn handle_close(&self) {
212 if let Some(on_close) = self.on_close.as_ref() {
213 spawn_scoped!(on_close(()));
214 }
215
216 self.connected.store(false, Ordering::Release);
217 }
218}
219
220#[cfg_attr(tarpaulin, ignore)]
221impl Debug for Socket {
222 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
223 f.write_fmt(format_args!(
224 "EngineSocket(transport: {:?}, on_error: {:?}, on_open: {:?}, on_close: {:?}, on_packet: {:?}, on_data: {:?}, connected: {:?}, last_ping: {:?}, last_pong: {:?}, connection_data: {:?})",
225 self.transport,
226 self.on_error,
227 self.on_open,
228 self.on_close,
229 self.on_packet,
230 self.on_data,
231 self.connected,
232 self.last_ping,
233 self.last_pong,
234 self.connection_data,
235 ))
236 }
237}