use futures_util::{SinkExt, StreamExt};
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc, Mutex};
use tokio::time::{Duration, MissedTickBehavior};
use tokio_tungstenite::{connect_async, tungstenite::Message};
use tracing::{debug, warn};
use url::Url;
use crate::error::{RelayError, Result};
use crate::types::WsEvent;
/// Base URL used when `WsClientOptions::base_url` is not set. `WsClient::new`
/// converts its `https://` scheme to `wss://` before connecting.
const DEFAULT_BASE_URL: &str = "https://api.relaycast.dev";
/// Crate version captured at build time by Cargo; the default `origin_version`
/// query parameter sent on every connection.
const SDK_VERSION: &str = env!("CARGO_PKG_VERSION");
// Defaults for the `origin_*` query parameters attached to the WebSocket URL.
/// Fallback for `WsClientOptions::origin_surface`.
const DEFAULT_ORIGIN_SURFACE: &str = "sdk";
/// Fallback for `WsClientOptions::origin_client`.
const DEFAULT_ORIGIN_CLIENT: &str = "@relaycast/sdk-rust";

// Connection-maintenance tuning.
/// Period, in seconds, of the `{"type":"ping"}` keep-alive message.
const PING_INTERVAL_SECS: u64 = 30;
/// Fallback for `WsClientOptions::max_reconnect_attempts`.
const DEFAULT_MAX_RECONNECT_ATTEMPTS: u32 = 10;
/// Fallback for `WsClientOptions::max_reconnect_delay_ms` (30 seconds).
const DEFAULT_MAX_RECONNECT_DELAY_MS: u64 = 30 * 1_000;
/// Configuration for [`WsClient`], built with [`WsClientOptions::new`] and
/// the chainable `with_*` methods. Every `None` field falls back to a default
/// when the client is constructed.
#[derive(Debug, Clone)]
pub struct WsClientOptions {
    /// Auth token sent as the `token` query parameter.
    pub token: String,
    /// HTTP(S) or WS(S) base URL; `None` selects the built-in default.
    pub base_url: Option<String>,
    /// When `true`, messages that fail to parse are logged with `warn!`.
    pub debug: bool,
    /// Value for the `origin_surface` query parameter.
    pub origin_surface: Option<String>,
    /// Value for the `origin_client` query parameter.
    pub origin_client: Option<String>,
    /// Value for the `origin_version` query parameter.
    pub origin_version: Option<String>,
    /// How many consecutive reconnect attempts are made before giving up.
    pub max_reconnect_attempts: Option<u32>,
    /// Cap, in milliseconds, on the exponential reconnect backoff.
    pub max_reconnect_delay_ms: Option<u64>,
}

impl WsClientOptions {
    /// Starts a configuration holding only `token`; everything else is unset.
    pub fn new(token: impl Into<String>) -> Self {
        let token = token.into();
        Self {
            token,
            base_url: None,
            debug: false,
            origin_surface: None,
            origin_client: None,
            origin_version: None,
            max_reconnect_attempts: None,
            max_reconnect_delay_ms: None,
        }
    }

    /// Overrides the server base URL.
    pub fn with_base_url(self, base_url: impl Into<String>) -> Self {
        Self {
            base_url: Some(base_url.into()),
            ..self
        }
    }

    /// Enables or disables debug logging of dropped messages.
    pub fn with_debug(self, debug: bool) -> Self {
        Self { debug, ..self }
    }

    /// Sets all three origin identifiers at once.
    pub fn with_origin(
        self,
        origin_surface: impl Into<String>,
        origin_client: impl Into<String>,
        origin_version: impl Into<String>,
    ) -> Self {
        Self {
            origin_surface: Some(origin_surface.into()),
            origin_client: Some(origin_client.into()),
            origin_version: Some(origin_version.into()),
            ..self
        }
    }

    /// Sets how many reconnect attempts are made before the client gives up.
    pub fn with_max_reconnect_attempts(self, attempts: u32) -> Self {
        Self {
            max_reconnect_attempts: Some(attempts),
            ..self
        }
    }

