// leptos_fetch/arc_local_signal.rs

1use std::{fmt::Debug, ops::Deref};
2
3use leptos::prelude::{
4    ArcSignal, DefinedAt, Get, LocalStorage, ReadUntracked, Signal, Track,
5    guards::{Mapped, ReadGuard},
6};
7use send_wrapper::SendWrapper;
8
9/// A local variant of an [`ArcSignal`], that will panic if accessed from a different thread.
10///
11/// Used for the [`QueryClient::subscribe_value_arc_local`] return type.
12pub struct ArcLocalSignal<T: 'static>(ArcSignal<SendWrapper<T>>);
13
14impl<T> Debug for ArcLocalSignal<T>
15where
16    T: Debug + 'static,
17{
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        f.debug_tuple("ArcLocalSignal").field(&self.0).finish()
20    }
21}
22
23impl<T> Clone for ArcLocalSignal<T> {
24    fn clone(&self) -> Self {
25        Self(self.0.clone())
26    }
27}
28
29impl<T> ArcLocalSignal<T> {
30    /// Like [`ArcSignal::derive`], but the value is not threadsafe and will panic if accessed from a different thread.
31    pub fn derive_local(derive_fn: impl Fn() -> T + 'static) -> Self {
32        let derive_fn = SendWrapper::new(derive_fn);
33        Self(ArcSignal::derive(move || {
34            let value = derive_fn();
35            SendWrapper::new(value)
36        }))
37    }
38}
39
40impl<T> DefinedAt for ArcLocalSignal<T> {
41    fn defined_at(&self) -> Option<&'static std::panic::Location<'static>> {
42        self.0.defined_at()
43    }
44}
45
46impl<T> ReadUntracked for ArcLocalSignal<T> {
47    type Value = ReadGuard<T, Mapped<<ArcSignal<SendWrapper<T>> as ReadUntracked>::Value, T>>;
48
49    fn try_read_untracked(&self) -> Option<Self::Value> {
50        self.0
51            .try_read_untracked()
52            .map(|g| ReadGuard::new(Mapped::new_with_guard(g, |v| v.deref())))
53    }
54}
55
56impl<T> Track for ArcLocalSignal<T> {
57    fn track(&self) {
58        self.0.track()
59    }
60}
61
62impl<T: Clone> From<ArcLocalSignal<T>> for Signal<T, LocalStorage> {
63    fn from(value: ArcLocalSignal<T>) -> Self {
64        Signal::derive_local(move || value.get())
65    }
66}
67
#[cfg(test)]
mod tests {
    use std::marker::PhantomData;
    use std::ops::Deref;
    use std::ptr::NonNull;

    use super::*;
    use leptos::prelude::*;

    /// A value that is deliberately neither `Send` nor `Sync` (through `NonNull`). It can
    /// therefore only end up in a reactive signal via the `SendWrapper` indirection.
    ///
    /// The derives are equivalent to hand-written impls that compare/clone only the `u64`:
    /// `PhantomData` is always equal to itself and trivially cloneable.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct UnsyncValue(u64, PhantomData<NonNull<()>>);

    impl UnsyncValue {
        fn new(value: u64) -> Self {
            Self(value, PhantomData)
        }
    }

    #[test]
    fn test_local_arc_signal() {
        let signal = ArcLocalSignal::derive_local(|| UnsyncValue::new(42));

        assert_eq!(signal.get_untracked().0, 42);
        assert_eq!(signal.read_untracked().0, 42);

        // The read guard must deref straight to `UnsyncValue`; `SendWrapper` must not leak
        // into the public interface.
        let owned: UnsyncValue = signal.read_untracked().deref().clone();
        assert_eq!(owned, UnsyncValue::new(42));
    }
}
106}