use crate::ack::AckType;
use crate::error::{ClientBuilderError, ClientError, PayloadError, SocketError};
use crate::marker::{AckId, AckMarker, BinaryMarker};
use bytestring::ByteString;
use sioc_engine::engine::Engine;
use sioc_engine::transport::TransportStrategy;
use sioc_engine::websocket::WebSocketConnector;
use sioc_socket::error::ManagerError;
use sioc_socket::manager::{Manager, ManagerAction, ManagerSender, manager_sink};
use sioc_socket::packet::{Directive, Signal};
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use url::Url;
/// Conversion of an outgoing event into a [`Directive`], used by [`SocketSender::emit`].
///
/// `A` and `B` are marker type parameters bounded by [`AckMarker`] and [`BinaryMarker`]
/// (presumably selecting the ack / binary flavour of the event — confirm in `crate::marker`).
pub trait Emit<A: AckMarker, B: BinaryMarker> {
    /// Value returned to the caller of `emit` after the directive has been sent.
    type Output;

    /// Splits the event into the directive to send and the caller-facing output.
    ///
    /// # Errors
    /// Returns a [`PayloadError`] if the event cannot be turned into a directive.
    fn prepare(self) -> Result<(Directive, Self::Output), PayloadError>;
}
/// Conversion of an acknowledgement payload into a [`Directive`], used by
/// [`SocketSender::acknowledge`].
///
/// `A` is bounded by [`AckType`] and `B` by [`BinaryMarker`]; both are type-level markers.
pub trait Acknowledge<A: AckType, B: BinaryMarker> {
    /// Builds the directive that answers the ack request identified by `id`.
    ///
    /// `id` is the raw value obtained from `AckId::get` in [`SocketSender::acknowledge`].
    ///
    /// # Errors
    /// Returns a [`PayloadError`] if the payload cannot be turned into a directive.
    fn into_directive(self, id: u64) -> Result<Directive, PayloadError>;
}
/// Configures and launches a [`Client`].
///
/// `C` is the WebSocket connector handed to the engine. A builder created with
/// [`ClientBuilder::new`] starts with `()` in that slot, and
/// [`ClientBuilder::websocket_connector`] swaps in a different connector type.
pub struct ClientBuilder<C = ()> {
    /// Base server URL; `path` is joined onto it when the client is opened.
    url: Url,
    /// Engine endpoint path, resolved relative to `url` (default `"socket.io/"`).
    path: String,
    /// Custom HTTP client; `None` means a default `reqwest::Client` is created on open.
    http_client: Option<reqwest::Client>,
    /// Connector used for the WebSocket transport.
    websocket_connector: C,
    /// Transport selection strategy passed to the engine.
    transport_strategy: TransportStrategy,
}

impl ClientBuilder<()> {
    /// Creates a builder for `url`.
    ///
    /// Defaults: path `"socket.io/"`, no custom HTTP client, the `()` connector and
    /// `TransportStrategy::default()`.
    pub fn new(url: impl Into<Url>) -> Self {
        let url = url.into();
        Self {
            url,
            path: String::from("socket.io/"),
            http_client: None,
            websocket_connector: (),
            transport_strategy: TransportStrategy::default(),
        }
    }
}

impl<C> ClientBuilder<C>
where
    C: WebSocketConnector,
{
    /// Sets the engine path.
    ///
    /// The path is joined onto the base URL with `Url::join` when the client opens.
    pub fn path(self, path: impl Into<String>) -> Self {
        Self {
            path: path.into(),
            ..self
        }
    }

    /// Uses `client` for HTTP requests instead of a default `reqwest::Client`.
    pub fn http_client(self, client: reqwest::Client) -> Self {
        Self {
            http_client: Some(client),
            ..self
        }
    }

    /// Replaces the WebSocket connector, changing the builder's connector type to `C2`.
    pub fn websocket_connector<C2>(self, connector: C2) -> ClientBuilder<C2>
    where
        C2: WebSocketConnector,
    {
        let Self {
            url,
            path,
            http_client,
            transport_strategy,
            ..
        } = self;
        ClientBuilder {
            url,
            path,
            http_client,
            websocket_connector: connector,
            transport_strategy,
        }
    }

    /// Selects the transport strategy the engine will use.
    pub fn transport(self, strategy: TransportStrategy) -> Self {
        Self {
            transport_strategy: strategy,
            ..self
        }
    }

    /// Connects the engine and spawns the manager task, returning a [`Client`] handle.
    ///
    /// The engine endpoint is the base URL joined with the configured path. A default
    /// `reqwest::Client` is used if none was supplied.
    ///
    /// # Errors
    /// Returns a [`ClientBuilderError`] if joining the path onto the base URL fails.
    ///
    /// # Panics
    /// `tokio::spawn` panics if called outside a Tokio runtime.
    #[must_use = "dropping the Client stops the background tasks"]
    pub fn open(self) -> Result<Client, ClientBuilderError> {
        let Self {
            url,
            path,
            http_client,
            websocket_connector,
            transport_strategy,
        } = self;
        let http_client = http_client.unwrap_or_default();
        let endpoint = url.join(&path)?;

        // One action channel feeds the manager. The engine gets a sink over a clone of the
        // sender, and the Client keeps the original wrapped in a `ManagerSender`.
        let (action_tx, action_rx) = mpsc::channel::<ManagerAction>(32);
        let engine = Engine::connect(
            endpoint,
            http_client,
            websocket_connector,
            transport_strategy,
            manager_sink(action_tx.clone()),
        );
        let handle = tokio::spawn(Manager::new(action_rx).socket_io(engine));
        let tx = ManagerSender::new(action_tx);
        Ok(Client { tx, handle })
    }
}
#[derive(Debug)]
pub struct Client {
tx: ManagerSender,
handle: JoinHandle<Result<(), ManagerError>>,
}
impl Client {
pub fn builder(url: impl Into<Url>) -> ClientBuilder {
ClientBuilder::new(url)
}
pub async fn connect<S>(&self, ns: S) -> Result<(SocketSender, SocketReceiver), SocketError>
where
S: Into<ByteString>,
{
self.connect_with(ns, ByteString::new()).await
}
pub async fn connect_with<S, B>(
&self,
ns: S,
payload: B,
) -> Result<(SocketSender, SocketReceiver), SocketError>
where
S: Into<ByteString>,
B: Into<ByteString>,
{
let (tx, rx) = mpsc::channel(32);
let socket_tx = SocketSender::new(ns.into(), self.tx.clone());
let socket_rx = SocketReceiver { rx };
let directive = Directive::Connect {
tx,
payload: payload.into(),
};
socket_tx.send(directive).await?;
Ok((socket_tx, socket_rx))
}
pub async fn join(self) -> Result<(), ClientError> {
drop(self.tx);
self.handle.await??;
Ok(())
}
}
/// Sending half of a namespace socket, created by [`Client::connect`] / [`Client::connect_with`].
///
/// Every directive is tagged with the namespace `ns` and forwarded through a clone of the
/// client's [`ManagerSender`]. `is_connected` records whether a Disconnect still needs to be
/// issued. [`SocketSender::disconnect`] clears it, and if it is still set on drop a
/// best-effort Disconnect is attempted.
#[derive(Debug)]
pub struct SocketSender {
    /// Namespace every directive is addressed to.
    ns: ByteString,
    /// Channel to the manager task.
    tx: ManagerSender,
    /// `true` until a Disconnect has been issued, either explicitly or on drop.
    is_connected: AtomicBool,
}

