use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use futures_util::StreamExt;
use tokio::sync::{broadcast, mpsc, watch, RwLock};
use tokio_util::sync::CancellationToken;
use tungstenite::protocol::WebSocketConfig;
use tungstenite::Message;
use crate::connection::ConnectionEvent;
use crate::connection::{
ConnectionSnapshot, ConnectionSupervisor, DefaultConnector, ExponentialBackoff, NoRetry,
RetryStrategy,
};
use crate::error::{ClientError, DisconnectReason, HandshakeError, SendError, SupervisorError};
use crate::extension::{Extension, ExtensionHost};
use crate::handshake::{BoxHandshaker, Handshaker, NoOpHandshaker};
use crate::message::{DispatcherConfig, MessageDispatcher, ProcessorErrorPolicy, SharedMessage};
/// Tuning knobs for a [`WebSocketClient`].
///
/// Start from [`ClientConfig::new`] (the defaults) and adjust with the `with_*`
/// setters, or use the [`ClientConfig::fast_reconnect`] /
/// [`ClientConfig::stable_connection`] presets.
#[derive(Clone)]
pub struct ClientConfig {
    // --- timing -----------------------------------------------------------------
    /// Receive timeout handed to the message dispatcher.
    pub receive_timeout: Duration,
    /// Connect timeout forwarded to the connection supervisor.
    pub connect_timeout: Duration,
    /// Pause before retrying after a handshake failure the handshaker deems retryable.
    pub handshake_retry_delay: Duration,
    // --- reconnect behaviour ------------------------------------------------------
    /// Forwarded to the supervisor's `exit_on_first_failure` setting.
    pub exit_on_first_failure: bool,
    // --- transport ----------------------------------------------------------------
    /// Optional tungstenite protocol configuration applied to the connector.
    pub ws_config: Option<WebSocketConfig>,
    /// Passed to `DefaultConnector::with_nodelay`.
    pub disable_nagle: bool,
    // --- buffering / dispatch -----------------------------------------------------
    /// Capacity of the inbound broadcast channels (client-level and dispatcher).
    pub channel_buffer_size: usize,
    /// Capacity of the outbound per-session send queue and dispatcher send buffer.
    pub send_queue_capacity: usize,
    /// Processor error policy forwarded to the dispatcher.
    pub processor_error_policy: ProcessorErrorPolicy,
}
impl Default for ClientConfig {
    /// Balanced defaults: 20 s receive timeout, 30 s connect timeout, 5 s handshake
    /// retry delay, Nagle left enabled, no custom WebSocket config, 256-slot channels
    /// and the `Ignore` processor error policy.
    fn default() -> Self {
        const DEFAULT_CAPACITY: usize = 256;
        Self {
            receive_timeout: Duration::from_secs(20),
            connect_timeout: Duration::from_secs(30),
            handshake_retry_delay: Duration::from_secs(5),
            exit_on_first_failure: false,
            ws_config: None,
            disable_nagle: false,
            channel_buffer_size: DEFAULT_CAPACITY,
            send_queue_capacity: DEFAULT_CAPACITY,
            processor_error_policy: ProcessorErrorPolicy::Ignore,
        }
    }
}
impl ClientConfig {
    /// Same as [`ClientConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets [`ClientConfig::receive_timeout`].
    #[must_use]
    pub const fn with_receive_timeout(mut self, timeout: Duration) -> Self {
        self.receive_timeout = timeout;
        self
    }

    /// Sets [`ClientConfig::exit_on_first_failure`].
    #[must_use]
    pub const fn with_exit_on_first_failure(mut self, exit: bool) -> Self {
        self.exit_on_first_failure = exit;
        self
    }

