socketioxide/
ns.rs

1use std::{
2    collections::HashMap,
3    sync::{Arc, RwLock, Weak},
4    time::Duration,
5};
6
7use crate::{
8    ProtocolVersion, SocketError, SocketIoConfig,
9    ack::AckInnerStream,
10    adapter::Adapter,
11    client::SocketData,
12    errors::{ConnectFail, Error},
13    handler::{BoxedConnectHandler, ConnectHandler, MakeErasedHandler},
14    parser::Parser,
15    socket::{DisconnectReason, Socket},
16};
17use socketioxide_core::{
18    Sid, Str, Uid, Value,
19    adapter::{BroadcastIter, CoreLocalAdapter, RemoteSocketData, SocketEmitter},
20    packet::{ConnectPacket, Packet, PacketData},
21    parser::Parse,
22};
23
/// A [`Namespace`] constructor used for dynamic namespaces.
///
/// A namespace constructor only holds a common connect handler. Each namespace it
/// instantiates receives its own clone of that handler.
pub struct NamespaceCtr<A: Adapter> {
    /// Type-erased connect handler, cloned into every instantiated namespace.
    handler: BoxedConnectHandler<A>,
}
/// A socket.io namespace.
///
/// It holds the sockets connected on a given path, together with the adapter and the
/// connect handler serving them.
pub struct Namespace<A: Adapter> {
    /// The path of the namespace (e.g. `/`).
    pub path: Str,
    /// The adapter of this namespace. Its local part is built with an `Emitter` holding a
    /// weak reference back to this namespace.
    pub(crate) adapter: Arc<A>,
    /// The parser given to the sockets of this namespace and used to encode connect payloads.
    parser: Parser,
    /// The type-erased connect handler whose middlewares and handler run for every new socket.
    handler: BoxedConnectHandler<A>,
    /// The sockets currently registered in this namespace, indexed by their id.
    sockets: RwLock<HashMap<Sid, Arc<Socket<A>>>>,
}
37
38/// ===== impl NamespaceCtr =====
39impl<A: Adapter> NamespaceCtr<A> {
40    pub fn new<C, T>(handler: C) -> Self
41    where
42        C: ConnectHandler<A, T> + Send + Sync + 'static,
43        T: Send + Sync + 'static,
44    {
45        Self {
46            handler: MakeErasedHandler::new_ns_boxed(handler),
47        }
48    }
49    pub fn get_new_ns(
50        &self,
51        path: Str,
52        adapter_state: &A::State,
53        config: &SocketIoConfig,
54    ) -> Arc<Namespace<A>> {
55        let handler = self.handler.boxed_clone();
56        Namespace::new_boxed(path, handler, adapter_state, config)
57    }
58}
59
60impl<A: Adapter> Namespace<A> {
61    pub(crate) fn new<C, T>(
62        path: Str,
63        handler: C,
64        adapter_state: &A::State,
65        config: &SocketIoConfig,
66    ) -> Arc<Self>
67    where
68        C: ConnectHandler<A, T> + Send + Sync + 'static,
69        T: Send + Sync + 'static,
70    {
71        let handler = MakeErasedHandler::new_ns_boxed(handler);
72        Self::new_boxed(path, handler, adapter_state, config)
73    }
74
75    fn new_boxed(
76        path: Str,
77        handler: BoxedConnectHandler<A>,
78        adapter_state: &A::State,
79        config: &SocketIoConfig,
80    ) -> Arc<Self> {
81        let parser = config.parser;
82        let ack_timeout = config.ack_timeout;
83        let uid = config.server_id;
84        Arc::new_cyclic(|ns| Self {
85            path: path.clone(),
86            handler,
87            parser,
88            sockets: HashMap::new().into(),
89            adapter: Arc::new(A::new(
90                adapter_state,
91                CoreLocalAdapter::new(Emitter::new(ns.clone(), parser, path, ack_timeout, uid)),
92            )),
93        })
94    }
95
96    /// Connects a socket to a namespace.
97    ///
98    /// Middlewares are first called to check if the connection is allowed.
99    /// * If the handler returns an error, a connect_error packet is sent to the client.
100    /// * If the handler returns Ok, a connect packet is sent to the client and the handler `is` called.
101    pub(crate) async fn connect(
102        self: Arc<Self>,
103        sid: Sid,
104        esocket: Arc<engineioxide::Socket<SocketData<A>>>,
105        auth: Option<Value>,
106    ) -> Result<(), ConnectFail> {
107        let socket: Arc<Socket<A>> =
108            Socket::new(sid, self.clone(), esocket.clone(), self.parser).into();
109
110        if let Err(e) = self.handler.call_middleware(socket.clone(), &auth).await {
111            #[cfg(feature = "tracing")]
112            tracing::trace!(ns = self.path.as_str(), ?socket.id, "emitting connect_error packet");
113
114            let data = e.to_string();
115            if let Err(_e) = socket.send(Packet::connect_error(self.path.clone(), data)) {
116                #[cfg(feature = "tracing")]
117                tracing::debug!("error sending connect_error packet: {:?}, closing conn", _e);
118                esocket.close(engineioxide::DisconnectReason::PacketParsingError);
119            }
120            return Err(ConnectFail);
121        }
122
123        self.sockets.write().unwrap().insert(sid, socket.clone());
124        #[cfg(feature = "tracing")]
125        tracing::trace!(?socket.id, ?self.path, "socket added to namespace");
126
127        let protocol = esocket.protocol.into();
128        let payload = ConnectPacket { sid: socket.id };
129        let payload = match protocol {
130            ProtocolVersion::V5 => Some(self.parser.encode_default(&payload).unwrap()),
131            ProtocolVersion::V4 => None,
132        };
133        if let Err(_e) = socket.send(Packet::connect(self.path.clone(), payload)) {
134            #[cfg(feature = "tracing")]
135            tracing::debug!("error sending connect packet: {:?}, closing conn", _e);
136            esocket.close(engineioxide::DisconnectReason::PacketParsingError);
137            return Err(ConnectFail);
138        }
139
140        socket.set_connected(true);
141        self.handler.call(socket, auth);
142
143        Ok(())
144    }
145
146    /// Removes a socket from a namespace
147    pub fn remove_socket(&self, sid: Sid) {
148        #[cfg(feature = "tracing")]
149        tracing::trace!(?sid, ?self.path, "removing socket from namespace");
150
151        self.sockets.write().unwrap().remove(&sid);
152        self.adapter.get_local().del_all(sid);
153    }
154
155    pub fn has(&self, sid: Sid) -> bool {
156        self.sockets.read().unwrap().values().any(|s| s.id == sid)
157    }
158
159    pub fn recv(&self, sid: Sid, packet: PacketData) -> Result<(), Error> {
160        match packet {
161            PacketData::Connect(_) => unreachable!("connect packets should be handled before"),
162            PacketData::ConnectError(_) => Err(Error::InvalidPacketType),
163            packet => self.get_socket(sid)?.recv(packet),
164        }
165    }
166
167    pub fn get_socket(&self, sid: Sid) -> Result<Arc<Socket<A>>, Error> {
168        self.sockets
169            .read()
170            .unwrap()
171            .get(&sid)
172            .cloned()
173            .ok_or(Error::SocketGone(sid))
174    }
175
176    pub fn get_sockets(&self) -> Vec<Arc<Socket<A>>> {
177        self.sockets.read().unwrap().values().cloned().collect()
178    }
179
180    /// Closes the entire namespace :
181    /// * Closes the adapter
182    /// * Closes all the sockets and
183    ///   their underlying connections in case of [`DisconnectReason::ClosingServer`]
184    /// * Removes all the sockets from the namespace
185    ///
186    /// This function is using .await points only when called with [`DisconnectReason::ClosingServer`]
187    pub async fn close(&self, reason: DisconnectReason) {
188        use futures_util::future;
189        let sockets = self.sockets.read().unwrap().clone();
190
191        #[cfg(feature = "tracing")]
192        tracing::debug!(?self.path, "closing {} sockets in namespace", sockets.len());
193
194        if reason == DisconnectReason::ClosingServer {
195            // When closing the underlying transport, this will indirectly close the socket
196            // Therefore there is no need to manually call `s.close()`.
197            future::join_all(sockets.values().map(|s| s.close_underlying_transport())).await;
198        } else {
199            for s in sockets.into_values() {
200                let _sid = s.id;
201                s.close(reason);
202            }
203        }
204        #[cfg(feature = "tracing")]
205        tracing::debug!(?self.path, "all sockets in namespace closed");
206
207        let _err = self.adapter.close().await;
208        #[cfg(feature = "tracing")]
209        if let Err(err) = _err {
210            tracing::debug!(?err, ?self.path, "could not close adapter");
211        }
212    }
213}
214
/// Type-erased view of a [`Namespace`], used by the [`Emitter`].
///
/// It discards the adapter type parameter `A`. Otherwise the emitter would have to be
/// generic over `A`, creating a cyclic dependency between the namespace, the emitter and
/// the adapter.
trait InnerEmitter: Send + Sync + 'static {
    /// Get the remote socket data of the given socket ids, tagged with the server id `uid`.
    fn get_remote_sockets(&self, sids: BroadcastIter<'_>, uid: Uid) -> Vec<RemoteSocketData>;
    /// Get all the socket ids in the namespace for which `filter` returns `true`.
    fn get_all_sids(&self, filter: &dyn Fn(&Sid) -> bool) -> Vec<Sid>;
    /// Send raw `data` to the list of socket ids.
    /// On failure, returns the errors of every send that failed.
    fn send_many(&self, sids: BroadcastIter<'_>, data: Value) -> Result<(), Vec<SocketError>>;
    /// Send `packet` to the list of socket ids and get a stream of acks, each awaited for
    /// at most `timeout`. The `u32` is presumably the number of sockets the packet was sent
    /// to (NOTE(review): confirm against `AckInnerStream::broadcast`).
    fn send_many_with_ack(
        &self,
        sids: BroadcastIter<'_>,
        packet: Packet,
        timeout: Duration,
    ) -> (AckInnerStream, u32);
    /// Disconnect all the sockets in the list.
    /// On failure, returns the errors of every disconnection that failed.
    fn disconnect_many(&self, sids: Vec<Sid>) -> Result<(), Vec<SocketError>>;
}
234
235impl<A: Adapter> InnerEmitter for Namespace<A> {
236    fn get_remote_sockets(&self, sids: BroadcastIter<'_>, uid: Uid) -> Vec<RemoteSocketData> {
237        let sockets = self.sockets.read().unwrap();
238        sids.filter_map(|sid| sockets.get(&sid))
239            .map(|socket| RemoteSocketData {
240                id: socket.id,
241                ns: self.path.clone(),
242                server_id: uid,
243            })
244            .collect()
245    }
246    fn get_all_sids(&self, filter: &dyn Fn(&Sid) -> bool) -> Vec<Sid> {
247        self.sockets
248            .read()
249            .unwrap()
250            .keys()
251            .filter(|id| filter(id))
252            .copied()
253            .collect()
254    }
255
256    fn send_many(&self, sids: BroadcastIter<'_>, data: Value) -> Result<(), Vec<SocketError>> {
257        let sockets = self.sockets.read().unwrap();
258        let errs: Vec<SocketError> = sids
259            .filter_map(|sid| sockets.get(&sid))
260            .filter_map(|socket| socket.send_raw(data.clone()).err())
261            .collect();
262        if errs.is_empty() { Ok(()) } else { Err(errs) }
263    }
264
265    fn send_many_with_ack(
266        &self,
267        sids: BroadcastIter<'_>,
268        packet: Packet,
269        timeout: Duration,
270    ) -> (AckInnerStream, u32) {
271        let sockets_map = self.sockets.read().unwrap();
272        let sockets = sids.filter_map(|sid| sockets_map.get(&sid));
273        AckInnerStream::broadcast(packet, sockets, timeout)
274    }
275
276    fn disconnect_many(&self, sids: Vec<Sid>) -> Result<(), Vec<SocketError>> {
277        if sids.is_empty() {
278            return Ok(());
279        }
280        // Here we can't take a ref because this would cause a deadlock.
281        // Ideally the disconnect / closing process should be refactored to avoid this.
282        let sockets = {
283            let sock_map = self.sockets.read().unwrap();
284            sids.into_iter()
285                .filter_map(|sid| sock_map.get(&sid))
286                .cloned()
287                .collect::<Vec<_>>()
288        };
289
290        let errs = sockets
291            .into_iter()
292            .filter_map(|socket| socket.disconnect().err())
293            .collect::<Vec<_>>();
294        if errs.is_empty() { Ok(()) } else { Err(errs) }
295    }
296}
297
/// Internal interface implementor to apply global operations on a namespace.
///
/// It is given to the namespace's adapter (through `CoreLocalAdapter`) when the namespace is
/// created. It forwards socket operations to that namespace. Once the namespace has been
/// dropped, the operations do nothing and return empty results.
#[doc(hidden)]
pub struct Emitter {
    /// This `Weak<dyn>` allows to break the cyclic dependency between the namespace and the emitter.
    ns: Weak<dyn InnerEmitter>,
    /// Parser of the namespace, exposed through `SocketEmitter::parser`.
    parser: Parser,
    /// Path of the namespace.
    path: Str,
    /// Ack timeout used when `send_many_with_ack` is called without an explicit timeout.
    ack_timeout: Duration,
    /// Server id, reported as the `server_id` of remote socket data.
    uid: Uid,
}
309impl Emitter {
310    fn new<A: Adapter>(
311        ns: Weak<Namespace<A>>,
312        parser: Parser,
313        path: Str,
314        ack_timeout: Duration,
315        uid: Uid,
316    ) -> Self {
317        Self {
318            ns,
319            parser,
320            path,
321            ack_timeout,
322            uid,
323        }
324    }
325}
326
327impl SocketEmitter for Emitter {
328    type AckError = crate::AckError;
329    type AckStream = AckInnerStream;
330
331    fn get_all_sids(&self, filter: impl Fn(&Sid) -> bool) -> Vec<Sid> {
332        self.ns
333            .upgrade()
334            .map(|ns| ns.get_all_sids(&filter))
335            .unwrap_or_default()
336    }
337    fn get_remote_sockets(&self, sids: BroadcastIter<'_>) -> Vec<RemoteSocketData> {
338        self.ns
339            .upgrade()
340            .map(|ns| ns.get_remote_sockets(sids, self.uid))
341            .unwrap_or_default()
342    }
343
344    fn send_many(&self, sids: BroadcastIter<'_>, data: Value) -> Result<(), Vec<SocketError>> {
345        match self.ns.upgrade() {
346            Some(ns) => ns.send_many(sids, data),
347            None => Ok(()),
348        }
349    }
350
351    fn send_many_with_ack(
352        &self,
353        sids: BroadcastIter<'_>,
354        packet: Packet,
355        timeout: Option<Duration>,
356    ) -> (Self::AckStream, u32) {
357        self.ns
358            .upgrade()
359            .map(|ns| ns.send_many_with_ack(sids, packet, timeout.unwrap_or(self.ack_timeout)))
360            .unwrap_or((AckInnerStream::empty(), 0))
361    }
362
363    fn disconnect_many(&self, sids: Vec<Sid>) -> Result<(), Vec<SocketError>> {
364        match self.ns.upgrade() {
365            Some(ns) => ns.disconnect_many(sids),
366            None => Ok(()),
367        }
368    }
369    fn parser(&self) -> impl Parse {
370        self.parser
371    }
372    fn server_id(&self) -> Uid {
373        self.uid
374    }
375    fn path(&self) -> &Str {
376        &self.path
377    }
378}
379
#[doc(hidden)]
#[cfg(feature = "__test_harness")]
impl Namespace<crate::adapter::LocalAdapter> {
    /// Builds a `/` namespace with a no-op connect handler and registers one dummy socket
    /// per id in `sockets`.
    pub fn new_dummy<const S: usize>(sockets: [Sid; S]) -> Arc<Self> {
        let ns = Namespace::new("/".into(), || {}, &(), &SocketIoConfig::default());
        {
            let mut registry = ns.sockets.write().unwrap();
            for sid in sockets {
                registry.insert(sid, Socket::new_dummy(sid, ns.clone()).into());
            }
        }
        ns
    }

    /// Removes every socket from the namespace without closing them.
    pub fn clean_dummy_sockets(&self) {
        let mut registry = self.sockets.write().unwrap();
        registry.clear();
    }
}
398
399impl<A: Adapter> std::fmt::Debug for Namespace<A> {
400    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
401        f.debug_struct("Namespace")
402            .field("path", &self.path)
403            .field("sockets", &self.sockets)
404            .finish()
405    }
406}
407
#[cfg(feature = "tracing")]
impl<A: Adapter> Drop for Namespace<A> {
    /// Logs the destruction of the namespace. This impl only exists with the `tracing` feature.
    fn drop(&mut self) {
        let path = &self.path;
        tracing::debug!("dropping namespace {}", path);
    }
}