impl SocketSender {
    /// Creates a sender for `ns` that starts out marked as connected.
    fn new(ns: ByteString, tx: ManagerSender) -> Self {
        let is_connected = AtomicBool::new(true);
        Self {
            ns,
            tx,
            is_connected,
        }
    }

    /// Forwards `directive` to the manager, addressed to this socket's namespace.
    async fn send(&self, directive: Directive) -> Result<(), SocketError> {
        let namespace = self.ns.clone();
        let outcome = self.tx.send(namespace, directive).await;
        outcome.map_err(SocketError::Send)
    }

    /// Emits `event` on this namespace and returns the output prepared alongside it.
    ///
    /// # Errors
    /// Fails if [`Emit::prepare`] returns a `PayloadError` (converted into `SocketError`)
    /// or if the directive cannot be handed to the manager.
    pub async fn emit<E, A, B>(&self, event: E) -> Result<E::Output, SocketError>
    where
        E: Emit<A, B>,
        A: AckMarker,
        B: BinaryMarker,
    {
        let (directive, output) = event.prepare()?;
        self.send(directive).await.map(|()| output)
    }

    /// Answers the ack request `id` with `payload`.
    ///
    /// # Errors
    /// Fails if [`Acknowledge::into_directive`] returns a `PayloadError` (converted into
    /// `SocketError`) or if the directive cannot be handed to the manager.
    pub async fn acknowledge<T, A, B>(&self, id: AckId<A>, payload: T) -> Result<(), SocketError>
    where
        T: Acknowledge<A, B>,
        A: AckType,
        B: BinaryMarker,
    {
        let raw_id = id.get();
        let directive = payload.into_directive(raw_id)?;
        self.send(directive).await
    }

    /// Disconnects from the namespace.
    ///
    /// Only the first call sends a Disconnect directive; later calls return `Ok(())`
    /// without sending anything.
    ///
    /// # Errors
    /// Fails if the Disconnect directive cannot be handed to the manager. The socket is
    /// still marked disconnected in that case.
    pub async fn disconnect(&self) -> Result<(), SocketError> {
        let was_connected = self.is_connected.swap(false, Ordering::Relaxed);
        if !was_connected {
            return Ok(());
        }
        self.send(Directive::Disconnect).await
    }
}

impl Drop for SocketSender {
    /// Warns and makes a non-blocking, best-effort Disconnect attempt if the socket is still
    /// marked connected. A failed `try_send` is deliberately ignored.
    fn drop(&mut self) {
        // `&mut self` gives exclusive access, so the flag can be read and cleared directly.
        let was_connected = std::mem::take(self.is_connected.get_mut());
        if !was_connected {
            return;
        }
        let type_name = std::any::type_name::<Self>();
        tracing::warn!(ns = %self.ns, type_name, "dropped while connected");
        let _ = self.tx.try_send(self.ns.clone(), Directive::Disconnect);
    }
}
/// Receiving half of a namespace socket, created by [`Client::connect`] / [`Client::connect_with`].
#[derive(Debug)]
pub struct SocketReceiver {
    /// Receives the [`Signal`]s for this namespace; the matching sender is carried by the
    /// Connect directive.
    rx: mpsc::Receiver<Signal>,
}

impl SocketReceiver {
    /// Waits for the next [`Signal`] on this namespace.
    ///
    /// Resolves to `None` once the channel is closed and no buffered signals remain, i.e.
    /// every sender for it has been dropped (presumably the one owned by the manager —
    /// confirm).
    pub async fn recv(&mut self) -> Option<Signal> {
        let Self { rx } = self;
        rx.recv().await
    }
}