    /// Sets [`ClientConfig::connect_timeout`].
    #[must_use]
    pub const fn with_connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = timeout;
        self
    }

    /// Sets [`ClientConfig::handshake_retry_delay`].
    #[must_use]
    pub const fn with_handshake_retry_delay(mut self, delay: Duration) -> Self {
        self.handshake_retry_delay = delay;
        self
    }

    /// Installs a tungstenite [`WebSocketConfig`] for the connector.
    #[must_use]
    // NOTE(review): the reason for opting out of this lint is not recorded; confirm.
    #[allow(clippy::missing_const_for_fn)]
    pub fn with_ws_config(mut self, config: WebSocketConfig) -> Self {
        self.ws_config = Some(config);
        self
    }

    /// Sets [`ClientConfig::disable_nagle`] (`true` requests `nodelay`).
    #[must_use]
    pub const fn with_nodelay(mut self, nodelay: bool) -> Self {
        self.disable_nagle = nodelay;
        self
    }

    /// Sets [`ClientConfig::channel_buffer_size`].
    #[must_use]
    pub const fn with_channel_buffer(mut self, size: usize) -> Self {
        self.channel_buffer_size = size;
        self
    }

    /// Sets [`ClientConfig::send_queue_capacity`].
    #[must_use]
    pub const fn with_send_queue_capacity(mut self, cap: usize) -> Self {
        self.send_queue_capacity = cap;
        self
    }

    /// Sets [`ClientConfig::processor_error_policy`].
    #[must_use]
    pub const fn with_processor_error_policy(mut self, policy: ProcessorErrorPolicy) -> Self {
        self.processor_error_policy = policy;
        self
    }

    /// Preset with short timeouts (10 s receive/connect, 500 ms handshake retry),
    /// `nodelay` requested and 512-slot channels.
    #[must_use]
    pub const fn fast_reconnect() -> Self {
        Self {
            receive_timeout: Duration::from_secs(10),
            connect_timeout: Duration::from_secs(10),
            handshake_retry_delay: Duration::from_millis(500),
            exit_on_first_failure: false,
            ws_config: None,
            disable_nagle: true,
            channel_buffer_size: 512,
            send_queue_capacity: 512,
            processor_error_policy: ProcessorErrorPolicy::Ignore,
        }
    }

    /// Preset with long timeouts (60 s receive/connect, 2 s handshake retry),
    /// Nagle left enabled and 128-slot channels.
    #[must_use]
    pub const fn stable_connection() -> Self {
        Self {
            receive_timeout: Duration::from_secs(60),
            connect_timeout: Duration::from_secs(60),
            handshake_retry_delay: Duration::from_secs(2),
            exit_on_first_failure: false,
            ws_config: None,
            disable_nagle: false,
            channel_buffer_size: 128,
            send_queue_capacity: 128,
            processor_error_policy: ProcessorErrorPolicy::Ignore,
        }
    }
}
impl From<&ClientConfig> for DispatcherConfig {
fn from(config: &ClientConfig) -> Self {
Self::new()
.with_receive_timeout(config.receive_timeout)
.with_broadcast_capacity(config.channel_buffer_size)
.with_send_buffer_capacity(config.send_queue_capacity)
.with_processor_error_policy(config.processor_error_policy)
}
}
#[derive(Clone)]
pub struct Sender {
tx: mpsc::Sender<Message>,
}
impl Sender {
pub fn send(&self, message: Message) -> Result<(), SendError> {
match self.tx.try_send(message) {
Ok(()) => Ok(()),
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Err(SendError::ChannelFull),
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => Err(SendError::ChannelClosed),
}
}
pub fn send_text(&self, text: impl Into<String>) -> Result<(), SendError> {
self.send(Message::Text(text.into().into()))
}
pub fn send_binary(&self, data: impl Into<Vec<u8>>) -> Result<(), SendError> {
self.send(Message::Binary(data.into().into()))
}
pub fn ping(&self, data: impl Into<Vec<u8>>) -> Result<(), SendError> {
self.send(Message::Ping(data.into().into()))
}
pub async fn send_async(&self, message: Message) -> Result<(), SendError> {
self.tx
.send(message)
.await
.map_err(|_| SendError::ChannelClosed)
}
pub async fn send_text_async(&self, text: impl Into<String>) -> Result<(), SendError> {
self.send_async(Message::Text(text.into().into())).await
}
pub async fn send_binary_async(&self, data: impl Into<Vec<u8>>) -> Result<(), SendError> {
self.send_async(Message::Binary(data.into().into())).await
}
pub async fn ping_async(&self, data: impl Into<Vec<u8>>) -> Result<(), SendError> {
self.send_async(Message::Ping(data.into().into())).await
}
pub async fn send_timeout(&self, message: Message, timeout: Duration) -> Result<(), SendError> {
match tokio::time::timeout(timeout, self.tx.send(message)).await {
Ok(Ok(())) => Ok(()),
Ok(Err(_)) => Err(SendError::ChannelClosed),
Err(_) => Err(SendError::Timeout(timeout)),
}
}
}
/// Shared state behind a [`WebSocketClient`], kept in an `Arc` so the client's
/// methods and its spawned tasks observe the same values.
struct ClientRuntime {
    /// `true` while `run()` is executing; `begin_run` uses it to reject a second run.
    is_running: AtomicBool,
    /// Cancelled by shutdown. Nothing in this type resets it.
    cancel: CancellationToken,
    /// Client-level broadcast that `subscribe()` receivers read inbound messages from.
    message_tx: broadcast::Sender<SharedMessage>,
    /// Outbound queue of the current session, or `None` when none is published.
    send_tx: Arc<RwLock<Option<mpsc::Sender<Message>>>>,
    /// Message dispatcher shared by every session.
    dispatcher: Arc<MessageDispatcher<crate::connection::DefaultWsStream>>,
    /// Observable copy of `is_running` for async waiters (`run_state_receiver`).
    run_state: watch::Sender<bool>,
}
impl ClientRuntime {
    /// Builds idle runtime state sized from `config`.
    fn new(config: &ClientConfig) -> Self {
        let (message_tx, _) = broadcast::channel(config.channel_buffer_size);
        let dispatcher_config = DispatcherConfig::from(config);
        // The initial receiver is not kept; watchers subscribe via `run_state_receiver`.
        let (run_state, _) = watch::channel(false);
        Self {
            is_running: AtomicBool::new(false),
            cancel: CancellationToken::new(),
            message_tx,
            send_tx: Arc::new(RwLock::new(None)),
            dispatcher: Arc::new(MessageDispatcher::new(dispatcher_config)),
            run_state,
        }
    }

