use std::{
ops::{Deref, DerefMut},
sync::Arc,
time::Duration,
};
use crate::{
socket::Socket as InnerSocket, AckId, ClientBuilder, Error, Event, Packet, Payload, Result,
};
use backoff::{backoff::Backoff, ExponentialBackoff, ExponentialBackoffBuilder};
use futures_util::future::BoxFuture;
use tokio::sync::RwLock;
use tracing::{trace, warn};
/// Client-side handle around a connected socket.
///
/// Cloning is cheap for the shared parts: `socket` and `connected` sit behind
/// `Arc`s, so every clone observes the same connection and the same
/// connected flag. `builder` and `backoff` are copied per clone.
#[derive(Clone)]
pub struct Client {
    /// Configuration used for the first connection and cloned again for each
    /// reconnect attempt.
    builder: ClientBuilder,
    /// The current connection; swapped for a fresh one under the write lock
    /// when a reconnect succeeds.
    socket: Arc<RwLock<InnerSocket<Socket>>>,
    /// Produces the delay slept before each reconnect attempt.
    backoff: ExponentialBackoff,
    /// Starts as `true`; set to `false` by `disconnect`, which also makes the
    /// background polling task stop.
    connected: Arc<RwLock<bool>>,
}
/// Thin wrapper around the inner socket, used as the socket type handed to
/// event and acknowledgement callbacks (see the callback signature of
/// `Client::emit_with_ack`).
#[derive(Clone)]
pub struct Socket {
    /// The wrapped inner socket; also reachable through `Deref`/`DerefMut`.
    pub(crate) socket: InnerSocket<Self>,
}
impl From<InnerSocket<Socket>> for Socket {
fn from(socket: InnerSocket<Socket>) -> Self {
Self { socket }
}
}
impl Client {
/// Emits `event` with `data` on the current connection.
///
/// A read lock on the shared socket is held until the inner `emit` call has
/// completed; its result is returned unchanged.
#[inline]
pub async fn emit<E, D>(&self, event: E, data: D) -> Result<()>
where
    E: Into<Event>,
    D: Into<Payload>,
{
    self.socket.read().await.emit(event, data).await
}
/// Emits `event` with `data` and registers `callback` for the acknowledgement.
///
/// `timeout` and `callback` are forwarded untouched to the inner socket
/// (presumably `timeout` bounds how long the acknowledgement is awaited —
/// confirm against the inner socket). The callback receives the optional
/// payload, a [`Socket`] handle and an optional ack id.
///
/// A read lock on the shared socket is held until the inner call returns.
#[inline]
pub async fn emit_with_ack<F, E, D>(
    &self,
    event: E,
    data: D,
    timeout: Duration,
    callback: F,
) -> Result<()>
where
    F: for<'a> std::ops::FnMut(
            Option<Payload>,
            Socket,
            Option<AckId>,
        ) -> BoxFuture<'static, ()>
        + 'static
        + Send
        + Sync,
    E: Into<Event>,
    D: Into<Payload>,
{
    let guard = self.socket.read().await;
    let outcome = guard.emit_with_ack(event, data, timeout, callback).await;
    outcome
}
/// Forwards `id` and `data` to the inner socket's `ack`, holding a read lock
/// on the shared socket for the duration of the call.
pub async fn ack(&self, id: usize, data: Payload) -> Result<()> {
    self.socket.read().await.ack(id, data).await
}
/// Marks the client as disconnected and closes the underlying socket.
///
/// Calling this on an already disconnected client is a no-op that returns
/// `Ok(())`. The write lock on the connected flag is held until the socket
/// has been closed, so concurrent callers are serialized.
pub async fn disconnect(&self) -> Result<()> {
    trace!("client disconnect");
    let mut flag = self.connected.write().await;
    // Clear the flag and learn its previous value in one step.
    let was_connected = std::mem::replace(&mut *flag, false);
    if was_connected {
        self.disconnect_socket().await
    } else {
        Ok(())
    }
}
/// Closes the current socket without touching the connected flag.
async fn disconnect_socket(&self) -> Result<()> {
    self.socket.read().await.disconnect().await
}
/// Connects a socket from `builder` and wraps it in a new client.
///
/// The reconnect backoff is configured from the builder's
/// `reconnect_delay_min` / `reconnect_delay_max`, both interpreted as
/// milliseconds. The client starts in the connected state.
///
/// # Errors
/// Returns the error of the initial connection attempt.
pub(crate) async fn new(builder: ClientBuilder) -> Result<Self> {
    // Connect with a copy so the builder stays available for reconnects.
    let socket = builder.clone().connect_socket().await?;

    let min_delay = Duration::from_millis(builder.reconnect_delay_min);
    let max_delay = Duration::from_millis(builder.reconnect_delay_max);
    let backoff = ExponentialBackoffBuilder::new()
        .with_initial_interval(min_delay)
        .with_max_interval(max_delay)
        .build();

    Ok(Self {
        socket: Arc::new(RwLock::new(socket)),
        connected: Arc::new(RwLock::new(true)),
        backoff,
        builder,
    })
}
async fn reconnect(&mut self) {
let mut reconnect_attempts = 0;
if self.builder.reconnect {
loop {
if let Some(max_reconnect_attempts) = self.builder.max_reconnect_attempts {
if reconnect_attempts > max_reconnect_attempts {
break;
}
}
reconnect_attempts += 1;
if let Some(backoff) = self.backoff.next_backoff() {
trace!("reconnect backoff {:?}", backoff);
tokio::time::sleep(backoff).await;
}
trace!("client reconnect {}", reconnect_attempts);
if self.do_reconnect().await.is_ok() {
break;
}
}
}
}
/// Performs a single reconnect attempt.
///
/// A new socket is connected from a copy of the builder first; only when
/// that succeeds is the write lock taken and the shared socket replaced.
///
/// # Errors
/// Returns the connection error; the existing socket is left in place.
async fn do_reconnect(&self) -> Result<()> {
    let replacement = self.builder.clone().connect_socket().await?;
    *self.socket.write().await = replacement;
    Ok(())
}
/// Spawns a background task that keeps polling packets from the socket.
///
/// The task works on its own clone of the client and runs until the shared
/// connected flag becomes `false`. When polling yields
/// `Error::IncompleteResponseFromEngineIo`, the current socket is closed
/// (best effort) and a reconnect is attempted.
///
/// The connected flag is checked *before* reconnecting, so a deliberate
/// `disconnect()` is not answered by opening a new connection. If the client
/// is disconnected while a reconnect is in progress, the socket is closed
/// again before the task ends.
pub(crate) fn poll_callback(&self) {
    let mut worker = self.clone();
    tokio::spawn(async move {
        trace!("start poll_callback ");
        loop {
            let packet = worker.poll_packet().await;
            trace!("poll_callback packet {:?}", packet);

            if !*worker.connected.read().await {
                break;
            }

            if let Some(Err(Error::IncompleteResponseFromEngineIo(_))) = packet {
                // Best effort: the connection is already considered broken.
                let _ = worker.disconnect_socket().await;
                worker.reconnect().await;

                if !*worker.connected.read().await {
                    // Disconnected while reconnecting: do not leave a freshly
                    // opened connection behind.
                    let _ = worker.disconnect_socket().await;
                    break;
                }
            }
        }
        warn!("poll_callback exist");
    });
}
/// Polls the next packet from the current socket, holding a read lock on the
/// shared socket while waiting. The inner result is returned unchanged.
pub(crate) async fn poll_packet(&self) -> Option<Result<Packet>> {
    self.socket.read().await.poll_packet().await
}
}
/// Lets a [`Socket`] be used wherever the wrapped inner socket is expected.
impl Deref for Socket {
    type Target = InnerSocket<Self>;

