ratatui_kit/hooks/use_state.rs

1use super::{Hook, Hooks};
2use generational_box::{
3    AnyStorage, BorrowError, BorrowMutError, GenerationalBox, Owner, SyncStorage,
4};
5use std::{
6    cmp,
7    fmt::{self, Debug, Display, Formatter},
8    hash::{Hash, Hasher},
9    ops::{self, Deref, DerefMut},
10    task::{Poll, Waker},
11};
12
13mod private {
14    pub trait Sealed {}
15    impl Sealed for crate::hooks::Hooks<'_, '_> {}
16}
17
18pub trait UseState: private::Sealed {
19    fn use_state<T, F>(&mut self, init: F) -> State<T>
20    where
21        F: FnOnce() -> T,
22        T: Unpin + Send + Sync + 'static;
23}
24
25struct UseStateImpl<T>
26where
27    T: Unpin + Send + Sync + 'static,
28{
29    state: State<T>,
30    _storage: Owner<SyncStorage>,
31}
32
33impl<T> UseStateImpl<T>
34where
35    T: Unpin + Send + Sync + 'static,
36{
37    pub fn new(initial_value: T) -> Self {
38        let storage = Owner::default();
39        UseStateImpl {
40            state: State {
41                inner: storage.insert(StateValue {
42                    value: initial_value,
43                    waker: None,
44                    is_changed: false,
45                }),
46            },
47            _storage: storage,
48        }
49    }
50}
51
52impl<T> Hook for UseStateImpl<T>
53where
54    T: Unpin + Send + Sync + 'static,
55{
56    fn poll_change(
57        mut self: std::pin::Pin<&mut Self>,
58        cx: &mut std::task::Context,
59    ) -> std::task::Poll<()> {
60        if let Ok(mut value) = self.state.inner.try_write() {
61            if value.is_changed {
62                value.is_changed = false;
63                Poll::Ready(())
64            } else {
65                value.waker = Some(cx.waker().clone());
66                Poll::Pending
67            }
68        } else {
69            Poll::Pending
70        }
71    }
72}
73
74impl UseState for Hooks<'_, '_> {
75    fn use_state<T, F>(&mut self, init: F) -> State<T>
76    where
77        F: FnOnce() -> T,
78        T: Unpin + Send + Sync + 'static,
79    {
80        self.use_hook(move || UseStateImpl::new(init())).state
81    }
82}
83
/// The value stored behind a [`State`] handle, together with change-tracking bookkeeping.
struct StateValue<T> {
    /// The current user value.
    value: T,
    /// Waker saved by the hook's last pending `poll_change`; taken and woken by the next
    /// mutating write.
    waker: Option<Waker>,
    /// Set by a mutating write; cleared when `poll_change` reports the change.
    is_changed: bool,
}
89
90pub struct StateRef<'a, T: 'static> {
91    inner: <SyncStorage as AnyStorage>::Ref<'a, StateValue<T>>,
92}
93
94impl<T: 'static> Deref for StateRef<'_, T> {
95    type Target = T;
96
97    fn deref(&self) -> &Self::Target {
98        &self.inner.value
99    }
100}
101
102pub struct StateMutRef<'a, T: 'static> {
103    inner: <SyncStorage as AnyStorage>::Mut<'a, StateValue<T>>,
104    is_deref_mut: bool,
105}
106
107impl<T: 'static> Deref for StateMutRef<'_, T> {
108    type Target = T;
109
110    fn deref(&self) -> &Self::Target {
111        &self.inner.value
112    }
113}
114
115impl<T: 'static> DerefMut for StateMutRef<'_, T> {
116    fn deref_mut(&mut self) -> &mut Self::Target {
117        self.is_deref_mut = true;
118        &mut self.inner.value
119    }
120}
121
122impl<T: 'static> Drop for StateMutRef<'_, T> {
123    fn drop(&mut self) {
124        if self.is_deref_mut {
125            self.inner.is_changed = true;
126            if let Some(waker) = self.inner.waker.take() {
127                waker.wake();
128            }
129        }
130    }
131}
132
133pub struct State<T: Send + Sync + 'static> {
134    inner: GenerationalBox<StateValue<T>, SyncStorage>,
135}
136
137impl<T: Send + Sync + 'static> Clone for State<T> {
138    fn clone(&self) -> Self {
139        *self
140    }
141}
142
143impl<T: Send + Sync + 'static> Copy for State<T> {}
144
145impl<T: Send + Sync + Copy + 'static> State<T> {
146    pub fn get(&self) -> T {
147        *self.read()
148    }
149}
150
151impl<T: Send + Sync + 'static> State<T> {
152    pub fn try_read(&self) -> Option<StateRef<T>> {
153        loop {
154            match self.inner.try_read() {
155                Ok(inner) => return Some(StateRef { inner }),
156                Err(BorrowError::Dropped(_)) => {
157                    return None;
158                }
159                Err(BorrowError::AlreadyBorrowedMut(_)) => match self.inner.try_write() {
160                    Err(BorrowMutError::Dropped(_)) => {
161                        return None;
162                    }
163                    _ => continue,
164                },
165            }
166        }
167    }
168
169    pub fn read(&self) -> StateRef<T> {
170        self.try_read()
171            .expect("attempt to read state after owner was dropped")
172    }
173
174    pub fn try_write(&self) -> Option<StateMutRef<T>> {
175        self.inner
176            .try_write()
177            .map(|inner| StateMutRef {
178                inner,
179                is_deref_mut: false,
180            })
181            .ok()
182    }
183
184    pub fn write(&self) -> StateMutRef<T> {
185        self.try_write()
186            .expect("attempt to write state after owner was dropped")
187    }
188
189    pub fn set(&mut self, value: T) {
190        if let Some(mut v) = self.try_write() {
191            *v = value;
192        }
193    }
194}
195
196impl<T: Debug + Sync + Send + 'static> Debug for State<T> {
197    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
198        self.read().fmt(f)
199    }
200}
201
202impl<T: Display + Sync + Send + 'static> Display for State<T> {
203    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
204        self.read().fmt(f)
205    }
206}
207
208impl<T: ops::Add<Output = T> + Copy + Sync + Send + 'static> ops::Add<T> for State<T> {
209    type Output = T;
210
211    fn add(self, rhs: T) -> Self::Output {
212        self.get() + rhs
213    }
214}
215
216impl<T: ops::AddAssign<T> + Copy + Sync + Send + 'static> ops::AddAssign<T> for State<T> {
217    fn add_assign(&mut self, rhs: T) {
218        if let Some(mut v) = self.try_write() {
219            *v += rhs;
220        }
221    }
222}
223
224impl<T: ops::Sub<Output = T> + Copy + Sync + Send + 'static> ops::Sub<T> for State<T> {
225    type Output = T;
226
227    fn sub(self, rhs: T) -> Self::Output {
228        self.get() - rhs
229    }
230}
231
232impl<T: ops::SubAssign<T> + Copy + Sync + Send + 'static> ops::SubAssign<T> for State<T> {
233    fn sub_assign(&mut self, rhs: T) {
234        if let Some(mut v) = self.try_write() {
235            *v -= rhs;
236        }
237    }
238}
239
240impl<T: ops::Mul<Output = T> + Copy + Sync + Send + 'static> ops::Mul<T> for State<T> {
241    type Output = T;
242
243    fn mul(self, rhs: T) -> Self::Output {
244        self.get() * rhs
245    }
246}
247
248impl<T: ops::MulAssign<T> + Copy + Sync + Send + 'static> ops::MulAssign<T> for State<T> {
249    fn mul_assign(&mut self, rhs: T) {
250        if let Some(mut v) = self.try_write() {
251            *v *= rhs;
252        }
253    }
254}
255
256impl<T: ops::Div<Output = T> + Copy + Sync + Send + 'static> ops::Div<T> for State<T> {
257    type Output = T;
258
259    fn div(self, rhs: T) -> Self::Output {
260        self.get() / rhs
261    }
262}
263
264impl<T: ops::DivAssign<T> + Copy + Sync + Send + 'static> ops::DivAssign<T> for State<T> {
265    fn div_assign(&mut self, rhs: T) {
266        if let Some(mut v) = self.try_write() {
267            *v /= rhs;
268        }
269    }
270}
271
272impl<T: Hash + Sync + Send> Hash for State<T> {
273    fn hash<H: Hasher>(&self, state: &mut H) {
274        self.read().hash(state)
275    }
276}
277
278impl<T: cmp::PartialEq<T> + Sync + Send + 'static> cmp::PartialEq<T> for State<T> {
279    fn eq(&self, other: &T) -> bool {
280        *self.read() == *other
281    }
282}
283
284impl<T: cmp::PartialOrd<T> + Sync + Send + 'static> cmp::PartialOrd<T> for State<T> {
285    fn partial_cmp(&self, other: &T) -> Option<cmp::Ordering> {
286        self.read().partial_cmp(other)
287    }
288}
289
290impl<T: cmp::PartialEq<T> + Sync + Send + 'static> cmp::PartialEq<State<T>> for State<T> {
291    fn eq(&self, other: &State<T>) -> bool {
292        *self.read() == *other.read()
293    }
294}
295
296impl<T: cmp::PartialOrd<T> + Sync + Send + 'static> cmp::PartialOrd<State<T>> for State<T> {
297    fn partial_cmp(&self, other: &State<T>) -> Option<cmp::Ordering> {
298        self.read().partial_cmp(&other.read())
299    }
300}
301
302impl<T: cmp::Eq + Sync + Send + 'static> cmp::Eq for State<T> {}