use crate::ReconnectLimit;
use crate::core::ConnectionReadyState;
use codee::Decoder;
use default_struct_builder::DefaultBuilder;
use leptos::prelude::*;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::Arc;
use thiserror::Error;
use wasm_bindgen::JsCast;
pub fn use_event_source<T, C>(
url: impl Into<Signal<String>>,
) -> UseEventSourceReturn<
T,
C,
C::Error,
impl Fn() + Clone + Send + Sync + 'static,
impl Fn() + Clone + Send + Sync + 'static,
>
where
T: Clone + PartialEq + Send + Sync + 'static,
C: Decoder<T, Encoded = str> + Send + Sync,
C::Error: Send + Sync,
{
use_event_source_with_options::<T, C>(url, UseEventSourceOptions::<T>::default())
}
pub fn use_event_source_with_options<T, C>(
url: impl Into<Signal<String>>,
options: UseEventSourceOptions<T>,
) -> UseEventSourceReturn<
T,
C,
C::Error,
impl Fn() + Clone + Send + Sync + 'static,
impl Fn() + Clone + Send + Sync + 'static,
>
where
T: Clone + PartialEq + Send + Sync + 'static,
C: Decoder<T, Encoded = str> + Send + Sync,
C::Error: Send + Sync,
{
let UseEventSourceOptions {
reconnect_limit,
reconnect_interval,
on_failed,
immediate,
named_events,
on_event,
with_credentials,
_marker,
} = options;
let (message, set_message) = signal(None::<UseEventSourceMessage<T, C>>);
let (ready_state, set_ready_state) = signal(ConnectionReadyState::Closed);
let (error, set_error) = signal(None::<UseEventSourceError<C::Error>>);
let open;
let close;
#[cfg(not(feature = "ssr"))]
{
use crate::{sendwrap_fn, use_event_listener};
use std::sync::atomic::{AtomicBool, AtomicU32};
use std::time::Duration;
use wasm_bindgen::prelude::*;
let (event_source, set_event_source) = signal_local(None::<web_sys::EventSource>);
let explicitly_closed = Arc::new(AtomicBool::new(false));
let retried = Arc::new(AtomicU32::new(0));
let on_event_return = move |e: &web_sys::Event| {
#[cfg(debug_assertions)]
let _ = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
on_event(e)
};
let on_message_event = {
let on_event_return = on_event_return.clone();
move |e: &web_sys::Event| {
match on_event_return(e) {
UseEventSourceOnEventReturn::IgnoreProcessingMessage => {
}
UseEventSourceOnEventReturn::ProcessMessage => {
let message_event = e
.dyn_ref::<web_sys::MessageEvent>()
.expect("Event is not a MessageEvent");
match UseEventSourceMessage::<T, C>::try_from(message_event) {
Ok(event_msg) => {
set_message.set(Some(event_msg));
}
Err(err) => {
set_error.set(Some(err));
}
}
}
}
}
};
let init = StoredValue::new(None::<Arc<dyn Fn() + Send + Sync>>);
let set_init = {
let explicitly_closed = Arc::clone(&explicitly_closed);
let retried = Arc::clone(&retried);
move |url: String| {
init.set_value(Some(Arc::new({
let explicitly_closed = Arc::clone(&explicitly_closed);
let retried = Arc::clone(&retried);
let on_event_return = on_event_return.clone();
let on_message_event = on_message_event.clone();
let named_events = named_events.clone();
let on_failed = Arc::clone(&on_failed);
move || {
if explicitly_closed.load(std::sync::atomic::Ordering::Relaxed) {
return;
}
let event_src_opts = web_sys::EventSourceInit::new();
event_src_opts.set_with_credentials(with_credentials);
let es = web_sys::EventSource::new_with_event_source_init_dict(
&url,
&event_src_opts,
)
.unwrap_throw();
set_ready_state.set(ConnectionReadyState::Connecting);
set_event_source.set(Some(es.clone()));
let on_open = Closure::wrap(Box::new({
let on_event_return = on_event_return.clone();
move |e: web_sys::Event| {
on_event_return(&e);
set_ready_state.set(ConnectionReadyState::Open);
set_error.set(None);
}})
as Box<dyn FnMut(web_sys::Event)>);
es.set_onopen(Some(on_open.as_ref().unchecked_ref()));
on_open.forget();
let on_error = Closure::wrap(Box::new({
let on_event_return = on_event_return.clone();
let explicitly_closed = Arc::clone(&explicitly_closed);
let retried = Arc::clone(&retried);
let on_failed = Arc::clone(&on_failed);
let es = es.clone();
move |e: web_sys::Event| {
on_event_return(&e);
set_ready_state.set(ConnectionReadyState::Closed);
set_error.set(Some(UseEventSourceError::ErrorEvent));
if es.ready_state() == 2
&& !explicitly_closed.load(std::sync::atomic::Ordering::Relaxed)
{
es.close();
let retried_value = retried
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
+ 1;
if !reconnect_limit.is_exceeded_by(retried_value as u64) {
set_timeout(
move || {
if let Some(init) = init.get_value() {
init();
}
},
Duration::from_millis(reconnect_interval),
);
} else {
#[cfg(debug_assertions)]
let _z =
leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
on_failed();
}
}
}
})
as Box<dyn FnMut(web_sys::Event)>);
es.set_onerror(Some(on_error.as_ref().unchecked_ref()));
on_error.forget();
let on_message = Closure::wrap(Box::new({
let on_message_event = on_message_event.clone();
move |e: web_sys::MessageEvent| {
let e: &web_sys::Event = e.as_ref();
on_message_event(e);
}})
as Box<dyn FnMut(web_sys::MessageEvent)>);
es.set_onmessage(Some(on_message.as_ref().unchecked_ref()));
on_message.forget();
for event_name in named_events.clone() {
let event_handler = {
let on_message_event = on_message_event.clone();
move |e: web_sys::Event| {
on_message_event(&e);
}
};
let _ = use_event_listener(
es.clone(),
leptos::ev::Custom::<leptos::ev::Event>::new(event_name),
event_handler,
);
}
}
})))
}
};
close = {
let explicitly_closed = Arc::clone(&explicitly_closed);
sendwrap_fn!(move || {
if let Some(event_source) = event_source.get_untracked() {
event_source.close();
set_event_source.set(None);
set_ready_state.set(ConnectionReadyState::Closed);
explicitly_closed.store(true, std::sync::atomic::Ordering::Relaxed);
}
})
};
let url: Signal<String> = url.into();
open = {
let close = close.clone();
let explicitly_closed = Arc::clone(&explicitly_closed);
let retried = Arc::clone(&retried);
let set_init = set_init.clone();
sendwrap_fn!(move || {
close();
explicitly_closed.store(false, std::sync::atomic::Ordering::Relaxed);
retried.store(0, std::sync::atomic::Ordering::Relaxed);
if init.get_value().is_none() && !url.get_untracked().is_empty() {
set_init(url.get_untracked());
}
if let Some(init) = init.get_value() {
init();
}
})
};
{
let close = close.clone();
let open = open.clone();
let set_init = set_init.clone();
Effect::watch(
move || url.get(),
move |url, prev_url, _| {
if url.is_empty() {
close();
} else if Some(url) != prev_url {
close();
set_init(url.to_owned());
open();
}
},
immediate,
);
}
on_cleanup(close.clone());
}
#[cfg(feature = "ssr")]
{
open = move || {};
close = move || {};
let _ = reconnect_limit;
let _ = reconnect_interval;
let _ = on_failed;
let _ = immediate;
let _ = named_events;
let _ = on_event;
let _ = with_credentials;
let _ = set_message;
let _ = set_ready_state;
let _ = set_error;
let _ = url;
}
UseEventSourceReturn {
message: message.into(),
ready_state: ready_state.into(),
error: error.into(),
open,
close,
}
}
#[derive(PartialEq)]
pub struct UseEventSourceMessage<T, C>
where
T: Clone + Send + Sync + 'static,
C: Decoder<T, Encoded = str> + Send + Sync,
C::Error: Send + Sync,
{
pub event_type: String,
pub data: T,
pub last_event_id: String,
_marker: PhantomData<C>,
}
impl<T, C> Clone for UseEventSourceMessage<T, C>
where
    T: Clone + Send + Sync + 'static,
    C: Decoder<T, Encoded = str> + Send + Sync,
    C::Error: Send + Sync,
{
    /// Clones the message field by field. The codec `C` is only a `PhantomData` marker,
    /// so no `C: Clone` bound is needed (unlike `#[derive(Clone)]`).
    fn clone(&self) -> Self {
        let Self {
            event_type,
            data,
            last_event_id,
            ..
        } = self;

        let event_type = event_type.clone();
        let data = data.clone();
        let last_event_id = last_event_id.clone();

        Self {
            event_type,
            data,
            last_event_id,
            _marker: PhantomData,
        }
    }
}
impl<T, C> Debug for UseEventSourceMessage<T, C>
where
    T: Debug + Clone + Send + Sync + 'static,
    C: Decoder<T, Encoded = str> + Send + Sync,
    C::Error: Send + Sync,
{
    /// Formats the data-carrying fields. The codec marker is omitted and `C: Debug`
    /// is not required.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            event_type,
            data,
            last_event_id,
            ..
        } = self;

        let mut builder = f.debug_struct("UseEventSourceMessage");
        builder.field("data", data);
        builder.field("event_type", event_type);
        builder.field("last_event_id", last_event_id);
        builder.finish()
    }
}
impl<T, C> TryFrom<&web_sys::MessageEvent> for UseEventSourceMessage<T, C>
where
    T: Clone + Send + Sync + 'static,
    C: Decoder<T, Encoded = str> + Send + Sync,
    C::Error: Send + Sync,
{
    type Error = UseEventSourceError<C::Error>;

    /// Decodes the event's data with the codec `C`.
    ///
    /// Data that is not a JS string is decoded as the empty string. A decoding failure
    /// is returned as [`UseEventSourceError::Deserialize`].
    fn try_from(message_event: &web_sys::MessageEvent) -> Result<Self, Self::Error> {
        let raw = match message_event.data().as_string() {
            Some(text) => text,
            None => String::new(),
        };

        C::decode(&raw)
            .map(|data| Self {
                event_type: message_event.type_(),
                data,
                last_event_id: message_event.last_event_id(),
                _marker: PhantomData,
            })
            .map_err(UseEventSourceError::Deserialize)
    }
}
impl<T, C> TryFrom<web_sys::Event> for UseEventSourceMessage<T, C>
where
    T: Clone + Send + Sync + 'static,
    C: Decoder<T, Encoded = str> + Send + Sync,
    C::Error: Send + Sync,
{
    type Error = UseEventSourceError<C::Error>;

    /// Casts a generic event to a `MessageEvent` and decodes it.
    ///
    /// If the cast fails, the error carries the event's type name.
    fn try_from(event: web_sys::Event) -> Result<Self, Self::Error> {
        match event.dyn_into::<web_sys::MessageEvent>() {
            Ok(message_event) => Self::try_from(&message_event),
            Err(not_a_message) => Err(UseEventSourceError::CastToMessageEvent(
                not_a_message.type_(),
            )),
        }
    }
}
/// Options for [`use_event_source_with_options`].
///
/// `T` is the message type. It is only carried as a `PhantomData` marker.
#[derive(DefaultBuilder)]
pub struct UseEventSourceOptions<T>
where
    T: 'static,
{
    /// Limits the automatic reconnect attempts made after the `EventSource` has been
    /// closed by an error. Once the attempt count exceeds it, `on_failed` is called.
    reconnect_limit: ReconnectLimit,

    /// Delay in milliseconds before each automatic reconnect attempt. Defaults to `3000`.
    reconnect_interval: u64,

    /// Called when the reconnect limit has been exceeded. Defaults to a no-op.
    on_failed: Arc<dyn Fn() + Send + Sync>,

    /// Passed to `Effect::watch` as its `immediate` flag. If `true` (default), the hook
    /// connects right away for a non-empty URL. Otherwise it connects only once the URL
    /// changes or `open` is called.
    immediate: bool,

    /// Names of additional events to listen for on the `EventSource`. They are decoded
    /// and stored in `message` just like regular `message` events.
    #[builder(into)]
    named_events: Vec<String>,

    /// Called for every `open`, `error`, `message` and named event.
    ///
    /// For message events, returning [`UseEventSourceOnEventReturn::IgnoreProcessingMessage`]
    /// skips decoding. The return value is ignored for `open` and `error` events. The
    /// default always returns `ProcessMessage`.
    on_event: Arc<dyn Fn(&web_sys::Event) -> UseEventSourceOnEventReturn + Send + Sync>,

    /// Forwarded to `EventSourceInit::with_credentials`. Defaults to `false`.
    with_credentials: bool,

    /// Ties the options to the message type `T`.
    _marker: PhantomData<T>,
}
impl<T> Default for UseEventSourceOptions<T> {
    /// Default options:
    /// - connect immediately, without credentials
    /// - the default reconnect limit, with 3000 ms between attempts
    /// - no named events
    /// - a no-op `on_failed`
    /// - an `on_event` that processes every message
    fn default() -> Self {
        Self {
            immediate: true,
            with_credentials: false,
            reconnect_limit: Default::default(),
            reconnect_interval: 3_000,
            named_events: Vec::new(),
            on_failed: Arc::new(|| {}),
            on_event: Arc::new(|_| UseEventSourceOnEventReturn::ProcessMessage),
            _marker: PhantomData,
        }
    }
}
/// Returned by the `on_event` callback to decide how an incoming message event is handled.
///
/// The value is ignored for `open` and `error` events. It is a plain value-like enum, so it
/// derives the usual comparison, copy and debug traits, which lets callers log and compare it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UseEventSourceOnEventReturn {
    /// Do not decode the event. `message` and `error` are left untouched.
    IgnoreProcessingMessage,
    /// Decode the event. Store the result in `message`, or the decoding error in `error`.
    ProcessMessage,
}
/// Return value of [`use_event_source`] / [`use_event_source_with_options`].
pub struct UseEventSourceReturn<T, C, Err, OpenFn, CloseFn>
where
    T: Clone + Send + Sync + 'static,
    C: Decoder<T, Encoded = str> + Send + Sync,
    C::Error: Send + Sync,
    Err: Send + Sync + 'static,
    OpenFn: Fn() + Clone + Send + Sync + 'static,
    CloseFn: Fn() + Clone + Send + Sync + 'static,
{
    /// The most recently received and successfully decoded message.
    pub message: Signal<Option<UseEventSourceMessage<T, C>>>,

    /// The current state of the connection.
    pub ready_state: Signal<ConnectionReadyState>,

    /// The most recent error (error event or decoding failure). Reset to `None` when the
    /// connection opens.
    pub error: Signal<Option<UseEventSourceError<Err>>>,

    /// Closes any existing connection, resets the reconnect counter and connects again.
    pub open: OpenFn,

    /// Closes the connection and stops automatic reconnects until `open` is called.
    pub close: CloseFn,
}
/// Errors reported through the `error` signal of [`UseEventSourceReturn`] or returned by
/// the `TryFrom` conversions into [`UseEventSourceMessage`].
///
/// `Err` is the error type of the codec used to decode message data.
#[derive(Error, Debug)]
pub enum UseEventSourceError<Err> {
    /// The `EventSource` fired an `error` event.
    #[error("Error event received")]
    ErrorEvent,

    /// The codec failed to decode the message data.
    #[error("Error decoding value")]
    Deserialize(Err),

    /// An event could not be cast to a `MessageEvent`. Holds the event's type name.
    #[error("Error casting event '{0}' to MessageEvent")]
    CastToMessageEvent(String),
}