1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![warn(
3 clippy::all,
4 clippy::todo,
5 clippy::empty_enum,
6 clippy::mem_forget,
7 clippy::unused_self,
8 clippy::filter_map_next,
9 clippy::needless_continue,
10 clippy::needless_borrow,
11 clippy::match_wildcard_for_single_variants,
12 clippy::if_let_mutex,
13 clippy::await_holding_lock,
14 clippy::imprecise_flops,
15 clippy::suboptimal_flops,
16 clippy::lossy_float_literal,
17 clippy::rest_pat_in_fully_bound_structs,
18 clippy::fn_params_excessive_bools,
19 clippy::exit,
20 clippy::inefficient_to_string,
21 clippy::linkedlist,
22 clippy::macro_use_imports,
23 clippy::option_option,
24 clippy::verbose_file_reads,
25 clippy::unnested_or_patterns,
26 rust_2018_idioms,
27 rust_2024_compatibility,
28 future_incompatible,
29 nonstandard_style,
30 missing_docs
31)]
32
33use std::{
171 borrow::Cow,
172 collections::HashMap,
173 fmt,
174 future::{self, Future},
175 pin::Pin,
176 sync::{Arc, Mutex},
177 task::{Context, Poll},
178 time::Duration,
179};
180
181use drivers::{ChanItem, Driver, MessageStream};
182use futures_core::Stream;
183use futures_util::StreamExt;
184use serde::{Serialize, de::DeserializeOwned};
185use socketioxide_core::adapter::remote_packet::{
186 RequestIn, RequestOut, RequestTypeIn, RequestTypeOut, Response, ResponseType, ResponseTypeId,
187};
188use socketioxide_core::{
189 Sid, Uid,
190 adapter::errors::{AdapterError, BroadcastError},
191 adapter::{
192 BroadcastOptions, CoreAdapter, CoreLocalAdapter, DefinedAdapter, RemoteSocketData, Room,
193 RoomParam, SocketEmitter, Spawnable,
194 },
195 packet::Packet,
196};
197use stream::{AckStream, DropStream};
198use tokio::{sync::mpsc, time};
199
/// Pub/sub drivers used by the adapter: the [`Driver`](drivers::Driver) abstraction, its
/// message stream types and the feature-gated implementations (`redis`, `fred`).
pub mod drivers;

// Response streams used by the adapter (`AckStream`, `DropStream`).
mod stream;
205
206#[derive(thiserror::Error)]
208pub enum Error<R: Driver> {
209 #[error("driver error: {0}")]
211 Driver(R::Error),
212 #[error("packet encoding error: {0}")]
214 Decode(#[from] rmp_serde::decode::Error),
215 #[error("packet decoding error: {0}")]
217 Encode(#[from] rmp_serde::encode::Error),
218}
219
220impl<R: Driver> Error<R> {
221 fn from_driver(err: R::Error) -> Self {
222 Self::Driver(err)
223 }
224}
225impl<R: Driver> fmt::Debug for Error<R> {
226 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
227 match self {
228 Self::Driver(err) => write!(f, "Driver error: {err:?}"),
229 Self::Decode(err) => write!(f, "Decode error: {err:?}"),
230 Self::Encode(err) => write!(f, "Encode error: {err:?}"),
231 }
232 }
233}
234
235impl<R: Driver> From<Error<R>> for AdapterError {
236 fn from(err: Error<R>) -> Self {
237 AdapterError::from(Box::new(err) as Box<dyn std::error::Error + Send>)
238 }
239}
240
/// Configuration of the redis adapter.
///
/// Start from [`RedisAdapterConfig::new`] (or [`Default::default`]) and adjust it with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct RedisAdapterConfig {
    /// Maximum time to wait for responses from remote servers.
    ///
    /// Default: 5 seconds.
    pub request_timeout: Duration,

    /// Prefix of every pub/sub channel name created by the adapter.
    ///
    /// Default: `"socket.io"`.
    pub prefix: Cow<'static, str>,

    /// Base capacity of the channel buffering remote acknowledgement responses.
    ///
    /// The number of remote servers is added on top of it for each broadcast with ack.
    /// Default: 255.
    pub ack_response_buffer: usize,

    /// Buffer capacity requested from the driver for each channel subscription.
    ///
    /// Default: 1024.
    pub stream_buffer: usize,
}

