leptos_use/
use_websocket.rs

1use crate::{ReconnectLimit, core::ConnectionReadyState};
2use cfg_if::cfg_if;
3use codee::{CodecError, Decoder, Encoder, HybridCoderError, HybridDecoder, HybridEncoder};
4use default_struct_builder::DefaultBuilder;
5use leptos::prelude::*;
6use std::marker::PhantomData;
7use std::sync::Arc;
8use thiserror::Error;
9use web_sys::{CloseEvent, Event};
10
11#[allow(rustdoc::bare_urls)]
12/// Creating and managing a [Websocket](https://developer.mozilla.org/en-US/docs/Web/API/WebSocket) connection.
13///
14/// ## Demo
15///
16/// [Link to Demo](https://github.com/Synphonyte/leptos-use/tree/main/examples/use_websocket)
17///
18/// ## Usage
19///
20/// Values are (en)decoded via the given codec. You can use any of the codecs, string or binary.
21///
22/// > Please check [the codec chapter](https://leptos-use.rs/codecs.html) to see what codecs are
23/// > available and what feature flags they require.
24///
25/// ```
26/// # use leptos::prelude::*;
27/// # use codee::string::FromToStringCodec;
28/// # use leptos_use::{use_websocket, UseWebSocketReturn};
29/// # use leptos_use::core::ConnectionReadyState;
30/// #
31/// # #[component]
32/// # fn Demo() -> impl IntoView {
33/// let UseWebSocketReturn {
34///     ready_state,
35///     message,
36///     send,
37///     open,
38///     close,
39///     ..
40/// } = use_websocket::<String, String, FromToStringCodec>("wss://echo.websocket.events/");
41///
42/// let send_message = move |_| {
43///     send(&"Hello, world!".to_string());
44/// };
45///
46/// let status = move || ready_state.get().to_string();
47///
48/// let connected = move || ready_state.get() == ConnectionReadyState::Open;
49///
50/// let open_connection = move |_| {
51///     open();
52/// };
53///
54/// let close_connection = move |_| {
55///     close();
56/// };
57///
58/// view! {
59///     <div>
60///         <p>"status: " {status}</p>
61///
62///         <button on:click=send_message disabled=move || !connected()>"Send"</button>
63///         <button on:click=open_connection disabled=connected>"Open"</button>
64///         <button on:click=close_connection disabled=move || !connected()>"Close"</button>
65///
66///         <p>"Receive message: " {move || format!("{:?}", message.get())}</p>
67///     </div>
68/// }
69/// # }
70/// ```
71///
72/// Here is another example using `msgpack` for encoding and decoding. This means that only binary
73/// messages can be sent or received. For this to work you have to enable the **`msgpack_serde` feature** flag.
74///
75/// ```
76/// # use leptos::*;
77/// # use codee::binary::MsgpackSerdeCodec;
78/// # use leptos_use::{use_websocket, UseWebSocketReturn};
79/// # use serde::{Deserialize, Serialize};
80/// #
81/// # #[component]
82/// # fn Demo() -> impl IntoView {
83/// #[derive(Serialize, Deserialize)]
84/// struct SomeData {
85///     name: String,
86///     count: i32,
87/// }
88///
89/// let UseWebSocketReturn {
90///     message,
91///     send,
92///     ..
93/// } = use_websocket::<SomeData, SomeData, MsgpackSerdeCodec>("wss://some.websocket.server/");
94///
95/// let send_data = move || {
96///     send(&SomeData {
97///         name: "John Doe".to_string(),
98///         count: 42,
99///     });
100/// };
101/// #
102/// # view! {}
103/// }
104/// ```
105///
106/// ### Heartbeats
107///
108/// Heartbeats can be configured by the `heartbeat` option. You have to provide a heartbeat
109/// type, that implements the `Default` trait and an `Encoder` for it. This encoder doesn't have
110/// to be the same as the one used for the other websocket messages.
111///
112/// ```
113/// # use leptos::*;
114/// # use codee::string::FromToStringCodec;
115/// # use leptos_use::{use_websocket_with_options, UseWebSocketOptions, UseWebSocketReturn};
116/// # use serde::{Deserialize, Serialize};
117/// #
118/// # #[component]
119/// # fn Demo() -> impl IntoView {
120/// #[derive(Default)]
121/// struct Heartbeat;
122///
123/// // Simple example for usage with `FromToStringCodec`
124/// impl std::fmt::Display for Heartbeat {
125///     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126///         write!(f, "<Heartbeat>")
127///     }
128/// }
129///
130/// let UseWebSocketReturn {
131///     send,
132///     message,
133///     ..
134/// } = use_websocket_with_options::<String, String, FromToStringCodec, _, _>(
135///     "wss://echo.websocket.events/",
136///     UseWebSocketOptions::default()
137///         // Enable heartbeats every 10 seconds. In this case we use the same codec as for the
138///         // other messages. But this is not necessary.
139///         .heartbeat::<Heartbeat, FromToStringCodec>(10_000),
140/// );
141/// #
142/// # view! {}
143/// }
144/// ```
145///
146/// ## Relative Paths
147///
148/// If the provided `url` is relative, it will be resolved relative to the current page.
149/// Urls will be resolved like this the following. Please note that the protocol (http vs https) will
150/// be taken into account as well.
151///
152/// | Current Page                   | Relative Url             | Resolved Url                        |
153/// |--------------------------------|--------------------------|-------------------------------------|
154/// | http://example.com/some/where  | /api/ws                  | ws://example.com/api/ws             |
155/// | https://example.com/some/where | /api/ws                  | wss://example.com/api/ws            |
156/// | https://example.com/some/where | api/ws                   | wss://example.com/some/where/api/ws |
157/// | https://example.com/some/where | //otherdomain.com/api/ws | wss://otherdomain.com/api/ws        |
158///
159///
160/// ## Usage with `provide_context`
161///
162/// The return value of `use_websocket` utilizes several type parameters which can make it
163/// cumbersome to use with `provide_context` + `expect_context`.
164/// The following example shows how to avoid type parameters with dynamic dispatch.
165/// This sacrifices a little bit of performance for the sake of ergonomics. However,
166/// compared to network transmission speeds this loss of performance is negligible.
167///
168/// First we define the `struct` that is going to be passed around as context.
169///
170/// ```
171/// # use leptos::prelude::*;
172/// use std::sync::Arc;
173///
174/// #[derive(Clone)]
175/// pub struct WebsocketContext {
176///     pub message: Signal<Option<String>>,
177///     send: Arc<dyn Fn(&String)>,  // use Arc to make it easily cloneable
178/// }
179///
180/// impl WebsocketContext {
181///     pub fn new(message: Signal<Option<String>>, send: Arc<dyn Fn(&String)>) -> Self {
182///         Self {
183///             message,
184///             send,
185///         }
186///     }
187///
188///     // create a method to avoid having to use parantheses around the field
189///     #[inline(always)]
190///     pub fn send(&self, message: &str) {
191///         (self.send)(&message.to_string())
192///     }
193/// }
194/// ```
195///
196/// Now you can provide the context like the following.
197///
198/// ```
199/// # use leptos::prelude::*;
200/// # use codee::string::FromToStringCodec;
201/// # use leptos_use::{use_websocket, UseWebSocketReturn};
202/// # use std::sync::Arc;
203/// # #[derive(Clone)]
204/// # pub struct WebsocketContext {
205/// #     pub message: Signal<Option<String>>,
206/// #     send: Arc<dyn Fn(&String) + Send + Sync>,
207/// # }
208/// #
209/// # impl WebsocketContext {
210/// #     pub fn new(message: Signal<Option<String>>, send: Arc<dyn Fn(&String) + Send + Sync>) -> Self {
211/// #         Self {
212/// #             message,
213/// #             send,
214/// #         }
215/// #     }
216/// # }
217///
218/// # #[component]
219/// # fn Demo() -> impl IntoView {
220/// let UseWebSocketReturn {
221///     message,
222///     send,
223///     ..
224/// } = use_websocket::<String, String, FromToStringCodec>("ws:://some.websocket.io");
225///
226/// provide_context(WebsocketContext::new(message, Arc::new(send.clone())));
227/// #
228/// # view! {}
229/// # }
230/// ```
231///
232/// Finally let's use the context:
233///
234/// ```
235/// # use leptos::prelude::*;
236/// # use leptos_use::{use_websocket, UseWebSocketReturn};
237/// # use std::sync::Arc;
238/// # #[derive(Clone)]
239/// # pub struct WebsocketContext {
240/// #     pub message: Signal<Option<String>>,
241/// #     send: Arc<dyn Fn(&String)>,
242/// # }
243/// #
244/// # impl WebsocketContext {
245/// #     #[inline(always)]
246/// #     pub fn send(&self, message: &str) {
247/// #         (self.send)(&message.to_string())
248/// #     }
249/// # }
250///
251/// # #[component]
252/// # fn Demo() -> impl IntoView {
253/// let websocket = expect_context::<WebsocketContext>();
254///
255/// websocket.send("Hello World!");
256/// #
257/// # view! {}
258/// # }
259/// ```
260///
261/// ## SendWrapped Return
262///
263/// The returned closures `open`, `close`, and `send` are sendwrapped functions. They can
264/// only be called from the same thread that called `use_websocket`.
265///
266/// ## Server-Side Rendering
267///
268/// > Make sure you follow the [instructions in Server-Side Rendering](https://leptos-use.rs/server_side_rendering.html).
269///
270/// On the server the returned functions amount to no-ops.
271pub fn use_websocket<Tx, Rx, C>(
272    url: &str,
273) -> UseWebSocketReturn<
274    Tx,
275    Rx,
276    impl Fn() + Clone + Send + Sync + 'static,
277    impl Fn() + Clone + Send + Sync + 'static,
278    impl Fn(&Tx) + Clone + Send + Sync + 'static,
279>
280where
281    Tx: Send + Sync + 'static,
282    Rx: Send + Sync + 'static,
283    C: Encoder<Tx> + Decoder<Rx>,
284    C: HybridEncoder<Tx, <C as Encoder<Tx>>::Encoded, Error = <C as Encoder<Tx>>::Error>,
285    C: HybridDecoder<Rx, <C as Decoder<Rx>>::Encoded, Error = <C as Decoder<Rx>>::Error>,
286{
287    use_websocket_with_options::<Tx, Rx, C, (), DummyEncoder>(url, UseWebSocketOptions::default())
288}
289
290/// Version of [`use_websocket`] that takes `UseWebSocketOptions`. See [`use_websocket`] for how to use.
291#[allow(clippy::type_complexity)]
292pub fn use_websocket_with_options<Tx, Rx, C, Hb, HbCodec>(
293    url: &str,
294    options: UseWebSocketOptions<
295        Rx,
296        HybridCoderError<<C as Encoder<Tx>>::Error>,
297        HybridCoderError<<C as Decoder<Rx>>::Error>,
298        Hb,
299        HbCodec,
300    >,
301) -> UseWebSocketReturn<
302    Tx,
303    Rx,
304    impl Fn() + Clone + Send + Sync + 'static,
305    impl Fn() + Clone + Send + Sync + 'static,
306    impl Fn(&Tx) + Clone + Send + Sync + 'static,
307>
308where
309    Tx: Send + Sync + 'static,
310    Rx: Send + Sync + 'static,
311    C: Encoder<Tx> + Decoder<Rx>,
312    C: HybridEncoder<Tx, <C as Encoder<Tx>>::Encoded, Error = <C as Encoder<Tx>>::Error>,
313    C: HybridDecoder<Rx, <C as Decoder<Rx>>::Encoded, Error = <C as Decoder<Rx>>::Error>,
314    Hb: Default + Send + Sync + 'static,
315    HbCodec: Encoder<Hb> + Send + Sync,
316    HbCodec: HybridEncoder<
317            Hb,
318            <HbCodec as Encoder<Hb>>::Encoded,
319            Error = <HbCodec as Encoder<Hb>>::Error,
320        >,
321    <HbCodec as Encoder<Hb>>::Error: std::fmt::Debug,
322{
323    let url = normalize_url(url);
324
325    let UseWebSocketOptions {
326        on_open,
327        on_message,
328        on_message_raw,
329        on_message_raw_bytes,
330        on_error,
331        on_close,
332        reconnect_limit,
333        reconnect_interval,
334        immediate,
335        protocols,
336        heartbeat,
337    } = options;
338
339    let (ready_state, set_ready_state) = signal(ConnectionReadyState::Closed);
340    let (message, set_message) = signal(None);
341
342    let open;
343    let close;
344    let send;
345
346    #[cfg(not(feature = "ssr"))]
347    {
348        use crate::{sendwrap_fn, use_interval_fn, utils::Pausable};
349        use js_sys::Array;
350        use leptos::leptos_dom::helpers::TimeoutHandle;
351        use std::sync::atomic::AtomicBool;
352        use std::time::Duration;
353        use wasm_bindgen::prelude::*;
354        use web_sys::{BinaryType, MessageEvent, WebSocket};
355
356        let ws = StoredValue::new_local(None::<WebSocket>);
357
358        let reconnect_timer_ref: StoredValue<Option<TimeoutHandle>> = StoredValue::new(None);
359
360        let reconnect_times_ref: StoredValue<u64> = StoredValue::new(0);
361        let manually_closed_ref: StoredValue<bool> = StoredValue::new(false);
362
363        let unmounted = Arc::new(AtomicBool::new(false));
364
365        let connect_ref: StoredValue<Option<Arc<dyn Fn() + Send + Sync>>> = StoredValue::new(None);
366
367        let send_str = move |data: &str| {
368            if ready_state.get_untracked() == ConnectionReadyState::Open
369                && let Some(web_socket) = ws.get_value()
370            {
371                let _ = web_socket.send_with_str(data);
372            }
373        };
374
375        let send_bytes = move |data: &[u8]| {
376            if ready_state.get_untracked() == ConnectionReadyState::Open
377                && let Some(web_socket) = ws.get_value()
378            {
379                let _ = web_socket.send_with_u8_array(data);
380            }
381        };
382
383        send = {
384            let on_error = Arc::clone(&on_error);
385
386            sendwrap_fn!(move |value: &Tx| {
387                let on_error = Arc::clone(&on_error);
388
389                send_with_codec::<Tx, C>(value, send_str, send_bytes, move |err| {
390                    on_error(UseWebSocketError::Codec(CodecError::Encode(err)));
391                });
392            })
393        };
394
395        let heartbeat_interval_ref = StoredValue::new_local(None::<(Arc<dyn Fn()>, Arc<dyn Fn()>)>);
396
397        let stop_heartbeat = move || {
398            if let Some((pause, _)) = heartbeat_interval_ref.get_value() {
399                pause();
400            }
401        };
402
403        let start_heartbeat = {
404            let on_error = Arc::clone(&on_error);
405
406            move || {
407                if let Some(heartbeat) = &heartbeat {
408                    if let Some((pause, resume)) = heartbeat_interval_ref.get_value() {
409                        pause();
410                        resume();
411                    } else {
412                        let on_error = Arc::clone(&on_error);
413
414                        let Pausable { pause, resume, .. } = use_interval_fn(
415                            move || {
416                                send_with_codec::<Hb, HbCodec>(
417                                    &Hb::default(),
418                                    send_str,
419                                    send_bytes,
420                                    {
421                                        let on_error = Arc::clone(&on_error);
422
423                                        move |err| {
424                                            on_error(UseWebSocketError::HeartbeatCodec(format!(
425                                                "Failed to encode heartbeat data: {err:?}"
426                                            )))
427                                        }
428                                    },
429                                )
430                            },
431                            heartbeat.interval,
432                        );
433
434                        heartbeat_interval_ref.set_value(Some((Arc::new(pause), Arc::new(resume))));
435                    }
436                }
437            }
438        };
439
440        let reconnect_ref: StoredValue<Option<Arc<dyn Fn() + Send + Sync>>> =
441            StoredValue::new(None);
442        reconnect_ref.set_value({
443            let unmounted = Arc::clone(&unmounted);
444
445            Some(Arc::new(move || {
446                let unmounted = Arc::clone(&unmounted);
447
448                if !manually_closed_ref.get_value()
449                    && !reconnect_limit.is_exceeded_by(reconnect_times_ref.get_value())
450                    && ws
451                        .get_value()
452                        .is_some_and(|ws: WebSocket| ws.ready_state() != WebSocket::OPEN)
453                    && reconnect_timer_ref.get_value().is_none()
454                {
455                    reconnect_timer_ref.set_value(
456                        set_timeout_with_handle(
457                            move || {
458                                if unmounted.load(std::sync::atomic::Ordering::Relaxed) {
459                                    return;
460                                }
461                                if let Some(connect) = connect_ref.get_value() {
462                                    connect();
463                                    reconnect_times_ref.update_value(|current| *current += 1);
464                                }
465                            },
466                            Duration::from_millis(reconnect_interval),
467                        )
468                        .ok(),
469                    );
470                }
471            }))
472        });
473
474        connect_ref.set_value({
475            let unmounted = Arc::clone(&unmounted);
476            let on_error = Arc::clone(&on_error);
477
478            Some(Arc::new(move || {
479                if let Some(reconnect_timer) = reconnect_timer_ref.get_value() {
480                    reconnect_timer.clear();
481                    reconnect_timer_ref.set_value(None);
482                }
483
484                if let Some(web_socket) = ws.get_value() {
485                    let _ = web_socket.close();
486                }
487
488                let web_socket = {
489                    protocols.with_untracked(|protocols| {
490                        protocols.as_ref().map_or_else(
491                            || WebSocket::new(&url).unwrap_throw(),
492                            |protocols| {
493                                let array = protocols
494                                    .iter()
495                                    .map(|p| JsValue::from(p.clone()))
496                                    .collect::<Array>();
497                                WebSocket::new_with_str_sequence(&url, &JsValue::from(&array))
498                                    .unwrap_throw()
499                            },
500                        )
501                    })
502                };
503                web_socket.set_binary_type(BinaryType::Arraybuffer);
504                set_ready_state.set(ConnectionReadyState::Connecting);
505
506                // onopen handler
507                {
508                    let unmounted = Arc::clone(&unmounted);
509                    let on_open = Arc::clone(&on_open);
510
511                    let onopen_closure = Closure::wrap(Box::new({
512                        let start_heartbeat = start_heartbeat.clone();
513
514                        move |e: Event| {
515                            if unmounted.load(std::sync::atomic::Ordering::Relaxed) {
516                                return;
517                            }
518
519                            #[cfg(debug_assertions)]
520                            let zone = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
521
522                            on_open(e);
523
524                            #[cfg(debug_assertions)]
525                            drop(zone);
526
527                            set_ready_state.set(ConnectionReadyState::Open);
528
529                            start_heartbeat();
530                        }
531                    })
532                        as Box<dyn FnMut(Event)>);
533                    web_socket.set_onopen(Some(onopen_closure.as_ref().unchecked_ref()));
534                    // Forget the closure to keep it alive
535                    onopen_closure.forget();
536                }
537
538                // onmessage handler
539                {
540                    let unmounted = Arc::clone(&unmounted);
541                    let on_message = Arc::clone(&on_message);
542                    let on_message_raw = Arc::clone(&on_message_raw);
543                    let on_message_raw_bytes = Arc::clone(&on_message_raw_bytes);
544                    let on_error = Arc::clone(&on_error);
545
546                    let onmessage_closure = Closure::wrap(Box::new(move |e: MessageEvent| {
547                        if unmounted.load(std::sync::atomic::Ordering::Relaxed) {
548                            return;
549                        }
550
551                        e.data().dyn_into::<js_sys::ArrayBuffer>().map_or_else(
552                            |_| {
553                                e.data().dyn_into::<js_sys::JsString>().map_or_else(
554                                    |_| {
555                                        unreachable!(
556                                            "message event, received Unknown: {:?}",
557                                            e.data()
558                                        );
559                                    },
560                                    |txt| {
561                                        let txt = String::from(&txt);
562
563                                        #[cfg(debug_assertions)]
564                                        let zone = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
565
566                                        on_message_raw(&txt);
567
568                                        #[cfg(debug_assertions)]
569                                        drop(zone);
570
571                                        match C::decode_str(&txt) {
572                                            Ok(val) => {
573                                                #[cfg(debug_assertions)]
574                                                let prev = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
575
576                                                on_message(&val);
577
578                                                #[cfg(debug_assertions)]
579                                                drop(prev);
580
581                                                set_message.set(Some(val));
582                                            }
583                                            Err(err) => {
584                                                on_error(CodecError::Decode(err).into());
585                                            }
586                                        }
587                                    },
588                                );
589                            },
590                            |array_buffer| {
591                                let array = js_sys::Uint8Array::new(&array_buffer);
592                                let array = array.to_vec();
593
594                                #[cfg(debug_assertions)]
595                                let zone = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
596
597                                on_message_raw_bytes(&array);
598
599                                #[cfg(debug_assertions)]
600                                drop(zone);
601
602                                match C::decode_bin(array.as_slice()) {
603                                    Ok(val) => {
604                                        #[cfg(debug_assertions)]
605                                        let prev = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
606
607                                        on_message(&val);
608
609                                        #[cfg(debug_assertions)]
610                                        drop(prev);
611
612                                        set_message.set(Some(val));
613                                    }
614                                    Err(err) => {
615                                        on_error(CodecError::Decode(err).into());
616                                    }
617                                }
618                            },
619                        );
620                    })
621                        as Box<dyn FnMut(MessageEvent)>);
622                    web_socket.set_onmessage(Some(onmessage_closure.as_ref().unchecked_ref()));
623                    onmessage_closure.forget();
624                }
625
626                // onerror handler
627                {
628                    let unmounted = Arc::clone(&unmounted);
629                    let on_error = Arc::clone(&on_error);
630
631                    let onerror_closure = Closure::wrap(Box::new(move |e: Event| {
632                        if unmounted.load(std::sync::atomic::Ordering::Relaxed) {
633                            return;
634                        }
635
636                        stop_heartbeat();
637
638                        #[cfg(debug_assertions)]
639                        let zone = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
640
641                        on_error(UseWebSocketError::Event(e));
642
643                        #[cfg(debug_assertions)]
644                        drop(zone);
645
646                        set_ready_state.set(ConnectionReadyState::Closed);
647
648                        // try to reconnect
649                        if let Some(reconnect) = &reconnect_ref.get_value() {
650                            reconnect();
651                        }
652                    })
653                        as Box<dyn FnMut(Event)>);
654                    web_socket.set_onerror(Some(onerror_closure.as_ref().unchecked_ref()));
655                    onerror_closure.forget();
656                }
657
658                // onclose handler
659                {
660                    let unmounted = Arc::clone(&unmounted);
661                    let on_close = Arc::clone(&on_close);
662
663                    let onclose_closure = Closure::wrap(Box::new(move |e: CloseEvent| {
664                        if unmounted.load(std::sync::atomic::Ordering::Relaxed) {
665                            return;
666                        }
667
668                        stop_heartbeat();
669
670                        #[cfg(debug_assertions)]
671                        let zone = leptos::reactive::diagnostics::SpecialNonReactiveZone::enter();
672
673                        on_close(e);
674
675                        #[cfg(debug_assertions)]
676                        drop(zone);
677
678                        set_ready_state.set(ConnectionReadyState::Closed);
679
680                        // if closing was not intentional, try to reconnect
681                        if let Some(reconnect) = &reconnect_ref.get_value() {
682                            reconnect();
683                        }
684                    })
685                        as Box<dyn FnMut(CloseEvent)>);
686                    web_socket.set_onclose(Some(onclose_closure.as_ref().unchecked_ref()));
687                    onclose_closure.forget();
688                }
689
690                ws.set_value(Some(web_socket));
691            }))
692        });
693
694        // Open connection
695        open = sendwrap_fn!(move || {
696            reconnect_times_ref.set_value(0);
697            if let Some(connect) = connect_ref.get_value() {
698                connect();
699            }
700        });
701
702        // Close connection
703        close = {
704            reconnect_timer_ref.set_value(None);
705
706            sendwrap_fn!(move || {
707                stop_heartbeat();
708                manually_closed_ref.set_value(true);
709                if let Some(web_socket) = ws.get_value() {
710                    let _ = web_socket.close();
711                }
712            })
713        };
714
715        // Open connection (not called if option `manual` is true)
716        Effect::new({
717            let open = open.clone();
718            move |_| {
719                if immediate {
720                    open();
721                }
722            }
723        });
724
725        // clean up (unmount)
726        on_cleanup({
727            let close = close.clone();
728            move || {
729                unmounted.store(true, std::sync::atomic::Ordering::Relaxed);
730                close();
731            }
732        });
733    }
734
735    #[cfg(feature = "ssr")]
736    {
737        open = move || {};
738        close = move || {};
739        send = move |_: &Tx| {};
740
741        let _ = url;
742        let _ = on_open;
743        let _ = on_message;
744        let _ = on_message_raw;
745        let _ = on_message_raw_bytes;
746        let _ = on_error;
747        let _ = on_close;
748        let _ = reconnect_limit;
749        let _ = reconnect_interval;
750        let _ = immediate;
751        let _ = protocols;
752        let _ = heartbeat;
753        let _ = set_ready_state;
754        let _ = set_message;
755    }
756
757    UseWebSocketReturn {
758        ready_state: ready_state.into(),
759        message: message.into(),
760        open,
761        close,
762        send,
763        _marker: PhantomData,
764    }
765}
766
767#[cfg(not(feature = "ssr"))]
768fn send_with_codec<T, Codec>(
769    value: &T,
770    send_str: impl Fn(&str),
771    send_bytes: impl Fn(&[u8]),
772    on_error: impl Fn(HybridCoderError<<Codec as Encoder<T>>::Error>),
773) where
774    Codec: Encoder<T>,
775    Codec: HybridEncoder<T, <Codec as Encoder<T>>::Encoded, Error = <Codec as Encoder<T>>::Error>,
776{
777    if Codec::is_binary_encoder() {
778        match Codec::encode_bin(value) {
779            Ok(val) => send_bytes(&val),
780            Err(err) => on_error(err),
781        }
782    } else {
783        match Codec::encode_str(value) {
784            Ok(val) => send_str(&val),
785            Err(err) => on_error(err),
786        }
787    }
788}
789
/// Shared, thread-safe callback that receives the payload of a binary message.
type ArcFnBytes = Arc<dyn Fn(&[u8]) + Sync + Send>;
791
792/// Options for [`use_websocket_with_options`].
793#[derive(DefaultBuilder)]
794pub struct UseWebSocketOptions<Rx, E, D, Hb, HbCodec>
795where
796    Rx: ?Sized,
797    Hb: Default + Send + Sync + 'static,
798    HbCodec: Encoder<Hb>,
799    HbCodec: HybridEncoder<
800            Hb,
801            <HbCodec as Encoder<Hb>>::Encoded,
802            Error = <HbCodec as Encoder<Hb>>::Error,
803        >,
804{
805    /// Heartbeat options
806    #[builder(skip)]
807    heartbeat: Option<HeartbeatOptions<Hb, HbCodec>>,
808    /// `WebSocket` connect callback.
809    on_open: Arc<dyn Fn(Event) + Send + Sync>,
810    /// `WebSocket` message callback for typed message decoded by codec.
811    #[builder(skip)]
812    on_message: Arc<dyn Fn(&Rx) + Send + Sync>,
813    /// `WebSocket` message callback for text.
814    on_message_raw: Arc<dyn Fn(&str) + Send + Sync>,
815    /// `WebSocket` message callback for binary.
816    on_message_raw_bytes: ArcFnBytes,
817    /// `WebSocket` error callback.
818    #[builder(skip)]
819    on_error: Arc<dyn Fn(UseWebSocketError<E, D>) + Send + Sync>,
820    /// `WebSocket` close callback.
821    on_close: Arc<dyn Fn(CloseEvent) + Send + Sync>,
822    /// Retry times. Defaults to `ReconnectLimit::Limited(3)`. Use `ReconnectLimit::Infinite` for
823    /// infinite retries.
824    reconnect_limit: ReconnectLimit,
825    /// Retry interval in ms. Defaults to 3000.
826    reconnect_interval: u64,
827    /// If `true` the `WebSocket` connection will immediately be opened when calling this function.
828    /// If `false` you have to manually call the `open` function.
829    /// Defaults to `true`.
830    immediate: bool,
831    /// Sub protocols. See [MDN Docs](https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/WebSocket#protocols).
832    ///
833    /// Can be set as a signal to support protocols only available after the initial render.
834    ///
835    /// Note that protocols are only updated on the next websocket open() call, not whenever the signal is updated.
836    /// Therefore "lazy" protocols should use the `immediate(false)` option and manually call `open()`.
837    #[builder(into)]
838    protocols: Signal<Option<Vec<String>>>,
839}
840
841impl<Rx: ?Sized, E, D, Hb, HbCodec> UseWebSocketOptions<Rx, E, D, Hb, HbCodec>
842where
843    Hb: Default + Send + Sync + 'static,
844    HbCodec: Encoder<Hb>,
845    HbCodec: HybridEncoder<
846            Hb,
847            <HbCodec as Encoder<Hb>>::Encoded,
848            Error = <HbCodec as Encoder<Hb>>::Error,
849        >,
850{
851    /// `WebSocket` error callback.
852    pub fn on_error<F>(self, handler: F) -> Self
853    where
854        F: Fn(UseWebSocketError<E, D>) + Send + Sync + 'static,
855    {
856        Self {
857            on_error: Arc::new(handler),
858            ..self
859        }
860    }
861
862    /// `WebSocket` message callback for typed message decoded by codec.
863    pub fn on_message<F>(self, handler: F) -> Self
864    where
865        F: Fn(&Rx) + Send + Sync + 'static,
866    {
867        Self {
868            on_message: Arc::new(handler),
869            ..self
870        }
871    }
872
873    /// Set the data, codec and interval at which the heartbeat is sent. The heartbeat
874    /// is the default value of the `NewHb` type.
875    pub fn heartbeat<NewHb, NewHbCodec>(
876        self,
877        interval: u64,
878    ) -> UseWebSocketOptions<Rx, E, D, NewHb, NewHbCodec>
879    where
880        NewHb: Default + Send + Sync + 'static,
881        NewHbCodec: Encoder<NewHb>,
882        NewHbCodec: HybridEncoder<
883                NewHb,
884                <NewHbCodec as Encoder<NewHb>>::Encoded,
885                Error = <NewHbCodec as Encoder<NewHb>>::Error,
886            >,
887    {
888        UseWebSocketOptions {
889            heartbeat: Some(HeartbeatOptions {
890                data: PhantomData::<NewHb>,
891                interval,
892                codec: PhantomData::<NewHbCodec>,
893            }),
894            on_open: self.on_open,
895            on_message: self.on_message,
896            on_message_raw: self.on_message_raw,
897            on_message_raw_bytes: self.on_message_raw_bytes,
898            on_close: self.on_close,
899            on_error: self.on_error,
900            reconnect_limit: self.reconnect_limit,
901            reconnect_interval: self.reconnect_interval,
902            immediate: self.immediate,
903            protocols: self.protocols,
904        }
905    }
906}
907
908impl<Rx: ?Sized, E, D> Default for UseWebSocketOptions<Rx, E, D, (), DummyEncoder> {
909    fn default() -> Self {
910        Self {
911            heartbeat: None,
912            on_open: Arc::new(|_| {}),
913            on_message: Arc::new(|_| {}),
914            on_message_raw: Arc::new(|_| {}),
915            on_message_raw_bytes: Arc::new(|_| {}),
916            on_error: Arc::new(|_| {}),
917            on_close: Arc::new(|_| {}),
918            reconnect_limit: ReconnectLimit::default(),
919            reconnect_interval: 3000,
920            immediate: true,
921            protocols: Default::default(),
922        }
923    }
924}
925
926pub struct DummyEncoder;
927
928impl Encoder<()> for DummyEncoder {
929    type Encoded = String;
930    type Error = ();
931
932    fn encode(_: &()) -> Result<Self::Encoded, Self::Error> {
933        Ok("".to_string())
934    }
935}
936
937/// Options for heartbeats
938#[cfg_attr(feature = "ssr", allow(dead_code))]
939pub struct HeartbeatOptions<Hb, HbCodec>
940where
941    Hb: Default + Send + Sync + 'static,
942    HbCodec: Encoder<Hb>,
943    HbCodec: HybridEncoder<
944            Hb,
945            <HbCodec as Encoder<Hb>>::Encoded,
946            Error = <HbCodec as Encoder<Hb>>::Error,
947        >,
948{
949    /// Heartbeat data that will be sent to the server
950    data: PhantomData<Hb>,
951    /// Heartbeat interval in ms. A heartbeat will be sent every `interval` ms.
952    interval: u64,
953    /// Codec used to encode the heartbeat data
954    codec: PhantomData<HbCodec>,
955}
956
957impl<Hb, HbCodec> Clone for HeartbeatOptions<Hb, HbCodec>
958where
959    Hb: Default + Send + Sync + 'static,
960    HbCodec: Encoder<Hb>,
961    HbCodec: HybridEncoder<
962            Hb,
963            <HbCodec as Encoder<Hb>>::Encoded,
964            Error = <HbCodec as Encoder<Hb>>::Error,
965        >,
966{
967    fn clone(&self) -> Self {
968        *self
969    }
970}
971
972impl<Hb, HbCodec> Copy for HeartbeatOptions<Hb, HbCodec>
973where
974    Hb: Default + Send + Sync + 'static,
975    HbCodec: Encoder<Hb>,
976    HbCodec: HybridEncoder<
977            Hb,
978            <HbCodec as Encoder<Hb>>::Encoded,
979            Error = <HbCodec as Encoder<Hb>>::Error,
980        >,
981{
982}
983
984/// Return type of [`use_websocket`].
985#[derive(Clone)]
986pub struct UseWebSocketReturn<Tx, Rx, OpenFn, CloseFn, SendFn>
987where
988    Tx: Send + Sync + 'static,
989    Rx: Send + Sync + 'static,
990    OpenFn: Fn() + Clone + Send + Sync + 'static,
991    CloseFn: Fn() + Clone + Send + Sync + 'static,
992    SendFn: Fn(&Tx) + Clone + Send + Sync + 'static,
993{
994    /// The current state of the `WebSocket` connection.
995    pub ready_state: Signal<ConnectionReadyState>,
996    /// Latest message received from `WebSocket`.
997    pub message: Signal<Option<Rx>>,
998    /// Opens the `WebSocket` connection
999    pub open: OpenFn,
1000    /// Closes the `WebSocket` connection
1001    pub close: CloseFn,
1002    /// Sends data through the socket
1003    pub send: SendFn,
1004
1005    _marker: PhantomData<Tx>,
1006}
1007
1008#[derive(Error, Debug)]
1009pub enum UseWebSocketError<E, D> {
1010    #[error("WebSocket error event")]
1011    Event(Event),
1012    #[error("WebSocket codec error: {0}")]
1013    Codec(#[from] CodecError<E, D>),
1014    #[error("WebSocket heartbeat codec error: {0}")]
1015    HeartbeatCodec(String),
1016}
1017
1018fn normalize_url(url: &str) -> String {
1019    cfg_if! { if #[cfg(feature = "ssr")] {
1020        url.to_string()
1021    } else {
1022        if url.starts_with("ws://") || url.starts_with("wss://") {
1023            url.to_string()
1024        } else if url.starts_with("//") {
1025            format!("{}{}", detect_protocol(), url)
1026        } else if url.starts_with('/') {
1027            format!(
1028                "{}//{}{}",
1029                detect_protocol(),
1030                window().location().host().expect("Host not found"),
1031                url
1032            )
1033        } else {
1034            let mut path = window().location().pathname().expect("Pathname not found");
1035            if !path.ends_with('/') {
1036                path.push('/')
1037            }
1038            format!(
1039                "{}//{}{}{}",
1040                detect_protocol(),
1041                window().location().host().expect("Host not found"),
1042                path,
1043                url
1044            )
1045        }
1046    }}
1047}
1048
1049#[cfg_attr(feature = "ssr", allow(dead_code))]
1050fn detect_protocol() -> String {
1051    cfg_if! { if #[cfg(feature = "ssr")] {
1052        "ws".to_string()
1053    } else {
1054        window().location().protocol().expect("Protocol not found").replace("http", "ws")
1055    }}
1056}