use crate::socket::{InMessage, MessageSignal, SocketConfig};
use crate::CloseFrame;
use crate::Error;
use crate::Message;
use crate::RawMessage;
use crate::Request;
use crate::Socket;
use async_trait::async_trait;
use base64::Engine;
use bytes::Bytes;
use enfync::Handle;
use futures::{FutureExt, SinkExt, StreamExt};
use http::header::HeaderName;
use http::HeaderValue;
use std::fmt;
use std::time::Duration;
use tokio_tungstenite_wasm::error::ProtocolError;
use tokio_tungstenite_wasm::Error as WSError;
use tracing::Instrument;
use tungstenite::Utf8Bytes;
use url::Url;
#[cfg(not(target_family = "wasm"))]
use tokio::time::sleep;
#[cfg(target_family = "wasm")]
use wasmtimer::tokio::sleep;
/// Default delay between connection attempts (five seconds), used by [`ClientConfig::new`]
/// unless overridden with [`ClientConfig::reconnect_interval`].
pub const DEFAULT_RECONNECT_INTERVAL: Duration = Duration::from_secs(5);
/// Settings used to open (and re-open) a client websocket connection.
///
/// Built with [`ClientConfig::new`] and the builder-style setters on it.
#[derive(Debug)]
pub struct ClientConfig {
    /// Server URL; also the source of the handshake request's URI and `Host` header.
    url: Url,
    /// Maximum number of attempts for the very first connection.
    max_initial_connect_attempts: usize,
    /// Maximum number of attempts each time the client reconnects.
    max_reconnect_attempts: usize,
    /// Delay between failed attempts, and before reconnecting after a server close frame.
    reconnect_interval: Duration,
    /// Extra headers added to the handshake request.
    headers: http::HeaderMap,
    /// Socket configuration; `None` means the default configuration is used.
    socket_config: Option<SocketConfig>,
}
impl ClientConfig {
    /// Creates a configuration targeting `url`.
    ///
    /// Defaults: unlimited initial and reconnect attempts, [`DEFAULT_RECONNECT_INTERVAL`]
    /// between attempts, no extra headers, and the default socket configuration.
    ///
    /// # Panics
    ///
    /// Panics if `url` cannot be converted into a [`Url`].
    pub fn new<U>(url: U) -> Self
    where
        U: TryInto<Url>,
        U::Error: fmt::Debug,
    {
        let url = url.try_into().expect("invalid URL");
        Self {
            url,
            max_initial_connect_attempts: usize::MAX,
            max_reconnect_attempts: usize::MAX,
            reconnect_interval: DEFAULT_RECONNECT_INTERVAL,
            headers: http::HeaderMap::new(),
            socket_config: None,
        }
    }

    /// Sets an HTTP Basic `Authorization` header from `username` and `password`,
    /// replacing any previously configured `Authorization` header.
    pub fn basic(mut self, username: impl fmt::Display, password: impl fmt::Display) -> Self {
        let credentials =
            base64::engine::general_purpose::STANDARD.encode(format!("{username}:{password}"));
        self.headers.insert(
            http::header::AUTHORIZATION,
            // "Basic " followed by base64 output only contains visible ASCII characters.
            http::HeaderValue::from_str(&format!("Basic {credentials}"))
                .expect("base64-encoded credentials are always a valid header value"),
        );
        self
    }

    /// Sets a `Bearer` `Authorization` header, replacing any previously configured one.
    ///
    /// # Panics
    ///
    /// Panics if `token` contains characters that are not allowed in a header value.
    pub fn bearer(mut self, token: impl fmt::Display) -> Self {
        self.headers.insert(
            http::header::AUTHORIZATION,
            http::HeaderValue::from_str(&format!("Bearer {token}"))
                .expect("token contains invalid character"),
        );
        self
    }

    /// Sets (or replaces) header `key` with `value` on the handshake request.
    ///
    /// # Panics
    ///
    /// Panics if `key` is not a valid header name or `value` is not a valid header value.
    pub fn header<K, V>(mut self, key: K, value: V) -> Self
    where
        HeaderName: TryFrom<K>,
        <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
        HeaderValue: TryFrom<V>,
        <HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
    {
        let name = <HeaderName as TryFrom<K>>::try_from(key)
            .map_err(Into::into)
            .expect("invalid header name");
        let value = <HeaderValue as TryFrom<V>>::try_from(value)
            .map_err(Into::into)
            .expect("invalid header value");
        self.headers.insert(name, value);
        self
    }

