#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
use std::{
borrow::Cow,
collections::HashMap,
fmt,
future::{self, Future},
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
time::Duration,
};
use drivers::{ChanItem, Driver, MessageStream};
use futures_core::Stream;
use futures_util::StreamExt;
use serde::{Serialize, de::DeserializeOwned};
use socketioxide_core::adapter::remote_packet::{
RequestIn, RequestOut, RequestTypeIn, RequestTypeOut, Response, ResponseType, ResponseTypeId,
};
use socketioxide_core::{
Sid, Uid,
adapter::errors::{AdapterError, BroadcastError},
adapter::{
BroadcastOptions, CoreAdapter, CoreLocalAdapter, DefinedAdapter, RemoteSocketData, Room,
RoomParam, SocketEmitter, Spawnable,
},
packet::Packet,
};
use stream::{AckStream, DropStream};
use tokio::{sync::mpsc, time};
pub mod drivers;
mod stream;
/// Errors produced by the Redis adapter at runtime.
#[derive(thiserror::Error)]
pub enum Error<R: Driver> {
    /// The underlying [`Driver`] returned an error (publish, unsubscribe, server count...).
    #[error("driver error: {0}")]
    Driver(R::Error),
    /// A MessagePack payload could not be decoded.
    #[error("packet decoding error: {0}")]
    Decode(#[from] rmp_serde::decode::Error),
    /// A value could not be encoded to MessagePack.
    #[error("packet encoding error: {0}")]
    Encode(#[from] rmp_serde::encode::Error),
}
impl<R: Driver> Error<R> {
fn from_driver(err: R::Error) -> Self {
Self::Driver(err)
}
}
impl<R: Driver> fmt::Debug for Error<R> {
    /// Formats as `<Kind> error: <inner Debug output>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (kind, inner): (&str, &dyn fmt::Debug) = match self {
            Self::Driver(err) => ("Driver", err),
            Self::Decode(err) => ("Decode", err),
            Self::Encode(err) => ("Encode", err),
        };
        write!(f, "{kind} error: {inner:?}")
    }
}
impl<R: Driver> From<Error<R>> for AdapterError {
    /// Type-erases the adapter error so it can travel through the core adapter API.
    fn from(err: Error<R>) -> Self {
        let erased: Box<dyn std::error::Error + Send> = Box::new(err);
        AdapterError::from(erased)
    }
}
/// Configuration of the Redis adapter.
///
/// Start from [`RedisAdapterConfig::new`] (or [`Default::default`]) and adjust it with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct RedisAdapterConfig {
    /// How long to wait for responses from remote servers to a request.
    ///
    /// For broadcasts with acks it is added on top of the ack timeout.
    /// Defaults to 5 seconds.
    pub request_timeout: Duration,
    /// Prefix of every channel name used by the adapter
    /// (`{prefix}-request#…` and `{prefix}-response#…`). Defaults to `"socket.io"`.
    pub prefix: Cow<'static, str>,
    /// Base capacity of the channel buffering ack responses from remote servers; the number of
    /// remote servers is added on top of it. Defaults to 255.
    pub ack_response_buffer: usize,
    /// Buffer size handed to the driver for each subscription stream. Defaults to 1024.
    pub stream_buffer: usize,
}

impl RedisAdapterConfig {
    /// Creates a config holding the default values (same as [`Default::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets [`request_timeout`](Self::request_timeout).
    #[must_use]
    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
        self.request_timeout = timeout;
        self
    }

    /// Sets the channel name [`prefix`](Self::prefix).
    #[must_use]
    pub fn with_prefix(mut self, prefix: impl Into<Cow<'static, str>>) -> Self {
        self.prefix = prefix.into();
        self
    }

    /// Sets [`ack_response_buffer`](Self::ack_response_buffer).
    ///
    /// # Panics
    /// Panics if `buffer` is 0.
    #[must_use]
    pub fn with_ack_response_buffer(mut self, buffer: usize) -> Self {
        assert!(buffer > 0, "buffer size must be greater than 0");
        self.ack_response_buffer = buffer;
        self
    }

    /// Sets [`stream_buffer`](Self::stream_buffer).
    ///
    /// # Panics
    /// Panics if `buffer` is 0.
    #[must_use]
    pub fn with_stream_buffer(mut self, buffer: usize) -> Self {
        assert!(buffer > 0, "buffer size must be greater than 0");
        self.stream_buffer = buffer;
        self
    }
}