    /// Marks the start of a run and publishes `true` on `run_state`.
    ///
    /// # Errors
    /// [`ClientError::AlreadyRunning`] if a run is already in progress.
    fn begin_run(&self) -> Result<(), ClientError> {
        if self.is_running.swap(true, Ordering::SeqCst) {
            return Err(ClientError::AlreadyRunning);
        }
        // `watch::Sender::send` does not store the value when no receiver exists, and
        // none is normally held (the initial one is dropped in `new`). With `send`, the
        // published state would stay `false` and graceful-shutdown waiters would think
        // the client idle. `send_replace` always stores the value.
        self.run_state.send_replace(true);
        Ok(())
    }

    /// Marks the end of a run and publishes `false` on `run_state`.
    fn finish_run(&self) {
        self.is_running.store(false, Ordering::SeqCst);
        // See `begin_run` for why this is not `send`.
        self.run_state.send_replace(false);
    }

    /// Cancels the runtime token, signalling every task watching it.
    fn cancel(&self) {
        self.cancel.cancel();
    }

    /// Returns a clone of the runtime cancellation token.
    fn cancel_token(&self) -> CancellationToken {
        self.cancel.clone()
    }

    /// Whether `cancel` has been called.
    fn is_cancelled(&self) -> bool {
        self.cancel.is_cancelled()
    }

    /// Whether a run is currently in progress.
    fn is_running(&self) -> bool {
        self.is_running.load(Ordering::SeqCst)
    }

    /// New receiver on the client-level inbound broadcast.
    fn subscribe(&self) -> broadcast::Receiver<SharedMessage> {
        self.message_tx.subscribe()
    }

    /// Clone of the client-level inbound broadcast sender.
    fn message_channel(&self) -> broadcast::Sender<SharedMessage> {
        self.message_tx.clone()
    }

    /// Shared handle to the message dispatcher.
    fn dispatcher(&self) -> Arc<MessageDispatcher<crate::connection::DefaultWsStream>> {
        self.dispatcher.clone()
    }

    /// Handle to the published outbound queue, or `None` if none is published.
    async fn sender(&self) -> Option<Sender> {
        let guard = self.send_tx.read().await;
        guard.as_ref().map(|tx| Sender { tx: tx.clone() })
    }