    /// Appends a `key=value` pair (form-urlencoded) to the URL's query string.
    pub fn query_parameter(mut self, key: &str, value: &str) -> Self {
        self.url.query_pairs_mut().append_pair(key, value);
        self
    }

    /// Limits the number of attempts for the first connection (at least one is always made).
    pub fn max_initial_connect_attempts(mut self, max_initial_connect_attempts: usize) -> Self {
        self.max_initial_connect_attempts = max_initial_connect_attempts;
        self
    }

    /// Limits the number of attempts per reconnection (at least one is always made).
    pub fn max_reconnect_attempts(mut self, max_reconnect_attempts: usize) -> Self {
        self.max_reconnect_attempts = max_reconnect_attempts;
        self
    }

    /// Sets the delay between failed attempts and before reconnecting after a server close.
    pub fn reconnect_interval(mut self, reconnect_interval: Duration) -> Self {
        self.reconnect_interval = reconnect_interval;
        self
    }

    /// Overrides the socket configuration; otherwise the default configuration is used.
    pub fn socket_config(mut self, socket_config: SocketConfig) -> Self {
        self.socket_config = Some(socket_config);
        self
    }

    /// Extra headers that will be added to the handshake request.
    pub fn headers(&self) -> &http::HeaderMap {
        &self.headers
    }

    /// Builds the websocket upgrade (handshake) HTTP request for this configuration.
    ///
    /// A fresh `Sec-WebSocket-Key` is generated on every call. Headers configured on this
    /// `ClientConfig` replace the built-in handshake headers of the same name, and every value
    /// of a multi-valued header is kept.
    ///
    /// # Panics
    ///
    /// Panics if the URL has no host, or if it cannot be used as an HTTP request URI.
    pub fn connect_http_request(&self) -> Request {
        let host = self
            .url
            .host()
            .expect("websocket URL must contain a host");
        // `Url::port` is `None` when the port is the scheme's default, in which case the port
        // may be omitted; otherwise it must be part of the `Host` header (RFC 7230 §5.4).
        let host_header = match self.url.port() {
            Some(port) => format!("{host}:{port}"),
            None => host.to_string(),
        };
        let mut http_request = Request::builder()
            .uri(self.url.as_str())
            .method("GET")
            .header("Host", host_header)
            .header("Connection", "Upgrade")
            .header("Upgrade", "websocket")
            .header("Sec-WebSocket-Version", "13")
            .header(
                "Sec-WebSocket-Key",
                tungstenite::handshake::client::generate_key(),
            )
            .body(())
            .unwrap();
        let request_headers = http_request.headers_mut();
        // `keys()` yields each distinct name once, so no value is lost or unwrapped.
        for name in self.headers.keys() {
            request_headers.remove(name);
            for value in self.headers.get_all(name) {
                request_headers.append(name.clone(), value.clone());
            }
        }
        http_request
    }