impl Default for RedisAdapterConfig {
    /// 5 s request timeout, `"socket.io"` prefix, ack buffer of 255, stream buffer of 1024.
    fn default() -> Self {
        Self {
            request_timeout: Duration::from_secs(5),
            prefix: Cow::Borrowed("socket.io"),
            ack_response_buffer: 255,
            stream_buffer: 1024,
        }
    }
}
/// Constructor state shared by every namespace adapter: the driver and the adapter config.
///
/// It is the `CoreAdapter::State` of [`CustomRedisAdapter`]; each namespace adapter clones the
/// driver and the config from it when it is created.
#[derive(Debug)]
pub struct RedisAdapterCtr<R> {
    driver: R,
    config: RedisAdapterConfig,
}
#[cfg(feature = "redis")]
impl RedisAdapterCtr<drivers::redis::RedisDriver> {
    /// Creates a constructor backed by a single Redis server, with the default
    /// [`RedisAdapterConfig`].
    ///
    /// # Errors
    /// Returns the error raised while creating the `RedisDriver` from `client`.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
    pub async fn new_with_redis(client: &redis::Client) -> redis::RedisResult<Self> {
        let config = RedisAdapterConfig::default();
        Self::new_with_redis_config(client, config).await
    }

    /// Creates a constructor backed by a single Redis server, with a custom config.
    ///
    /// # Errors
    /// Returns the error raised while creating the `RedisDriver` from `client`.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
    pub async fn new_with_redis_config(
        client: &redis::Client,
        config: RedisAdapterConfig,
    ) -> redis::RedisResult<Self> {
        Ok(Self {
            driver: drivers::redis::RedisDriver::new(client).await?,
            config,
        })
    }
}
#[cfg(feature = "redis-cluster")]
impl RedisAdapterCtr<drivers::redis::ClusterDriver> {
    /// Creates a constructor backed by a Redis cluster, with the default
    /// [`RedisAdapterConfig`].
    ///
    /// # Errors
    /// Returns the error raised while creating the `ClusterDriver` from `client`.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
    pub async fn new_with_cluster(
        client: &redis::cluster::ClusterClient,
    ) -> redis::RedisResult<Self> {
        let config = RedisAdapterConfig::default();
        Self::new_with_cluster_config(client, config).await
    }

    /// Creates a constructor backed by a Redis cluster, with a custom config.
    ///
    /// # Errors
    /// Returns the error raised while creating the `ClusterDriver` from `client`.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
    pub async fn new_with_cluster_config(
        client: &redis::cluster::ClusterClient,
        config: RedisAdapterConfig,
    ) -> redis::RedisResult<Self> {
        Ok(Self {
            driver: drivers::redis::ClusterDriver::new(client).await?,
            config,
        })
    }
}
#[cfg(feature = "fred")]
impl RedisAdapterCtr<drivers::fred::FredDriver> {
    /// Creates a constructor backed by a `fred` subscriber client, with the default
    /// [`RedisAdapterConfig`].
    ///
    /// # Errors
    /// Returns the error raised while creating the `FredDriver` from `client`.
    #[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
    pub async fn new_with_fred(
        client: fred::clients::SubscriberClient,
    ) -> fred::prelude::FredResult<Self> {
        let config = RedisAdapterConfig::default();
        Self::new_with_fred_config(client, config).await
    }

    /// Creates a constructor backed by a `fred` subscriber client, with a custom config.
    ///
    /// # Errors
    /// Returns the error raised while creating the `FredDriver` from `client`.
    #[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
    pub async fn new_with_fred_config(
        client: fred::clients::SubscriberClient,
        config: RedisAdapterConfig,
    ) -> fred::prelude::FredResult<Self> {
        Ok(Self {
            driver: drivers::fred::FredDriver::new(client).await?,
            config,
        })
    }
}
impl<R: Driver> RedisAdapterCtr<R> {
    /// Creates a constructor from an already-built [`Driver`] and a config.
    ///
    /// Use this to plug in a custom driver implementation.
    pub fn new_with_driver(driver: R, config: RedisAdapterConfig) -> RedisAdapterCtr<R> {
        Self { config, driver }
    }
}
/// Pending request handlers keyed by request id; raw response payloads received on this
/// server's response channel are forwarded to the matching sender.
pub(crate) type ResponseHandlers = HashMap<Sid, mpsc::Sender<Vec<u8>>>;
/// Redis adapter using the `fred` driver (`drivers::fred::FredDriver`).
#[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
#[cfg(feature = "fred")]
pub type FredAdapter<E> = CustomRedisAdapter<E, drivers::fred::FredDriver>;
/// Redis adapter using the `redis` crate with a single server (`drivers::redis::RedisDriver`).
#[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
#[cfg(feature = "redis")]
pub type RedisAdapter<E> = CustomRedisAdapter<E, drivers::redis::RedisDriver>;
/// Redis adapter using the `redis` crate with a cluster (`drivers::redis::ClusterDriver`).
#[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
#[cfg(feature = "redis-cluster")]
pub type ClusterAdapter<E> = CustomRedisAdapter<E, drivers::redis::ClusterDriver>;
/// Redis adapter for one namespace, generic over the socket emitter `E` and the [`Driver`] `R`.
///
/// It wraps the local adapter and forwards non-local operations to other servers through
/// request/response channels published with the driver.
pub struct CustomRedisAdapter<E, R> {
    /// Driver cloned from the constructor state.
    driver: R,
    /// Config cloned from the constructor state.
    config: RedisAdapterConfig,
    /// Id of this server (taken from the local adapter).
    uid: Uid,
    /// Local (in-process) adapter of the namespace.
    local: CoreLocalAdapter<E>,
    /// Global request channel of the namespace: `{prefix}-request#{path}#`.
    req_chan: String,
    /// Handlers of the requests that are waiting for remote responses.
    responses: Arc<Mutex<ResponseHandlers>>,
}
impl<E, R> DefinedAdapter for CustomRedisAdapter<E, R> {}
impl<E: SocketEmitter, R: Driver> CoreAdapter<E> for CustomRedisAdapter<E, R> {
    type Error = Error<R>;
    type State = RedisAdapterCtr<R>;
    type AckStream = AckStream<E::AckStream>;
    type InitRes = InitRes<R>;