impl RedisAdapterConfig {
    /// Creates a configuration holding the default values (same as [`Default::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets [`request_timeout`](Self::request_timeout).
    #[must_use]
    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
        self.request_timeout = timeout;
        self
    }

    /// Sets the channel name [`prefix`](Self::prefix).
    #[must_use]
    pub fn with_prefix(mut self, prefix: impl Into<Cow<'static, str>>) -> Self {
        self.prefix = prefix.into();
        self
    }

    /// Sets [`ack_response_buffer`](Self::ack_response_buffer).
    ///
    /// # Panics
    ///
    /// Panics if `buffer` is 0.
    #[must_use]
    pub fn with_ack_response_buffer(mut self, buffer: usize) -> Self {
        assert!(buffer > 0, "buffer size must be greater than 0");
        self.ack_response_buffer = buffer;
        self
    }

    /// Sets [`stream_buffer`](Self::stream_buffer).
    ///
    /// # Panics
    ///
    /// Panics if `buffer` is 0.
    #[must_use]
    pub fn with_stream_buffer(mut self, buffer: usize) -> Self {
        assert!(buffer > 0, "buffer size must be greater than 0");
        self.stream_buffer = buffer;
        self
    }
}

impl Default for RedisAdapterConfig {
    fn default() -> Self {
        Self {
            request_timeout: Duration::from_secs(5),
            prefix: Cow::Borrowed("socket.io"),
            ack_response_buffer: 255,
            stream_buffer: 1024,
        }
    }
}
309
/// Shared adapter state: the pub/sub driver and the adapter configuration.
///
/// It is the `State` of [`CustomRedisAdapter`]; each adapter instance clones the driver and
/// the configuration out of it.
#[derive(Debug)]
pub struct RedisAdapterCtr<R> {
    // Driver cloned into every adapter instance.
    driver: R,
    // Configuration cloned into every adapter instance.
    config: RedisAdapterConfig,
}
317
#[cfg(feature = "redis")]
impl RedisAdapterCtr<drivers::redis::RedisDriver> {
    /// Builds the adapter state from a `redis` client with [`RedisAdapterConfig::default`].
    ///
    /// # Errors
    ///
    /// Returns any error raised while creating the driver.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
    pub async fn new_with_redis(client: &redis::Client) -> redis::RedisResult<Self> {
        let config = RedisAdapterConfig::default();
        Self::new_with_redis_config(client, config).await
    }

    /// Builds the adapter state from a `redis` client and a custom configuration.
    ///
    /// # Errors
    ///
    /// Returns any error raised while creating the driver.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
    pub async fn new_with_redis_config(
        client: &redis::Client,
        config: RedisAdapterConfig,
    ) -> redis::RedisResult<Self> {
        let redis_driver = drivers::redis::RedisDriver::new(client).await?;
        let ctr = Self::new_with_driver(redis_driver, config);
        Ok(ctr)
    }
}
#[cfg(feature = "redis-cluster")]
impl RedisAdapterCtr<drivers::redis::ClusterDriver> {
    /// Builds the adapter state from a redis cluster client with
    /// [`RedisAdapterConfig::default`].
    ///
    /// # Errors
    ///
    /// Returns any error raised while creating the cluster driver.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
    pub async fn new_with_cluster(
        client: &redis::cluster::ClusterClient,
    ) -> redis::RedisResult<Self> {
        let config = RedisAdapterConfig::default();
        Self::new_with_cluster_config(client, config).await
    }

    /// Builds the adapter state from a redis cluster client and a custom configuration.
    ///
    /// # Errors
    ///
    /// Returns any error raised while creating the cluster driver.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
    pub async fn new_with_cluster_config(
        client: &redis::cluster::ClusterClient,
        config: RedisAdapterConfig,
    ) -> redis::RedisResult<Self> {
        let cluster_driver = drivers::redis::ClusterDriver::new(client).await?;
        let ctr = Self::new_with_driver(cluster_driver, config);
        Ok(ctr)
    }
}
#[cfg(feature = "fred")]
impl RedisAdapterCtr<drivers::fred::FredDriver> {
    /// Builds the adapter state from a `fred` subscriber client with
    /// [`RedisAdapterConfig::default`].
    ///
    /// # Errors
    ///
    /// Returns any error raised while creating the fred driver.
    #[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
    pub async fn new_with_fred(
        client: fred::clients::SubscriberClient,
    ) -> fred::prelude::FredResult<Self> {
        let config = RedisAdapterConfig::default();
        Self::new_with_fred_config(client, config).await
    }