    /// Sets the upper bound, in milliseconds, of the reconnect backoff.
    pub fn with_max_reconnect_delay_ms(self, delay_ms: u64) -> Self {
        Self {
            max_reconnect_delay_ms: Some(delay_ms),
            ..self
        }
    }
}
/// Receiver for server messages that `WsClient` successfully parsed into
/// `WsEvent`s; obtained from `WsClient::subscribe_events`.
pub type EventReceiver = broadcast::Receiver<WsEvent>;
/// Receiver for connection lifecycle notifications; obtained from
/// `WsClient::subscribe_lifecycle`.
pub type LifecycleReceiver = broadcast::Receiver<WsLifecycleEvent>;
/// Connection-state notifications broadcast by `WsClient` to lifecycle
/// subscribers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsLifecycleEvent {
    /// A WebSocket connection (initial or reconnected) is now open.
    Open,
    /// An open connection has ended.
    Close,
    /// An error occurred; carries the error's display text.
    Error(String),
    /// A reconnect is scheduled; `attempt` counts from 1 since the last open.
    Reconnecting {
        /// 1-based index of the pending reconnect attempt.
        attempt: u32,
    },
}
/// Realtime WebSocket client that forwards parsed events and lifecycle
/// notifications over broadcast channels and runs its connection in a
/// background task.
pub struct WsClient {
    /// Current auth token; shared with the background task so a token set via
    /// `set_token` is used for later reconnects.
    token: Arc<Mutex<String>>,
    /// Base URL with a `ws://`/`wss://` scheme and no trailing slash.
    base_url: String,
    /// Whether unparseable messages are logged.
    debug: bool,
    /// `origin_surface` query value.
    origin_surface: String,
    /// `origin_client` query value.
    origin_client: String,
    /// `origin_version` query value.
    origin_version: String,
    /// Reconnect attempts allowed before the background task gives up.
    max_reconnect_attempts: u32,
    /// Backoff cap in milliseconds.
    max_reconnect_delay_ms: u64,
    /// Fan-out for parsed server events.
    event_tx: broadcast::Sender<WsEvent>,
    /// Fan-out for lifecycle notifications.
    lifecycle_tx: broadcast::Sender<WsLifecycleEvent>,
    /// Command channel into the background task; `None` before `connect` or
    /// after `disconnect`.
    command_tx: Option<mpsc::Sender<WsCommand>>,
    /// Shared connected flag, written by both the client and the task.
    is_connected: Arc<Mutex<bool>>,
}
/// Requests sent from the `WsClient` handle to its background connection task.
enum WsCommand {
    /// Add these channels to the subscription set (and send if connected).
    Subscribe(Vec<String>),
    /// Remove these channels from the subscription set (and send if connected).
    Unsubscribe(Vec<String>),
    /// Close the connection and stop the task.
    Disconnect,
}
impl WsClient {
pub fn new(options: WsClientOptions) -> Self {
let base_url = options
.base_url
.unwrap_or_else(|| DEFAULT_BASE_URL.to_string())
.replace("https://", "wss://")
.replace("http://", "ws://");
let (event_tx, _) = broadcast::channel(1024);
let (lifecycle_tx, _) = broadcast::channel(128);
Self {
token: Arc::new(Mutex::new(options.token)),
base_url: base_url.trim_end_matches('/').to_string(),
debug: options.debug,
origin_surface: options
.origin_surface
.unwrap_or_else(|| DEFAULT_ORIGIN_SURFACE.to_string()),
origin_client: options
.origin_client
.unwrap_or_else(|| DEFAULT_ORIGIN_CLIENT.to_string()),
origin_version: options
.origin_version
.unwrap_or_else(|| SDK_VERSION.to_string()),
max_reconnect_attempts: options
.max_reconnect_attempts
.unwrap_or(DEFAULT_MAX_RECONNECT_ATTEMPTS),
max_reconnect_delay_ms: options
.max_reconnect_delay_ms
.unwrap_or(DEFAULT_MAX_RECONNECT_DELAY_MS),
event_tx,
lifecycle_tx,
command_tx: None,
is_connected: Arc::new(Mutex::new(false)),
}
}
pub async fn is_connected(&self) -> bool {
*self.is_connected.lock().await
}
pub fn subscribe_events(&self) -> EventReceiver {
self.event_tx.subscribe()
}
pub fn subscribe_lifecycle(&self) -> LifecycleReceiver {
self.lifecycle_tx.subscribe()
}
pub async fn set_token(&self, token: impl Into<String>) {
*self.token.lock().await = token.into();
}
pub async fn connect(&mut self) -> Result<()> {
if *self.is_connected.lock().await {
return Ok(());
}
let mut url = Url::parse(&format!("{}/v1/ws", self.base_url))?;
{
let token = self.token.lock().await.clone();
let mut query = url.query_pairs_mut();
query.append_pair("token", &token);
query.append_pair("origin_surface", &self.origin_surface);
query.append_pair("origin_client", &self.origin_client);
query.append_pair("origin_version", &self.origin_version);
}
let (ws_stream, _) = connect_async(url.as_str()).await?;
let (command_tx, mut command_rx) = mpsc::channel::<WsCommand>(32);
self.command_tx = Some(command_tx);
let token = self.token.clone();
let event_tx = self.event_tx.clone();
let lifecycle_tx = self.lifecycle_tx.clone();
let is_connected = self.is_connected.clone();
let debug = self.debug;
let base_url = self.base_url.clone();
let origin_surface = self.origin_surface.clone();
let origin_client = self.origin_client.clone();
let origin_version = self.origin_version.clone();
let max_reconnect_attempts = self.max_reconnect_attempts;
let max_reconnect_delay_ms = self.max_reconnect_delay_ms;
*is_connected.lock().await = true;
tokio::spawn(async move {
let mut subscribed_channels: HashSet<String> = HashSet::new();
let mut current_stream = Some(ws_stream);
let mut reconnect_attempt = 0u32;
let mut should_stop = false;
'outer: while !should_stop {
let stream = if let Some(stream) = current_stream.take() {
stream
} else {
let mut reconnect_url = match Url::parse(&format!("{}/v1/ws", base_url)) {
Ok(url) => url,
Err(err) => {
let _ = lifecycle_tx.send(WsLifecycleEvent::Error(err.to_string()));
break 'outer;
}
};
let current_token = token.lock().await.clone();
{
let mut query = reconnect_url.query_pairs_mut();
query.append_pair("token", ¤t_token);
query.append_pair("origin_surface", &origin_surface);
query.append_pair("origin_client", &origin_client);
query.append_pair("origin_version", &origin_version);
}
match connect_async(reconnect_url.as_str()).await {
Ok((stream, _)) => stream,
Err(err) => {
let _ = lifecycle_tx.send(WsLifecycleEvent::Error(err.to_string()));
if reconnect_attempt >= max_reconnect_attempts {
break 'outer;
}
reconnect_attempt += 1;
let _ = lifecycle_tx.send(WsLifecycleEvent::Reconnecting {
attempt: reconnect_attempt,
});
let delay_ms =
reconnect_delay_ms(reconnect_attempt, max_reconnect_delay_ms);
let reconnect_sleep =
tokio::time::sleep(Duration::from_millis(delay_ms));
tokio::pin!(reconnect_sleep);
loop {
tokio::select! {
_ = &mut reconnect_sleep => break,
cmd = command_rx.recv() => {
match cmd {
Some(WsCommand::Subscribe(channels)) => {
for ch in channels {
subscribed_channels.insert(ch);
}
}
Some(WsCommand::Unsubscribe(channels)) => {
for ch in channels {
subscribed_channels.remove(&ch);
}
}
Some(WsCommand::Disconnect) | None => {
should_stop = true;
break;
}
}
}
}
}
continue;
}
}
};
let (mut write, mut read) = stream.split();
reconnect_attempt = 0;
*is_connected.lock().await = true;
let _ = lifecycle_tx.send(WsLifecycleEvent::Open);
if !subscribed_channels.is_empty() {
let msg = serde_json::json!({
"type": "subscribe",
"channels": subscribed_channels.iter().cloned().collect::<Vec<_>>()
});
if let Err(err) = write.send(Message::Text(msg.to_string())).await {
let _ = lifecycle_tx.send(WsLifecycleEvent::Error(err.to_string()));
*is_connected.lock().await = false;
continue;
}
}
let mut ping_interval =
tokio::time::interval(Duration::from_secs(PING_INTERVAL_SECS));
ping_interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
loop {
tokio::select! {
msg = read.next() => {
match msg {
Some(Ok(Message::Text(text))) => {
match serde_json::from_str::<WsEvent>(&text) {
Ok(event) => {
let _ = event_tx.send(event);
}
Err(err) => {
if debug {
warn!("[relaycast] Dropped WebSocket message: {}: {}", err, &text[..text.len().min(200)]);
}
}
}
}
Some(Ok(Message::Close(_))) | None => {
debug!("WebSocket connection closed");
break;
}
Some(Err(err)) => {
let _ = lifecycle_tx.send(WsLifecycleEvent::Error(err.to_string()));
break;
}
_ => {}
}
}
cmd = command_rx.recv() => {
match cmd {
Some(WsCommand::Subscribe(channels)) => {
for ch in &channels {
subscribed_channels.insert(ch.clone());
}
let msg = serde_json::json!({
"type": "subscribe",
"channels": channels
});
if let Err(err) = write.send(Message::Text(msg.to_string())).await {
let _ = lifecycle_tx.send(WsLifecycleEvent::Error(err.to_string()));
break;
}
}
Some(WsCommand::Unsubscribe(channels)) => {
for ch in &channels {
subscribed_channels.remove(ch);
}
let msg = serde_json::json!({
"type": "unsubscribe",
"channels": channels
});
if let Err(err) = write.send(Message::Text(msg.to_string())).await {
let _ = lifecycle_tx.send(WsLifecycleEvent::Error(err.to_string()));
break;
}
}
Some(WsCommand::Disconnect) | None => {
should_stop = true;
let _ = write.send(Message::Close(None)).await;
break;
}
}
}
_ = ping_interval.tick() => {
let ping = serde_json::json!({"type": "ping"});
if let Err(err) = write.send(Message::Text(ping.to_string())).await {
let _ = lifecycle_tx.send(WsLifecycleEvent::Error(err.to_string()));
break;
}
}
}
}
*is_connected.lock().await = false;
let _ = lifecycle_tx.send(WsLifecycleEvent::Close);
if should_stop {
break 'outer;
}
if reconnect_attempt >= max_reconnect_attempts {
break 'outer;
}
reconnect_attempt += 1;
let _ = lifecycle_tx.send(WsLifecycleEvent::Reconnecting {
attempt: reconnect_attempt,
});
let delay_ms = reconnect_delay_ms(reconnect_attempt, max_reconnect_delay_ms);
let reconnect_sleep = tokio::time::sleep(Duration::from_millis(delay_ms));
tokio::pin!(reconnect_sleep);
loop {
tokio::select! {
_ = &mut reconnect_sleep => break,
cmd = command_rx.recv() => {
match cmd {
Some(WsCommand::Subscribe(channels)) => {
for ch in channels {
subscribed_channels.insert(ch);
}
}
Some(WsCommand::Unsubscribe(channels)) => {
for ch in channels {
subscribed_channels.remove(&ch);
}
}
Some(WsCommand::Disconnect) | None => {
should_stop = true;
break;
}
}
}
}
}
}
*is_connected.lock().await = false;
});
Ok(())
}
pub async fn disconnect(&mut self) {
if let Some(tx) = self.command_tx.take() {
let _ = tx.send(WsCommand::Disconnect).await;
}
*self.is_connected.lock().await = false;
}
pub async fn subscribe(&self, channels: Vec<String>) -> Result<()> {
if let Some(ref tx) = self.command_tx {
tx.send(WsCommand::Subscribe(channels))
.await
.map_err(|_| RelayError::NotConnected)?;
Ok(())
} else {
Err(RelayError::NotConnected)
}
}
pub async fn unsubscribe(&self, channels: Vec<String>) -> Result<()> {
if let Some(ref tx) = self.command_tx {
tx.send(WsCommand::Unsubscribe(channels))
.await
.map_err(|_| RelayError::NotConnected)?;
Ok(())
} else {
Err(RelayError::NotConnected)
}
}
}
impl Drop for WsClient {
    /// Releases the command sender. The background task then receives `None`
    /// from its command channel, which it handles exactly like `Disconnect`:
    /// it closes the socket and exits.
    fn drop(&mut self) {
        self.command_tx = None;
    }
}
/// Exponential backoff for reconnect attempt `attempt` (1-based).
///
/// The delay is 1 s for the first attempt and doubles on each later one. It is
/// capped at `max_delay_ms`, but the cap is never lower than 1 s. Arithmetic
/// saturates instead of overflowing for very large attempt counts.
fn reconnect_delay_ms(attempt: u32, max_delay_ms: u64) -> u64 {
    const BASE_DELAY_MS: u64 = 1_000;
    let ceiling = std::cmp::max(max_delay_ms, BASE_DELAY_MS);
    let doublings = attempt.saturating_sub(1);
    let growth = 2u64.saturating_pow(doublings);
    std::cmp::min(BASE_DELAY_MS.saturating_mul(growth), ceiling)
}