1use std::{
2 collections::HashMap,
3 sync::{Arc, RwLock, Weak},
4 time::Duration,
5};
6
7use crate::{
8 ProtocolVersion, SocketError, SocketIoConfig,
9 ack::AckInnerStream,
10 adapter::Adapter,
11 client::SocketData,
12 errors::{ConnectFail, Error},
13 handler::{BoxedConnectHandler, ConnectHandler, MakeErasedHandler},
14 parser::Parser,
15 socket::{DisconnectReason, Socket},
16};
17use socketioxide_core::{
18 Sid, Str, Uid, Value,
19 adapter::{BroadcastIter, CoreLocalAdapter, RemoteSocketData, SocketEmitter},
20 packet::{ConnectPacket, Packet, PacketData},
21 parser::Parse,
22};
23
/// Factory for [`Namespace`] instances.
///
/// It stores a single type-erased connect handler. Every namespace built
/// through `get_new_ns` receives its own clone of that handler.
pub struct NamespaceCtr<A: Adapter> {
    /// Type-erased connect handler cloned into each namespace created by this factory.
    handler: BoxedConnectHandler<A>,
}
/// A socket.io namespace.
///
/// It groups the sockets connected under one path together with:
/// - the connect handler run for each newly connecting socket;
/// - the parser used to encode packets;
/// - the adapter, which reaches back into this namespace through an [`Emitter`].
pub struct Namespace<A: Adapter> {
    /// Path of this namespace (e.g. `"/"`).
    pub path: Str,
    /// Adapter owned by this namespace, built in `new_boxed` around a `CoreLocalAdapter`.
    pub(crate) adapter: Arc<A>,
    /// Parser copied from the config; used to encode packets and handed to each socket.
    parser: Parser,
    /// Connect handler: its middlewares and handler are invoked from `connect`.
    handler: BoxedConnectHandler<A>,
    /// Sockets currently registered in this namespace, keyed by their sid.
    sockets: RwLock<HashMap<Sid, Arc<Socket<A>>>>,
}
37
38impl<A: Adapter> NamespaceCtr<A> {
40 pub fn new<C, T>(handler: C) -> Self
41 where
42 C: ConnectHandler<A, T> + Send + Sync + 'static,
43 T: Send + Sync + 'static,
44 {
45 Self {
46 handler: MakeErasedHandler::new_ns_boxed(handler),
47 }
48 }
49 pub fn get_new_ns(
50 &self,
51 path: Str,
52 adapter_state: &A::State,
53 config: &SocketIoConfig,
54 ) -> Arc<Namespace<A>> {
55 let handler = self.handler.boxed_clone();
56 Namespace::new_boxed(path, handler, adapter_state, config)
57 }
58}
59
60impl<A: Adapter> Namespace<A> {
61 pub(crate) fn new<C, T>(
62 path: Str,
63 handler: C,
64 adapter_state: &A::State,
65 config: &SocketIoConfig,
66 ) -> Arc<Self>
67 where
68 C: ConnectHandler<A, T> + Send + Sync + 'static,
69 T: Send + Sync + 'static,
70 {
71 let handler = MakeErasedHandler::new_ns_boxed(handler);
72 Self::new_boxed(path, handler, adapter_state, config)
73 }
74
75 fn new_boxed(
76 path: Str,
77 handler: BoxedConnectHandler<A>,
78 adapter_state: &A::State,
79 config: &SocketIoConfig,
80 ) -> Arc<Self> {
81 let parser = config.parser;
82 let ack_timeout = config.ack_timeout;
83 let uid = config.server_id;
84 Arc::new_cyclic(|ns| Self {
85 path: path.clone(),
86 handler,
87 parser,
88 sockets: HashMap::new().into(),
89 adapter: Arc::new(A::new(
90 adapter_state,
91 CoreLocalAdapter::new(Emitter::new(ns.clone(), parser, path, ack_timeout, uid)),
92 )),
93 })
94 }
95
96 pub(crate) async fn connect(
102 self: Arc<Self>,
103 sid: Sid,
104 esocket: Arc<engineioxide::Socket<SocketData<A>>>,
105 auth: Option<Value>,
106 ) -> Result<(), ConnectFail> {
107 let socket: Arc<Socket<A>> =
108 Socket::new(sid, self.clone(), esocket.clone(), self.parser).into();
109
110 if let Err(e) = self.handler.call_middleware(socket.clone(), &auth).await {
111 #[cfg(feature = "tracing")]
112 tracing::trace!(ns = self.path.as_str(), ?socket.id, "emitting connect_error packet");
113
114 let data = e.to_string();
115 if let Err(_e) = socket.send(Packet::connect_error(self.path.clone(), data)) {
116 #[cfg(feature = "tracing")]
117 tracing::debug!("error sending connect_error packet: {:?}, closing conn", _e);
118 esocket.close(engineioxide::DisconnectReason::PacketParsingError);
119 }
120 return Err(ConnectFail);
121 }
122
123 self.sockets.write().unwrap().insert(sid, socket.clone());
124 #[cfg(feature = "tracing")]
125 tracing::trace!(?socket.id, ?self.path, "socket added to namespace");
126
127 let protocol = esocket.protocol.into();
128 let payload = ConnectPacket { sid: socket.id };
129 let payload = match protocol {
130 ProtocolVersion::V5 => Some(self.parser.encode_default(&payload).unwrap()),
131 ProtocolVersion::V4 => None,
132 };
133 if let Err(_e) = socket.send(Packet::connect(self.path.clone(), payload)) {
134 #[cfg(feature = "tracing")]
135 tracing::debug!("error sending connect packet: {:?}, closing conn", _e);
136 esocket.close(engineioxide::DisconnectReason::PacketParsingError);
137 return Err(ConnectFail);
138 }
139
140 socket.set_connected(true);
141 self.handler.call(socket, auth);
142
143 Ok(())
144 }
145
146 pub fn remove_socket(&self, sid: Sid) {
148 #[cfg(feature = "tracing")]
149 tracing::trace!(?sid, ?self.path, "removing socket from namespace");
150
151 self.sockets.write().unwrap().remove(&sid);
152 self.adapter.get_local().del_all(sid);
153 }
154
155 pub fn has(&self, sid: Sid) -> bool {
156 self.sockets.read().unwrap().values().any(|s| s.id == sid)
157 }
158
159 pub fn recv(&self, sid: Sid, packet: PacketData) -> Result<(), Error> {
160 match packet {
161 PacketData::Connect(_) => unreachable!("connect packets should be handled before"),
162 PacketData::ConnectError(_) => Err(Error::InvalidPacketType),
163 packet => self.get_socket(sid)?.recv(packet),
164 }
165 }
166
167 pub fn get_socket(&self, sid: Sid) -> Result<Arc<Socket<A>>, Error> {
168 self.sockets
169 .read()
170 .unwrap()
171 .get(&sid)
172 .cloned()
173 .ok_or(Error::SocketGone(sid))
174 }
175
176 pub fn get_sockets(&self) -> Vec<Arc<Socket<A>>> {
177 self.sockets.read().unwrap().values().cloned().collect()
178 }
179
180 pub async fn close(&self, reason: DisconnectReason) {
188 use futures_util::future;
189 let sockets = self.sockets.read().unwrap().clone();
190
191 #[cfg(feature = "tracing")]
192 tracing::debug!(?self.path, "closing {} sockets in namespace", sockets.len());
193
194 if reason == DisconnectReason::ClosingServer {
195 future::join_all(sockets.values().map(|s| s.close_underlying_transport())).await;
198 } else {
199 for s in sockets.into_values() {
200 let _sid = s.id;
201 s.close(reason);
202 }
203 }
204 #[cfg(feature = "tracing")]
205 tracing::debug!(?self.path, "all sockets in namespace closed");
206
207 let _err = self.adapter.close().await;
208 #[cfg(feature = "tracing")]
209 if let Err(err) = _err {
210 tracing::debug!(?err, ?self.path, "could not close adapter");
211 }
212 }
213}
214
/// Object-safe set of namespace operations needed by [`Emitter`].
///
/// `Emitter` stores a `Weak<dyn InnerEmitter>`, which is what lets it stay
/// non-generic even though `Namespace` is generic over the adapter type.
trait InnerEmitter: Send + Sync + 'static {
    /// Returns a [`RemoteSocketData`] for each sid in `sids` that is registered
    /// here. Each entry is stamped with the server id `uid`.
    fn get_remote_sockets(&self, sids: BroadcastIter<'_>, uid: Uid) -> Vec<RemoteSocketData>;
    /// Returns the sids of all registered sockets for which `filter` returns `true`.
    fn get_all_sids(&self, filter: &dyn Fn(&Sid) -> bool) -> Vec<Sid>;
    /// Sends `data` to every registered socket in `sids`.
    ///
    /// The `Namespace` implementation uses `send_raw`. It returns every
    /// per-socket error encountered, or `Ok(())` if there were none.
    fn send_many(&self, sids: BroadcastIter<'_>, data: Value) -> Result<(), Vec<SocketError>>;
    /// Broadcasts `packet` to the registered sockets in `sids` and returns the
    /// resulting `AckInnerStream`, with `timeout` passed to `AckInnerStream::broadcast`.
    ///
    /// NOTE(review): the `u32` comes from `AckInnerStream::broadcast`;
    /// presumably it is the number of sockets reached — confirm.
    fn send_many_with_ack(
        &self,
        sids: BroadcastIter<'_>,
        packet: Packet,
        timeout: Duration,
    ) -> (AckInnerStream, u32);
    /// Disconnects every registered socket in `sids`.
    ///
    /// Returns every per-socket error encountered, or `Ok(())` if there were none.
    fn disconnect_many(&self, sids: Vec<Sid>) -> Result<(), Vec<SocketError>>;
}
234
235impl<A: Adapter> InnerEmitter for Namespace<A> {
236 fn get_remote_sockets(&self, sids: BroadcastIter<'_>, uid: Uid) -> Vec<RemoteSocketData> {
237 let sockets = self.sockets.read().unwrap();
238 sids.filter_map(|sid| sockets.get(&sid))
239 .map(|socket| RemoteSocketData {
240 id: socket.id,
241 ns: self.path.clone(),
242 server_id: uid,
243 })
244 .collect()
245 }
246 fn get_all_sids(&self, filter: &dyn Fn(&Sid) -> bool) -> Vec<Sid> {
247 self.sockets
248 .read()
249 .unwrap()
250 .keys()
251 .filter(|id| filter(id))
252 .copied()
253 .collect()
254 }
255
256 fn send_many(&self, sids: BroadcastIter<'_>, data: Value) -> Result<(), Vec<SocketError>> {
257 let sockets = self.sockets.read().unwrap();
258 let errs: Vec<SocketError> = sids
259 .filter_map(|sid| sockets.get(&sid))
260 .filter_map(|socket| socket.send_raw(data.clone()).err())
261 .collect();
262 if errs.is_empty() { Ok(()) } else { Err(errs) }
263 }
264
265 fn send_many_with_ack(
266 &self,
267 sids: BroadcastIter<'_>,
268 packet: Packet,
269 timeout: Duration,
270 ) -> (AckInnerStream, u32) {
271 let sockets_map = self.sockets.read().unwrap();
272 let sockets = sids.filter_map(|sid| sockets_map.get(&sid));
273 AckInnerStream::broadcast(packet, sockets, timeout)
274 }
275
276 fn disconnect_many(&self, sids: Vec<Sid>) -> Result<(), Vec<SocketError>> {
277 if sids.is_empty() {
278 return Ok(());
279 }
280 let sockets = {
283 let sock_map = self.sockets.read().unwrap();
284 sids.into_iter()
285 .filter_map(|sid| sock_map.get(&sid))
286 .cloned()
287 .collect::<Vec<_>>()
288 };
289
290 let errs = sockets
291 .into_iter()
292 .filter_map(|socket| socket.disconnect().err())
293 .collect::<Vec<_>>();
294 if errs.is_empty() { Ok(()) } else { Err(errs) }
295 }
296}
297
/// Bridge handed to the adapter so it can reach the sockets of its namespace.
///
/// Only a [`Weak`] reference to the namespace is kept. Each `SocketEmitter`
/// operation upgrades it; if the namespace is gone, the operation returns an
/// empty or no-op result instead.
#[doc(hidden)]
pub struct Emitter {
    /// Weak back-reference to the owning namespace, erased behind `InnerEmitter`.
    ns: Weak<dyn InnerEmitter>,
    /// Parser returned by `SocketEmitter::parser`.
    parser: Parser,
    /// Namespace path returned by `SocketEmitter::path`.
    path: Str,
    /// Timeout used by `send_many_with_ack` when the caller passes `None`.
    ack_timeout: Duration,
    /// Server id returned by `server_id` and stamped onto remote socket data.
    uid: Uid,
}
308
309impl Emitter {
310 fn new<A: Adapter>(
311 ns: Weak<Namespace<A>>,
312 parser: Parser,
313 path: Str,
314 ack_timeout: Duration,
315 uid: Uid,
316 ) -> Self {
317 Self {
318 ns,
319 parser,
320 path,
321 ack_timeout,
322 uid,
323 }
324 }
325}
326
327impl SocketEmitter for Emitter {
328 type AckError = crate::AckError;
329 type AckStream = AckInnerStream;
330
331 fn get_all_sids(&self, filter: impl Fn(&Sid) -> bool) -> Vec<Sid> {
332 self.ns
333 .upgrade()
334 .map(|ns| ns.get_all_sids(&filter))
335 .unwrap_or_default()
336 }
337 fn get_remote_sockets(&self, sids: BroadcastIter<'_>) -> Vec<RemoteSocketData> {
338 self.ns
339 .upgrade()
340 .map(|ns| ns.get_remote_sockets(sids, self.uid))
341 .unwrap_or_default()
342 }
343
344 fn send_many(&self, sids: BroadcastIter<'_>, data: Value) -> Result<(), Vec<SocketError>> {
345 match self.ns.upgrade() {
346 Some(ns) => ns.send_many(sids, data),
347 None => Ok(()),
348 }
349 }
350
351 fn send_many_with_ack(
352 &self,
353 sids: BroadcastIter<'_>,
354 packet: Packet,
355 timeout: Option<Duration>,
356 ) -> (Self::AckStream, u32) {
357 self.ns
358 .upgrade()
359 .map(|ns| ns.send_many_with_ack(sids, packet, timeout.unwrap_or(self.ack_timeout)))
360 .unwrap_or((AckInnerStream::empty(), 0))
361 }
362
363 fn disconnect_many(&self, sids: Vec<Sid>) -> Result<(), Vec<SocketError>> {
364 match self.ns.upgrade() {
365 Some(ns) => ns.disconnect_many(sids),
366 None => Ok(()),
367 }
368 }
369 fn parser(&self) -> impl Parse {
370 self.parser
371 }
372 fn server_id(&self) -> Uid {
373 self.uid
374 }
375 fn path(&self) -> &Str {
376 &self.path
377 }
378}
379
#[doc(hidden)]
#[cfg(feature = "__test_harness")]
impl Namespace<crate::adapter::LocalAdapter> {
    /// Test-only: builds a `"/"` namespace and registers one dummy socket per
    /// sid in `sockets`.
    ///
    /// The namespace uses a no-op connect handler, the unit adapter state and
    /// the default config.
    pub fn new_dummy<const S: usize>(sockets: [Sid; S]) -> Arc<Self> {
        let namespace = Namespace::new("/".into(), || {}, &(), &SocketIoConfig::default());
        sockets.into_iter().for_each(|sid| {
            let dummy = Arc::new(Socket::new_dummy(sid, namespace.clone()));
            namespace.sockets.write().unwrap().insert(sid, dummy);
        });
        namespace
    }

    /// Test-only: drops every socket registered in this namespace.
    pub fn clean_dummy_sockets(&self) {
        let mut guard = self.sockets.write().unwrap();
        guard.clear();
    }
}
398
399impl<A: Adapter> std::fmt::Debug for Namespace<A> {
400 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
401 f.debug_struct("Namespace")
402 .field("path", &self.path)
403 .field("sockets", &self.sockets)
404 .finish()
405 }
406}
407
#[cfg(feature = "tracing")]
impl<A: Adapter> Drop for Namespace<A> {
    /// Logs the namespace path when it is dropped.
    ///
    /// The whole impl is already gated on the `tracing` feature, so the log
    /// call needs no extra `cfg` of its own.
    fn drop(&mut self) {
        tracing::debug!("dropping namespace {}", self.path);
    }
}