    /// Builds the adapter state from a `fred` subscriber client and a custom configuration.
    ///
    /// # Errors
    ///
    /// Returns any error raised while creating the fred driver.
    #[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
    pub async fn new_with_fred_config(
        client: fred::clients::SubscriberClient,
        config: RedisAdapterConfig,
    ) -> fred::prelude::FredResult<Self> {
        let fred_driver = drivers::fred::FredDriver::new(client).await?;
        let ctr = Self::new_with_driver(fred_driver, config);
        Ok(ctr)
    }
}
374impl<R: Driver> RedisAdapterCtr<R> {
375 pub fn new_with_driver(driver: R, config: RedisAdapterConfig) -> RedisAdapterCtr<R> {
380 RedisAdapterCtr { driver, config }
381 }
382}
383
/// Pending requests of this server: request id -> sender feeding the raw response payloads
/// to whoever awaits them.
pub(crate) type ResponseHandlers = HashMap<Sid, mpsc::Sender<Vec<u8>>>;

/// Redis adapter backed by the `fred` driver.
#[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
#[cfg(feature = "fred")]
pub type FredAdapter<E> = CustomRedisAdapter<E, drivers::fred::FredDriver>;

/// Redis adapter backed by the `redis` crate driver.
#[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
#[cfg(feature = "redis")]
pub type RedisAdapter<E> = CustomRedisAdapter<E, drivers::redis::RedisDriver>;