    fn deref(&self) -> &Self::Target {
        let Self { socket } = self;
        socket
    }
}
/// Mutable counterpart of the `Deref` implementation.
impl DerefMut for Socket {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { socket } = self;
        socket
    }
}
#[cfg(test)]
mod test {
use std::time::Duration;
use super::*;
use crate::{
test::socket_io_server, AckId, Client, ClientBuilder, Event, Packet, PacketType, Payload,
Result, ServerBuilder, ServerSocket,
};
use bytes::Bytes;
use futures_util::FutureExt;
use serde_json::json;
use tokio::{sync::mpsc::unbounded_channel, time::sleep};
use tracing::info;
/// Starts the test server once and runs the client scenarios one after the
/// other, stopping at the first failure.
#[tokio::test(flavor = "multi_thread", worker_threads = 3)]
async fn test_client() -> Result<()> {
    setup_server();

    socket_io_integration().await?;
    socket_io_builder_integration().await?;
    socket_io_builder_integration_iterator().await
}
/// Connects a client, emits with and without an acknowledgement callback and
/// finally disconnects.
async fn socket_io_integration() -> Result<()> {
    let client = ClientBuilder::new(socket_io_server())
        .on("test", |msg, _, _| {
            async move {
                match msg {
                    Some(Payload::Json(data)) => info!("Received string: {:?}", data),
                    Some(Payload::Binary(bin)) => info!("Received binary data: {:#?}", bin),
                    Some(Payload::Multi(multi)) => info!("Received multi {:?}", multi),
                    _ => {}
                }
            }
            .boxed()
        })
        .connect()
        .await?;

    let token = json!({"token": 123_i32});

    // Plain emit.
    assert!(client
        .emit("test", Payload::Json(token.clone()))
        .await
        .is_ok());

    // Emit with an ack callback that itself emits on the socket it is given.
    let on_ack = |message: Option<Payload>, socket: Socket, _| {
        async move {
            let result = socket
                .emit("test", Payload::Json(json!({"got ack": true})))
                .await;
            assert!(result.is_ok());
            info!("Yehaa! My ack got acked?");
            if let Some(Payload::Json(data)) = message {
                info!("Received string Ack");
                info!("Ack data: {:?}", data);
            }
        }
        .boxed()
    };
    let ack = client
        .emit_with_ack("test", Payload::Json(token), Duration::from_secs(1), on_ack)
        .await;
    assert!(ack.is_ok());

    // Leave time for the ack round trip before closing the connection.
    sleep(Duration::from_secs(2)).await;
    assert!(client.disconnect().await.is_ok());
    Ok(())
}
/// Exercises the builder options (namespace, opening header, several
/// handlers) and emits JSON, binary and ack-carrying events.
async fn socket_io_builder_integration() -> Result<()> {
    let client = ClientBuilder::new(socket_io_server())
        .namespace("/admin")
        .opening_header("accept-encoding", "application/json")
        .on("test", |payload, _, _| {
            async move { info!("Received: {:#?}", payload) }.boxed()
        })
        .on("message", |payload, _, _| {
            async move { info!("{:#?}", payload) }.boxed()
        })
        .connect()
        .await?;

    let sent_text = client.emit("message", json!("Hello World")).await;
    assert!(sent_text.is_ok());

    let sent_binary = client.emit("binary", Bytes::from_static(&[46, 88])).await;
    assert!(sent_binary.is_ok());

    let sent_with_ack = client
        .emit_with_ack(
            "binary",
            json!("pls ack"),
            Duration::from_secs(1),
            |payload, _, _| {
                async move {
                    info!("Yehaa the ack got acked");
                    info!("With data: {:#?}", payload);
                }
                .boxed()
            },
        )
        .await;
    assert!(sent_with_ack.is_ok());

    // Leave time for the ack round trip.
    sleep(Duration::from_secs(2)).await;
    Ok(())
}
/// Builds a client with `connect_client` on the `/admin` namespace and hands
/// it to `test_socketio_socket`, which polls its packets by hand.
async fn socket_io_builder_integration_iterator() -> Result<()> {
    let namespace = "/admin";
    let client = ClientBuilder::new(socket_io_server())
        .namespace(namespace)
        .opening_header("accept-encoding", "application/json")
        .on("test", |payload, _, _| {
            async move { info!("Received: {:#?}", payload) }.boxed()
        })
        .on("message", |payload, _, _| {
            async move { info!("Received binary {:#?}", payload) }.boxed()
        })
        .connect_client()
        .await?;

    test_socketio_socket(client, namespace.to_owned()).await
}
async fn test_socketio_socket(socket: Client, nsp: String) -> Result<()> {
let _: Option<Packet> = Some(socket.poll_packet().await.unwrap()?);
let packet: Option<Packet> = Some(socket.poll_packet().await.unwrap()?);
assert!(packet.is_some());
let packet = packet.unwrap();
assert_eq!(
packet,
Packet::new(
PacketType::Event,
nsp.clone(),
Some(json!(["test", "Hello from the test event!"])),
None,
0,
None
)
);
let packet: Option<Packet> = Some(socket.poll_packet().await.unwrap()?);
assert!(packet.is_some());
let packet = packet.unwrap();
assert_eq!(
packet,
Packet::new(
PacketType::BinaryEvent,
nsp.clone(),
Some(json!(["test", {"_placeholder": true, "num": 0}])),
None,
1,
Some(vec![Bytes::from_static(&[1, 2, 3])]),
)
);
let packet: Option<Packet> = Some(socket.poll_packet().await.unwrap()?);
assert!(packet.is_some());
let packet = packet.unwrap();
match packet.data {
Some(serde_json::Value::Array(array)) => assert_eq!(array.len(), 5),
_ => panic!("invlaid emit multi payload"),
}
let socket_clone = socket.clone();
tokio::spawn(async move {
loop {
let _ = socket_clone.poll_packet().await;
}
});
let (tx, mut rx) = unbounded_channel();
let tx = Arc::new(tx);
let cb = move |message: Option<Payload>, _, _| {
let tx = tx.clone();
async move {
match message {
Some(Payload::Multi(vec)) => {
let _ = tx.send(vec.len() == 2);
}
_ => {
let _ = tx.send(false);
}
};
}
.boxed()
};
assert!(socket
.emit_with_ack(
"client_ack",
Payload::Multi(vec![json!(1).into(), json!(2).into()]),
Duration::from_secs(10),
cb
)
.await
.is_ok());
match rx.recv().await {
Some(true) => {}
_ => panic!("ACK callback invlaid"),
};
let (tx, mut rx) = unbounded_channel();
let cb = move |message: Option<Payload>, _, _| {
let tx = tx.clone();
async move {
match message {
Some(Payload::Multi(vec)) => {
let _ = tx.send(vec.len() == 2);
}
_ => {
let _ = tx.send(false);
}
};
}
.boxed()
};
assert!(socket
.emit_with_ack(
"client_ack",
Payload::Multi(vec![Bytes::from_static(b"1").into(), json!(2).into()]),
Duration::from_secs(10),
cb
)
.await
.is_ok());
match rx.recv().await {
Some(true) => {}
_ => panic!("BINARY_ACK callback invlaid"),
};
Ok(())
}
/// Builds the test server for the `/admin` namespace and spawns it on the
/// current tokio runtime.
///
/// Registered handlers:
/// * `echo` — emits an empty `echo` event back;
/// * `client_ack` — acknowledges with the received payload (or `"ackback"`);
/// * `server_ack` — emits `server_ask_ack` and, once acknowledged, emits
///   `server_recv_ack`;
/// * connect — emits three `test` events: JSON, binary and a multi payload.
fn setup_server() {
    // Sends the greeting sequence the client tests expect, in this order.
    let on_connect = move |_payload: Option<Payload>, socket: ServerSocket, _| {
        async move {
            let binary = Payload::Binary(Bytes::from_static(&[1, 2, 3]));
            let multi = Payload::Multi(vec![
                json!(1).into(),
                json!("2").into(),
                Bytes::from_static(&[3]).into(),
                Bytes::from_static(b"4").into(),
            ]);

            socket
                .emit("test", json!("Hello from the test event!"))
                .await
                .expect("success");
            socket.emit("test", binary).await.expect("success");
            socket.emit("test", multi).await.expect("success");
        }
        .boxed()
    };

    let on_echo =
        move |_payload: Option<Payload>, socket: ServerSocket, _need_ack: Option<AckId>| {
            async move {
                let _ = socket.emit("echo", json!("")).await;
            }
            .boxed()
        };

    // Answers only when the client actually asked for an acknowledgement.
    let on_client_ack =
        move |payload: Option<Payload>, socket: ServerSocket, need_ack: Option<AckId>| {
            async move {
                if let Some(ack_id) = need_ack {
                    let reply = match payload {
                        Some(received) => received,
                        None => json!("ackback").into(),
                    };
                    socket.ack(ack_id, reply).await.expect("success");
                }
            }
            .boxed()
        };

    let on_ack_received =
        move |_payload: Option<Payload>, socket: ServerSocket, _need_ack: Option<AckId>| {
            async move {
                socket
                    .emit("server_recv_ack", json!(""))
                    .await
                    .expect("success");
            }
            .boxed()
        };

    // Asks the client for an acknowledgement; `on_ack_received` runs on reply.
    let on_server_ack = move |message: Option<Payload>, socket: ServerSocket, _| {
        async move {
            let request = match message {
                Some(received) => received,
                None => json!({"ack_back": true}).into(),
            };
            socket
                .emit_with_ack(
                    "server_ask_ack",
                    request,
                    Duration::from_millis(400),
                    on_ack_received,
                )
                .await
                .expect("success");
        }
        .boxed()
    };

    let port = socket_io_server().port().unwrap();
    let server = ServerBuilder::new(port)
        .on("/admin", "echo", on_echo)
        .on("/admin", "client_ack", on_client_ack)
        .on("/admin", "server_ack", on_server_ack)
        .on("/admin", Event::Connect, on_connect)
        .build();

    tokio::spawn(async move { server.serve().await });
}
}