    /// Creates the adapter of one namespace from the shared constructor state.
    ///
    /// The driver and config are cloned from `state`; the global request channel of the
    /// namespace is `{prefix}-request#{path}#`.
    fn new(state: &Self::State, local: CoreLocalAdapter<E>) -> Self {
        let req_chan = format!("{}-request#{}#", state.config.prefix, local.path());
        let uid = local.server_id();
        Self {
            local,
            req_chan,
            uid,
            driver: state.driver.clone(),
            config: state.config.clone(),
            responses: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Validates the namespace path, then subscribes to the global request channel, the
    /// request channel targeted at this server and this server's response channel. Finally it
    /// spawns the dispatch loop (`pipe_stream`) and calls `on_success`.
    fn init(self: Arc<Self>, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes {
        let fut = async move {
            check_ns(self.local.path())?;
            let global_stream = self.subscribe(self.req_chan.clone()).await?;
            let specific_stream = self.subscribe(self.get_req_chan(Some(self.uid))).await?;
            // Built with the same helper `send_res` uses, so remote servers publish where we listen.
            let response_chan = self.get_res_chan(self.uid);
            let response_stream = self.subscribe(response_chan.clone()).await?;
            let stream = futures_util::stream::select(global_stream, specific_stream);
            let stream = futures_util::stream::select(stream, response_stream);
            tokio::spawn(self.pipe_stream(stream, response_chan));
            on_success();
            Ok(())
        };
        InitRes(Box::pin(fut))
    }

    /// Unsubscribes from the three channels subscribed to in `init`.
    async fn close(&self) -> Result<(), Self::Error> {
        let response_chan = self.get_res_chan(self.uid);
        tokio::try_join!(
            self.driver.unsubscribe(self.req_chan.clone()),
            self.driver.unsubscribe(self.get_req_chan(Some(self.uid))),
            self.driver.unsubscribe(response_chan)
        )
        .map_err(Error::from_driver)?;
        Ok(())
    }

    /// Number of servers reported by the driver for the global request channel.
    ///
    /// Callers subtract one to get the number of *remote* servers, so this count presumably
    /// includes the current server.
    async fn server_count(&self) -> Result<u16, Self::Error> {
        let count = self
            .driver
            .num_serv(&self.req_chan)
            .await
            .map_err(Error::from_driver)?;
        Ok(count)
    }

    /// Publishes a broadcast request to the other servers (unless the options are local-only),
    /// then broadcasts to the local sockets.
    async fn broadcast(
        &self,
        packet: Packet,
        opts: BroadcastOptions,
    ) -> Result<(), BroadcastError> {
        if !opts.is_local(self.uid) {
            let req = RequestOut::new(self.uid, RequestTypeOut::Broadcast(&packet), &opts);
            self.send_req(req, opts.server_id)
                .await
                .map_err(AdapterError::from)?;
        }
        self.local.broadcast(packet, opts)?;
        Ok(())
    }

    /// Broadcasts a packet and returns a stream merging local and remote acks.
    ///
    /// A response handler is registered before the request is published so that no remote ack
    /// is missed. The remote part expects ack counts from a single server when
    /// `opts.server_id` targets one, and from every other server otherwise. The remote timeout
    /// is `request_timeout` plus the ack timeout (`timeout`, or the local default).
    ///
    /// # Errors
    /// Fails if the server count cannot be fetched or the request cannot be published; in the
    /// latter case the registered response handler is removed again.
    async fn broadcast_with_ack(
        &self,
        packet: Packet,
        opts: BroadcastOptions,
        timeout: Option<Duration>,
    ) -> Result<Self::AckStream, Self::Error> {
        if opts.is_local(self.uid) {
            tracing::debug!(?opts, "broadcast with ack is local");
            let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
            let stream = AckStream::new_local(local);
            return Ok(stream);
        }
        let req = RequestOut::new(self.uid, RequestTypeOut::BroadcastWithAck(&packet), &opts);
        let req_id = req.id;
        // A targeted request is only answered by that server (same rule as `get_res`);
        // expecting more ack counts would keep the stream open until the timeout.
        let remote_serv_cnt = if opts.server_id.is_some() {
            1
        } else {
            self.server_count().await?.saturating_sub(1)
        };
        let (tx, rx) = mpsc::channel(self.config.ack_response_buffer + remote_serv_cnt as usize);
        self.responses.lock().unwrap().insert(req_id, tx);
        let remote = MessageStream::new(rx);
        if let Err(err) = self.send_req(req, opts.server_id).await {
            // No `AckStream` will be built to unregister the handler on drop: do it here.
            self.responses.lock().unwrap().remove(&req_id);
            return Err(err);
        }
        let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
        let timeout = self
            .config
            .request_timeout
            .saturating_add(timeout.unwrap_or(self.local.ack_timeout()));
        Ok(AckStream::new(
            local,
            remote,
            timeout,
            remote_serv_cnt,
            req_id,
            self.responses.clone(),
        ))
    }

    /// Publishes a disconnect request to the other servers (unless the options are
    /// local-only), then disconnects the matching local sockets.
    async fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), BroadcastError> {
        if !opts.is_local(self.uid) {
            let req = RequestOut::new(self.uid, RequestTypeOut::DisconnectSockets, &opts);
            self.send_req(req, opts.server_id)
                .await
                .map_err(AdapterError::from)?;
        }
        self.local
            .disconnect_socket(opts)
            .map_err(BroadcastError::Socket)?;
        Ok(())
    }

    /// Collects the rooms of the local adapter and, unless the options are local-only, the
    /// rooms reported by remote servers before `request_timeout`.
    async fn rooms(&self, opts: BroadcastOptions) -> Result<Vec<Room>, Self::Error> {
        if opts.is_local(self.uid) {
            return Ok(self.local.rooms(opts).into_iter().collect());
        }
        let req = RequestOut::new(self.uid, RequestTypeOut::AllRooms, &opts);
        let req_id = req.id;
        // Register the handler before publishing so that early responses are not lost.
        let stream = self
            .get_res::<()>(req_id, ResponseTypeId::AllRooms, opts.server_id)
            .await?;
        self.send_req(req, opts.server_id).await?;
        let local = self.local.rooms(opts);
        let rooms = stream
            .filter_map(|item| future::ready(item.into_rooms()))
            .fold(local, async |mut acc, item| {
                acc.extend(item);
                acc
            })
            .await;
        Ok(Vec::from_iter(rooms))
    }

    /// Publishes an add-sockets request to the other servers (unless the options are
    /// local-only), then adds the matching local sockets to `rooms`.
    async fn add_sockets(
        &self,
        opts: BroadcastOptions,
        rooms: impl RoomParam,
    ) -> Result<(), Self::Error> {
        let rooms: Vec<Room> = rooms.into_room_iter().collect();
        if !opts.is_local(self.uid) {
            let req = RequestOut::new(self.uid, RequestTypeOut::AddSockets(&rooms), &opts);
            self.send_req(req, opts.server_id).await?;
        }
        self.local.add_sockets(opts, rooms);
        Ok(())
    }

    /// Publishes a del-sockets request to the other servers (unless the options are
    /// local-only), then removes the matching local sockets from `rooms`.
    async fn del_sockets(
        &self,
        opts: BroadcastOptions,
        rooms: impl RoomParam,
    ) -> Result<(), Self::Error> {
        let rooms: Vec<Room> = rooms.into_room_iter().collect();
        if !opts.is_local(self.uid) {
            let req = RequestOut::new(self.uid, RequestTypeOut::DelSockets(&rooms), &opts);
            self.send_req(req, opts.server_id).await?;
        }
        self.local.del_sockets(opts, rooms);
        Ok(())
    }

    /// Collects the matching local sockets and, unless the options are local-only, the sockets
    /// reported by remote servers before `request_timeout`.
    async fn fetch_sockets(
        &self,
        opts: BroadcastOptions,
    ) -> Result<Vec<RemoteSocketData>, Self::Error> {
        if opts.is_local(self.uid) {
            return Ok(self.local.fetch_sockets(opts));
        }
        let req = RequestOut::new(self.uid, RequestTypeOut::FetchSockets, &opts);
        let req_id = req.id;
        // Register the handler before publishing so that early responses are not lost.
        let remote = self
            .get_res::<RemoteSocketData>(req_id, ResponseTypeId::FetchSockets, opts.server_id)
            .await?;
        self.send_req(req, opts.server_id).await?;
        let local = self.local.fetch_sockets(opts);
        let sockets = remote
            .filter_map(|item| future::ready(item.into_fetch_sockets()))
            .fold(local, async |mut acc, item| {
                acc.extend(item);
                acc
            })
            .await;
        Ok(sockets)
    }

    /// Returns the wrapped local adapter.
    fn get_local(&self) -> &CoreLocalAdapter<E> {
        &self.local
    }
}
/// Error returned while initializing the adapter (the future returned by `CoreAdapter::init`).
#[derive(thiserror::Error)]
pub enum InitError<D: Driver> {
    /// The driver failed, e.g. while subscribing to one of the adapter channels.
    #[error("driver error: {0}")]
    Driver(D::Error),
    /// The namespace path is empty or contains `#`, which separates the parts of channel names.
    #[error("malformed namespace path, it must not contain '#'")]
    MalformedNamespace,
}
impl<D: Driver> fmt::Debug for InitError<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Driver(err) => fmt::Debug::fmt(err, f),
Self::MalformedNamespace => write!(f, "Malformed namespace path"),
}
}
}
/// Future returned by `CoreAdapter::init`: resolves once the adapter has subscribed to its
/// channels and spawned its dispatch loop.
///
/// It can be awaited, or handed to [`Spawnable::spawn`], which runs it in a task and only logs
/// a failure.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct InitRes<D: Driver>(futures_core::future::BoxFuture<'static, Result<(), InitError<D>>>);
impl<D: Driver> Future for InitRes<D> {
    type Output = Result<(), InitError<D>>;

