1use std::{
4 borrow::Cow,
5 collections::HashMap,
6 fmt::{self, Debug},
7 future::Future,
8 sync::{
9 Arc, Mutex, RwLock,
10 atomic::{AtomicBool, AtomicI64, Ordering},
11 },
12 time::Duration,
13};
14
15use engineioxide::socket::{DisconnectReason as EIoDisconnectReason, Permit};
16use serde::Serialize;
17use tokio::sync::{
18 mpsc::error::TrySendError,
19 oneshot::{self, Receiver},
20};
21
22#[cfg(feature = "extensions")]
23use crate::extensions::Extensions;
24
25use crate::{
26 AckError, SendError, SocketError, SocketIo,
27 ack::{AckInnerStream, AckResult, AckStream},
28 adapter::{Adapter, LocalAdapter},
29 client::SocketData,
30 errors::Error,
31 handler::{
32 BoxedDisconnectHandler, BoxedMessageHandler, DisconnectHandler, MakeErasedHandler,
33 MessageHandler,
34 },
35 ns::Namespace,
36 operators::{BroadcastOperators, ConfOperators},
37 parser::Parser,
38};
39use socketioxide_core::{
40 Value,
41 adapter::errors::{AdapterError, BroadcastError},
42 adapter::{BroadcastOptions, RemoteSocketData, Room, RoomParam},
43 packet::{Packet, PacketData},
44 parser::Parse,
45};
46
47pub use socketioxide_core::Sid;
48
/// The reason why a [`Socket`] was disconnected from its namespace.
///
/// Engine.io level reasons are converted into this type through the
/// `From<EIoDisconnectReason>` implementation. The two namespace-level variants
/// (`ClientNSDisconnect` and `ServerNSDisconnect`) are produced by the
/// socket.io layer itself.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum DisconnectReason {
    /// The client gracefully closed the connection.
    TransportClose,

    /// The client sent multiple HTTP polling requests at the same time.
    MultipleHttpPollingError,

    /// The client sent a bad request, or a packet could not be parsed.
    PacketParsingError,

    /// The connection was abruptly closed.
    TransportError,

    /// The client did not send a PONG packet in time.
    HeartbeatTimeout,

    /// The client manually disconnected the socket from the namespace
    /// (a disconnect packet was received from it).
    ClientNSDisconnect,

    /// The socket was forcefully disconnected from the namespace by the
    /// server, e.g. through `Socket::disconnect`.
    ServerNSDisconnect,

    /// The server is being closed.
    ClosingServer,
}
78
79impl std::fmt::Display for DisconnectReason {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 use DisconnectReason::*;
82 let str: &'static str = match self {
83 TransportClose => "client gracefully closed the connection",
84 MultipleHttpPollingError => "client sent multiple polling requests at the same time",
85 PacketParsingError => "client sent a bad request / the packet could not be parsed",
86 TransportError => "The connection was abruptly closed",
87 HeartbeatTimeout => "client did not send a PONG packet in time",
88 ClientNSDisconnect => "client has manually disconnected the socket from the namespace",
89 ServerNSDisconnect => "socket was forcefully disconnected from the namespace",
90 ClosingServer => "server is being closed",
91 };
92 f.write_str(str)
93 }
94}
95
96impl From<EIoDisconnectReason> for DisconnectReason {
97 fn from(reason: EIoDisconnectReason) -> Self {
98 use DisconnectReason::*;
99 match reason {
100 EIoDisconnectReason::TransportClose => TransportClose,
101 EIoDisconnectReason::TransportError => TransportError,
102 EIoDisconnectReason::HeartbeatTimeout => HeartbeatTimeout,
103 EIoDisconnectReason::MultipleHttpPollingError => MultipleHttpPollingError,
104 EIoDisconnectReason::PacketParsingError => PacketParsingError,
105 EIoDisconnectReason::ClosingServer => ClosingServer,
106 }
107 }
108}
109
/// Extension methods on an engine.io [`Permit`] to push socket.io payloads
/// through a slot that was already reserved on the socket's internal channel.
pub(crate) trait PermitExt<'a> {
    /// Encodes `packet` with `parser` and sends the resulting value.
    fn send(self, packet: Packet, parser: Parser);
    /// Sends an already-encoded value as-is.
    fn send_raw(self, value: Value);
}
114impl<'a> PermitExt<'a> for Permit<'a> {
115 fn send(self, packet: Packet, parser: Parser) {
116 match parser.encode(packet) {
117 Value::Str(msg, None) => self.emit(msg),
118 Value::Str(msg, Some(bin_payloads)) => self.emit_many(msg, bin_payloads),
119 Value::Bytes(bin) => self.emit_binary(bin),
120 }
121 }
122
123 fn send_raw(self, value: Value) {
124 match value {
125 Value::Str(msg, None) => self.emit(msg),
126 Value::Str(msg, Some(bin_payloads)) => self.emit_many(msg, bin_payloads),
127 Value::Bytes(bin) => self.emit_binary(bin),
128 }
129 }
130}
131
/// A handle to a socket reported by the adapter, which may be connected to
/// another server.
///
/// Every operation builds [`BroadcastOptions`] targeting this socket
/// (`BroadcastOptions::new_remote`) and delegates to the adapter.
pub struct RemoteSocket<A> {
    /// Adapter through which every operation is routed.
    adapter: Arc<A>,
    /// Parser used to encode emitted payloads.
    parser: Parser,
    /// Identity of the socket (id, server id, namespace, ...).
    data: RemoteSocketData,
}
139
140impl<A> RemoteSocket<A> {
141 pub(crate) fn new(data: RemoteSocketData, adapter: &Arc<A>, parser: Parser) -> Self {
142 Self {
143 data,
144 adapter: adapter.clone(),
145 parser,
146 }
147 }
148 #[inline]
150 pub fn into_data(self) -> RemoteSocketData {
151 self.data
152 }
153 #[inline]
155 pub fn data(&self) -> &RemoteSocketData {
156 &self.data
157 }
158}
159impl<A: Adapter> RemoteSocket<A> {
160 pub async fn emit<T: ?Sized + Serialize>(
164 &self,
165 event: impl AsRef<str>,
166 data: &T,
167 ) -> Result<(), RemoteActionError> {
168 let opts = self.get_opts();
169 let data = self.parser.encode_value(data, Some(event.as_ref()))?;
170 let packet = Packet::event(self.data.ns.clone(), data);
171 self.adapter.broadcast(packet, opts).await?;
172 Ok(())
173 }
174
175 pub async fn emit_with_ack<T: ?Sized + Serialize, V: serde::de::DeserializeOwned>(
179 &self,
180 event: impl AsRef<str>,
181 data: &T,
182 ) -> Result<AckStream<V, A>, RemoteActionError> {
183 let opts = self.get_opts();
184 let data = self.parser.encode_value(data, Some(event.as_ref()))?;
185 let packet = Packet::event(self.data.ns.clone(), data);
186 let stream = self
187 .adapter
188 .broadcast_with_ack(packet, opts, None)
189 .await
190 .map_err(Into::<AdapterError>::into)?;
191 Ok(AckStream::new(stream, self.parser))
192 }
193
194 #[inline]
198 pub async fn rooms(&self) -> Result<Vec<Room>, A::Error> {
199 self.adapter.rooms(self.get_opts()).await
200 }
201
202 #[inline]
206 pub fn join(&self, rooms: impl RoomParam) -> impl Future<Output = Result<(), A::Error>> + '_ {
207 self.adapter.add_sockets(self.get_opts(), rooms)
208 }
209
210 #[inline]
214 pub fn leave(
215 &self,
216 rooms: impl RoomParam,
217 ) -> impl Future<Output = Result<(), A::Error>> + Send + '_ {
218 self.adapter.del_sockets(self.get_opts(), rooms)
219 }
220
221 #[inline]
225 pub async fn disconnect(self) -> Result<(), RemoteActionError> {
226 self.adapter.disconnect_socket(self.get_opts()).await?;
227 Ok(())
228 }
229
230 #[inline(always)]
231 fn get_opts(&self) -> BroadcastOptions {
232 BroadcastOptions::new_remote(&self.data)
233 }
234}
235impl<A> fmt::Debug for RemoteSocket<A> {
236 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
237 f.debug_struct("RemoteSocket")
238 .field("id", &self.data.id)
239 .field("server_id", &self.data.server_id)
240 .field("ns", &self.data.ns)
241 .finish()
242 }
243}
244impl<A> Clone for RemoteSocket<A> {
245 fn clone(&self) -> Self {
246 Self {
247 adapter: self.adapter.clone(),
248 parser: self.parser,
249 data: self.data.clone(),
250 }
251 }
252}
253
/// Error returned by the operations of a [`RemoteSocket`].
#[derive(Debug, thiserror::Error)]
pub enum RemoteActionError {
    /// The payload could not be encoded by the parser.
    #[error("cannot encode data: {0}")]
    Serialize(#[from] crate::parser::ParserError),
    /// The message could not be sent to the targeted socket on this server.
    #[error("cannot send the message to the local socket: {0}")]
    Socket(crate::SocketError),
    /// The adapter failed to propagate the request to the server.
    #[error("cannot propagate the request to the server: {0}")]
    Adapter(#[from] AdapterError),
}
267impl From<BroadcastError> for RemoteActionError {
268 fn from(value: BroadcastError) -> Self {
269 match value {
271 BroadcastError::Socket(s) if !s.is_empty() => RemoteActionError::Socket(s[0].clone()),
272 BroadcastError::Socket(_) => {
273 panic!("BroadcastError with an empty socket vec is not permitted")
274 }
275 BroadcastError::Adapter(e) => e.into(),
276 BroadcastError::Serialize(e) => e.into(),
277 }
278 }
279}
280
/// A socket.io socket connected to a single namespace.
///
/// It holds the handlers registered by the user, the pending acknowledgements,
/// and a reference to the underlying engine.io socket.
pub struct Socket<A: Adapter = LocalAdapter> {
    /// The namespace this socket is connected to.
    pub(crate) ns: Arc<Namespace<A>>,
    /// Event handlers registered with `Socket::on`, keyed by event name.
    message_handlers: RwLock<HashMap<Cow<'static, str>, BoxedMessageHandler<A>>>,
    /// Handler registered with `Socket::on_fallback`, called for events that
    /// have no dedicated handler.
    fallback_message_handler: RwLock<Option<BoxedMessageHandler<A>>>,
    /// Handler registered with `Socket::on_disconnect`. It is taken out when
    /// the socket is closed, so it runs at most once.
    disconnect_handler: Mutex<Option<BoxedDisconnectHandler<A>>>,
    /// Pending acknowledgements: ack id -> sender resolved when the matching
    /// ack packet is received.
    ack_message: Mutex<HashMap<i64, oneshot::Sender<AckResult<Value>>>>,
    /// Counter used to generate ack ids.
    ack_counter: AtomicI64,
    /// Whether the socket is connected; checked before emitting.
    connected: AtomicBool,
    /// Parser used to encode and decode packets for this socket.
    pub(crate) parser: Parser,
    /// The id of this socket.
    pub id: Sid,

    /// Extension storage attached to this socket (requires the `extensions`
    /// feature).
    #[cfg_attr(docsrs, doc(cfg(feature = "extensions")))]
    #[cfg(feature = "extensions")]
    pub extensions: Extensions,
    /// The underlying engine.io socket.
    esocket: Arc<engineioxide::Socket<SocketData<A>>>,
}
306
307impl<A: Adapter> Socket<A> {
308 pub(crate) fn new(
309 sid: Sid,
310 ns: Arc<Namespace<A>>,
311 esocket: Arc<engineioxide::Socket<SocketData<A>>>,
312 parser: Parser,
313 ) -> Self {
314 Self {
315 ns,
316 message_handlers: RwLock::new(HashMap::new()),
317 fallback_message_handler: RwLock::new(None),
318 disconnect_handler: Mutex::new(None),
319 ack_message: Mutex::new(HashMap::new()),
320 ack_counter: AtomicI64::new(0),
321 connected: AtomicBool::new(false),
322 parser,
323 id: sid,
324 #[cfg(feature = "extensions")]
325 extensions: Extensions::new(),
326 esocket,
327 }
328 }
329
330 pub fn on<H, T>(&self, event: impl Into<Cow<'static, str>>, handler: H)
399 where
400 H: MessageHandler<A, T>,
401 T: Send + Sync + 'static,
402 {
403 self.message_handlers
404 .write()
405 .unwrap()
406 .insert(event.into(), MakeErasedHandler::new_message_boxed(handler));
407 }
408
409 pub fn on_fallback<H, T>(&self, handler: H)
435 where
436 H: MessageHandler<A, T>,
437 T: Send + Sync + 'static,
438 {
439 self.fallback_message_handler
440 .write()
441 .unwrap()
442 .replace(MakeErasedHandler::new_message_boxed(handler));
443 }
444
445 pub fn on_disconnect<C, T>(&self, callback: C)
477 where
478 C: DisconnectHandler<A, T> + Send + Sync + 'static,
479 T: Send + Sync + 'static,
480 {
481 let handler = MakeErasedHandler::new_disconnect_boxed(callback);
482 self.disconnect_handler.lock().unwrap().replace(handler);
483 }
484
485 #[doc = include_str!("../docs/operators/emit.md")]
486 pub fn emit<T: ?Sized + Serialize>(
487 &self,
488 event: impl AsRef<str>,
489 data: &T,
490 ) -> Result<(), SendError> {
491 if !self.connected() {
492 return Err(SendError::Socket(SocketError::Closed));
493 }
494
495 let permit = match self.reserve() {
496 Ok(permit) => permit,
497 Err(e) => {
498 #[cfg(feature = "tracing")]
499 tracing::debug!("sending error during emit message: {e:?}");
500 return Err(SendError::Socket(e));
501 }
502 };
503
504 let ns = self.ns.path.clone();
505 let data = self.parser.encode_value(data, Some(event.as_ref()))?;
506
507 permit.send(Packet::event(ns, data), self.parser);
508 Ok(())
509 }
510
511 #[doc = include_str!("../docs/operators/emit_with_ack.md")]
512 pub fn emit_with_ack<T: ?Sized + Serialize, V>(
513 &self,
514 event: impl AsRef<str>,
515 data: &T,
516 ) -> Result<AckStream<V>, SendError> {
517 if !self.connected() {
518 return Err(SendError::Socket(SocketError::Closed));
519 }
520 let permit = match self.reserve() {
521 Ok(permit) => permit,
522 Err(e) => {
523 #[cfg(feature = "tracing")]
524 tracing::debug!("sending error during emit message: {e:?}");
525 return Err(SendError::Socket(e));
526 }
527 };
528 let ns = self.ns.path.clone();
529 let data = self.parser.encode_value(data, Some(event.as_ref()))?;
530 let packet = Packet::event(ns, data);
531 let rx = self.send_with_ack_permit(packet, permit);
532 let stream = AckInnerStream::send(rx, self.get_io().config().ack_timeout, self.id);
533 Ok(AckStream::<V>::new(stream, self.parser))
534 }
535
536 pub fn join(&self, rooms: impl RoomParam) {
554 self.ns.adapter.get_local().add_all(self.id, rooms)
555 }
556
557 pub fn leave(&self, rooms: impl RoomParam) {
571 self.ns.adapter.get_local().del(self.id, rooms)
572 }
573
574 pub fn leave_all(&self) {
576 self.ns.adapter.get_local().del_all(self.id);
577 }
578
579 pub fn rooms(&self) -> Vec<Room> {
595 self.ns
596 .adapter
597 .get_local()
598 .socket_rooms(self.id)
599 .into_iter()
600 .collect()
601 }
602
603 pub fn connected(&self) -> bool {
608 self.connected.load(Ordering::SeqCst)
609 }
610
611 #[doc = include_str!("../docs/operators/to.md")]
614 pub fn to(&self, rooms: impl RoomParam) -> BroadcastOperators<A> {
615 BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser).to(rooms)
616 }
617
618 #[doc = include_str!("../docs/operators/within.md")]
619 pub fn within(&self, rooms: impl RoomParam) -> BroadcastOperators<A> {
620 BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser).within(rooms)
621 }
622
623 #[doc = include_str!("../docs/operators/except.md")]
624 pub fn except(&self, rooms: impl RoomParam) -> BroadcastOperators<A> {
625 BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser).except(rooms)
626 }
627
628 #[doc = include_str!("../docs/operators/local.md")]
629 pub fn local(&self) -> BroadcastOperators<A> {
630 BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser).local()
631 }
632
633 #[doc = include_str!("../docs/operators/timeout.md")]
634 pub fn timeout(&self, timeout: Duration) -> ConfOperators<'_, A> {
635 ConfOperators::new(self).timeout(timeout)
636 }
637
638 #[doc = include_str!("../docs/operators/broadcast.md")]
639 pub fn broadcast(&self) -> BroadcastOperators<A> {
640 BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser).broadcast()
641 }
642
643 pub(crate) fn get_io(&self) -> &SocketIo<A> {
649 self.esocket.data.io.get().unwrap()
650 }
651
652 pub fn disconnect(self: Arc<Self>) -> Result<(), SocketError> {
656 let res = self.send(Packet::disconnect(self.ns.path.clone()));
657 if let Err(SocketError::InternalChannelFull) = res {
658 return Err(SocketError::InternalChannelFull);
659 }
660 self.close(DisconnectReason::ServerNSDisconnect);
661 Ok(())
662 }
663
664 pub fn req_parts(&self) -> &http::request::Parts {
668 &self.esocket.req_parts
669 }
670
671 pub fn transport_type(&self) -> crate::TransportType {
683 self.esocket.transport_type()
684 }
685
686 pub fn protocol(&self) -> crate::ProtocolVersion {
698 self.esocket.protocol.into()
699 }
700
701 #[inline]
703 pub fn ns(&self) -> &str {
704 &self.ns.path
705 }
706
707 pub(crate) async fn close_underlying_transport(&self) {
711 if !self.esocket.is_closed() {
712 #[cfg(feature = "tracing")]
713 tracing::debug!("closing underlying transport for socket: {}", self.id);
714 self.esocket.close(EIoDisconnectReason::ClosingServer);
715 }
716 self.esocket.closed().await;
717 }
718
719 pub(crate) fn set_connected(&self, connected: bool) {
720 self.connected.store(connected, Ordering::SeqCst);
721 }
722
723 pub(crate) fn reserve(&self) -> Result<Permit<'_>, SocketError> {
724 match self.esocket.reserve() {
725 Ok(permit) => Ok(permit),
726 Err(TrySendError::Full(_)) => Err(SocketError::InternalChannelFull),
727 Err(TrySendError::Closed(_)) => Err(SocketError::Closed),
728 }
729 }
730
731 pub(crate) fn send(&self, packet: Packet) -> Result<(), SocketError> {
732 let permit = self.reserve()?;
733 permit.send(packet, self.parser);
734 Ok(())
735 }
736 pub(crate) fn send_raw(&self, value: Value) -> Result<(), SocketError> {
737 let permit = self.reserve()?;
738 permit.send_raw(value);
739 Ok(())
740 }
741
742 pub(crate) fn send_with_ack_permit(
743 &self,
744 mut packet: Packet,
745 permit: Permit<'_>,
746 ) -> Receiver<AckResult<Value>> {
747 let (tx, rx) = oneshot::channel();
748
749 let ack = self.ack_counter.fetch_add(1, Ordering::SeqCst) + 1;
750 packet.inner.set_ack_id(ack);
751 permit.send(packet, self.parser);
752 self.ack_message.lock().unwrap().insert(ack, tx);
753 rx
754 }
755
756 pub(crate) fn send_with_ack(&self, mut packet: Packet) -> Receiver<AckResult<Value>> {
757 let (tx, rx) = oneshot::channel();
758
759 let ack = self.ack_counter.fetch_add(1, Ordering::SeqCst) + 1;
760 packet.inner.set_ack_id(ack);
761 match self.send(packet) {
762 Ok(()) => {
763 self.ack_message.lock().unwrap().insert(ack, tx);
764 }
765 Err(e) => {
766 tx.send(Err(AckError::Socket(e))).ok();
767 }
768 }
769 rx
770 }
771
772 pub(crate) fn close(self: Arc<Self>, reason: DisconnectReason) {
776 self.set_connected(false);
777
778 let disconnect_handler = { self.disconnect_handler.lock().unwrap().take() };
779
780 if let Some(handler) = disconnect_handler {
781 #[cfg(feature = "tracing")]
782 tracing::trace!(?reason, ?self.id, "spawning disconnect handler");
783
784 handler.call_with_defer(self.clone(), reason, |s| s.ns.remove_socket(s.id));
785 } else {
786 self.ns.remove_socket(self.id);
787 }
788 }
789
790 pub(crate) fn recv(self: Arc<Self>, packet: PacketData) -> Result<(), Error> {
792 match packet {
793 PacketData::Event(d, ack) | PacketData::BinaryEvent(d, ack) => self.recv_event(d, ack),
794 PacketData::EventAck(d, ack) | PacketData::BinaryAck(d, ack) => self.recv_ack(d, ack),
795 PacketData::Disconnect => {
796 self.close(DisconnectReason::ClientNSDisconnect);
797 Ok(())
798 }
799 _ => unreachable!(),
800 }
801 }
802
803 fn recv_event(self: Arc<Self>, data: Value, ack: Option<i64>) -> Result<(), Error> {
804 let event = self.parser.read_event(&data).map_err(|_e| {
805 #[cfg(feature = "tracing")]
806 tracing::debug!(?_e, "failed to read event");
807 Error::InvalidEventName
808 })?;
809 #[cfg(feature = "tracing")]
810 tracing::debug!(?event, "reading");
811 if let Some(handler) = self.message_handlers.read().unwrap().get(event) {
812 handler.call(self.clone(), data, ack);
813 } else if let Some(fallback) = self.fallback_message_handler.read().unwrap().as_ref() {
814 fallback.call(self.clone(), data, ack);
815 }
816 Ok(())
817 }
818
819 fn recv_ack(self: Arc<Self>, data: Value, ack: i64) -> Result<(), Error> {
820 if let Some(tx) = self.ack_message.lock().unwrap().remove(&ack) {
821 tx.send(Ok(data)).ok();
822 } else {
823 #[cfg(feature = "tracing")]
824 tracing::debug!(sid = ?self.id, "ack not found: {ack}");
825 }
826 Ok(())
827 }
828}
829
830impl<A: Adapter> Debug for Socket<A> {
831 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
832 f.debug_struct("Socket")
833 .field("ns", &self.ns())
834 .field("ack_message", &self.ack_message)
835 .field("ack_counter", &self.ack_counter)
836 .field("sid", &self.id)
837 .finish()
838 }
839}
840impl<A: Adapter> PartialEq for Socket<A> {
841 fn eq(&self, other: &Self) -> bool {
842 self.id == other.id
843 }
844}
845
#[doc(hidden)]
#[cfg(feature = "__test_harness")]
impl Socket<LocalAdapter> {
    /// Builds a connected socket backed by a dummy engine.io socket, for tests.
    ///
    /// The dummy engine.io socket gets a no-op close callback, and a `SocketIo`
    /// built from a default `SocketIoConfig` is attached to it.
    pub fn new_dummy(sid: Sid, ns: Arc<Namespace<LocalAdapter>>) -> Socket<LocalAdapter> {
        use crate::client::Client;
        use crate::io::SocketIoConfig;

        let noop_close = Box::new(move |_, _| ());
        let client = Client::new(
            SocketIoConfig::default(),
            (),
            #[cfg(feature = "state")]
            std::default::Default::default(),
        );
        let io = SocketIo::from(Arc::new(client));
        let esocket = engineioxide::Socket::new_dummy(sid, noop_close);

        let socket = Socket::new(sid, ns, esocket, Parser::default());
        socket.esocket.data.io.set(io).unwrap();
        socket.set_connected(true);
        socket
    }
}
873
#[cfg(test)]
mod test {
    use super::*;

    /// Once the internal engine.io channel is saturated, `emit_with_ack` must
    /// fail with `InternalChannelFull` instead of queueing the packet.
    #[tokio::test]
    async fn send_with_ack_error() {
        let sid = Sid::new();
        let ns = Namespace::<LocalAdapter>::new_dummy([sid]);
        let socket: Arc<Socket> = Socket::new_dummy(sid, ns).into();
        let parser = Parser::default();
        // Fill the internal channel. NOTE(review): assumes the dummy engine.io
        // socket's channel capacity is 1024 packets — confirm.
        for _ in 0..1024 {
            socket
                .send(Packet::event(
                    "test",
                    parser.encode_value(&(), Some("test")).unwrap(),
                ))
                .unwrap();
        }

        let ack = socket.emit_with_ack::<_, ()>("test", &());
        assert!(matches!(
            ack,
            Err(SendError::Socket(SocketError::InternalChannelFull))
        ));
    }
}