    /// Queues `message` on the published queue without waiting.
    ///
    /// # Errors
    /// `NotConnected` if no queue is published, `ChannelFull` if it is at capacity,
    /// `ChannelClosed` if its receiver has been dropped.
    async fn send(&self, message: Message) -> Result<(), SendError> {
        let guard = self.send_tx.read().await;
        guard.as_ref().map_or(Err(SendError::NotConnected), |tx| {
            match tx.try_send(message) {
                Ok(()) => Ok(()),
                Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => Err(SendError::ChannelFull),
                Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
                    Err(SendError::ChannelClosed)
                }
            }
        })
    }

    /// Queues `message` on the published queue, waiting for space.
    ///
    /// The read lock is released (the sender is cloned) before waiting.
    async fn send_async(&self, message: Message) -> Result<(), SendError> {
        let tx = self
            .send_tx
            .read()
            .await
            .as_ref()
            .ok_or(SendError::NotConnected)?
            .clone();
        tx.send(message).await.map_err(|_| SendError::ChannelClosed)
    }

    /// Queues `message` on the published queue, waiting at most `timeout` for space.
    async fn send_timeout(&self, message: Message, timeout: Duration) -> Result<(), SendError> {
        let tx = self
            .send_tx
            .read()
            .await
            .as_ref()
            .ok_or(SendError::NotConnected)?
            .clone();
        match tokio::time::timeout(timeout, tx.send(message)).await {
            Ok(Ok(())) => Ok(()),
            Ok(Err(_)) => Err(SendError::ChannelClosed),
            Err(_) => Err(SendError::Timeout(timeout)),
        }
    }

    /// Publishes `tx` as the current outbound queue.
    async fn set_send_channel(&self, tx: mpsc::Sender<Message>) {
        let mut guard = self.send_tx.write().await;
        *guard = Some(tx);
    }

    /// Withdraws the current outbound queue.
    async fn clear_send_channel(&self) {
        let mut guard = self.send_tx.write().await;
        *guard = None;
    }

    /// New receiver observing the run state (`true` while running).
    fn run_state_receiver(&self) -> watch::Receiver<bool> {
        self.run_state.subscribe()
    }
}
/// A WebSocket client that (re)connects through a [`ConnectionSupervisor`].
///
/// [`WebSocketClient::run`] repeatedly obtains a connection from the supervisor, runs
/// the configured handshaker on it, and then drives a session. During a session,
/// inbound frames go through the message dispatcher to [`WebSocketClient::subscribe`]
/// receivers, and outbound frames queued via [`Sender`] or the `send*` methods are
/// written through the dispatcher.
pub struct WebSocketClient {
    uri: String,
    config: ClientConfig,
    handshaker: BoxHandshaker,
    extension_host: Arc<ExtensionHost>,
    supervisor: ConnectionSupervisor<DefaultConnector>,
    runtime: Arc<ClientRuntime>,
}
impl WebSocketClient {
    /// Returns a [`WebSocketClientBuilder`] targeting `uri`.
    pub fn builder(uri: impl Into<String>) -> WebSocketClientBuilder {
        WebSocketClientBuilder::new(uri)
    }

    /// Creates a client for `uri` with the builder's defaults.
    pub fn new(uri: impl Into<String>) -> Self {
        Self::builder(uri).build()
    }

    /// Subscribes to inbound messages from every session of this client.
    #[must_use]
    pub fn subscribe(&self) -> broadcast::Receiver<SharedMessage> {
        self.runtime.subscribe()
    }

    /// The URI this client connects to.
    #[must_use]
    pub fn uri(&self) -> &str {
        &self.uri
    }

    /// Subscribes to connection events published by the supervisor.
    #[must_use]
    pub fn subscribe_events(&self) -> broadcast::Receiver<ConnectionEvent> {
        self.supervisor.subscribe()
    }

    /// Returns a handle to the current outbound queue, or `None` if none is published.
    ///
    /// The handle is bound to the session that was current when it was obtained. Once
    /// that session ends, its queue receiver is dropped and sends through the handle
    /// fail with [`SendError::ChannelClosed`].
    pub async fn sender(&self) -> Option<Sender> {
        self.runtime.sender().await
    }

    /// Queues `message` on the current session without waiting.
    ///
    /// # Errors
    /// [`SendError::NotConnected`] without a published queue, [`SendError::ChannelFull`]
    /// when the queue is at capacity, [`SendError::ChannelClosed`] when it is closed.
    pub async fn send(&self, message: Message) -> Result<(), SendError> {
        self.runtime.send(message).await
    }

    /// Queues `message` on the current session, waiting for space.
    ///
    /// # Errors
    /// [`SendError::NotConnected`] or [`SendError::ChannelClosed`].
    pub async fn send_async(&self, message: Message) -> Result<(), SendError> {
        self.runtime.send_async(message).await
    }

    /// Queues `message` on the current session, waiting at most `timeout` for space.
    ///
    /// # Errors
    /// [`SendError::NotConnected`], [`SendError::ChannelClosed`] or
    /// [`SendError::Timeout`].
    pub async fn send_timeout(&self, message: Message, timeout: Duration) -> Result<(), SendError> {
        self.runtime.send_timeout(message, timeout).await
    }

    /// Snapshot of the supervisor's connection state.
    pub async fn state(&self) -> ConnectionSnapshot {
        self.supervisor.snapshot().await
    }

    /// Whether the supervisor currently reports a connection.
    #[must_use]
    pub fn is_connected(&self) -> bool {
        self.supervisor.is_connected()
    }