    /// The URL this configuration connects to, as a string.
    pub fn connect_url(&self) -> &str {
        self.url.as_str()
    }
}
/// Decision returned by [`ClientExt`] callbacks after a connection failure, close or disconnect.
#[derive(Debug, Clone)]
pub enum ClientCloseMode {
    /// Try to (re)connect using the configured attempt limits.
    Reconnect,
    /// Stop the client; no further connection attempts are made.
    Close,
}
/// User-implemented callbacks that define a client's behavior.
///
/// They are driven by the client actor spawned in [`connect_with`]. An `Err` returned from
/// any callback stops the actor, and the error becomes the result of its task.
#[async_trait]
pub trait ClientExt: Send {
    /// Custom message type delivered through [`Client::call`] and [`Client::call_with`].
    type Call: Send;
    /// Called for every text message received from the server.
    async fn on_text(&mut self, text: Utf8Bytes) -> Result<(), Error>;
    /// Called for every binary message received from the server.
    async fn on_binary(&mut self, bytes: Bytes) -> Result<(), Error>;
    /// Called for every call queued through a [`Client`] handle.
    async fn on_call(&mut self, call: Self::Call) -> Result<(), Error>;
    /// Called after each successful connection or reconnection. Does nothing by default.
    async fn on_connect(&mut self) -> Result<(), Error> {
        Ok(())
    }
    /// Called when a connection attempt fails.
    ///
    /// Returning [`ClientCloseMode::Reconnect`] (the default) lets the remaining attempts
    /// proceed. [`ClientCloseMode::Close`] stops connecting.
    async fn on_connect_fail(&mut self, _error: WSError) -> Result<ClientCloseMode, Error> {
        Ok(ClientCloseMode::Reconnect)
    }
    /// Called when the server sends a close frame.
    ///
    /// Defaults to reconnecting after the configured reconnect interval.
    async fn on_close(&mut self, _frame: Option<CloseFrame>) -> Result<ClientCloseMode, Error> {
        Ok(ClientCloseMode::Reconnect)
    }
    /// Called when the connection is lost without a close frame.
    ///
    /// This covers a stream error, the end of the stream, or a failed send. Defaults to
    /// reconnecting immediately.
    async fn on_disconnect(&mut self) -> Result<ClientCloseMode, Error> {
        Ok(ClientCloseMode::Reconnect)
    }
}
/// Abstraction over how the client opens websocket connections (used by [`connect_with`]).
#[async_trait]
pub trait ClientConnector {
    /// Runtime handle used to spawn the client actor and handed to every [`Socket`].
    type Handle: enfync::Handle;
    /// Message type sent and received on [`Self::Socket`]; convertible to and from [`RawMessage`].
    type Message: Into<RawMessage> + From<RawMessage> + std::fmt::Debug + Send + 'static;
    /// Connection error type; converted into [`WSError`] for [`ClientExt::on_connect_fail`].
    type WSError: std::error::Error + Into<WSError> + Send;
    /// A connected socket: a sink of [`Self::Message`] and a stream of message results.
    type Socket: SinkExt<Self::Message, Error = Self::WSError>
        + StreamExt<Item = Result<Self::Message, Self::WSError>>
        + Unpin
        + Send
        + 'static;
    /// Returns a handle to the runtime this connector uses.
    fn handle(&self) -> Self::Handle;
    /// Opens one new connection described by `client_config`.
    async fn connect(&self, client_config: &ClientConfig) -> Result<Self::Socket, Self::WSError>;
}
/// Cloneable handle used to talk to a running client actor.
///
/// Everything sent through it is queued on channels that the actor spawned by
/// [`connect_with`] reads from.
#[derive(Debug)]
pub struct Client<E: ClientExt> {
    /// Outgoing websocket messages (text, binary, close) destined for the socket.
    to_socket_sender: async_channel::Sender<InMessage>,
    /// Calls delivered to [`ClientExt::on_call`].
    client_call_sender: async_channel::Sender<E::Call>,
}
// Implemented by hand: `#[derive(Clone)]` would add an unnecessary `E: Clone` bound.
impl<E: ClientExt> Clone for Client<E> {
    /// Returns another handle that feeds the same channels as `self`.
    fn clone(&self) -> Self {
        let Self {
            to_socket_sender,
            client_call_sender,
        } = self;
        Self {
            to_socket_sender: to_socket_sender.clone(),
            client_call_sender: client_call_sender.clone(),
        }
    }
}
impl<E: ClientExt> From<Client<E>> for async_channel::Sender<E::Call> {
fn from(client: Client<E>) -> Self {
client.client_call_sender
}
}
impl<E: ClientExt> Client<E> {
pub fn text(
&self,
text: impl Into<Utf8Bytes>,
) -> Result<MessageSignal, async_channel::SendError<InMessage>> {
let inmessage = InMessage::new(Message::Text(text.into()));
let inmessage_signal = inmessage.clone_signal().unwrap(); self.to_socket_sender
.send_blocking(inmessage)
.map(|_| inmessage_signal)
}
pub fn binary(
&self,
bytes: impl Into<Bytes>,
) -> Result<MessageSignal, async_channel::SendError<InMessage>> {
let inmessage = InMessage::new(Message::Binary(bytes.into()));
let inmessage_signal = inmessage.clone_signal().unwrap(); self.to_socket_sender
.send_blocking(inmessage)
.map(|_| inmessage_signal)
}
pub fn call(&self, message: E::Call) -> Result<(), async_channel::SendError<E::Call>> {
self.client_call_sender.send_blocking(message)
}
pub async fn call_with<R: fmt::Debug>(
&self,
f: impl FnOnce(async_channel::Sender<R>) -> E::Call,
) -> Option<R> {
let (sender, receiver) = async_channel::bounded(1usize);
let call = f(sender);
let Ok(_) = self.client_call_sender.send(call).await else {
return None;
};
let Ok(result) = receiver.recv().await else {
return None;
};
Some(result)
}
pub fn close(
&self,
frame: Option<CloseFrame>,
) -> Result<MessageSignal, async_channel::SendError<InMessage>> {
let inmessage = InMessage::new(Message::Close(frame));
let inmessage_signal = inmessage.clone_signal().unwrap(); self.to_socket_sender
.send_blocking(inmessage)
.map(|_| inmessage_signal)
}
}
/// Connects using the tokio-based connector.
///
/// This is a convenience wrapper around [`connect_with`]. It returns the [`Client`] handle
/// and a future that resolves to the actor's result. If no result can be obtained from the
/// spawned task, that future resolves to a "client actor crashed" error.
#[cfg(feature = "native_client")]
#[cfg_attr(docsrs, doc(cfg(feature = "native_client")))]
pub async fn connect<E: ClientExt + 'static>(
    client_fn: impl FnOnce(Client<E>) -> E,
    config: ClientConfig,
) -> (
    Client<E>,
    impl std::future::Future<Output = Result<(), Error>>,
) {
    let connector = crate::ClientConnectorTokio::default();
    let (client, mut pending) = connect_with(client_fn, config, connector);
    let completion = async move {
        let extracted = pending.extract().await;
        extracted.unwrap_or(Err("client actor crashed".into()))
    };
    (client, completion)
}
/// Spawns a client actor on `client_connector`'s runtime and returns a handle to it.
///
/// `client_fn` receives a clone of the returned [`Client`] and builds the [`ClientExt`]
/// implementation. The spawned task first connects (up to `max_initial_connect_attempts`
/// times), then runs the actor's event loop. Its outcome is available through the returned
/// [`enfync::PendingResult`].
pub fn connect_with<E: ClientExt + 'static>(
    client_fn: impl FnOnce(Client<E>) -> E,
    config: ClientConfig,
    client_connector: impl ClientConnector + Send + Sync + 'static,
) -> (Client<E>, enfync::PendingResult<Result<(), Error>>) {
    let (to_socket_sender, mut to_socket_receiver) = async_channel::unbounded();
    let (client_call_sender, client_call_receiver) = async_channel::unbounded();
    let client_handle = Client {
        to_socket_sender,
        client_call_sender,
    };
    let mut client = client_fn(client_handle.clone());
    let spawner = client_connector.handle();

    let actor_task = async move {
        tracing::info!("connecting to {}...", config.url);
        let first_socket = client_connect(
            config.max_initial_connect_attempts,
            &config,
            &client_connector,
            &mut to_socket_receiver,
            &mut client,
        )
        .await?;
        match first_socket {
            // The user or `on_connect_fail` chose to stop before any connection was made.
            None => Ok(()),
            Some(socket) => {
                tracing::info!("connected to {}", config.url);
                let mut actor = ClientActor {
                    client,
                    to_socket_receiver,
                    client_call_receiver,
                    config,
                    client_connector,
                };
                actor.run(Some(socket)).await
            }
        }
    };
    let pending = spawner.spawn(actor_task.instrument(tracing::Span::current()));
    (client_handle, pending)
}
/// State owned by the spawned client task.
///
/// It holds the user's [`ClientExt`] implementation, the receiving ends of the [`Client`]
/// channels, and everything needed to reconnect.
struct ClientActor<E: ClientExt, C: ClientConnector> {
    /// User callbacks.
    client: E,
    /// Messages queued through [`Client::text`], [`Client::binary`] and [`Client::close`].
    to_socket_receiver: async_channel::Receiver<InMessage>,
    /// Calls queued through [`Client::call`] and [`Client::call_with`].
    client_call_receiver: async_channel::Receiver<E::Call>,
    /// Connection settings, reused when reconnecting.
    config: ClientConfig,
    /// Opens new connections when reconnecting.
    client_connector: C,
}
impl<E: ClientExt, C: ClientConnector> ClientActor<E, C> {
    /// Runs the actor's event loop.
    ///
    /// Each iteration waits for whichever comes first: an outgoing message from the user, a
    /// call for the [`ClientExt`] implementation, or an incoming socket message. Handlers hand
    /// back the socket to keep using (possibly a freshly reconnected one), or `None` to stop.
    /// The loop returns `Ok(())` when there is no socket left or an input channel is closed,
    /// and propagates any error from handlers or callbacks.
    async fn run(&mut self, mut socket_shuttle: Option<Socket>) -> Result<(), Error> {
        loop {
            let Some(mut socket) = socket_shuttle.take() else {
                return Ok(());
            };
            futures::select! {
                res = self.to_socket_receiver.recv().fuse() => {
                    // The channel is closed (e.g. every `Client` handle was dropped).
                    let Ok(inmessage) = res else {
                        break;
                    };
                    socket_shuttle = self.handle_outgoing_msg(socket, inmessage).await?;
                }
                res = self.client_call_receiver.recv().fuse() => {
                    let Ok(call) = res else {
                        break;
                    };
                    self.client.on_call(call).await?;
                    socket_shuttle = Some(socket);
                }
                result = socket.stream.recv().fuse() => {
                    socket_shuttle = self.handle_incoming_msg(socket, result).await?;
                }
            }
        }
        Ok(())
    }

