use crate::{
asynchronous::{
async_socket::Socket as InnerSocket,
async_transports::{PollingTransport, WebsocketSecureTransport, WebsocketTransport},
callback::OptionalCallback,
transport::AsyncTransport,
},
error::Result,
header::HeaderMap,
packet::HandshakePacket,
Error, Packet, ENGINE_IO_VERSION,
};
use bytes::Bytes;
use futures_util::{future::BoxFuture, StreamExt};
use native_tls::TlsConnector;
use url::Url;
use super::Client;
/// Builder for an asynchronous engine.io [`Client`].
///
/// Holds the server URL, optional TLS and header configuration, the handshake
/// result once one has been performed, and the callbacks that are handed to the
/// socket when one of the `build*` methods constructs the client.
#[derive(Clone, Debug)]
pub struct ClientBuilder {
    /// Server URL. `new` appends the `EIO` query parameter; a successful
    /// handshake appends the `sid` query parameter.
    url: Url,
    /// TLS connector passed to the polling and secure websocket transports.
    tls_config: Option<TlsConnector>,
    /// Extra headers, converted into each transport's header type on connect.
    headers: Option<HeaderMap>,
    /// Handshake packet received from the server; `None` until a handshake succeeds.
    handshake: Option<HandshakePacket>,
    // Callbacks forwarded unchanged to the inner socket when the client is built.
    on_error: OptionalCallback<String>,
    on_open: OptionalCallback<()>,
    on_close: OptionalCallback<()>,
    on_data: OptionalCallback<Bytes>,
    on_packet: OptionalCallback<Packet>,
}
impl ClientBuilder {
pub fn new(url: Url) -> Self {
let mut url = url;
url.query_pairs_mut()
.append_pair("EIO", &ENGINE_IO_VERSION.to_string());
if url.path() == "/" {
url.set_path("/engine.io/");
}
ClientBuilder {
url,
headers: None,
tls_config: None,
handshake: None,
on_close: OptionalCallback::default(),
on_data: OptionalCallback::default(),
on_error: OptionalCallback::default(),
on_open: OptionalCallback::default(),
on_packet: OptionalCallback::default(),
}
}
pub fn tls_config(mut self, tls_config: TlsConnector) -> Self {
self.tls_config = Some(tls_config);
self
}
pub fn headers(mut self, headers: HeaderMap) -> Self {
self.headers = Some(headers);
self
}
#[cfg(feature = "async-callbacks")]
pub fn on_close<T>(mut self, callback: T) -> Self
where
T: 'static + Send + Sync + Fn(()) -> BoxFuture<'static, ()>,
{
self.on_close = OptionalCallback::new(callback);
self
}
#[cfg(feature = "async-callbacks")]
pub fn on_data<T>(mut self, callback: T) -> Self
where
T: 'static + Send + Sync + Fn(Bytes) -> BoxFuture<'static, ()>,
{
self.on_data = OptionalCallback::new(callback);
self
}
#[cfg(feature = "async-callbacks")]
pub fn on_error<T>(mut self, callback: T) -> Self
where
T: 'static + Send + Sync + Fn(String) -> BoxFuture<'static, ()>,
{
self.on_error = OptionalCallback::new(callback);
self
}
#[cfg(feature = "async-callbacks")]
pub fn on_open<T>(mut self, callback: T) -> Self
where
T: 'static + Send + Sync + Fn(()) -> BoxFuture<'static, ()>,
{
self.on_open = OptionalCallback::new(callback);
self
}
#[cfg(feature = "async-callbacks")]
pub fn on_packet<T>(mut self, callback: T) -> Self
where
T: 'static + Send + Sync + Fn(Packet) -> BoxFuture<'static, ()>,
{
self.on_packet = OptionalCallback::new(callback);
self
}
async fn handshake_with_transport<T: AsyncTransport + Unpin>(
&mut self,
transport: &mut T,
) -> Result<()> {
if self.handshake.is_some() {
return Ok(());
}
let mut url = self.url.clone();
let handshake: HandshakePacket =
Packet::try_from(transport.next().await.ok_or(Error::IncompletePacket())??)?
.try_into()?;
url.query_pairs_mut().append_pair("sid", &handshake.sid[..]);
self.handshake = Some(handshake);
self.url = url;
Ok(())
}
async fn handshake(&mut self) -> Result<()> {
if self.handshake.is_some() {
return Ok(());
}
let headers = if let Some(map) = self.headers.clone() {
Some(map.try_into()?)
} else {
None
};
let mut transport =
PollingTransport::new(self.url.clone(), self.tls_config.clone(), headers);
self.handshake_with_transport(&mut transport).await
}
pub async fn build(mut self) -> Result<Client> {
self.handshake().await?;
if self.websocket_upgrade()? {
self.build_websocket_with_upgrade().await
} else {
self.build_polling().await
}
}
pub async fn build_polling(mut self) -> Result<Client> {
self.handshake().await?;
let transport = PollingTransport::new(
self.url,
self.tls_config,
self.headers.map(|v| v.try_into().unwrap()),
);
Ok(Client::new(InnerSocket::new(
transport.into(),
self.handshake.unwrap(),
self.on_close,
self.on_data,
self.on_error,
self.on_open,
self.on_packet,
)))
}
pub async fn build_websocket_with_upgrade(mut self) -> Result<Client> {
self.handshake().await?;
if self.websocket_upgrade()? {
self.build_websocket().await
} else {
Err(Error::IllegalWebsocketUpgrade())
}
}
pub async fn build_websocket(mut self) -> Result<Client> {
let headers = if let Some(map) = self.headers.clone() {
Some(map.try_into()?)
} else {
None
};
match self.url.scheme() {
"http" | "ws" => {
let mut transport = WebsocketTransport::new(self.url.clone(), headers).await?;
if self.handshake.is_some() {
transport.upgrade().await?;
} else {
self.handshake_with_transport(&mut transport).await?;
}
Ok(Client::new(InnerSocket::new(
transport.into(),
self.handshake.unwrap(),
self.on_close,
self.on_data,
self.on_error,
self.on_open,
self.on_packet,
)))
}
"https" | "wss" => {
let mut transport = WebsocketSecureTransport::new(
self.url.clone(),
self.tls_config.clone(),
headers,
)
.await?;
if self.handshake.is_some() {
transport.upgrade().await?;
} else {
self.handshake_with_transport(&mut transport).await?;
}
Ok(Client::new(InnerSocket::new(
transport.into(),
self.handshake.unwrap(),
self.on_close,
self.on_data,
self.on_error,
self.on_open,
self.on_packet,
)))
}
_ => Err(Error::InvalidUrlScheme(self.url.scheme().to_string())),
}
}
pub async fn build_with_fallback(self) -> Result<Client> {
let result = self.clone().build().await;
if result.is_err() {
self.build_polling().await
} else {
result
}
}
fn websocket_upgrade(&mut self) -> Result<bool> {
if self.handshake.is_none() {
return Ok(false);
}
Ok(self
.handshake
.as_ref()
.unwrap()
.upgrades
.iter()
.any(|upgrade| upgrade.to_lowercase() == *"websocket"))
}
}