    /// Registers `extension` with the client's extension host.
    ///
    /// # Errors
    /// [`ClientError::Extension`] if the host rejects the registration.
    pub async fn register_extension<E: Extension + 'static>(
        &self,
        extension: E,
    ) -> Result<(), ClientError> {
        self.extension_host
            .register(extension)
            .await
            .map_err(ClientError::Extension)
    }

    /// Runs the connect / handshake / session loop until shutdown or a fatal error.
    ///
    /// # Errors
    /// [`ClientError::AlreadyRunning`] if another `run` is in progress; otherwise the
    /// error that ended the loop.
    pub async fn run(&self) -> Result<(), ClientError> {
        self.runtime.begin_run()?;
        let result = self.run_loop().await;
        self.runtime.finish_run();
        result
    }

    /// Requests shutdown: cancels the runtime token and shuts the supervisor down.
    /// Does not wait for `run` to return.
    pub fn shutdown(&self) {
        self.runtime.cancel();
        self.supervisor.shutdown();
    }

    /// Requests shutdown and waits up to `timeout` for `run` to report completion.
    /// Returns immediately when no run is in progress.
    ///
    /// # Errors
    /// [`ClientError::ShutdownTimeout`] if `run` has not finished within `timeout`.
    pub async fn shutdown_graceful(&self, timeout: Duration) -> Result<(), ClientError> {
        // Subscribe before requesting shutdown so a completion that races with the
        // check below is still observed through `changed()`.
        let mut run_state = self.runtime.run_state_receiver();
        self.shutdown();
        if !self.runtime.is_running() || !*run_state.borrow() {
            return Ok(());
        }
        let wait_for_shutdown = async {
            while run_state.changed().await.is_ok() {
                if !*run_state.borrow() {
                    break;
                }
            }
        };
        match tokio::time::timeout(timeout, wait_for_shutdown).await {
            Ok(()) => Ok(()),
            Err(_) => Err(ClientError::ShutdownTimeout(timeout)),
        }
    }

    /// Main loop: one iteration per connection attempt / session.
    ///
    /// Extensions are shut down exactly once, on whichever exit path is taken.
    async fn run_loop(&self) -> Result<(), ClientError> {
        loop {
            if self.runtime.is_cancelled() {
                tracing::info!("Shutdown requested");
                self.extension_host.shutdown().await?;
                return Ok(());
            }
            let (stream, mut send_rx, connection_id) = match self.establish_session().await {
                Ok(session) => session,
                Err(ClientError::Supervisor(SupervisorError::Shutdown)) => {
                    tracing::info!("Supervisor shutdown requested");
                    self.extension_host.shutdown().await?;
                    return Ok(());
                }
                // Retryable handshake failure: the retry delay has already elapsed (or
                // was cut short by cancellation, which the check above will notice).
                Err(ClientError::Handshake(_)) => continue,
                Err(e) => {
                    self.extension_host.shutdown().await?;
                    return Err(e);
                }
            };
            let (mut recv_task, forward_task) = self.spawn_receiver_and_bridge(stream);
            let disconnect_reason = self.drive_session(&mut send_rx, &mut recv_task).await;
            // If the session ended on an outbound-send failure, the receive loop is still
            // running on the old socket. Abort it so a stale reader cannot keep feeding
            // the shared dispatcher. Aborting an already-finished task is a no-op.
            recv_task.abort();
            self.cleanup_session(forward_task, disconnect_reason, connection_id)
                .await;
        }
    }

    /// Obtains a connection from the supervisor.
    async fn connect_via_supervisor(
        &self,
    ) -> Result<crate::connection::DefaultWsStream, ClientError> {
        self.supervisor.connect().await.map_err(ClientError::Supervisor)
    }

    /// Connects, runs the handshake and attaches the sink to the dispatcher.
    ///
    /// Returns the read half, the receiver of the outbound queue, and the connection id.
    /// A retryable handshake failure yields `ClientError::Handshake` after the retry
    /// delay. A non-retryable one shuts the supervisor down and yields
    /// `ClientError::Supervisor(SupervisorError::Shutdown)`.
    async fn establish_session(
        &self,
    ) -> Result<
        (
            futures_util::stream::SplitStream<crate::connection::DefaultWsStream>,
            mpsc::Receiver<Message>,
            u64,
        ),
        ClientError,
    > {
        let ws_stream = self.connect_via_supervisor().await?;
        let (mut sink, mut stream) = ws_stream.split();
        // The outbound queue is published before the handshake so callers can already
        // enqueue. Those messages are drained by `drive_session` once the sink is attached.
        let (send_tx, send_rx) = mpsc::channel::<Message>(self.config.send_queue_capacity);
        self.runtime.set_send_channel(send_tx).await;
        if let Err(e) = self.perform_handshake(&mut sink, &mut stream).await {
            tracing::error!(error = ?e, "Handshake failed");
            // `send_rx` is dropped on every failure path below, so withdraw the published
            // sender. Callers then see `NotConnected` instead of a dead queue.
            self.runtime.clear_send_channel().await;
            self.supervisor
                .mark_disconnected(DisconnectReason::Error(e.to_string()))
                .await;
            if self.handshaker.is_retryable(&e) {
                // Back off before the caller retries, but let shutdown cut the wait short.
                let cancel_token = self.runtime.cancel_token();
                tokio::select! {
                    () = cancel_token.cancelled() => {}
                    () = tokio::time::sleep(self.config.handshake_retry_delay) => {}
                }
                return Err(ClientError::Handshake(e));
            }
            // NOTE(review): a non-retryable failure is recorded via `fatal` and then routed
            // into the shutdown path, so `run()` returns `Ok(())`. Confirm callers learn
            // about the failure from supervisor events.
            self.supervisor
                .fatal(crate::error::ConnectError::HandshakeFailed(e.to_string()));
            self.supervisor.shutdown();
            return Err(ClientError::Supervisor(SupervisorError::Shutdown));
        }
        self.runtime.dispatcher().attach(sink).await;
        let connection_id = self.supervisor.connection_id();
        let snapshot = self.supervisor.snapshot().await;
        self.extension_host
            .update_context(connection_id, snapshot.reconnect_count)
            .await;
        // Best effort: an extension's connect hook failing does not abort the session.
        let _ = self.extension_host.notify_connect().await;
        tracing::info!(connection_id = connection_id, "Connected");
        Ok((stream, send_rx, connection_id))
    }

    /// Spawns the session's two background tasks.
    ///
    /// The forwarder republishes dispatcher broadcasts on the client-level channel
    /// until cancellation or until the dispatcher channel closes. Lagged gaps are
    /// skipped. The receive task runs the dispatcher's receive loop on `stream`,
    /// recording supervisor activity and passing each message to the extension host.
    fn spawn_receiver_and_bridge(
        &self,
        stream: futures_util::stream::SplitStream<crate::connection::DefaultWsStream>,
    ) -> (
        tokio::task::JoinHandle<Result<(), crate::error::ReceiveError>>,
        tokio::task::JoinHandle<()>,
    ) {
        let dispatcher = self.runtime.dispatcher();
        let ext_host = self.extension_host.clone();
        let mut disp_rx = dispatcher.subscribe();
        let client_broadcast = self.runtime.message_channel();
        let cancel_token = self.runtime.cancel_token();
        let forward_task = tokio::spawn(async move {
            loop {
                tokio::select! {
                    () = cancel_token.cancelled() => break,
                    msg = disp_rx.recv() => {
                        match msg {
                            // No subscribers is not an error for the forwarder.
                            Ok(m) => { let _ = client_broadcast.send(m); }
                            Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => { }
                            Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
                        }
                    }
                }
            }
        });
        let activity = self.supervisor.activity_handle();
        let recv_task = tokio::spawn(async move {
            dispatcher
                .receive_loop_with_processor(
                    stream,
                    move || {
                        let activity = activity.clone();
                        async move { activity.update().await }
                    },
                    move |msg| {
                        let ext_host = ext_host.clone();
                        async move { ext_host.process_message(&msg).await }
                    },
                )
                .await
        });
        (recv_task, forward_task)
    }

    /// Multiplexes the session until cancellation, completion of the receive task, or
    /// an outbound-send failure, and returns why it ended (always `Some`).
    async fn drive_session(
        &self,
        send_rx: &mut mpsc::Receiver<Message>,
        recv_task: &mut tokio::task::JoinHandle<Result<(), crate::error::ReceiveError>>,
    ) -> Option<DisconnectReason> {
        let cancel_token = self.runtime.cancel_token();
        let dispatcher = self.runtime.dispatcher();
        loop {
            tokio::select! {
                () = cancel_token.cancelled() => {
                    recv_task.abort();
                    return Some(DisconnectReason::Shutdown);
                }
                res = &mut *recv_task => {
                    return match res {
                        Ok(Ok(())) => Some(DisconnectReason::Normal),
                        Ok(Err(e)) => Some(match e {
                            crate::error::ReceiveError::Timeout(_) => DisconnectReason::Timeout,
                            crate::error::ReceiveError::StreamClosed => DisconnectReason::Normal,
                            crate::error::ReceiveError::WebSocket(err) => DisconnectReason::Error(err),
                        }),
                        Err(_) => Some(DisconnectReason::Error("receiver task aborted".to_string())),
                    }
                }
                msg = send_rx.recv() => {
                    if let Some(message) = msg {
                        if let Err(e) = dispatcher.send(message).await {
                            return Some(DisconnectReason::Error(format!("send error: {e:?}")));
                        }
                    } else {
                        return Some(DisconnectReason::Error("send channel closed".to_string()));
                    }
                }
            }
        }
    }

    /// Tears down per-session state and reports the disconnect.
    ///
    /// Does not shut extensions down; `run_loop` does that once, when it observes
    /// cancellation at the top of its next iteration.
    async fn cleanup_session(
        &self,
        forward_task: tokio::task::JoinHandle<()>,
        disconnect_reason: Option<DisconnectReason>,
        connection_id: u64,
    ) {
        self.runtime.clear_send_channel().await;
        forward_task.abort();
        self.runtime.dispatcher().detach().await;
        let reason = disconnect_reason.unwrap_or(DisconnectReason::Normal);
        self.supervisor.mark_disconnected(reason.clone()).await;
        // Best effort, mirroring `notify_connect`.
        let _ = self.extension_host.notify_disconnect().await;
        tracing::info!(
            connection_id = connection_id,
            reason = ?reason,
            "Disconnected"
        );
    }

    /// Runs the configured handshaker (with its timeout) over the fresh connection,
    /// passing a context built from the supervisor's snapshot id and reconnect count.
    async fn perform_handshake<S, St>(
        &self,
        sink: &mut S,
        stream: &mut St,
    ) -> Result<(), HandshakeError>
    where
        S: futures_util::Sink<Message, Error = tungstenite::Error> + Unpin + Send,
        St: futures_util::Stream<Item = Result<Message, tungstenite::Error>> + Unpin + Send,
    {
        use crate::context::ConnectionContext;
        let snapshot = self.supervisor.snapshot().await;
        let context =
            ConnectionContext::new(snapshot.id).with_reconnect_count(snapshot.reconnect_count);
        self.handshaker
            .handshake_with_timeout(sink, stream, &context)
            .await
    }
}
/// Fluent builder for [`WebSocketClient`].
///
/// Starts from [`ClientConfig::default`], an [`ExponentialBackoff::default`] retry
/// strategy and a [`NoOpHandshaker`].
pub struct WebSocketClientBuilder {
    uri: String,
    config: ClientConfig,
    retry_strategy: Box<dyn RetryStrategy>,
    handshaker: BoxHandshaker,
}
impl WebSocketClientBuilder {
    /// Starts a builder targeting `uri` with default settings.
    pub fn new(uri: impl Into<String>) -> Self {
        let uri = uri.into();
        Self {
            uri,
            config: ClientConfig::default(),
            retry_strategy: Box::new(ExponentialBackoff::default()),
            handshaker: Box::new(NoOpHandshaker),
        }
    }