    /// Sends a user message over the socket.
    ///
    /// If the send fails, the sink's close error decides what happens. Closed, IO and TLS
    /// errors go through [`ClientExt::on_disconnect`]; any other error aborts the actor. A
    /// `Close` message from the user always ends the actor (`Ok(None)`).
    async fn handle_outgoing_msg(
        &mut self,
        mut socket: Socket,
        inmessage: InMessage,
    ) -> Result<Option<Socket>, Error> {
        let closed_self = matches!(inmessage.message, Some(Message::Close(_)));
        if socket.send(inmessage).await.is_err() {
            let result = socket.await_sink_close().await;
            if let Err(err) = &result {
                tracing::warn!(
                    ?err,
                    "encountered sink closing error when trying to send a message"
                );
            }
            match result {
                Err(WSError::Protocol(ProtocolError::SendAfterClosing))
                | Err(WSError::ConnectionClosed)
                | Err(WSError::AlreadyClosed)
                    if !closed_self =>
                {
                    tracing::debug!("client socket closed");
                    return self.handle_disconnect(socket).await;
                }
                Err(WSError::Io(_)) | Err(WSError::Tls(_)) if !closed_self => {
                    tracing::debug!("client socket IO send error");
                    return self.handle_disconnect(socket).await;
                }
                Err(_) if !closed_self => {
                    return Err(Error::from("unexpected sink error, aborting client actor"));
                }
                _ => (),
            }
        }
        if closed_self {
            tracing::debug!("client closed itself");
            return Ok(None);
        }
        Ok(Some(socket))
    }