/// Redis cluster adapter backed by the `redis` crate cluster driver.
#[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
#[cfg(feature = "redis-cluster")]
pub type ClusterAdapter<E> = CustomRedisAdapter<E, drivers::redis::ClusterDriver>;
400
/// Socket.IO adapter that relays operations to the other servers through a pub/sub [`Driver`].
///
/// Operations that are not local to this server are published as requests on the namespace
/// request channel; answers come back on a response channel dedicated to this server.
pub struct CustomRedisAdapter<E, R> {
    /// Pub/sub driver used to publish and subscribe.
    driver: R,
    /// Adapter configuration (timeouts, channel prefix, buffer sizes).
    config: RedisAdapterConfig,
    /// Id of this server, taken from the local adapter.
    uid: Uid,
    /// Adapter applying operations to the sockets of this server.
    local: CoreLocalAdapter<E>,
    /// Namespace request channel: `{prefix}-request#{namespace}#`.
    req_chan: String,
    /// Handlers awaiting responses to requests sent by this server, keyed by request id.
    responses: Arc<Mutex<ResponseHandlers>>,
}
421
422impl<E, R> DefinedAdapter for CustomRedisAdapter<E, R> {}
423impl<E: SocketEmitter, R: Driver> CoreAdapter<E> for CustomRedisAdapter<E, R> {
424 type Error = Error<R>;
425 type State = RedisAdapterCtr<R>;
426 type AckStream = AckStream<E::AckStream>;
427 type InitRes = InitRes<R>;
428
429 fn new(state: &Self::State, local: CoreLocalAdapter<E>) -> Self {
430 let req_chan = format!("{}-request#{}#", state.config.prefix, local.path());
431 let uid = local.server_id();
432 Self {
433 local,
434 req_chan,
435 uid,
436 driver: state.driver.clone(),
437 config: state.config.clone(),
438 responses: Arc::new(Mutex::new(HashMap::new())),
439 }
440 }
441
442 fn init(self: Arc<Self>, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes {
443 let fut = async move {
444 check_ns(self.local.path())?;
445 let global_stream = self.subscribe(self.req_chan.clone()).await?;
446 let specific_stream = self.subscribe(self.get_req_chan(Some(self.uid))).await?;
447 let response_chan = format!(
448 "{}-response#{}#{}#",
449 &self.config.prefix,
450 self.local.path(),
451 self.uid
452 );
453
454 let response_stream = self.subscribe(response_chan.clone()).await?;
455 let stream = futures_util::stream::select(global_stream, specific_stream);
456 let stream = futures_util::stream::select(stream, response_stream);
457 tokio::spawn(self.pipe_stream(stream, response_chan));
458 on_success();
459 Ok(())
460 };
461 InitRes(Box::pin(fut))
462 }
463
464 async fn close(&self) -> Result<(), Self::Error> {
465 let response_chan = format!(
466 "{}-response#{}#{}#",
467 &self.config.prefix,
468 self.local.path(),
469 self.uid
470 );
471 tokio::try_join!(
472 self.driver.unsubscribe(self.req_chan.clone()),
473 self.driver.unsubscribe(self.get_req_chan(Some(self.uid))),
474 self.driver.unsubscribe(response_chan)
475 )
476 .map_err(Error::from_driver)?;
477
478 Ok(())
479 }
480
481 async fn server_count(&self) -> Result<u16, Self::Error> {
483 let count = self
484 .driver
485 .num_serv(&self.req_chan)
486 .await
487 .map_err(Error::from_driver)?;
488
489 Ok(count)
490 }
491
492 async fn broadcast(
494 &self,
495 packet: Packet,
496 opts: BroadcastOptions,
497 ) -> Result<(), BroadcastError> {
498 if !opts.is_local(self.uid) {
499 let req = RequestOut::new(self.uid, RequestTypeOut::Broadcast(&packet), &opts);
500 self.send_req(req, opts.server_id)
501 .await
502 .map_err(AdapterError::from)?;
503 }
504
505 self.local.broadcast(packet, opts)?;
506 Ok(())
507 }
508
509 async fn broadcast_with_ack(
537 &self,
538 packet: Packet,
539 opts: BroadcastOptions,
540 timeout: Option<Duration>,
541 ) -> Result<Self::AckStream, Self::Error> {
542 if opts.is_local(self.uid) {
543 tracing::debug!(?opts, "broadcast with ack is local");
544 let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
545 let stream = AckStream::new_local(local);
546 return Ok(stream);
547 }
548 let req = RequestOut::new(self.uid, RequestTypeOut::BroadcastWithAck(&packet), &opts);
549 let req_id = req.id;
550
551 let remote_serv_cnt = self.server_count().await?.saturating_sub(1);
552
553 let (tx, rx) = mpsc::channel(self.config.ack_response_buffer + remote_serv_cnt as usize);
554 self.responses.lock().unwrap().insert(req_id, tx);
555 let remote = MessageStream::new(rx);
556
557 self.send_req(req, opts.server_id).await?;
558 let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
559
560 Ok(AckStream::new(
561 local,
562 remote,
563 self.config.request_timeout,
564 remote_serv_cnt,
565 req_id,
566 self.responses.clone(),
567 ))
568 }
569
570 async fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), BroadcastError> {
571 if !opts.is_local(self.uid) {
572 let req = RequestOut::new(self.uid, RequestTypeOut::DisconnectSockets, &opts);
573 self.send_req(req, opts.server_id)
574 .await
575 .map_err(AdapterError::from)?;
576 }
577 self.local
578 .disconnect_socket(opts)
579 .map_err(BroadcastError::Socket)?;
580
581 Ok(())
582 }
583
584 async fn rooms(&self, opts: BroadcastOptions) -> Result<Vec<Room>, Self::Error> {
585 if opts.is_local(self.uid) {
586 return Ok(self.local.rooms(opts).into_iter().collect());
587 }
588 let req = RequestOut::new(self.uid, RequestTypeOut::AllRooms, &opts);
589 let req_id = req.id;
590
591 let stream = self
594 .get_res::<()>(req_id, ResponseTypeId::AllRooms, opts.server_id)
595 .await?;
596 self.send_req(req, opts.server_id).await?;
597 let local = self.local.rooms(opts);
598 let rooms = stream
599 .filter_map(|item| future::ready(item.into_rooms()))
600 .fold(local, async |mut acc, item| {
601 acc.extend(item);
602 acc
603 })
604 .await;
605 Ok(Vec::from_iter(rooms))
606 }
607
608 async fn add_sockets(
609 &self,
610 opts: BroadcastOptions,
611 rooms: impl RoomParam,
612 ) -> Result<(), Self::Error> {
613 let rooms: Vec<Room> = rooms.into_room_iter().collect();
614 if !opts.is_local(self.uid) {
615 let req = RequestOut::new(self.uid, RequestTypeOut::AddSockets(&rooms), &opts);
616 self.send_req(req, opts.server_id).await?;
617 }
618 self.local.add_sockets(opts, rooms);
619 Ok(())
620 }
621
622 async fn del_sockets(
623 &self,
624 opts: BroadcastOptions,
625 rooms: impl RoomParam,
626 ) -> Result<(), Self::Error> {
627 let rooms: Vec<Room> = rooms.into_room_iter().collect();
628 if !opts.is_local(self.uid) {
629 let req = RequestOut::new(self.uid, RequestTypeOut::DelSockets(&rooms), &opts);
630 self.send_req(req, opts.server_id).await?;
631 }
632 self.local.del_sockets(opts, rooms);
633 Ok(())
634 }
635
636 async fn fetch_sockets(
637 &self,
638 opts: BroadcastOptions,
639 ) -> Result<Vec<RemoteSocketData>, Self::Error> {
640 if opts.is_local(self.uid) {
641 return Ok(self.local.fetch_sockets(opts));
642 }
643 let req = RequestOut::new(self.uid, RequestTypeOut::FetchSockets, &opts);
644 let req_id = req.id;
645 let remote = self
648 .get_res::<RemoteSocketData>(req_id, ResponseTypeId::FetchSockets, opts.server_id)
649 .await?;
650
651 self.send_req(req, opts.server_id).await?;
652 let local = self.local.fetch_sockets(opts);
653 let sockets = remote
654 .filter_map(|item| future::ready(item.into_fetch_sockets()))
655 .fold(local, async |mut acc, item| {
656 acc.extend(item);
657 acc
658 })
659 .await;
660 Ok(sockets)
661 }
662
663 fn get_local(&self) -> &CoreLocalAdapter<E> {
664 &self.local
665 }
666}
667
/// Errors that can occur while initializing the adapter.
#[derive(thiserror::Error)]
pub enum InitError<D: Driver> {
    /// A driver call failed (in this module: subscribing to one of the adapter channels).
    #[error("driver error: {0}")]
    Driver(D::Error),
    /// The namespace path is empty or contains `#`, which separates the parts of the channel
    /// names.
    #[error("malformed namespace path, it must not contain '#'")]
    MalformedNamespace,
}
678impl<D: Driver> fmt::Debug for InitError<D> {
679 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
680 match self {
681 Self::Driver(err) => fmt::Debug::fmt(err, f),
682 Self::MalformedNamespace => write!(f, "Malformed namespace path"),
683 }
684 }
685}
/// Future returned by the adapter `init`: resolves once the adapter channels are subscribed,
/// or with the first [`InitError`].
///
/// Await it directly, or hand it to [`Spawnable::spawn`] to run it in the background.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct InitRes<D: Driver>(futures_core::future::BoxFuture<'static, Result<(), InitError<D>>>);
689
690impl<D: Driver> Future for InitRes<D> {
691 type Output = Result<(), InitError<D>>;
692
693 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
694 self.0.as_mut().poll(cx)
695 }
696}
697impl<D: Driver> Spawnable for InitRes<D> {
698 fn spawn(self) {
699 tokio::spawn(async move {
700 if let Err(e) = self.0.await {
701 tracing::error!("error initializing adapter: {e}");
702 }
703 });
704 }
705}
706
707impl<E: SocketEmitter, R: Driver> CustomRedisAdapter<E, R> {
708 fn get_res_chan(&self, uid: Uid) -> String {
713 let path = self.local.path();
714 let prefix = &self.config.prefix;
715 format!("{prefix}-response#{path}#{uid}#")
716 }
717 fn get_req_chan(&self, node_id: Option<Uid>) -> String {
722 match node_id {
723 Some(uid) => format!("{}{}#", self.req_chan, uid),
724 None => self.req_chan.clone(),
725 }
726 }
727
728 async fn pipe_stream(
729 self: Arc<Self>,
730 mut stream: impl Stream<Item = ChanItem> + Unpin,
731 response_chan: String,
732 ) {
733 while let Some((chan, item)) = stream.next().await {
734 if chan.starts_with(&self.req_chan) {
735 if let Err(e) = self.recv_req(item) {
736 let ns = self.local.path();
737 let uid = self.uid;
738 tracing::warn!(?uid, ?ns, "request handler error: {e}");
739 }
740 } else if chan == response_chan {
741 let req_id = read_req_id(&item);
742 tracing::trace!(?req_id, ?chan, ?response_chan, "extracted sid");
743 let handlers = self.responses.lock().unwrap();
744 if let Some(tx) = req_id.and_then(|id| handlers.get(&id)) {
745 if let Err(e) = tx.try_send(item) {
746 tracing::warn!("error sending response to handler: {e}");
747 }
748 } else {
749 tracing::warn!(?req_id, "could not find req handler");
750 }
751 } else {
752 tracing::warn!("unexpected message/channel: {chan}");
753 }
754 }
755 }
756
757 fn recv_req(self: &Arc<Self>, item: Vec<u8>) -> Result<(), Error<R>> {
759 let req: RequestIn = rmp_serde::from_slice(&item)?;
760 if req.node_id == self.uid {
761 return Ok(());
762 }
763
764 tracing::trace!(?req, "handling request");
765 let Some(opts) = req.opts else {
766 tracing::warn!(?req, "request is missing options");
767 return Ok(());
768 };
769
770 match req.r#type {
771 RequestTypeIn::Broadcast(p) => self.recv_broadcast(opts, p),
772 RequestTypeIn::BroadcastWithAck(p) => {
773 self.clone()
774 .recv_broadcast_with_ack(req.node_id, req.id, p, opts)
775 }
776 RequestTypeIn::DisconnectSockets => self.recv_disconnect_sockets(opts),
777 RequestTypeIn::AllRooms => self.recv_rooms(req.node_id, req.id, opts),
778 RequestTypeIn::AddSockets(rooms) => self.recv_add_sockets(opts, rooms),
779 RequestTypeIn::DelSockets(rooms) => self.recv_del_sockets(opts, rooms),
780 RequestTypeIn::FetchSockets => self.recv_fetch_sockets(req.node_id, req.id, opts),
781 _ => (),
782 };
783 Ok(())
784 }
785
786 fn recv_broadcast(&self, opts: BroadcastOptions, packet: Packet) {
787 if let Err(e) = self.local.broadcast(packet, opts) {
788 let ns = self.local.path();
789 tracing::warn!(?self.uid, ?ns, "remote request broadcast handler: {:?}", e);
790 }
791 }
792
793 fn recv_disconnect_sockets(&self, opts: BroadcastOptions) {
794 if let Err(e) = self.local.disconnect_socket(opts) {
795 let ns = self.local.path();
796 tracing::warn!(
797 ?self.uid,
798 ?ns,
799 "remote request disconnect sockets handler: {:?}",
800 e
801 );
802 }
803 }
804
805 fn recv_broadcast_with_ack(
806 self: Arc<Self>,
807 origin: Uid,
808 req_id: Sid,
809 packet: Packet,
810 opts: BroadcastOptions,
811 ) {
812 let (stream, count) = self.local.broadcast_with_ack(packet, opts, None);
813 tokio::spawn(async move {
814 let on_err = |err| {
815 let ns = self.local.path();
816 tracing::warn!(
817 ?origin,
818 ?ns,
819 "remote request broadcast with ack handler errors: {:?}",
820 err
821 );
822 };
823 let res = Response {
826 r#type: ResponseType::<()>::BroadcastAckCount(count),
827 node_id: self.uid,
828 };
829 if let Err(err) = self.send_res(origin, req_id, res).await {
830 on_err(err);
831 return;
832 }
833
834 futures_util::pin_mut!(stream);
836 while let Some(ack) = stream.next().await {
837 let res = Response {
838 r#type: ResponseType::BroadcastAck(ack),
839 node_id: self.uid,
840 };
841 if let Err(err) = self.send_res(origin, req_id, res).await {
842 on_err(err);
843 return;
844 }
845 }
846 });
847 }
848
849 fn recv_rooms(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) {
850 let rooms = self.local.rooms(opts);
851 let res = Response {
852 r#type: ResponseType::<()>::AllRooms(rooms),
853 node_id: self.uid,
854 };
855 let fut = self.send_res(origin, req_id, res);
856 let ns = self.local.path().clone();
857 let uid = self.uid;
858 tokio::spawn(async move {
859 if let Err(err) = fut.await {
860 tracing::warn!(?uid, ?ns, "remote request rooms handler: {:?}", err);
861 }
862 });
863 }
864
865 fn recv_add_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
866 self.local.add_sockets(opts, rooms);
867 }
868
869 fn recv_del_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
870 self.local.del_sockets(opts, rooms);
871 }
872 fn recv_fetch_sockets(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) {
873 let sockets = self.local.fetch_sockets(opts);
874 let res = Response {
875 node_id: self.uid,
876 r#type: ResponseType::FetchSockets(sockets),
877 };
878 let fut = self.send_res(origin, req_id, res);
879 let ns = self.local.path().clone();
880 let uid = self.uid;
881 tokio::spawn(async move {
882 if let Err(err) = fut.await {
883 tracing::warn!(?uid, ?ns, "remote request fetch sockets handler: {:?}", err);
884 }
885 });
886 }
887
888 async fn send_req(&self, req: RequestOut<'_>, target_uid: Option<Uid>) -> Result<(), Error<R>> {
889 tracing::trace!(?req, "sending request");
890 let req = rmp_serde::to_vec(&req)?;
891 let chan = self.get_req_chan(target_uid);
892 self.driver
893 .publish(chan, req)
894 .await
895 .map_err(Error::from_driver)?;
896
897 Ok(())
898 }
899
900 fn send_res<D: Serialize + fmt::Debug>(
901 &self,
902 req_node_id: Uid,
903 req_id: Sid,
904 res: Response<D>,
905 ) -> impl Future<Output = Result<(), Error<R>>> + Send + 'static {
906 let chan = self.get_res_chan(req_node_id);
907 tracing::trace!(?res, "sending response to {}", &chan);
908 let res = rmp_serde::to_vec(&(req_id, res));
912 let driver = self.driver.clone();
913 async move {
914 driver
915 .publish(chan, res?)
916 .await
917 .map_err(Error::from_driver)?;
918 Ok(())
919 }
920 }
921
922 async fn get_res<D: DeserializeOwned + fmt::Debug>(
924 &self,
925 req_id: Sid,
926 response_type: ResponseTypeId,
927 target_uid: Option<Uid>,
928 ) -> Result<impl Stream<Item = Response<D>>, Error<R>> {
929 let remote_serv_cnt = if target_uid.is_none() {
931 self.server_count().await?.saturating_sub(1) as usize
932 } else {
933 1
934 };
935 let (tx, rx) = mpsc::channel(std::cmp::max(remote_serv_cnt, 1));
936 self.responses.lock().unwrap().insert(req_id, tx);
937 let stream = MessageStream::new(rx)
938 .filter_map(|item| {
939 let data = match rmp_serde::from_slice::<(Sid, Response<D>)>(&item) {
940 Ok((_, data)) => Some(data),
941 Err(e) => {
942 tracing::warn!("error decoding response: {e}");
943 None
944 }
945 };
946 future::ready(data)
947 })
948 .filter(move |item| future::ready(ResponseTypeId::from(&item.r#type) == response_type))
949 .take(remote_serv_cnt)
950 .take_until(time::sleep(self.config.request_timeout));
951 let stream = DropStream::new(stream, self.responses.clone(), req_id);
952 Ok(stream)
953 }
954
955 #[inline]
957 async fn subscribe(&self, pat: String) -> Result<MessageStream<ChanItem>, InitError<R>> {
958 tracing::trace!(?pat, "subscribing to");
959 self.driver
960 .subscribe(pat, self.config.stream_buffer)
961 .await
962 .map_err(InitError::Driver)
963 }
964}
965
966fn check_ns<D: Driver>(path: &str) -> Result<(), InitError<D>> {
969 if path.is_empty() || path.contains('#') {
970 Err(InitError::MalformedNamespace)
971 } else {
972 Ok(())
973 }
974}
975
976pub fn read_req_id(data: &[u8]) -> Option<Sid> {
978 use std::str::FromStr;
979 let mut rd = data;
980 let len = rmp::decode::read_array_len(&mut rd).ok()?;
981 if len < 1 {
982 return None;
983 }
984
985 let mut buff = [0u8; Sid::ZERO.as_str().len()];
986 let str = rmp::decode::read_str(&mut rd, &mut buff).ok()?;
987 Sid::from_str(str).ok()
988}
989
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream::{self, FusedStream, StreamExt};
    use socketioxide_core::{Str, Value, adapter::AckStreamItem};
    use std::convert::Infallible;

    /// Driver doing nothing: publishes succeed, subscriptions are empty, no servers.
    #[derive(Clone)]
    struct StubDriver;
    impl Driver for StubDriver {
        type Error = Infallible;

        async fn publish(&self, _: String, _: Vec<u8>) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn subscribe(
            &self,
            _: String,
            _: usize,
        ) -> Result<MessageStream<ChanItem>, Self::Error> {
            Ok(MessageStream::new_empty())
        }

        async fn unsubscribe(&self, _: String) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn num_serv(&self, _: &str) -> Result<u16, Self::Error> {
            Ok(0)
        }
    }

    /// Ack stream with no local acks, expecting answers from 2 remote servers.
    fn new_stub_ack_stream(
        remote: MessageStream<Vec<u8>>,
        timeout: Duration,
    ) -> AckStream<stream::Empty<AckStreamItem<()>>> {
        AckStream::new(
            stream::empty::<AckStreamItem<()>>(),
            remote,
            timeout,
            2,
            Sid::new(),
            Arc::new(Mutex::new(HashMap::new())),
        )
    }

    #[tokio::test]
    async fn ack_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let stream = new_stub_ack_stream(remote, Duration::from_secs(10));
        let node_id = Uid::new();
        let req_id = Sid::new();

        let ack_cnt_res = Response::<()> {
            node_id,
            r#type: ResponseType::BroadcastAckCount(2),
        };
        tx.try_send(rmp_serde::to_vec(&(req_id, &ack_cnt_res)).unwrap())
            .unwrap();
        tx.try_send(rmp_serde::to_vec(&(req_id, &ack_cnt_res)).unwrap())
            .unwrap();

        let ack_res = Response::<String> {
            node_id,
            r#type: ResponseType::BroadcastAck((Sid::new(), Ok(Value::Str(Str::from(""), None)))),
        };
        for _ in 0..4 {
            tx.try_send(rmp_serde::to_vec(&(req_id, &ack_res)).unwrap())
                .unwrap();
        }
        futures_util::pin_mut!(stream);
        for _ in 0..4 {
            assert!(stream.next().await.is_some());
        }
        assert!(stream.is_terminated());
    }

    #[tokio::test]
    async fn ack_stream_timeout() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let stream = new_stub_ack_stream(remote, Duration::from_millis(50));
        let node_id = Uid::new();
        let req_id = Sid::new();
        let ack_cnt_res = Response::<()> {
            node_id,
            r#type: ResponseType::BroadcastAckCount(2),
        };
        tx.try_send(rmp_serde::to_vec(&(req_id, ack_cnt_res)).unwrap())
            .unwrap();

        futures_util::pin_mut!(stream);
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(stream.next().await.is_none());
        assert!(stream.is_terminated());
    }

    #[tokio::test]
    async fn ack_stream_drop() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let handlers = Arc::new(Mutex::new(HashMap::new()));
        let id = Sid::new();
        handlers.lock().unwrap().insert(id, tx);
        let stream = AckStream::new(
            stream::empty::<AckStreamItem<()>>(),
            remote,
            Duration::from_secs(10),
            2,
            id,
            handlers.clone(),
        );
        drop(stream);
        // Dropping the stream must unregister its response handler.
        assert!(handlers.lock().unwrap().is_empty());
    }

    #[test]
    fn check_ns_error() {
        assert!(matches!(
            check_ns::<StubDriver>("#"),
            Err(InitError::MalformedNamespace)
        ));
        assert!(matches!(
            check_ns::<StubDriver>(""),
            Err(InitError::MalformedNamespace)
        ));
    }

    #[test]
    fn check_ns_ok() {
        assert!(check_ns::<StubDriver>("/").is_ok());
        assert!(check_ns::<StubDriver>("/admin").is_ok());
    }

    #[test]
    fn read_req_id_roundtrip() {
        let id = Sid::new();
        let data = rmp_serde::to_vec(&(id, ())).unwrap();
        assert!(read_req_id(&data) == Some(id));
    }

    #[test]
    fn read_req_id_malformed() {
        assert!(read_req_id(&[]).is_none());
        let not_an_array = rmp_serde::to_vec(&"foo").unwrap();
        assert!(read_req_id(&not_an_array).is_none());
        let empty_array = rmp_serde::to_vec(&Vec::<u8>::new()).unwrap();
        assert!(read_req_id(&empty_array).is_none());
    }
}