    /// Replaces the whole [`ClientConfig`].
    #[must_use]
    // NOTE(review): the reason for opting out of this lint is not recorded; confirm.
    #[allow(clippy::missing_const_for_fn)]
    pub fn config(mut self, config: ClientConfig) -> Self {
        self.config = config;
        self
    }

    /// Overrides only [`ClientConfig::receive_timeout`].
    #[must_use]
    pub const fn receive_timeout(mut self, timeout: Duration) -> Self {
        self.config.receive_timeout = timeout;
        self
    }

    /// Uses `strategy` to decide reconnect attempts.
    #[must_use]
    pub fn retry_strategy<S: RetryStrategy + 'static>(mut self, strategy: S) -> Self {
        self.retry_strategy = Box::new(strategy);
        self
    }

    /// Uses `handshaker` on every fresh connection.
    #[must_use]
    pub fn handshaker<H: Handshaker + 'static>(mut self, handshaker: H) -> Self {
        self.handshaker = Box::new(handshaker);
        self
    }

    /// Disables reconnection ([`NoRetry`]).
    #[must_use]
    pub fn no_retry(mut self) -> Self {
        self.retry_strategy = Box::new(NoRetry);
        self
    }

    /// Exponential backoff from `initial` up to `max`, scaled by `multiplier` per
    /// attempt; jitter is fixed at `0.1`.
    #[must_use]
    pub fn exponential_backoff(
        mut self,
        initial: Duration,
        max: Duration,
        multiplier: f64,
    ) -> Self {
        let backoff = ExponentialBackoff::new(initial, max)
            .with_factor(multiplier)
            .with_jitter(0.1);
        self.retry_strategy = Box::new(backoff);
        self
    }

    /// Assembles the client: runtime state, connector, supervisor and extension host.
    #[must_use]
    pub fn build(self) -> WebSocketClient {
        let runtime = Arc::new(ClientRuntime::new(&self.config));

        let base_connector = DefaultConnector::new().with_nodelay(self.config.disable_nagle);
        let connector = match self.config.ws_config {
            Some(ws_cfg) => base_connector.with_ws_config(ws_cfg),
            None => base_connector,
        };

        let mut supervisor_config = crate::connection::SupervisorConfig::new();
        supervisor_config.retry_strategy = self.retry_strategy;
        supervisor_config.exit_on_first_failure = self.config.exit_on_first_failure;
        supervisor_config.connect_timeout = self.config.connect_timeout;

        let supervisor = ConnectionSupervisor::with_connector(self.uri.clone(), connector)
            .with_config(supervisor_config);

        WebSocketClient {
            uri: self.uri,
            config: self.config,
            handshaker: self.handshaker,
            extension_host: Arc::new(ExtensionHost::new()),
            supervisor,
            runtime,
        }
    }
}
/// Extension for running a shared client on its own Tokio task.
pub trait WebSocketClientExt {
    /// Spawns the client's run loop and returns the task's join handle.
    fn spawn(self: Arc<Self>) -> tokio::task::JoinHandle<Result<(), ClientError>>;
}
impl WebSocketClientExt for WebSocketClient {
    /// Moves the `Arc` into a new task that awaits [`WebSocketClient::run`].
    fn spawn(self: Arc<Self>) -> tokio::task::JoinHandle<Result<(), ClientError>> {
        let client = self;
        let run_future = async move { client.run().await };
        tokio::spawn(run_future)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_config_defaults() {
        let config = ClientConfig::default();
        assert_eq!(config.receive_timeout, Duration::from_secs(20));
        assert_eq!(config.connect_timeout, Duration::from_secs(30));
        assert_eq!(config.handshake_retry_delay, Duration::from_secs(5));
        assert_eq!(config.channel_buffer_size, 256);
        assert_eq!(config.send_queue_capacity, 256);
        assert!(config.ws_config.is_none());
        assert!(!config.exit_on_first_failure);
        assert!(!config.disable_nagle);
    }

    #[test]
    fn test_client_config_presets() {
        let fast = ClientConfig::fast_reconnect();
        assert_eq!(fast.receive_timeout, Duration::from_secs(10));
        assert!(fast.disable_nagle);
        let stable = ClientConfig::stable_connection();
        assert_eq!(stable.receive_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_builder() {
        let client = WebSocketClient::builder("ws://localhost:8080")
            .receive_timeout(Duration::from_secs(30))
            .no_retry()
            .build();
        assert_eq!(client.config.receive_timeout, Duration::from_secs(30));
    }

    #[tokio::test]
    async fn test_sender_backpressure_full() {
        // Keep the receiver alive so the failure is "full", not "closed".
        let (tx, _rx) = mpsc::channel::<Message>(1);
        let sender = Sender { tx };
        assert!(sender.send(Message::Text("a".into())).is_ok());
        let res = sender.send(Message::Text("b".into()));
        assert!(matches!(res, Err(SendError::ChannelFull)));
    }

    #[tokio::test]
    async fn test_sender_reports_closed_channel() {
        let (tx, rx) = mpsc::channel::<Message>(1);
        drop(rx);
        let sender = Sender { tx };
        let res = sender.send(Message::Text("a".into()));
        assert!(matches!(res, Err(SendError::ChannelClosed)));
        let res = sender.send_async(Message::Text("b".into())).await;
        assert!(matches!(res, Err(SendError::ChannelClosed)));
    }

    #[tokio::test]
    async fn test_sender_send_timeout_on_full_queue() {
        let (tx, _rx) = mpsc::channel::<Message>(1);
        let sender = Sender { tx };
        assert!(sender.send(Message::Text("a".into())).is_ok());
        let limit = Duration::from_millis(10);
        let res = sender.send_timeout(Message::Text("b".into()), limit).await;
        assert!(matches!(res, Err(SendError::Timeout(d)) if d == limit));
    }

    #[tokio::test]
    async fn test_send_before_run_is_not_connected() {
        let client = WebSocketClient::builder("ws://localhost:8080")
            .no_retry()
            .build();
        assert!(client.sender().await.is_none());
        let res = client.send(Message::Text("hi".into())).await;
        assert!(matches!(res, Err(SendError::NotConnected)));
    }

    #[tokio::test]
    async fn test_shutdown_graceful_when_idle() {
        let client = WebSocketClient::builder("ws://localhost:8080")
            .no_retry()
            .build();
        let res = client.shutdown_graceful(Duration::from_millis(100)).await;
        assert!(res.is_ok(), "shutdown_graceful returned error: {res:?}");
    }

    #[tokio::test]
    async fn test_client_shutdown_exits_quickly() {
        let client = WebSocketClient::builder("wss://example.test/ws")
            .receive_timeout(Duration::from_millis(100))
            .no_retry()
            .build();
        let client = Arc::new(client);
        client.shutdown();
        let h = {
            let c = client.clone();
            tokio::spawn(async move { c.run().await })
        };
        let res = tokio::time::timeout(Duration::from_secs(1), h).await;
        assert!(res.is_ok(), "run() did not exit in time");
        let run_res = res.unwrap().unwrap();
        assert!(run_res.is_ok(), "run() returned error: {run_res:?}");
    }
}