#![cfg_attr(feature = "ssr", allow(unused_variables, unused_imports, dead_code))]
use crate::{core::ConnectionReadyState, use_interval_fn, ReconnectLimit};
use cfg_if::cfg_if;
use codee::{CodecError, Decoder, Encoder, HybridCoderError, HybridDecoder, HybridEncoder};
use default_struct_builder::DefaultBuilder;
use js_sys::Array;
use leptos::{leptos_dom::helpers::TimeoutHandle, prelude::*};
use std::marker::PhantomData;
use std::sync::{atomic::AtomicBool, Arc};
use std::time::Duration;
use thiserror::Error;
use wasm_bindgen::prelude::*;
use web_sys::{BinaryType, CloseEvent, Event, MessageEvent, WebSocket};
/// Reactive WebSocket client using the default [`UseWebSocketOptions`].
///
/// Equivalent to calling [`use_websocket_with_options`] with
/// `UseWebSocketOptions::default()`: no heartbeat, no-op callbacks and an
/// immediate connection attempt. `C` is the codec that encodes outgoing `Tx`
/// values and decodes incoming `Rx` values.
#[allow(rustdoc::bare_urls)]
pub fn use_websocket<Tx, Rx, C>(
    url: &str,
) -> UseWebSocketReturn<
    Tx,
    Rx,
    impl Fn() + Clone + Send + Sync + 'static,
    impl Fn() + Clone + Send + Sync + 'static,
    impl Fn(&Tx) + Clone + Send + Sync + 'static,
>
where
    Tx: Send + Sync + 'static,
    Rx: Send + Sync + 'static,
    C: Encoder<Tx> + Decoder<Rx>,
    C: HybridEncoder<Tx, <C as Encoder<Tx>>::Encoded, Error = <C as Encoder<Tx>>::Error>,
    C: HybridDecoder<Rx, <C as Decoder<Rx>>::Encoded, Error = <C as Decoder<Rx>>::Error>,
{
    // The concrete option type is pinned by the explicit turbofish below
    // (no heartbeat payload, placeholder heartbeat codec).
    let default_options = UseWebSocketOptions::default();
    use_websocket_with_options::<Tx, Rx, C, (), DummyEncoder>(url, default_options)
}
/// Reactive WebSocket client configured through [`UseWebSocketOptions`].
///
/// `url` is first passed through `normalize_url`, so relative and
/// protocol-relative URLs are resolved against the current page. Text frames
/// are decoded with `C::decode_str`, `ArrayBuffer` frames with `C::decode_bin`;
/// `send` uses `encode_bin` or `encode_str` depending on `C::is_binary_encoder()`.
///
/// Behaviour:
/// - When `immediate` is `true`, the connection is opened from an `Effect`.
/// - After an `error` or `close` event, one reconnect attempt is scheduled after
///   `reconnect_interval` milliseconds. This is skipped if the socket was closed
///   via `close`, the `reconnect_limit` is exhausted, or an attempt is already
///   pending.
/// - `close` cancels any pending reconnect and disables automatic reconnects.
///   A later `open` re-enables them and resets the retry counter.
/// - When a socket is replaced (by `open` or a reconnect), its event handlers are
///   detached first. Its late `close`/`error` events then cannot affect the new
///   connection; they are also not forwarded to `on_close`/`on_error`.
/// - If a heartbeat is configured, `Hb::default()` is encoded with `HbCodec` and
///   sent periodically while the connection is `Open`.
/// - On owner cleanup all event handlers become no-ops and the socket is closed.
/// - With the `ssr` feature no connection is ever created.
///
/// # Panics
/// Throws a JS exception if the `WebSocket` constructor rejects the URL or
/// protocols. Panics if a message event carries data that is neither an
/// `ArrayBuffer` nor a string.
#[allow(clippy::type_complexity)]
pub fn use_websocket_with_options<Tx, Rx, C, Hb, HbCodec>(
    url: &str,
    options: UseWebSocketOptions<
        Rx,
        HybridCoderError<<C as Encoder<Tx>>::Error>,
        HybridCoderError<<C as Decoder<Rx>>::Error>,
        Hb,
        HbCodec,
    >,
) -> UseWebSocketReturn<
    Tx,
    Rx,
    impl Fn() + Clone + Send + Sync + 'static,
    impl Fn() + Clone + Send + Sync + 'static,
    impl Fn(&Tx) + Clone + Send + Sync + 'static,
