1use std::{
2 fmt::Debug,
3 pin::Pin,
4 sync::{
5 atomic::{AtomicBool, Ordering},
6 Arc,
7 },
8 task::{ready, Poll},
9};
10
11use async_stream::try_stream;
12use bytes::Bytes;
13use futures_util::{FutureExt, Stream, StreamExt};
14use tokio::{
15 sync::{mpsc::Sender, Mutex},
16 time::Instant,
17};
18use tracing::trace;
19
20use crate::{
21 error::Result,
22 packet::{HandshakePacket, Payload},
23 transports::{Data, TransportType},
24 Error, Packet, PacketType, Sid, StreamGenerator,
25};
26
/// A connection socket: a transport plus connection state.
///
/// Every field is either `Copy` or behind an `Arc`, so all clones of a `Socket` share the
/// same transport, connection flag, timestamps and packet generator.
#[derive(Clone)]
pub struct Socket {
    /// Transport used for outgoing packets; replaced by `upgrade`.
    transport: Arc<Mutex<TransportType>>,
    /// Optional channel on which [`Event`]s (open/close, packets, data, errors) are reported.
    event_tx: Option<Arc<Sender<Event>>>,
    /// Whether the socket is currently connected; starts `false`, set by `connect`.
    connected: Arc<AtomicBool>,
    /// Time recorded by `pinged` (on incoming `Ping`) and reset by `connect`.
    last_ping: Arc<Mutex<Instant>>,
    /// Time recorded by `ponged`, which the stream runs for every incoming packet.
    last_pong: Arc<Mutex<Instant>>,
    /// Handshake data; supplies the session id returned by `sid`.
    connection_data: Arc<HandshakePacket>,
    /// Stream of decoded incoming packets, polled by the `Stream` impl.
    generator: Arc<Mutex<StreamGenerator<Packet, Error>>>,
    /// True on the server side; `connect` then skips sending the initial `Pong`.
    server_end: bool,
    /// Whether incoming `Ping` packets are answered with a `Pong`.
    should_pong: bool,
}
39
/// Notifications a [`Socket`] sends on its event channel; every variant carries the
/// session id of the socket that produced it.
#[derive(Debug)]
pub enum Event {
    /// The socket was connected by `Socket::connect`.
    OnOpen(Sid),
    /// The socket was closed, either locally (`disconnect`) or by a received `Close` packet.
    OnClose(Sid),
    /// Payload of a received `Message` or `MessageBinary` packet.
    OnData(Sid, Bytes),
    /// Every received packet, reported before any type-specific handling.
    OnPacket(Sid, Packet),
    /// Text of an error reported by the socket, e.g. emitting before `connect` or a failed
    /// transport write.
    OnError(Sid, String),
}
48
49impl Socket {
50 #[allow(clippy::too_many_arguments)]
51 pub(crate) fn new(
52 transport: TransportType,
53 handshake: HandshakePacket,
54 event_tx: Option<Arc<Sender<Event>>>,
55 should_pong: bool,
56 server_end: bool,
57 ) -> Self {
58 Socket {
59 transport: Arc::new(Mutex::new(transport.clone())),
60 connected: Arc::new(AtomicBool::default()),
61 last_ping: Arc::new(Mutex::new(Instant::now())),
62 last_pong: Arc::new(Mutex::new(Instant::now())),
63 connection_data: Arc::new(handshake),
64 generator: Arc::new(Mutex::new(StreamGenerator::new(Self::stream(transport)))),
65 event_tx,
66 server_end,
67 should_pong,
68 }
69 }
70
71 pub async fn connect(&self) -> Result<()> {
74 self.connected.store(true, Ordering::Release);
76 if let Some(ref event_tx) = self.event_tx {
77 event_tx.send(Event::OnOpen(self.sid())).await?;
78 }
79
80 *self.last_ping.lock().await = Instant::now();
82
83 if !self.server_end {
84 self.emit(Packet::new(PacketType::Pong, Bytes::new()))
86 .await?;
87 }
88
89 Ok(())
90 }
91
92 #[cfg(feature = "server")]
93 pub(crate) async fn last_pong(&self) -> Instant {
94 *(self.last_pong.lock().await)
95 }
96
97 async fn handle_incoming_packet(&self, packet: Packet) {
98 trace!("handle_incoming_packet {:?}", packet);
99 self.ponged().await;
101 self.handle_packet(packet.clone()).await;
103 match packet.ptype {
104 PacketType::MessageBinary | PacketType::Message => {
105 self.handle_data(packet.data).await;
106 }
107 PacketType::Close => {
108 self.handle_close().await;
109 }
110 PacketType::Upgrade => {
111 }
113 PacketType::Ping => {
114 self.pinged().await;
115 if self.should_pong {
117 let _ = self.emit(Packet::new(PacketType::Pong, packet.data)).await;
118 }
119 }
120 PacketType::Pong | PacketType::Open | PacketType::Noop => (),
121 }
122 }
123
124 fn sid(&self) -> Sid {
125 Arc::clone(&self.connection_data.sid)
126 }
127
128 pub async fn disconnect(&self) -> Result<()> {
129 if !self.is_connected() {
130 return Ok(());
131 }
132
133 if let Some(ref event_tx) = self.event_tx {
134 event_tx.send(Event::OnClose(self.sid())).await?;
135 }
136
137 self.emit(Packet::new(PacketType::Close, Bytes::new()))
138 .await?;
139
140 self.connected.store(false, Ordering::Release);
141
142 Ok(())
143 }
144
145 pub async fn emit_multi(&self, packets: Vec<Packet>) -> Result<()> {
147 if !self.connected.load(Ordering::Acquire) {
148 let error = Error::IllegalActionBeforeOpen();
149 self.on_error(format!("{}", error)).await;
150 return Err(error);
151 }
152
153 trace!("socket emit {:?}", packets);
154 let lock = self.transport.lock().await;
155 for packet in packets {
156 let data = match packet.ptype {
159 PacketType::MessageBinary => Data::Binary(packet.data),
160 _ => Data::Text(packet.into()),
161 };
162
163 let fut = lock.as_transport().emit(data);
164
165 if let Err(error) = fut.await {
166 self.on_error(error.to_string()).await;
167 return Err(error);
168 }
169 }
170
171 Ok(())
172 }
173
174 pub async fn emit(&self, packet: Packet) -> Result<()> {
176 if !self.connected.load(Ordering::Acquire) {
177 let error = Error::IllegalActionBeforeOpen();
178 self.on_error(format!("{}", error)).await;
179 return Err(error);
180 }
181
182 let data = match packet.ptype {
185 PacketType::MessageBinary => Data::Binary(packet.data),
186 _ => Data::Text(packet.into()),
187 };
188
189 let lock = self.transport.lock().await;
190 trace!("socket emit {:?} through {:?}", data, lock);
191 let fut = lock.as_transport().emit(data);
192
193 if let Err(error) = fut.await {
194 self.on_error(error.to_string()).await;
195 return Err(error);
196 }
197
198 Ok(())
199 }
200
201 #[inline]
203 async fn on_error(&self, text: String) {
204 trace!("socket on_error {}", text);
205 if let Some(ref event_tx) = self.event_tx {
206 let _ = event_tx.send(Event::OnError(self.sid(), text)).await;
207 }
208 }
209
210 pub fn is_connected(&self) -> bool {
212 self.connected.load(Ordering::Acquire)
213 }
214
215 pub(crate) async fn pinged(&self) {
216 *self.last_ping.lock().await = Instant::now();
217 }
218
219 pub(crate) async fn ponged(&self) {
220 *self.last_pong.lock().await = Instant::now();
221 }
222
223 pub(crate) async fn handle_packet(&self, packet: Packet) {
224 if let Some(ref event_tx) = self.event_tx {
225 let _ = event_tx.send(Event::OnPacket(self.sid(), packet)).await;
226 }
227 }
228
229 pub(crate) async fn handle_data(&self, data: Bytes) {
230 if let Some(ref event_tx) = self.event_tx {
231 let _ = event_tx.send(Event::OnData(self.sid(), data)).await;
232 }
233 }
234
235 pub(crate) async fn handle_close(&self) {
236 if !self.is_connected() {
237 return;
238 }
239 if let Some(ref event_tx) = self.event_tx {
240 let _ = event_tx.send(Event::OnClose(self.sid())).await;
241 }
242
243 self.connected.store(false, Ordering::Release);
244 }
245
246 pub(crate) async fn upgrade(&self, transport: TransportType) {
247 trace!("socket upgrade from {:?}", transport);
248 let mut lock = self.transport.lock().await;
249 *lock = transport.clone();
250
251 let mut lock = self.generator.lock().await;
252 *lock = StreamGenerator::new(Self::stream(transport));
253 }
254
255 fn parse_payload(bytes: Bytes) -> impl Stream<Item = Result<Packet>> {
257 try_stream! {
258 let payload = Payload::try_from(bytes);
259
260 for elem in payload?.into_iter() {
261 trace!("parse_payload yield {:?}", elem);
262 yield elem;
263 }
264 }
265 }
266
267 fn stream(
270 mut transport: TransportType,
271 ) -> Pin<Box<impl Stream<Item = Result<Packet>> + 'static + Send>> {
272 Box::pin(try_stream! {
275 for await payload in transport.as_pin_box() {
276 for await packet in Self::parse_payload(payload?) {
277 yield packet?;
278 }
279 }
280 })
281 }
282}
283
/// Yields the packets decoded from the socket's current transport.
///
/// Each successfully decoded packet is passed to `handle_incoming_packet` before it is
/// returned to the caller. That step updates timestamps, reports events and replies to
/// pings. Errors and end-of-stream are returned unchanged.
impl Stream for Socket {
    type Item = Result<Packet>;

    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // A fresh `lock()` future is boxed and polled once on every call. If it is not ready,
        // `ready!` returns `Pending` and the future is dropped.
        // NOTE(review): tokio's Mutex appears to remove a dropped `lock()` future from its
        // wait queue. If so, nothing may wake this task again under contention — confirm.
        let mut lock = ready!(Box::pin(self.generator.lock()).poll_unpin(cx));
        let item = lock.poll_next_unpin(cx);
        if let Poll::Ready(Some(Ok(packet))) = &item {
            // The generator guard `lock` is still held while the packet is handled.
            // NOTE(review): if handling returns `Pending`, `ready!` returns early and drops
            // `item`, which was already taken from the generator. The packet is lost and its
            // handling is abandoned partway. Keeping the in-flight future (or the packet)
            // across polls would need extra state on `Socket`.
            ready!(Box::pin(self.handle_incoming_packet(packet.clone())).poll_unpin(cx));
        }
        item
    }
}
299
300#[cfg_attr(tarpaulin, ignore)]
302impl Debug for Socket {
303 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
304 f.debug_struct("Socket")
305 .field("transport", &self.transport)
306 .field("connected", &self.connected)
307 .field("last_ping", &self.last_ping)
308 .field("last_pong", &self.last_pong)
309 .field("connection_data", &self.connection_data)
310 .field("server_end", &self.server_end)
311 .finish()
312 }
313}