    /// Dispatches one item read from the socket stream.
    ///
    /// Text and binary messages go to the user callbacks. A close frame goes through
    /// [`Self::handle_close`]. A stream error or end of stream goes through
    /// [`Self::handle_disconnect`].
    async fn handle_incoming_msg(
        &mut self,
        socket: Socket,
        result: Option<Result<Message, WSError>>,
    ) -> Result<Option<Socket>, Error> {
        match result {
            Some(Ok(message)) => {
                // `message` is already owned; match on it directly instead of cloning it.
                match message {
                    Message::Text(text) => self.client.on_text(text).await?,
                    Message::Binary(bytes) => self.client.on_binary(bytes).await?,
                    Message::Close(frame) => {
                        tracing::debug!("client closed by server");
                        return self.handle_close(frame, socket).await;
                    }
                };
            }
            Some(Err(error)) => {
                let error = Error::from(error);
                tracing::warn!("connection error: {error}");
                return self.handle_disconnect(socket).await;
            }
            None => {
                tracing::debug!("client socket died");
                return self.handle_disconnect(socket).await;
            }
        }
        Ok(Some(socket))
    }

    /// Handles a server close frame via [`ClientExt::on_close`].
    ///
    /// On [`ClientCloseMode::Reconnect`], the old socket is dropped and the actor waits for
    /// `reconnect_interval` before reconnecting with `max_reconnect_attempts`.
    async fn handle_close(
        &mut self,
        frame: Option<CloseFrame>,
        socket: Socket,
    ) -> Result<Option<Socket>, Error> {
        match self.client.on_close(frame).await? {
            ClientCloseMode::Reconnect => {
                std::mem::drop(socket);
                sleep(self.config.reconnect_interval).await;
                client_connect(
                    self.config.max_reconnect_attempts,
                    &self.config,
                    &self.client_connector,
                    &mut self.to_socket_receiver,
                    &mut self.client,
                )
                .await
            }
            ClientCloseMode::Close => Ok(None),
        }
    }