    /// Polls the boxed initialization future.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `InitRes` only holds a `Pin<Box<_>>`, so it is `Unpin` and may be unwrapped safely.
        let inner = &mut self.get_mut().0;
        inner.as_mut().poll(cx)
    }
}
impl<D: Driver> Spawnable for InitRes<D> {
fn spawn(self) {
tokio::spawn(async move {
if let Err(e) = self.0.await {
tracing::error!("error initializing adapter: {e}");
}
});
}
}
impl<E: SocketEmitter, R: Driver> CustomRedisAdapter<E, R> {
    /// Name of the channel on which server `uid` receives responses for this namespace:
    /// `{prefix}-response#{path}#{uid}#`.
    fn get_res_chan(&self, uid: Uid) -> String {
        let path = self.local.path();
        let prefix = &self.config.prefix;
        format!("{prefix}-response#{path}#{uid}#")
    }
    /// Name of a request channel: the global `{prefix}-request#{path}#` channel when `node_id`
    /// is `None`, or the server-specific `{prefix}-request#{path}#{uid}#` channel otherwise.
    fn get_req_chan(&self, node_id: Option<Uid>) -> String {
        match node_id {
            Some(uid) => format!("{}{}#", self.req_chan, uid),
            None => self.req_chan.clone(),
        }
    }
    /// Dispatch loop spawned by `init`: requests are handled by `recv_req`, responses are
    /// routed to the handler registered for their request id, anything else is logged.
    ///
    /// Runs until `stream` ends.
    async fn pipe_stream(
        self: Arc<Self>,
        mut stream: impl Stream<Item = ChanItem> + Unpin,
        response_chan: String,
    ) {
        while let Some((chan, item)) = stream.next().await {
            // Both the global and the server-specific request channels start with `req_chan`.
            if chan.starts_with(&self.req_chan) {
                if let Err(e) = self.recv_req(item) {
                    let ns = self.local.path();
                    let uid = self.uid;
                    tracing::warn!(?uid, ?ns, "request handler error: {e}");
                }
            } else if chan == response_chan {
                // Only the request id is decoded here; the whole response is decoded by the
                // handler's consumer (e.g. the stream built in `get_res`).
                let req_id = read_req_id(&item);
                tracing::trace!(?req_id, ?chan, ?response_chan, "extracted sid");
                // The lock is only held for the lookup and the non-blocking `try_send`.
                let handlers = self.responses.lock().unwrap();
                if let Some(tx) = req_id.and_then(|id| handlers.get(&id)) {
                    // `try_send` never waits: if the handler channel is full or closed the
                    // response is dropped and a warning is logged.
                    if let Err(e) = tx.try_send(item) {
                        tracing::warn!("error sending response to handler: {e}");
                    }
                } else {
                    tracing::warn!(?req_id, "could not find req handler");
                }
            } else {
                tracing::warn!("unexpected message/channel: {chan}");
            }
        }
    }
    /// Decodes a request and dispatches it to the matching `recv_*` handler.
    ///
    /// Requests sent by this server are ignored: it is subscribed to the global request
    /// channel it publishes to. Requests without options and unknown request types are
    /// ignored too. Only a decoding failure is returned as an error.
    fn recv_req(self: &Arc<Self>, item: Vec<u8>) -> Result<(), Error<R>> {
        let req: RequestIn = rmp_serde::from_slice(&item)?;
        if req.node_id == self.uid {
            return Ok(());
        }
        tracing::trace!(?req, "handling request");
        let Some(opts) = req.opts else {
            tracing::warn!(?req, "request is missing options");
            return Ok(());
        };
        match req.r#type {
            RequestTypeIn::Broadcast(p) => self.recv_broadcast(opts, p),
            // This handler spawns a task and therefore needs an owned `Arc`.
            RequestTypeIn::BroadcastWithAck(p) => {
                self.clone()
                    .recv_broadcast_with_ack(req.node_id, req.id, p, opts)
            }
            RequestTypeIn::DisconnectSockets => self.recv_disconnect_sockets(opts),
            RequestTypeIn::AllRooms => self.recv_rooms(req.node_id, req.id, opts),
            RequestTypeIn::AddSockets(rooms) => self.recv_add_sockets(opts, rooms),
            RequestTypeIn::DelSockets(rooms) => self.recv_del_sockets(opts, rooms),
            RequestTypeIn::FetchSockets => self.recv_fetch_sockets(req.node_id, req.id, opts),
            _ => (),
        };
        Ok(())
    }
    /// Handles a remote broadcast request by broadcasting locally; errors are only logged.
    fn recv_broadcast(&self, opts: BroadcastOptions, packet: Packet) {
        if let Err(e) = self.local.broadcast(packet, opts) {
            let ns = self.local.path();
            tracing::warn!(?self.uid, ?ns, "remote request broadcast handler: {:?}", e);
        }
    }
    /// Handles a remote disconnect request by disconnecting local sockets; errors are only
    /// logged.
    fn recv_disconnect_sockets(&self, opts: BroadcastOptions) {
        if let Err(e) = self.local.disconnect_socket(opts) {
            let ns = self.local.path();
            tracing::warn!(
                ?self.uid,
                ?ns,
                "remote request disconnect sockets handler: {:?}",
                e
            );
        }
    }
    /// Handles a remote broadcast-with-ack request.
    ///
    /// Broadcasts locally, then from a spawned task sends back to `origin` the number of
    /// expected acks (`BroadcastAckCount`), followed by each ack as it arrives. It stops at
    /// the first failed send, which is logged.
    fn recv_broadcast_with_ack(
        self: Arc<Self>,
        origin: Uid,
        req_id: Sid,
        packet: Packet,
        opts: BroadcastOptions,
    ) {
        let (stream, count) = self.local.broadcast_with_ack(packet, opts, None);
        tokio::spawn(async move {
            let on_err = |err| {
                let ns = self.local.path();
                tracing::warn!(
                    ?origin,
                    ?ns,
                    "remote request broadcast with ack handler errors: {:?}",
                    err
                );
            };
            let res = Response {
                r#type: ResponseType::<()>::BroadcastAckCount(count),
                node_id: self.uid,
            };
            if let Err(err) = self.send_res(origin, req_id, res).await {
                on_err(err);
                return;
            }
            futures_util::pin_mut!(stream);
            while let Some(ack) = stream.next().await {
                let res = Response {
                    r#type: ResponseType::BroadcastAck(ack),
                    node_id: self.uid,
                };
                if let Err(err) = self.send_res(origin, req_id, res).await {
                    on_err(err);
                    return;
                }
            }
        });
    }
    /// Handles a remote rooms request: replies to `origin` with the local rooms.
    fn recv_rooms(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) {
        let rooms = self.local.rooms(opts);
        let res = Response {
            r#type: ResponseType::<()>::AllRooms(rooms),
            node_id: self.uid,
        };
        // `send_res` returns a `'static` future, so it is spawned without borrowing `self`.
        let fut = self.send_res(origin, req_id, res);
        let ns = self.local.path().clone();
        let uid = self.uid;
        tokio::spawn(async move {
            if let Err(err) = fut.await {
                tracing::warn!(?uid, ?ns, "remote request rooms handler: {:?}", err);
            }
        });
    }
    /// Handles a remote add-sockets request by adding local sockets to `rooms`.
    fn recv_add_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
        self.local.add_sockets(opts, rooms);
    }
    /// Handles a remote del-sockets request by removing local sockets from `rooms`.
    fn recv_del_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
        self.local.del_sockets(opts, rooms);
    }
    /// Handles a remote fetch-sockets request: replies to `origin` with the local sockets.
    fn recv_fetch_sockets(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) {
        let sockets = self.local.fetch_sockets(opts);
        let res = Response {
            node_id: self.uid,
            r#type: ResponseType::FetchSockets(sockets),
        };
        // `send_res` returns a `'static` future, so it is spawned without borrowing `self`.
        let fut = self.send_res(origin, req_id, res);
        let ns = self.local.path().clone();
        let uid = self.uid;
        tokio::spawn(async move {
            if let Err(err) = fut.await {
                tracing::warn!(?uid, ?ns, "remote request fetch sockets handler: {:?}", err);
            }
        });
    }
    /// Encodes `req` with MessagePack and publishes it on the request channel of
    /// `target_uid`, or on the global request channel when `target_uid` is `None`.
    async fn send_req(&self, req: RequestOut<'_>, target_uid: Option<Uid>) -> Result<(), Error<R>> {
        tracing::trace!(?req, "sending request");
        let req = rmp_serde::to_vec(&req)?;
        let chan = self.get_req_chan(target_uid);
        self.driver
            .publish(chan, req)
            .await
            .map_err(Error::from_driver)?;
        Ok(())
    }
    /// Encodes `(req_id, res)` with MessagePack and returns a `'static` future publishing it on
    /// the response channel of `req_node_id`.
    ///
    /// Encoding happens eagerly, but an encoding error is only reported when the returned
    /// future is awaited.
    fn send_res<D: Serialize + fmt::Debug>(
        &self,
        req_node_id: Uid,
        req_id: Sid,
        res: Response<D>,
    ) -> impl Future<Output = Result<(), Error<R>>> + Send + 'static {
        let chan = self.get_res_chan(req_node_id);
        tracing::trace!(?res, "sending response to {}", &chan);
        let res = rmp_serde::to_vec(&(req_id, res));
        let driver = self.driver.clone();
        async move {
            driver
                .publish(chan, res?)
                .await
                .map_err(Error::from_driver)?;
            Ok(())
        }
    }
    /// Registers a response handler for `req_id` and returns a stream of the decoded remote
    /// responses whose type is `response_type`.
    ///
    /// Callers register the handler before publishing the request (see `rooms`), so that
    /// responses arriving early are not rejected by `pipe_stream`. The stream ends after one
    /// response per expected server (1 when `target_uid` is set, otherwise every other
    /// server), or after `request_timeout`, whichever comes first. Undecodable responses are
    /// logged and skipped. `DropStream` receives the handler map and `req_id`, presumably to
    /// unregister the handler when the stream is dropped.
    async fn get_res<D: DeserializeOwned + fmt::Debug>(
        &self,
        req_id: Sid,
        response_type: ResponseTypeId,
        target_uid: Option<Uid>,
    ) -> Result<impl Stream<Item = Response<D>>, Error<R>> {
        let remote_serv_cnt = if target_uid.is_none() {
            self.server_count().await?.saturating_sub(1) as usize
        } else {
            1
        };
        // A zero-capacity channel is not allowed, hence the floor of 1.
        let (tx, rx) = mpsc::channel(std::cmp::max(remote_serv_cnt, 1));
        self.responses.lock().unwrap().insert(req_id, tx);
        let stream = MessageStream::new(rx)
            .filter_map(|item| {
                let data = match rmp_serde::from_slice::<(Sid, Response<D>)>(&item) {
                    Ok((_, data)) => Some(data),
                    Err(e) => {
                        tracing::warn!("error decoding response: {e}");
                        None
                    }
                };
                future::ready(data)
            })
            .filter(move |item| future::ready(ResponseTypeId::from(&item.r#type) == response_type))
            .take(remote_serv_cnt)
            .take_until(time::sleep(self.config.request_timeout));
        let stream = DropStream::new(stream, self.responses.clone(), req_id);
        Ok(stream)
    }
    /// Subscribes to `pat` through the driver using the configured stream buffer size.
    #[inline]
    async fn subscribe(&self, pat: String) -> Result<MessageStream<ChanItem>, InitError<R>> {
        tracing::trace!(?pat, "subscribing to");
        self.driver
            .subscribe(pat, self.config.stream_buffer)
            .await
            .map_err(InitError::Driver)
    }
}
/// Rejects namespace paths that cannot be embedded in channel names: an empty path, or one
/// containing `#` (the channel-name separator).
fn check_ns<D: Driver>(path: &str) -> Result<(), InitError<D>> {
    if !path.is_empty() && !path.contains('#') {
        return Ok(());
    }
    Err(InitError::MalformedNamespace)
}
pub fn read_req_id(data: &[u8]) -> Option<Sid> {
use std::str::FromStr;
let mut rd = data;
let len = rmp::decode::read_array_len(&mut rd).ok()?;
if len < 1 {
return None;
}
let mut buff = [0u8; Sid::ZERO.as_str().len()];
let str = rmp::decode::read_str(&mut rd, &mut buff).ok()?;
Sid::from_str(str).ok()
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream::{self, FusedStream, StreamExt};
    use socketioxide_core::{Str, Value, adapter::AckStreamItem};
    use std::convert::Infallible;

    /// Driver that never fails: publishes are discarded, subscriptions return
    /// `MessageStream::new_empty()` and `num_serv` reports 0 servers.
    #[derive(Clone)]
    struct StubDriver;
    impl Driver for StubDriver {
        type Error = Infallible;
        async fn publish(&self, _: String, _: Vec<u8>) -> Result<(), Self::Error> {
            Ok(())
        }
        async fn subscribe(
            &self,
            _: String,
            _: usize,
        ) -> Result<MessageStream<ChanItem>, Self::Error> {
            Ok(MessageStream::new_empty())
        }
        async fn unsubscribe(&self, _: String) -> Result<(), Self::Error> {
            Ok(())
        }
        async fn num_serv(&self, _: &str) -> Result<u16, Self::Error> {
            Ok(0)
        }
    }

    /// Builds an `AckStream` with no local acks that expects ack counts from 2 remote servers.
    fn new_stub_ack_stream(
        remote: MessageStream<Vec<u8>>,
        timeout: Duration,
    ) -> AckStream<stream::Empty<AckStreamItem<()>>> {
        AckStream::new(
            stream::empty::<AckStreamItem<()>>(),
            remote,
            timeout,
            2,
            Sid::new(),
            Arc::new(Mutex::new(HashMap::new())),
        )
    }

    /// Both remote servers announce 2 acks and 4 acks arrive: the stream yields exactly
    /// 4 items and is then terminated.
    #[tokio::test]
    async fn ack_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let stream = new_stub_ack_stream(remote, Duration::from_secs(10));
        let node_id = Uid::new();
        let req_id = Sid::new();
        let ack_cnt_res = Response::<()> {
            node_id,
            r#type: ResponseType::BroadcastAckCount(2),
        };
        tx.try_send(rmp_serde::to_vec(&(req_id, &ack_cnt_res)).unwrap())
            .unwrap();
        tx.try_send(rmp_serde::to_vec(&(req_id, &ack_cnt_res)).unwrap())
            .unwrap();
        let ack_res = Response::<String> {
            node_id,
            r#type: ResponseType::BroadcastAck((Sid::new(), Ok(Value::Str(Str::from(""), None)))),
        };
        for _ in 0..4 {
            tx.try_send(rmp_serde::to_vec(&(req_id, &ack_res)).unwrap())
                .unwrap();
        }
        futures_util::pin_mut!(stream);
        for _ in 0..4 {
            assert!(stream.next().await.is_some());
        }
        assert!(stream.is_terminated());
    }

    /// Only one of the two expected ack counts arrives and no ack follows: once the timeout
    /// has elapsed the stream ends without yielding anything.
    #[tokio::test]
    async fn ack_stream_timeout() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let stream = new_stub_ack_stream(remote, Duration::from_millis(50));
        let node_id = Uid::new();
        let req_id = Sid::new();
        let ack_cnt_res = Response::<()> {
            node_id,
            r#type: ResponseType::BroadcastAckCount(2),
        };
        tx.try_send(rmp_serde::to_vec(&(req_id, ack_cnt_res)).unwrap())
            .unwrap();
        futures_util::pin_mut!(stream);
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(stream.next().await.is_none());
        assert!(stream.is_terminated());
    }

    /// Dropping an `AckStream` removes its handler from the shared response-handler map.
    #[tokio::test]
    async fn ack_stream_drop() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let handlers = Arc::new(Mutex::new(HashMap::new()));
        let id = Sid::new();
        handlers.lock().unwrap().insert(id, tx);
        let stream = AckStream::new(
            stream::empty::<AckStreamItem<()>>(),
            remote,
            Duration::from_secs(10),
            2,
            id,
            handlers.clone(),
        );
        drop(stream);
        assert!(handlers.lock().unwrap().is_empty(),);
    }

    /// Paths containing `#` and empty paths are rejected.
    #[test]
    fn check_ns_error() {
        assert!(matches!(
            check_ns::<StubDriver>("#"),
            Err(InitError::MalformedNamespace)
        ));
        assert!(matches!(
            check_ns::<StubDriver>(""),
            Err(InitError::MalformedNamespace)
        ));
    }
}