>
where
    Tx: Send + Sync + 'static,
    Rx: Send + Sync + 'static,
    C: Encoder<Tx> + Decoder<Rx>,
    C: HybridEncoder<Tx, <C as Encoder<Tx>>::Encoded, Error = <C as Encoder<Tx>>::Error>,
    C: HybridDecoder<Rx, <C as Decoder<Rx>>::Encoded, Error = <C as Decoder<Rx>>::Error>,
    Hb: Default + Send + Sync + 'static,
    HbCodec: Encoder<Hb> + Send + Sync,
    HbCodec: HybridEncoder<
        Hb,
        <HbCodec as Encoder<Hb>>::Encoded,
        Error = <HbCodec as Encoder<Hb>>::Error,
    >,
    <HbCodec as Encoder<Hb>>::Error: std::fmt::Debug,
{
    let url = normalize_url(url);
    let UseWebSocketOptions {
        on_open,
        on_message,
        on_message_raw,
        on_message_raw_bytes,
        on_error,
        on_close,
        reconnect_limit,
        reconnect_interval,
        immediate,
        protocols,
        heartbeat,
    } = options;
    let (ready_state, set_ready_state) = signal(ConnectionReadyState::Closed);
    let (message, set_message) = signal(None);
    let ws_signal = RwSignal::new_local(None::<WebSocket>);
    // Handle of the pending reconnect timeout, if one is scheduled.
    let reconnect_timer_ref: StoredValue<Option<TimeoutHandle>> = StoredValue::new(None);
    // Reconnect attempts made since the last explicit `open`.
    let reconnect_times_ref: StoredValue<u64> = StoredValue::new(0);
    // Set by `close`, cleared by `open`; suppresses automatic reconnects.
    let manually_closed_ref: StoredValue<bool> = StoredValue::new(false);
    // Set on owner cleanup; turns every socket event handler into a no-op.
    let unmounted = Arc::new(AtomicBool::new(false));
    // Filled in below (client only) so `open` and reconnects can call it.
    let connect_ref: StoredValue<Option<Arc<dyn Fn() + Send + Sync>>> = StoredValue::new(None);

    // Raw senders: best-effort, silently dropped unless the connection is open.
    let send_str = move |data: &str| {
        if ready_state.get_untracked() == ConnectionReadyState::Open {
            if let Some(web_socket) = ws_signal.get_untracked() {
                let _ = web_socket.send_with_str(data);
            }
        }
    };
    let send_bytes = move |data: &[u8]| {
        if ready_state.get_untracked() == ConnectionReadyState::Open {
            if let Some(web_socket) = ws_signal.get_untracked() {
                let _ = web_socket.send_with_u8_array(data);
            }
        }
    };
    let send = {
        let on_error = Arc::clone(&on_error);
        move |value: &Tx| {
            let on_error = Arc::clone(&on_error);
            send_with_codec::<Tx, C>(value, send_str, send_bytes, move |err| {
                on_error(UseWebSocketError::Codec(CodecError::Encode(err)));
            });
        }
    };

    // (pause, resume) of the heartbeat interval once it has been created.
    let heartbeat_interval_ref = StoredValue::new_local(None::<(Arc<dyn Fn()>, Arc<dyn Fn()>)>);
    let stop_heartbeat = move || {
        if let Some((pause, _)) = heartbeat_interval_ref.get_value() {
            pause();
        }
    };

    #[cfg(not(feature = "ssr"))]
    {
        use crate::utils::Pausable;

        // Starts the heartbeat on first open; later opens restart the existing
        // interval via pause + resume.
        let start_heartbeat = {
            let on_error = Arc::clone(&on_error);
            move || {
                if let Some(heartbeat) = &heartbeat {
                    if let Some((pause, resume)) = heartbeat_interval_ref.get_value() {
                        pause();
                        resume();
                    } else {
                        let on_error = Arc::clone(&on_error);
                        let Pausable { pause, resume, .. } = use_interval_fn(
                            move || {
                                send_with_codec::<Hb, HbCodec>(
                                    &Hb::default(),
                                    send_str,
                                    send_bytes,
                                    {
                                        let on_error = Arc::clone(&on_error);
                                        move |err| {
                                            on_error(UseWebSocketError::HeartbeatCodec(format!(
                                                "Failed to encode heartbeat data: {err:?}"
                                            )))
                                        }
                                    },
                                )
                            },
                            heartbeat.interval,
                        );
                        heartbeat_interval_ref.set_value(Some((Arc::new(pause), Arc::new(resume))));
                    }
                }
            }
        };

        // Schedules one reconnect attempt after `reconnect_interval` ms. It only
        // does so when the socket wasn't closed manually, the limit isn't
        // exceeded, a current socket exists that is not OPEN, and no attempt is
        // already pending.
        let reconnect_ref: StoredValue<Option<Arc<dyn Fn() + Send + Sync>>> =
            StoredValue::new(None);
        reconnect_ref.set_value({
            let unmounted = Arc::clone(&unmounted);
            Some(Arc::new(move || {
                let unmounted = Arc::clone(&unmounted);
                if !manually_closed_ref.get_value()
                    && !reconnect_limit.is_exceeded_by(reconnect_times_ref.get_value())
                    && ws_signal
                        .get_untracked()
                        .is_some_and(|ws: WebSocket| ws.ready_state() != WebSocket::OPEN)
                    && reconnect_timer_ref.get_value().is_none()
                {
                    reconnect_timer_ref.set_value(
                        set_timeout_with_handle(
                            move || {
                                if unmounted.load(std::sync::atomic::Ordering::Relaxed) {
                                    return;
                                }
                                if let Some(connect) = connect_ref.get_value() {
                                    connect();
                                    reconnect_times_ref.update_value(|current| *current += 1);
                                }
                            },
                            Duration::from_millis(reconnect_interval),
                        )
                        .ok(),
                    );
                }
            }))
        });

        // Replaces any existing socket with a freshly constructed one and wires
        // up its event handlers.
        connect_ref.set_value({
            let unmounted = Arc::clone(&unmounted);
            let on_error = Arc::clone(&on_error);
            Some(Arc::new(move || {
                if let Some(reconnect_timer) = reconnect_timer_ref.get_value() {
                    reconnect_timer.clear();
                    reconnect_timer_ref.set_value(None);
                }
                if let Some(old_web_socket) = ws_signal.get_untracked() {
                    // Detach the old socket's handlers before closing it. Its
                    // asynchronous close/error events would otherwise stop the
                    // new connection's heartbeat, flip `ready_state` to `Closed`
                    // while the new socket is open, or schedule a spurious
                    // reconnect.
                    old_web_socket.set_onopen(None);
                    old_web_socket.set_onmessage(None);
                    old_web_socket.set_onerror(None);
                    old_web_socket.set_onclose(None);
                    let _ = old_web_socket.close();
                }
                let web_socket = {
                    protocols.with_untracked(|protocols| {
                        protocols.as_ref().map_or_else(
                            || WebSocket::new(&url).unwrap_throw(),
                            |protocols| {
                                let array = protocols
                                    .iter()
                                    .map(|p| JsValue::from(p.clone()))
                                    .collect::<Array>();
                                WebSocket::new_with_str_sequence(&url, &JsValue::from(&array))
                                    .unwrap_throw()
                            },
                        )
                    })
                };
                web_socket.set_binary_type(BinaryType::Arraybuffer);
                set_ready_state.set(ConnectionReadyState::Connecting);
                {
                    let unmounted = Arc::clone(&unmounted);
                    let on_open = Arc::clone(&on_open);
                    let onopen_closure = Closure::wrap(Box::new({
                        let start_heartbeat = start_heartbeat.clone();
                        move |e: Event| {
                            if unmounted.load(std::sync::atomic::Ordering::Relaxed) {
                                return;
                            }
                            #[cfg(debug_assertions)]
                            let zone = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
                            on_open(e);
                            #[cfg(debug_assertions)]
                            drop(zone);
                            set_ready_state.set(ConnectionReadyState::Open);
                            start_heartbeat();
                        }
                    })
                        as Box<dyn FnMut(Event)>);
                    web_socket.set_onopen(Some(onopen_closure.as_ref().unchecked_ref()));
                    onopen_closure.forget();
                }
                {
                    let unmounted = Arc::clone(&unmounted);
                    let on_message = Arc::clone(&on_message);
                    let on_message_raw = Arc::clone(&on_message_raw);
                    let on_message_raw_bytes = Arc::clone(&on_message_raw_bytes);
                    let on_error = Arc::clone(&on_error);
                    let onmessage_closure = Closure::wrap(Box::new(move |e: MessageEvent| {
                        if unmounted.load(std::sync::atomic::Ordering::Relaxed) {
                            return;
                        }
                        e.data().dyn_into::<js_sys::ArrayBuffer>().map_or_else(
                            |_| {
                                e.data().dyn_into::<js_sys::JsString>().map_or_else(
                                    |_| {
                                        unreachable!(
                                            "message event, received Unknown: {:?}",
                                            e.data()
                                        );
                                    },
                                    |txt| {
                                        let txt = String::from(&txt);
                                        #[cfg(debug_assertions)]
                                        let zone = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
                                        on_message_raw(&txt);
                                        #[cfg(debug_assertions)]
                                        drop(zone);
                                        match C::decode_str(&txt) {
                                            Ok(val) => {
                                                #[cfg(debug_assertions)]
                                                let prev = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
                                                on_message(&val);
                                                #[cfg(debug_assertions)]
                                                drop(prev);
                                                set_message.set(Some(val));
                                            }
                                            Err(err) => {
                                                on_error(CodecError::Decode(err).into());
                                            }
                                        }
                                    },
                                );
                            },
                            |array_buffer| {
                                let array = js_sys::Uint8Array::new(&array_buffer);
                                let array = array.to_vec();
                                #[cfg(debug_assertions)]
                                let zone = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
                                on_message_raw_bytes(&array);
                                #[cfg(debug_assertions)]
                                drop(zone);
                                match C::decode_bin(array.as_slice()) {
                                    Ok(val) => {
                                        #[cfg(debug_assertions)]
                                        let prev = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
                                        on_message(&val);
                                        #[cfg(debug_assertions)]
                                        drop(prev);
                                        set_message.set(Some(val));
                                    }
                                    Err(err) => {
                                        on_error(CodecError::Decode(err).into());
                                    }
                                }
                            },
                        );
                    })
                        as Box<dyn FnMut(MessageEvent)>);
                    web_socket.set_onmessage(Some(onmessage_closure.as_ref().unchecked_ref()));
                    onmessage_closure.forget();
                }
                {
                    let unmounted = Arc::clone(&unmounted);
                    let on_error = Arc::clone(&on_error);
                    let onerror_closure = Closure::wrap(Box::new(move |e: Event| {
                        if unmounted.load(std::sync::atomic::Ordering::Relaxed) {
                            return;
                        }
                        stop_heartbeat();
                        if let Some(reconnect) = &reconnect_ref.get_value() {
                            reconnect();
                        }
                        #[cfg(debug_assertions)]
                        let zone = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
                        on_error(UseWebSocketError::Event(e));
                        #[cfg(debug_assertions)]
                        drop(zone);
                        set_ready_state.set(ConnectionReadyState::Closed);
                    })
                        as Box<dyn FnMut(Event)>);
                    web_socket.set_onerror(Some(onerror_closure.as_ref().unchecked_ref()));
                    onerror_closure.forget();
                }
                {
                    let unmounted = Arc::clone(&unmounted);
                    let on_close = Arc::clone(&on_close);
                    let onclose_closure = Closure::wrap(Box::new(move |e: CloseEvent| {
                        if unmounted.load(std::sync::atomic::Ordering::Relaxed) {
                            return;
                        }
                        stop_heartbeat();
                        if let Some(reconnect) = &reconnect_ref.get_value() {
                            reconnect();
                        }
                        #[cfg(debug_assertions)]
                        let zone = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
                        on_close(e);
                        #[cfg(debug_assertions)]
                        drop(zone);
                        set_ready_state.set(ConnectionReadyState::Closed);
                    })
                        as Box<dyn FnMut(CloseEvent)>);
                    web_socket.set_onclose(Some(onclose_closure.as_ref().unchecked_ref()));
                    onclose_closure.forget();
                }
                ws_signal.set(Some(web_socket));
            }))
        });
    }

    // Explicit (re)connect: re-enables automatic reconnects that a previous
    // `close` disabled and resets the retry budget.
    let open = move || {
        manually_closed_ref.set_value(false);
        reconnect_times_ref.set_value(0);
        if let Some(connect) = connect_ref.get_value() {
            connect();
        }
    };

    // Explicit close: cancels any pending reconnect (whose timer would
    // otherwise reopen the socket) and disables automatic reconnects.
    let close = move || {
        if let Some(reconnect_timer) = reconnect_timer_ref.get_value() {
            reconnect_timer.clear();
            reconnect_timer_ref.set_value(None);
        }
        stop_heartbeat();
        manually_closed_ref.set_value(true);
        if let Some(web_socket) = ws_signal.get_untracked() {
            let _ = web_socket.close();
        }
    };

    Effect::new(move |_| {
        if immediate {
            open();
        }
    });
    on_cleanup(move || {
        unmounted.store(true, std::sync::atomic::Ordering::Relaxed);
        close();
    });
    UseWebSocketReturn {
        ready_state: ready_state.into(),
        message: message.into(),
        ws: ws_signal.into(),
        open,
        close,
        send,
        _marker: PhantomData,
    }
}
/// Encodes `value` with `Codec` and hands the result to the matching sender.
///
/// Binary codecs (`Codec::is_binary_encoder()`) go through `encode_bin` and
/// `send_bytes`; all others go through `encode_str` and `send_str`. An encoding
/// failure is passed to `on_error`, and nothing is sent in that case.
fn send_with_codec<T, Codec>(
    value: &T,
    send_str: impl Fn(&str),
    send_bytes: impl Fn(&[u8]),
    on_error: impl Fn(HybridCoderError<<Codec as Encoder<T>>::Error>),
) where
    Codec: Encoder<T>,
    Codec: HybridEncoder<T, <Codec as Encoder<T>>::Encoded, Error = <Codec as Encoder<T>>::Error>,
{
    // Both branches collapse to `Result<(), _>` so the error is handled once.
    let outcome = if Codec::is_binary_encoder() {
        Codec::encode_bin(value).map(|bytes| send_bytes(&bytes))
    } else {
        Codec::encode_str(value).map(|text| send_str(&text))
    };
    if let Err(encode_error) = outcome {
        on_error(encode_error);
    }
}
/// Shared callback that receives the raw bytes of a binary WebSocket frame.
type ArcFnBytes = Arc<dyn Fn(&[u8]) + Send + Sync>;
/// Options for [`use_websocket_with_options`].
///
/// Start from `UseWebSocketOptions::default()`. Customise it with the builder
/// methods generated by `DefaultBuilder` and the hand-written `on_message`,
/// `on_error` and `heartbeat` setters; those three fields are skipped by the
/// builder.
#[derive(DefaultBuilder)]
pub struct UseWebSocketOptions<Rx, E, D, Hb, HbCodec>
where
    Rx: ?Sized,
    Hb: Default + Send + Sync + 'static,
    HbCodec: Encoder<Hb>,
    HbCodec: HybridEncoder<
        Hb,
        <HbCodec as Encoder<Hb>>::Encoded,
        Error = <HbCodec as Encoder<Hb>>::Error,
    >,
{
    /// Periodic heartbeat configuration; `None` disables heartbeats.
    /// Set through the `heartbeat` method.
    #[builder(skip)]
    heartbeat: Option<HeartbeatOptions<Hb, HbCodec>>,
    /// Called with the socket's `open` event, before `ready_state` becomes `Open`.
    on_open: Arc<dyn Fn(Event) + Send + Sync>,
    /// Called with every successfully decoded incoming message, before it is
    /// stored in the `message` signal.
    #[builder(skip)]
    on_message: Arc<dyn Fn(&Rx) + Send + Sync>,
    /// Called with the raw text of every text frame, before decoding.
    on_message_raw: Arc<dyn Fn(&str) + Send + Sync>,
    /// Called with the raw bytes of every binary frame, before decoding.
    on_message_raw_bytes: ArcFnBytes,
    /// Called on socket `error` events and on encode/decode failures,
    /// including failures to encode the heartbeat payload.
    #[builder(skip)]
    on_error: Arc<dyn Fn(UseWebSocketError<E, D>) + Send + Sync>,
    /// Called with the socket's `close` event.
    on_close: Arc<dyn Fn(CloseEvent) + Send + Sync>,
    /// Bound on the number of automatic reconnect attempts.
    reconnect_limit: ReconnectLimit,
    /// Delay in milliseconds before an automatic reconnect attempt.
    reconnect_interval: u64,
    /// Whether to connect as soon as the hook runs; otherwise call `open`.
    immediate: bool,
    /// WebSocket sub-protocols passed to the constructor. Read (untracked)
    /// every time a connection is made.
    #[builder(into)]
    protocols: Signal<Option<Vec<String>>>,
}
impl<Rx: ?Sized, E, D, Hb, HbCodec> UseWebSocketOptions<Rx, E, D, Hb, HbCodec>
where
    Hb: Default + Send + Sync + 'static,
    HbCodec: Encoder<Hb>,
    HbCodec: HybridEncoder<
        Hb,
        <HbCodec as Encoder<Hb>>::Encoded,
        Error = <HbCodec as Encoder<Hb>>::Error,
    >,
{
    /// Replaces the error callback, which receives socket error events and
    /// codec failures.
    pub fn on_error<F>(mut self, handler: F) -> Self
    where
        F: Fn(UseWebSocketError<E, D>) + Send + Sync + 'static,
    {
        self.on_error = Arc::new(handler);
        self
    }

    /// Replaces the callback invoked with each successfully decoded message.
    pub fn on_message<F>(mut self, handler: F) -> Self
    where
        F: Fn(&Rx) + Send + Sync + 'static,
    {
        self.on_message = Arc::new(handler);
        self
    }

    /// Enables a heartbeat that sends `NewHb::default()`, encoded with
    /// `NewHbCodec`, every `interval`. The heartbeat types become part of the
    /// options type, so all other settings are carried over into a new value.
    pub fn heartbeat<NewHb, NewHbCodec>(
        self,
        interval: u64,
    ) -> UseWebSocketOptions<Rx, E, D, NewHb, NewHbCodec>
    where
        NewHb: Default + Send + Sync + 'static,
        NewHbCodec: Encoder<NewHb>,
        NewHbCodec: HybridEncoder<
            NewHb,
            <NewHbCodec as Encoder<NewHb>>::Encoded,
            Error = <NewHbCodec as Encoder<NewHb>>::Error,
        >,
    {
        let Self {
            heartbeat: _,
            on_open,
            on_message,
            on_message_raw,
            on_message_raw_bytes,
            on_error,
            on_close,
            reconnect_limit,
            reconnect_interval,
            immediate,
            protocols,
        } = self;

        let heartbeat = HeartbeatOptions {
            interval,
            data: PhantomData,
            codec: PhantomData,
        };

        UseWebSocketOptions {
            heartbeat: Some(heartbeat),
            on_open,
            on_message,
            on_message_raw,
            on_message_raw_bytes,
            on_error,
            on_close,
            reconnect_limit,
            reconnect_interval,
            immediate,
            protocols,
        }
    }
}
impl<Rx: ?Sized, E, D> Default for UseWebSocketOptions<Rx, E, D, (), DummyEncoder> {
    /// Default settings:
    /// - connect immediately;
    /// - default reconnect limit with a 3000 ms reconnect interval;
    /// - no sub-protocols and no heartbeat;
    /// - every callback is a no-op.
    fn default() -> Self {
        Self {
            immediate: true,
            reconnect_limit: ReconnectLimit::default(),
            reconnect_interval: 3000,
            protocols: Default::default(),
            heartbeat: None,
            on_open: Arc::new(|_| ()),
            on_message: Arc::new(|_| ()),
            on_message_raw: Arc::new(|_| ()),
            on_message_raw_bytes: Arc::new(|_| ()),
            on_error: Arc::new(|_| ()),
            on_close: Arc::new(|_| ()),
        }
    }
}
/// Placeholder heartbeat codec used by the default options, where no heartbeat
/// is configured. Encodes `()` as an empty string and never fails.
pub struct DummyEncoder;
impl Encoder<()> for DummyEncoder {
    type Encoded = String;
    type Error = ();
    fn encode(_: &()) -> Result<Self::Encoded, Self::Error> {
        Ok(String::new())
    }
}
/// Heartbeat configuration. The payload type `Hb` (sent as `Hb::default()`)
/// and its codec `HbCodec` exist only at the type level; the interval is the
/// only runtime data.
pub struct HeartbeatOptions<Hb, HbCodec>
where
    Hb: Default + Send + Sync + 'static,
    HbCodec: Encoder<Hb>,
    HbCodec: HybridEncoder<
        Hb,
        <HbCodec as Encoder<Hb>>::Encoded,
        Error = <HbCodec as Encoder<Hb>>::Error,
    >,
{
    /// Period handed to `use_interval_fn` (presumably milliseconds — confirm).
    interval: u64,
    data: PhantomData<Hb>,
    codec: PhantomData<HbCodec>,
}
// Written by hand rather than derived, so that `Clone`/`Copy` don't pick up
// `Hb: Clone` / `HbCodec: Clone` bounds; only `PhantomData`s and a `u64` are stored.
impl<Hb, HbCodec> Clone for HeartbeatOptions<Hb, HbCodec>
where
    Hb: Default + Send + Sync + 'static,
    HbCodec: Encoder<Hb>,
    HbCodec: HybridEncoder<
        Hb,
        <HbCodec as Encoder<Hb>>::Encoded,
        Error = <HbCodec as Encoder<Hb>>::Error,
    >,
{
    fn clone(&self) -> Self {
        *self
    }
}
impl<Hb, HbCodec> Copy for HeartbeatOptions<Hb, HbCodec>
where
    Hb: Default + Send + Sync + 'static,
    HbCodec: Encoder<Hb>,
    HbCodec: HybridEncoder<
        Hb,
        <HbCodec as Encoder<Hb>>::Encoded,
        Error = <HbCodec as Encoder<Hb>>::Error,
    >,
{
}
/// Handle returned by [`use_websocket`] and [`use_websocket_with_options`].
#[derive(Clone)]
pub struct UseWebSocketReturn<Tx, Rx, OpenFn, CloseFn, SendFn>
where
    Tx: Send + Sync + 'static,
    Rx: Send + Sync + 'static,
    OpenFn: Fn() + Clone + Send + Sync + 'static,
    CloseFn: Fn() + Clone + Send + Sync + 'static,
    SendFn: Fn(&Tx) + Clone + Send + Sync + 'static,
{
    /// Current connection state; starts as `Closed`.
    pub ready_state: Signal<ConnectionReadyState>,
    /// Most recent successfully decoded message; `None` until one arrives.
    pub message: Signal<Option<Rx>>,
    /// The `web_sys::WebSocket` currently in use (`None` before the first
    /// connect), held in a `LocalStorage` signal.
    pub ws: Signal<Option<WebSocket>, LocalStorage>,
    /// Resets the reconnect counter and (re)connects, closing any existing socket.
    pub open: OpenFn,
    /// Closes the socket, stops the heartbeat and disables automatic reconnects.
    pub close: CloseFn,
    /// Encodes a value with the codec and sends it. The value is silently
    /// dropped unless the connection is `Open`.
    pub send: SendFn,
    /// Ties the otherwise unused `Tx` type parameter to the struct.
    _marker: PhantomData<Tx>,
}
/// Errors reported to the `on_error` callback of [`UseWebSocketOptions`].
///
/// `E` is the encode error type and `D` the decode error type of the message codec.
#[derive(Error, Debug)]
pub enum UseWebSocketError<E, D> {
    /// The browser fired an `error` event on the socket.
    #[error("WebSocket error event")]
    Event(Event),
    /// Encoding an outgoing value or decoding an incoming frame failed.
    #[error("WebSocket codec error: {0}")]
    Codec(#[from] CodecError<E, D>),
    /// Encoding the heartbeat payload failed; holds the `Debug` rendering of
    /// the underlying error.
    #[error("WebSocket heartbeat codec error: {0}")]
    HeartbeatCodec(String),
}
/// Turns `url` into an absolute `ws://` / `wss://` URL for the `WebSocket` constructor.
///
/// - `ws://…` and `wss://…` are returned unchanged.
/// - `http://…` and `https://…` are rewritten to `ws://…` and `wss://…`, not
///   treated as relative paths.
/// - `//host/path` (protocol-relative) gets the page's protocol mapped to `ws`/`wss`.
/// - `/path` is resolved against the page's host.
/// - Anything else is appended to the current page path, with a `/` inserted
///   if the path does not already end with one.
///
/// With the `ssr` feature the input is returned unchanged.
///
/// # Panics
/// Panics if the page's host, pathname or protocol cannot be read from `window().location()`.
fn normalize_url(url: &str) -> String {
    cfg_if! { if #[cfg(feature = "ssr")] {
        url.to_string()
    } else {
        if url.starts_with("ws://") || url.starts_with("wss://") {
            url.to_string()
        } else if let Some(rest) = url.strip_prefix("https://") {
            // Absolute HTTP(S) URL: only the scheme needs translating. Without
            // this branch it would fall through to the relative-path case.
            format!("wss://{rest}")
        } else if let Some(rest) = url.strip_prefix("http://") {
            format!("ws://{rest}")
        } else if url.starts_with("//") {
            format!("{}{}", detect_protocol(), url)
        } else if url.starts_with('/') {
            format!(
                "{}//{}{}",
                detect_protocol(),
                window().location().host().expect("Host not found"),
                url
            )
        } else {
            let mut path = window().location().pathname().expect("Pathname not found");
            if !path.ends_with('/') {
                path.push('/')
            }
            format!(
                "{}//{}{}{}",
                detect_protocol(),
                window().location().host().expect("Host not found"),
                path,
                url
            )
        }
    }}
}
/// WebSocket scheme matching the current page: every `http` in the page
/// protocol becomes `ws`, so `http:` → `ws:` and `https:` → `wss:`.
/// With the `ssr` feature there is no page, and `"ws"` is returned.
fn detect_protocol() -> String {
    cfg_if! { if #[cfg(feature = "ssr")] {
        String::from("ws")
    } else {
        {
            let page_protocol = window()
                .location()
                .protocol()
                .expect("Protocol not found");
            page_protocol.replace("http", "ws")
        }
    }}
}