    /// Handles a lost connection via [`ClientExt::on_disconnect`].
    ///
    /// On [`ClientCloseMode::Reconnect`], the old socket is dropped and reconnection starts
    /// immediately (no initial delay) with `max_reconnect_attempts`.
    async fn handle_disconnect(&mut self, socket: Socket) -> Result<Option<Socket>, Error> {
        match self.client.on_disconnect().await? {
            ClientCloseMode::Reconnect => {
                std::mem::drop(socket);
                client_connect(
                    self.config.max_reconnect_attempts,
                    &self.config,
                    &self.client_connector,
                    &mut self.to_socket_receiver,
                    &mut self.client,
                )
                .await
            }
            ClientCloseMode::Close => Ok(None),
        }
    }
}
/// Tries to connect up to `max_attempts` times, waiting `config.reconnect_interval` between
/// failed attempts.
///
/// At least one attempt is always made, even when `max_attempts` is 0. Before every attempt,
/// messages the user queued while disconnected are drained. A `Close` message stops
/// connecting; anything else is discarded with a warning.
///
/// Returns `Ok(Some(socket))` on success. Returns `Ok(None)` if the user or
/// [`ClientExt::on_connect_fail`] chose to close. Returns an error if the user channel is
/// closed, a callback fails, or all attempts are exhausted.
async fn client_connect<E: ClientExt, Connector: ClientConnector>(
    max_attempts: usize,
    config: &ClientConfig,
    client_connector: &Connector,
    to_socket_receiver: &mut async_channel::Receiver<InMessage>,
    client: &mut E,
) -> Result<Option<Socket>, Error> {
    for i in 1.. {
        // Drain everything queued while disconnected; only a close request is honored.
        loop {
            match to_socket_receiver.try_recv() {
                Ok(inmessage) => match &inmessage.message {
                    Some(Message::Close(frame)) => {
                        tracing::debug!(?frame, "client closed itself while connecting");
                        return Ok(None);
                    }
                    _ => {
                        tracing::warn!("client is connecting, discarding message from user");
                    }
                },
                Err(async_channel::TryRecvError::Empty) => break,
                Err(async_channel::TryRecvError::Closed) => {
                    tracing::warn!("client is dead, aborting connection attempts");
                    return Err(Error::from("client died while trying to connect"));
                }
            }
        }

        tracing::info!("connecting attempt no: {}...", i);
        match client_connector.connect(config).await {
            Ok(socket_impl) => {
                tracing::info!("successfully connected");
                client.on_connect().await?;
                let socket = Socket::new(
                    socket_impl,
                    config.socket_config.clone().unwrap_or_default(),
                    client_connector.handle(),
                );
                return Ok(Some(socket));
            }
            Err(err) => {
                tracing::warn!("connecting failed due to {}", err);
                match client.on_connect_fail(err.into()).await? {
                    ClientCloseMode::Reconnect => {}
                    ClientCloseMode::Close => {
                        tracing::debug!("client closed itself after a connection failure");
                        return Ok(None);
                    }
                }
            }
        }

        if i >= max_attempts {
            return Err(Error::from(format!(
                "failed to connect after {} attempt(s), aborting...",
                i
            )));
        }
        // Logged only when a retry will actually happen. `Duration`'s Debug output keeps
        // sub-second precision ("500ms") instead of truncating to whole seconds ("0s").
        tracing::debug!("will retry in {:?}", config.reconnect_interval);
        sleep(config.reconnect_interval).await;
    }
    // Not reached in practice: the loop returns by the time `i` reaches `max_attempts`,
    // which is at most `usize::MAX`.
    Err(Error::from("client failed to connect"))
}