Skip to main content

socketioxide/
socket.rs

1//! A [`Socket`] represents a client connected to a namespace.
2//! The socket struct itself should not be used directly, but through a [`SocketRef`](crate::extract::SocketRef).
3use std::{
4    borrow::Cow,
5    collections::HashMap,
6    fmt::{self, Debug},
7    future::Future,
8    sync::{
9        Arc, Mutex, RwLock,
10        atomic::{AtomicBool, AtomicI64, Ordering},
11    },
12    time::Duration,
13};
14
15use engineioxide::socket::{DisconnectReason as EIoDisconnectReason, Permit};
16use serde::Serialize;
17use tokio::sync::{
18    mpsc::error::TrySendError,
19    oneshot::{self, Receiver},
20};
21
22#[cfg(feature = "extensions")]
23use crate::extensions::Extensions;
24
25use crate::{
26    AckError, SendError, SocketError, SocketIo,
27    ack::{AckInnerStream, AckResult, AckStream},
28    adapter::{Adapter, LocalAdapter},
29    client::SocketData,
30    errors::Error,
31    handler::{
32        BoxedDisconnectHandler, BoxedMessageHandler, DisconnectHandler, MakeErasedHandler,
33        MessageHandler,
34    },
35    ns::Namespace,
36    operators::{BroadcastOperators, ConfOperators},
37    parser::Parser,
38};
39use socketioxide_core::{
40    Value,
41    adapter::errors::{AdapterError, BroadcastError},
42    adapter::{BroadcastOptions, RemoteSocketData, Room, RoomParam},
43    packet::{Packet, PacketData},
44    parser::Parse,
45};
46
47pub use socketioxide_core::Sid;
48
/// All the possible reasons for a [`Socket`] to be disconnected from a namespace.
///
/// It can be used as an extractor in the [`on_disconnect`](crate::handler::disconnect) handler.
///
/// Transport-level reasons are converted from the engine.io disconnect reason,
/// while the `*NSDisconnect` variants are specific to the socket.io namespace layer.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum DisconnectReason {
    /// The client gracefully closed the connection
    TransportClose,

    /// The client sent multiple polling requests at the same time (it is forbidden according to the engine.io protocol)
    MultipleHttpPollingError,

    /// The client sent a bad request / the packet could not be parsed correctly
    PacketParsingError,

    /// The connection was closed (example: the user has lost connection, or the network was changed from WiFi to 4G)
    TransportError,

    /// The client did not send a PONG packet in the `ping timeout` delay
    HeartbeatTimeout,

    /// The client has manually disconnected the socket using [`socket.disconnect()`](https://socket.io/docs/v4/client-api/#socketdisconnect)
    ClientNSDisconnect,

    /// The socket was forcefully disconnected from the namespace with [`Socket::disconnect`] or with [`SocketIo::delete_ns`](crate::io::SocketIo::delete_ns)
    ServerNSDisconnect,

    /// The server is being closed
    ClosingServer,
}
78
79impl std::fmt::Display for DisconnectReason {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        use DisconnectReason::*;
82        let str: &'static str = match self {
83            TransportClose => "client gracefully closed the connection",
84            MultipleHttpPollingError => "client sent multiple polling requests at the same time",
85            PacketParsingError => "client sent a bad request / the packet could not be parsed",
86            TransportError => "The connection was abruptly closed",
87            HeartbeatTimeout => "client did not send a PONG packet in time",
88            ClientNSDisconnect => "client has manually disconnected the socket from the namespace",
89            ServerNSDisconnect => "socket was forcefully disconnected from the namespace",
90            ClosingServer => "server is being closed",
91        };
92        f.write_str(str)
93    }
94}
95
96impl From<EIoDisconnectReason> for DisconnectReason {
97    fn from(reason: EIoDisconnectReason) -> Self {
98        use DisconnectReason::*;
99        match reason {
100            EIoDisconnectReason::TransportClose => TransportClose,
101            EIoDisconnectReason::TransportError => TransportError,
102            EIoDisconnectReason::HeartbeatTimeout => HeartbeatTimeout,
103            EIoDisconnectReason::MultipleHttpPollingError => MultipleHttpPollingError,
104            EIoDisconnectReason::PacketParsingError => PacketParsingError,
105            EIoDisconnectReason::ClosingServer => ClosingServer,
106        }
107    }
108}
109
/// Crate-internal helpers to push socket.io payloads through an engine.io [`Permit`].
///
/// Both methods consume the permit.
pub(crate) trait PermitExt<'a> {
    /// Encode `packet` with `parser` and emit the encoded value through the permit.
    fn send(self, packet: Packet, parser: Parser);
    /// Emit an already encoded `value` through the permit.
    fn send_raw(self, value: Value);
}
114impl<'a> PermitExt<'a> for Permit<'a> {
115    fn send(self, packet: Packet, parser: Parser) {
116        match parser.encode(packet) {
117            Value::Str(msg, None) => self.emit(msg),
118            Value::Str(msg, Some(bin_payloads)) => self.emit_many(msg, bin_payloads),
119            Value::Bytes(bin) => self.emit_binary(bin),
120        }
121    }
122
123    fn send_raw(self, value: Value) {
124        match value {
125            Value::Str(msg, None) => self.emit(msg),
126            Value::Str(msg, Some(bin_payloads)) => self.emit_many(msg, bin_payloads),
127            Value::Bytes(bin) => self.emit_binary(bin),
128        }
129    }
130}
131
/// A RemoteSocket is a [`Socket`] that is remotely connected on another server.
/// It implements a subset of the [`Socket`] API.
///
/// Every action is forwarded through the adapter, targeting the socket described by `data`.
pub struct RemoteSocket<A> {
    // Shared handle to the adapter used to reach the remote socket.
    adapter: Arc<A>,
    // Parser used to encode outgoing values and to decode acknowledgements.
    parser: Parser,
    // Descriptor of the remote socket (its `id`, `server_id` and `ns` are shown by `Debug`).
    data: RemoteSocketData,
}
139
140impl<A> RemoteSocket<A> {
141    pub(crate) fn new(data: RemoteSocketData, adapter: &Arc<A>, parser: Parser) -> Self {
142        Self {
143            data,
144            adapter: adapter.clone(),
145            parser,
146        }
147    }
148    /// Consume the [`RemoteSocket`] and return its underlying data
149    #[inline]
150    pub fn into_data(self) -> RemoteSocketData {
151        self.data
152    }
153    /// Get a ref to the underlying data of the socket
154    #[inline]
155    pub fn data(&self) -> &RemoteSocketData {
156        &self.data
157    }
158}
159impl<A: Adapter> RemoteSocket<A> {
160    /// # Emit a message to a client that is remotely connected on another server.
161    ///
162    /// See [`Socket::emit`] for more info.
163    pub async fn emit<T: ?Sized + Serialize>(
164        &self,
165        event: impl AsRef<str>,
166        data: &T,
167    ) -> Result<(), RemoteActionError> {
168        let opts = self.get_opts();
169        let data = self.parser.encode_value(data, Some(event.as_ref()))?;
170        let packet = Packet::event(self.data.ns.clone(), data);
171        self.adapter.broadcast(packet, opts).await?;
172        Ok(())
173    }
174
175    /// # Emit a message to a client that is remotely connected on another server and wait for an acknowledgement.
176    ///
177    /// See [`Socket::emit_with_ack`] for more info.
178    pub async fn emit_with_ack<T: ?Sized + Serialize, V: serde::de::DeserializeOwned>(
179        &self,
180        event: impl AsRef<str>,
181        data: &T,
182    ) -> Result<AckStream<V, A>, RemoteActionError> {
183        let opts = self.get_opts();
184        let data = self.parser.encode_value(data, Some(event.as_ref()))?;
185        let packet = Packet::event(self.data.ns.clone(), data);
186        let stream = self
187            .adapter
188            .broadcast_with_ack(packet, opts, None)
189            .await
190            .map_err(Into::<AdapterError>::into)?;
191        Ok(AckStream::new(stream, self.parser))
192    }
193
194    /// # Get all room names this remote socket is connected to.
195    ///
196    /// See [`Socket::rooms`] for more info.
197    #[inline]
198    pub async fn rooms(&self) -> Result<Vec<Room>, A::Error> {
199        self.adapter.rooms(self.get_opts()).await
200    }
201
202    /// # Add the remote socket to the specified room(s).
203    ///
204    /// See [`Socket::join`] for more info.
205    #[inline]
206    pub fn join(&self, rooms: impl RoomParam) -> impl Future<Output = Result<(), A::Error>> + '_ {
207        self.adapter.add_sockets(self.get_opts(), rooms)
208    }
209
210    /// # Remove the remote socket from the specified room(s).
211    ///
212    /// See [`Socket::leave`] for more info.
213    #[inline]
214    pub fn leave(
215        &self,
216        rooms: impl RoomParam,
217    ) -> impl Future<Output = Result<(), A::Error>> + Send + '_ {
218        self.adapter.del_sockets(self.get_opts(), rooms)
219    }
220
221    /// # Disconnect the remote socket from the current namespace,
222    ///
223    /// See [`Socket::disconnect`] for more info.
224    #[inline]
225    pub async fn disconnect(self) -> Result<(), RemoteActionError> {
226        self.adapter.disconnect_socket(self.get_opts()).await?;
227        Ok(())
228    }
229
230    #[inline(always)]
231    fn get_opts(&self) -> BroadcastOptions {
232        BroadcastOptions::new_remote(&self.data)
233    }
234}
235impl<A> fmt::Debug for RemoteSocket<A> {
236    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
237        f.debug_struct("RemoteSocket")
238            .field("id", &self.data.id)
239            .field("server_id", &self.data.server_id)
240            .field("ns", &self.data.ns)
241            .finish()
242    }
243}
244impl<A> Clone for RemoteSocket<A> {
245    fn clone(&self) -> Self {
246        Self {
247            adapter: self.adapter.clone(),
248            parser: self.parser,
249            data: self.data.clone(),
250        }
251    }
252}
253
/// An error that can occur when performing an action (emit, disconnect, ...) on a remote socket.
#[derive(Debug, thiserror::Error)]
pub enum RemoteActionError {
    /// The message data could not be encoded.
    #[error("cannot encode data: {0}")]
    Serialize(#[from] crate::parser::ParserError),
    /// The remote socket is, in fact, a local socket and we should not emit to it.
    #[error("cannot send the message to the local socket: {0}")]
    Socket(crate::SocketError),
    /// The message could not be sent to the remote server.
    #[error("cannot propagate the request to the server: {0}")]
    Adapter(#[from] AdapterError),
}
267impl From<BroadcastError> for RemoteActionError {
268    fn from(value: BroadcastError) -> Self {
269        // This conversion assumes that we broadcast to a single (remote or not) socket.
270        match value {
271            BroadcastError::Socket(s) if !s.is_empty() => RemoteActionError::Socket(s[0].clone()),
272            BroadcastError::Socket(_) => {
273                panic!("BroadcastError with an empty socket vec is not permitted")
274            }
275            BroadcastError::Adapter(e) => e.into(),
276            BroadcastError::Serialize(e) => e.into(),
277        }
278    }
279}
280
/// A Socket represents a client connected to a namespace.
/// It is used to send and receive messages from the client, join and leave rooms, etc.
/// The socket struct itself should not be used directly, but through a [`SocketRef`](crate::extract::SocketRef).
pub struct Socket<A: Adapter = LocalAdapter> {
    // Namespace this socket belongs to.
    pub(crate) ns: Arc<Namespace<A>>,
    // Message handlers keyed by event name, filled by `Socket::on`.
    message_handlers: RwLock<HashMap<Cow<'static, str>, BoxedMessageHandler<A>>>,
    // Handler called when no handler matches the event name, set by `Socket::on_fallback`.
    fallback_message_handler: RwLock<Option<BoxedMessageHandler<A>>>,
    // Handler taken out and called once when the socket is closed.
    disconnect_handler: Mutex<Option<BoxedDisconnectHandler<A>>>,
    // Pending acknowledgements keyed by ack id.
    ack_message: Mutex<HashMap<i64, oneshot::Sender<AckResult<Value>>>>,
    // Counter used to generate the ack id of outgoing packets expecting an acknowledgement.
    ack_counter: AtomicI64,
    // Connection flag, see `Socket::connected`.
    connected: AtomicBool,
    // Parser used to encode outgoing packets and read incoming ones.
    pub(crate) parser: Parser,
    /// The socket id
    pub id: Sid,

    /// A type map of protocol extensions.
    /// It can be used to share data through the lifetime of the socket.
    ///
    /// **Note**: This is not the same data as the `extensions` field on the [`http::Request::extensions()`](http::Request) struct.
    /// If you want to extract extensions from the http request, you should use the [`HttpExtension`](crate::extract::HttpExtension) extractor.
    #[cfg_attr(docsrs, doc(cfg(feature = "extensions")))]
    #[cfg(feature = "extensions")]
    pub extensions: Extensions,
    // Underlying engine.io socket used to reserve send permits and access request data.
    esocket: Arc<engineioxide::Socket<SocketData<A>>>,
}
306
307impl<A: Adapter> Socket<A> {
308    pub(crate) fn new(
309        sid: Sid,
310        ns: Arc<Namespace<A>>,
311        esocket: Arc<engineioxide::Socket<SocketData<A>>>,
312        parser: Parser,
313    ) -> Self {
314        Self {
315            ns,
316            message_handlers: RwLock::new(HashMap::new()),
317            fallback_message_handler: RwLock::new(None),
318            disconnect_handler: Mutex::new(None),
319            ack_message: Mutex::new(HashMap::new()),
320            ack_counter: AtomicI64::new(0),
321            connected: AtomicBool::new(false),
322            parser,
323            id: sid,
324            #[cfg(feature = "extensions")]
325            extensions: Extensions::new(),
326            esocket,
327        }
328    }
329
330    /// # Registers a [`MessageHandler`] for the given event.
331    ///
332    /// * See the [`message`](crate::handler::message) module doc for more details on message handler.
333    /// * See the [`extract`](crate::extract) module doc for more details on available extractors.
334    ///
335    /// _It is recommended for code clarity to define your handler as top level function rather than closures._
336    ///
337    /// # Simple example with an async closure and an async fn:
338    /// ```
339    /// # use socketioxide::{SocketIo, extract::*};
340    /// # use serde::{Serialize, Deserialize};
341    /// #[derive(Debug, Serialize, Deserialize)]
342    /// struct MyData {
343    ///     name: String,
344    ///     age: u8,
345    /// }
346    /// async fn handler(socket: SocketRef, Data(data): Data::<MyData>) {
347    ///     println!("Received a test message {:?}", data);
348    ///     socket.emit("test-test", &MyData { name: "Test".to_string(), age: 8 }).ok(); // Emit a message to the client
349    /// }
350    ///
351    /// let (_, io) = SocketIo::new_svc();
352    /// io.ns("/", async |socket: SocketRef| {
353    ///     // Register a handler for the "test" event and extract the data as a `MyData` struct
354    ///     // With the Data extractor, the handler is called only if the data can be deserialized as a `MyData` struct
355    ///     // If you want to manage errors yourself you can use the TryData extractor
356    ///     socket.on("test", async |socket: SocketRef, Data::<MyData>(data)| {
357    ///         println!("Received a test message {:?}", data);
358    ///         socket.emit("test-test", &MyData { name: "Test".to_string(), age: 8 }).ok(); // Emit a message to the client
359    ///     });
360    ///     // Do the same thing but with an async function
361    ///     socket.on("test_2", handler);
362    /// });
363    ///
364    /// ```
365    ///
366    /// # Example with a closure and an fn with an acknowledgement + binary data:
367    /// ```
368    /// # use socketioxide::{SocketIo, extract::*};
369    /// # use serde_json::Value;
370    /// # use serde::{Serialize, Deserialize};
371    /// #[derive(Debug, Serialize, Deserialize)]
372    /// struct MyData {
373    ///     name: String,
374    ///     age: u8,
375    /// }
376    /// async fn handler(socket: SocketRef, Data(data): Data::<MyData>, ack: AckSender) {
377    ///     println!("Received a test message {:?}", data);
378    ///     tokio::time::sleep(std::time::Duration::from_secs(1)).await;
379    ///     ack.send(&data).ok(); // The data received is sent back to the client through the ack
380    ///     socket.emit("test-test", &MyData { name: "Test".to_string(), age: 8 }).ok(); // Emit a message to the client
381    /// }
382    ///
383    /// let (_, io) = SocketIo::new_svc();
384    /// io.ns("/", async |socket: SocketRef| {
385    ///     // Register an async handler for the "test" event and extract the data as a `MyData` struct
386    ///     // Extract the binary payload as a `Vec<Bytes>` with the Bin extractor.
387    ///     // It should be the last extractor because it consumes the request
388    ///     socket.on("test", async |socket: SocketRef, Data::<MyData>(data), ack: AckSender| {
389    ///         println!("Received a test message {:?}", data);
390    ///         tokio::time::sleep(std::time::Duration::from_secs(1)).await;
391    ///         ack.send(&data).ok(); // The data received is sent back to the client through the ack
392    ///         socket.emit("test-test", &MyData { name: "Test".to_string(), age: 8 }).ok(); // Emit a message to the client
393    ///     });
394    ///     // Do the same thing but with an async function
395    ///     socket.on("test_2", handler);
396    /// });
397    /// ```
398    pub fn on<H, T>(&self, event: impl Into<Cow<'static, str>>, handler: H)
399    where
400        H: MessageHandler<A, T>,
401        T: Send + Sync + 'static,
402    {
403        self.message_handlers
404            .write()
405            .unwrap()
406            .insert(event.into(), MakeErasedHandler::new_message_boxed(handler));
407    }
408
409    /// # Registers a fallback [`MessageHandler`] when no other handler is found. You can see this as a 404 handler.
410    /// You can register only one fallback handler per socket. If you register multiple handlers, only the last one will be used.
411    ///
412    /// * See the [`message`](crate::handler::message) module doc for more details on message handler.
413    /// * See the [`extract`](crate::extract) module doc for more details on available extractors.
414    ///
415    /// _It is recommended for code clarity to define your handler as top level function rather than closures._
416    ///
417    /// # Example:
418    /// ```
419    /// # use socketioxide::{SocketIo, extract::*};
420    /// # use serde::{Serialize, Deserialize};
421    /// # use serde_json::Value;
422    /// async fn fallback_handler(socket: SocketRef, Event(event): Event, Data(data): Data::<Value>) {
423    ///     println!("Received an {event} event with message {:?}", data);
424    /// }
425    ///
426    /// let (_, io) = SocketIo::new_svc();
427    /// io.ns("/", async |socket: SocketRef| {
428    ///     // Register a fallback handler.
429    ///     // In our example it will be always called as there is no other handler.
430    ///     socket.on_fallback(fallback_handler);
431    /// });
432    ///
433    /// ```
434    pub fn on_fallback<H, T>(&self, handler: H)
435    where
436        H: MessageHandler<A, T>,
437        T: Send + Sync + 'static,
438    {
439        self.fallback_message_handler
440            .write()
441            .unwrap()
442            .replace(MakeErasedHandler::new_message_boxed(handler));
443    }
444
445    /// # Register a disconnect handler.
446    /// You can register only one disconnect handler per socket. If you register multiple handlers, only the last one will be used.
447    ///
448    /// This implementation is slightly different to the socket.io spec.
449    /// The difference being that [`rooms`](Self::rooms) are still available in this handler
450    /// and only cleaned up AFTER the execution of this handler.
451    /// Therefore you must not indefinitely stall/hang this handler, for example by entering an endless loop.
452    ///
453    /// _It is recommended for code clarity to define your handler as top level function rather than closures._
454    ///
455    /// * See the [`disconnect`](crate::handler::disconnect) module doc for more details on disconnect handler.
456    /// * See the [`extract`](crate::extract) module doc for more details on available extractors.
457    ///
458    /// The callback will be called when the socket is disconnected from the server or the client or when the underlying connection crashes.
459    /// A [`DisconnectReason`] is passed to the callback to indicate the reason for the disconnection.
460    ///
461    /// # Example
462    /// ```
463    /// # use socketioxide::{SocketIo, socket::DisconnectReason, extract::*};
464    /// # use serde_json::Value;
465    /// # use std::sync::Arc;
466    /// let (_, io) = SocketIo::new_svc();
467    /// io.ns("/", async |socket: SocketRef| {
468    ///     socket.on("test", async |socket: SocketRef| {
469    ///         // Close the current socket
470    ///         socket.disconnect().ok();
471    ///     });
472    ///     socket.on_disconnect(async |socket: SocketRef, reason: DisconnectReason| {
473    ///         println!("Socket {} on ns {} disconnected, reason: {:?}", socket.id, socket.ns(), reason);
474    ///     });
475    /// });
476    pub fn on_disconnect<C, T>(&self, callback: C)
477    where
478        C: DisconnectHandler<A, T> + Send + Sync + 'static,
479        T: Send + Sync + 'static,
480    {
481        let handler = MakeErasedHandler::new_disconnect_boxed(callback);
482        self.disconnect_handler.lock().unwrap().replace(handler);
483    }
484
485    #[doc = include_str!("../docs/operators/emit.md")]
486    pub fn emit<T: ?Sized + Serialize>(
487        &self,
488        event: impl AsRef<str>,
489        data: &T,
490    ) -> Result<(), SendError> {
491        if !self.connected() {
492            return Err(SendError::Socket(SocketError::Closed));
493        }
494
495        let permit = match self.reserve() {
496            Ok(permit) => permit,
497            Err(e) => {
498                #[cfg(feature = "tracing")]
499                tracing::debug!("sending error during emit message: {e:?}");
500                return Err(SendError::Socket(e));
501            }
502        };
503
504        let ns = self.ns.path.clone();
505        let data = self.parser.encode_value(data, Some(event.as_ref()))?;
506
507        permit.send(Packet::event(ns, data), self.parser);
508        Ok(())
509    }
510
511    #[doc = include_str!("../docs/operators/emit_with_ack.md")]
512    pub fn emit_with_ack<T: ?Sized + Serialize, V>(
513        &self,
514        event: impl AsRef<str>,
515        data: &T,
516    ) -> Result<AckStream<V>, SendError> {
517        if !self.connected() {
518            return Err(SendError::Socket(SocketError::Closed));
519        }
520        let permit = match self.reserve() {
521            Ok(permit) => permit,
522            Err(e) => {
523                #[cfg(feature = "tracing")]
524                tracing::debug!("sending error during emit message: {e:?}");
525                return Err(SendError::Socket(e));
526            }
527        };
528        let ns = self.ns.path.clone();
529        let data = self.parser.encode_value(data, Some(event.as_ref()))?;
530        let packet = Packet::event(ns, data);
531        let rx = self.send_with_ack_permit(packet, permit);
532        let stream = AckInnerStream::send(rx, self.get_io().config().ack_timeout, self.id);
533        Ok(AckStream::<V>::new(stream, self.parser))
534    }
535
536    // Room actions
537
538    /// # Add the current socket to the specified room(s).
539    ///
540    /// # Example
541    /// ```rust
542    /// # use socketioxide::{SocketIo, extract::*};
543    /// async fn handler(socket: SocketRef) {
544    ///     // Add all sockets that are in room1 and room3 to room4 and room5
545    ///     socket.join(["room4", "room5"]);
546    ///     // We should retrieve all the local sockets that are in room3 and room5
547    ///     let sockets = socket.within("room4").within("room5").sockets();
548    /// }
549    ///
550    /// let (_, io) = SocketIo::new_svc();
551    /// io.ns("/", async |s: SocketRef| s.on("test", handler));
552    /// ```
553    pub fn join(&self, rooms: impl RoomParam) {
554        self.ns.adapter.get_local().add_all(self.id, rooms)
555    }
556
557    /// # Remove the current socket from the specified room(s).
558    ///
559    /// # Example
560    /// ```rust
561    /// # use socketioxide::{SocketIo, extract::*};
562    /// async fn handler(socket: SocketRef) {
563    ///     // Remove all sockets that are in room1 and room3 from room4 and room5
564    ///     socket.within("room1").within("room3").leave(["room4", "room5"]);
565    /// }
566    ///
567    /// let (_, io) = SocketIo::new_svc();
568    /// io.ns("/", async |s: SocketRef| s.on("test", handler));
569    /// ```
570    pub fn leave(&self, rooms: impl RoomParam) {
571        self.ns.adapter.get_local().del(self.id, rooms)
572    }
573
574    /// # Remove the current socket from all its rooms.
575    pub fn leave_all(&self) {
576        self.ns.adapter.get_local().del_all(self.id);
577    }
578
579    /// # Get all room names this socket is connected to.
580    ///
581    /// # Example
582    /// ```rust
583    /// # use socketioxide::{SocketIo, extract::SocketRef};
584    /// async fn handler(socket: SocketRef) {
585    ///     println!("Socket connected to the / namespace with id: {}", socket.id);
586    ///     socket.join(["room1", "room2"]);
587    ///     let rooms = socket.rooms();
588    ///     println!("All rooms in the / namespace: {:?}", rooms);
589    /// }
590    ///
591    /// let (_, io) = SocketIo::new_svc();
592    /// io.ns("/", handler);
593    /// ```
594    pub fn rooms(&self) -> Vec<Room> {
595        self.ns
596            .adapter
597            .get_local()
598            .socket_rooms(self.id)
599            .into_iter()
600            .collect()
601    }
602
603    /// # Return true if the socket is connected to the namespace.
604    ///
605    /// A socket is considered connected when it has been successfully handshaked with the server
606    /// and that all [connect middlewares](crate::handler::connect#middlewares) have been executed.
607    pub fn connected(&self) -> bool {
608        self.connected.load(Ordering::SeqCst)
609    }
610
611    // Socket operators
612
613    #[doc = include_str!("../docs/operators/to.md")]
614    pub fn to(&self, rooms: impl RoomParam) -> BroadcastOperators<A> {
615        BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser).to(rooms)
616    }
617
618    #[doc = include_str!("../docs/operators/within.md")]
619    pub fn within(&self, rooms: impl RoomParam) -> BroadcastOperators<A> {
620        BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser).within(rooms)
621    }
622
623    #[doc = include_str!("../docs/operators/except.md")]
624    pub fn except(&self, rooms: impl RoomParam) -> BroadcastOperators<A> {
625        BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser).except(rooms)
626    }
627
628    #[doc = include_str!("../docs/operators/local.md")]
629    pub fn local(&self) -> BroadcastOperators<A> {
630        BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser).local()
631    }
632
633    #[doc = include_str!("../docs/operators/timeout.md")]
634    pub fn timeout(&self, timeout: Duration) -> ConfOperators<'_, A> {
635        ConfOperators::new(self).timeout(timeout)
636    }
637
638    #[doc = include_str!("../docs/operators/broadcast.md")]
639    pub fn broadcast(&self) -> BroadcastOperators<A> {
640        BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser).broadcast()
641    }
642
643    /// # Get the [`SocketIo`] context related to this socket
644    ///
645    /// # Panics
646    /// Because [`SocketData::io`] should be immediately set at the creation of the socket.
647    /// this should never panic.
648    pub(crate) fn get_io(&self) -> &SocketIo<A> {
649        self.esocket.data.io.get().unwrap()
650    }
651
652    /// # Disconnect the socket from the current namespace,
653    ///
654    /// It will also call the disconnect handler if it is set with a [`DisconnectReason::ServerNSDisconnect`].
655    pub fn disconnect(self: Arc<Self>) -> Result<(), SocketError> {
656        let res = self.send(Packet::disconnect(self.ns.path.clone()));
657        if let Err(SocketError::InternalChannelFull) = res {
658            return Err(SocketError::InternalChannelFull);
659        }
660        self.close(DisconnectReason::ServerNSDisconnect);
661        Ok(())
662    }
663
664    /// # Get the request info made by the client to connect.
665    ///
666    /// It might be used to retrieve the [`http::Extensions`]
667    pub fn req_parts(&self) -> &http::request::Parts {
668        &self.esocket.req_parts
669    }
670
671    /// # Get the [`TransportType`](crate::TransportType) used by the client to connect with this [`Socket`].
672    ///
673    /// It can also be accessed as an extractor
674    /// # Example
675    /// ```
676    /// # use socketioxide::{SocketIo, TransportType, extract::*};
677    ///
678    /// let (_, io) = SocketIo::new_svc();
679    /// io.ns("/", async |socket: SocketRef, transport: TransportType| {
680    ///     assert_eq!(socket.transport_type(), transport);
681    /// });
682    pub fn transport_type(&self) -> crate::TransportType {
683        self.esocket.transport_type()
684    }
685
686    /// Get the socket.io [`ProtocolVersion`](crate::ProtocolVersion) used by the client to connect with this [`Socket`].
687    ///
688    /// It can also be accessed as an extractor:
689    /// # Example
690    /// ```
691    /// # use socketioxide::{SocketIo, ProtocolVersion, extract::*};
692    ///
693    /// let (_, io) = SocketIo::new_svc();
694    /// io.ns("/", async |socket: SocketRef, v: ProtocolVersion| {
695    ///     assert_eq!(socket.protocol(), v);
696    /// });
697    pub fn protocol(&self) -> crate::ProtocolVersion {
698        self.esocket.protocol.into()
699    }
700
701    /// # Get the socket namespace path.
702    #[inline]
703    pub fn ns(&self) -> &str {
704        &self.ns.path
705    }
706
707    /// # Close the engine.io connection if it is not already closed.
708    ///
709    /// Return a future that resolves when the underlying transport is closed.
710    pub(crate) async fn close_underlying_transport(&self) {
711        if !self.esocket.is_closed() {
712            #[cfg(feature = "tracing")]
713            tracing::debug!("closing underlying transport for socket: {}", self.id);
714            self.esocket.close(EIoDisconnectReason::ClosingServer);
715        }
716        self.esocket.closed().await;
717    }
718
719    pub(crate) fn set_connected(&self, connected: bool) {
720        self.connected.store(connected, Ordering::SeqCst);
721    }
722
723    pub(crate) fn reserve(&self) -> Result<Permit<'_>, SocketError> {
724        match self.esocket.reserve() {
725            Ok(permit) => Ok(permit),
726            Err(TrySendError::Full(_)) => Err(SocketError::InternalChannelFull),
727            Err(TrySendError::Closed(_)) => Err(SocketError::Closed),
728        }
729    }
730
731    pub(crate) fn send(&self, packet: Packet) -> Result<(), SocketError> {
732        let permit = self.reserve()?;
733        permit.send(packet, self.parser);
734        Ok(())
735    }
736    pub(crate) fn send_raw(&self, value: Value) -> Result<(), SocketError> {
737        let permit = self.reserve()?;
738        permit.send_raw(value);
739        Ok(())
740    }
741
742    pub(crate) fn send_with_ack_permit(
743        &self,
744        mut packet: Packet,
745        permit: Permit<'_>,
746    ) -> Receiver<AckResult<Value>> {
747        let (tx, rx) = oneshot::channel();
748
749        let ack = self.ack_counter.fetch_add(1, Ordering::SeqCst) + 1;
750        packet.inner.set_ack_id(ack);
751        permit.send(packet, self.parser);
752        self.ack_message.lock().unwrap().insert(ack, tx);
753        rx
754    }
755
756    pub(crate) fn send_with_ack(&self, mut packet: Packet) -> Receiver<AckResult<Value>> {
757        let (tx, rx) = oneshot::channel();
758
759        let ack = self.ack_counter.fetch_add(1, Ordering::SeqCst) + 1;
760        packet.inner.set_ack_id(ack);
761        match self.send(packet) {
762            Ok(()) => {
763                self.ack_message.lock().unwrap().insert(ack, tx);
764            }
765            Err(e) => {
766                tx.send(Err(AckError::Socket(e))).ok();
767            }
768        }
769        rx
770    }
771
772    /// Called when the socket is gracefully disconnected from the server or the client
773    ///
774    /// It maybe also close when the underlying transport is closed or failed.
775    pub(crate) fn close(self: Arc<Self>, reason: DisconnectReason) {
776        self.set_connected(false);
777
778        let disconnect_handler = { self.disconnect_handler.lock().unwrap().take() };
779
780        if let Some(handler) = disconnect_handler {
781            #[cfg(feature = "tracing")]
782            tracing::trace!(?reason, ?self.id, "spawning disconnect handler");
783
784            handler.call_with_defer(self.clone(), reason, |s| s.ns.remove_socket(s.id));
785        } else {
786            self.ns.remove_socket(self.id);
787        }
788    }
789
790    /// Receive data from client
791    pub(crate) fn recv(self: Arc<Self>, packet: PacketData) -> Result<(), Error> {
792        match packet {
793            PacketData::Event(d, ack) | PacketData::BinaryEvent(d, ack) => self.recv_event(d, ack),
794            PacketData::EventAck(d, ack) | PacketData::BinaryAck(d, ack) => self.recv_ack(d, ack),
795            PacketData::Disconnect => {
796                self.close(DisconnectReason::ClientNSDisconnect);
797                Ok(())
798            }
799            _ => unreachable!(),
800        }
801    }
802
803    fn recv_event(self: Arc<Self>, data: Value, ack: Option<i64>) -> Result<(), Error> {
804        let event = self.parser.read_event(&data).map_err(|_e| {
805            #[cfg(feature = "tracing")]
806            tracing::debug!(?_e, "failed to read event");
807            Error::InvalidEventName
808        })?;
809        #[cfg(feature = "tracing")]
810        tracing::debug!(?event, "reading");
811        if let Some(handler) = self.message_handlers.read().unwrap().get(event) {
812            handler.call(self.clone(), data, ack);
813        } else if let Some(fallback) = self.fallback_message_handler.read().unwrap().as_ref() {
814            fallback.call(self.clone(), data, ack);
815        }
816        Ok(())
817    }
818
819    fn recv_ack(self: Arc<Self>, data: Value, ack: i64) -> Result<(), Error> {
820        if let Some(tx) = self.ack_message.lock().unwrap().remove(&ack) {
821            tx.send(Ok(data)).ok();
822        } else {
823            #[cfg(feature = "tracing")]
824            tracing::debug!(sid = ?self.id, "ack not found: {ack}");
825        }
826        Ok(())
827    }
828}
829
830impl<A: Adapter> Debug for Socket<A> {
831    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
832        f.debug_struct("Socket")
833            .field("ns", &self.ns())
834            .field("ack_message", &self.ack_message)
835            .field("ack_counter", &self.ack_counter)
836            .field("sid", &self.id)
837            .finish()
838    }
839}
840impl<A: Adapter> PartialEq for Socket<A> {
841    fn eq(&self, other: &Self) -> bool {
842        self.id == other.id
843    }
844}
845
#[doc(hidden)]
#[cfg(feature = "__test_harness")]
impl Socket<LocalAdapter> {
    /// Builds a dummy socket for tests.
    ///
    /// The socket wraps a dummy engine.io socket with a no-op close callback,
    /// gets a [`SocketIo`] handle built from a default configuration, and is
    /// marked as connected before being returned.
    pub fn new_dummy(sid: Sid, ns: Arc<Namespace<LocalAdapter>>) -> Socket<LocalAdapter> {
        use crate::client::Client;
        use crate::io::SocketIoConfig;

        // Close callback that does nothing.
        let on_close = Box::new(move |_, _| ());
        let io = SocketIo::from(Arc::new(Client::new(
            SocketIoConfig::default(),
            (),
            #[cfg(feature = "state")]
            std::default::Default::default(),
        )));

        let engine_socket = engineioxide::Socket::new_dummy(sid, on_close);
        let socket = Socket::new(sid, ns, engine_socket, Parser::default());
        socket.esocket.data.io.set(io).unwrap();
        socket.set_connected(true);
        socket
    }
}
873
#[cfg(test)]
mod test {
    use super::*;

    /// Emitting with ack on a socket whose internal channel is already full must
    /// fail with `InternalChannelFull`.
    #[tokio::test]
    async fn send_with_ack_error() {
        let sid = Sid::new();
        let ns = Namespace::<LocalAdapter>::new_dummy([sid]);
        let socket: Arc<Socket> = Arc::new(Socket::new_dummy(sid, ns));
        let parser = Parser::default();

        // Saturate the channel.
        // NOTE(review): 1024 is presumably the dummy socket's channel capacity — confirm.
        for _ in 0..1024 {
            let payload = parser.encode_value(&(), Some("test")).unwrap();
            socket.send(Packet::event("test", payload)).unwrap();
        }

        let outcome = socket.emit_with_ack::<_, ()>("test", &());
        assert!(matches!(
            outcome,
            Err(SendError::Socket(SocketError::InternalChannelFull))
        ));
    }
}