ratatui_kit/hooks/use_state.rs

1use super::{Hook, Hooks};
2use generational_box::{
3    AnyStorage, BorrowError, BorrowMutError, GenerationalBox, Owner, SyncStorage,
4};
5use std::{
6    cmp,
7    fmt::{self, Debug, Display, Formatter},
8    hash::{Hash, Hasher},
9    ops::{self, Deref, DerefMut},
10    task::{Poll, Waker},
11};
12
13mod private {
14    pub trait Sealed {}
15    impl Sealed for crate::hooks::Hooks<'_, '_> {}
16}
17
18pub trait UseState: private::Sealed {
19    /// 创建响应式状态,适合计数器、输入框等本地状态。
20    fn use_state<T, F>(&mut self, init: F) -> State<T>
21    where
22        F: FnOnce() -> T,
23        T: Unpin + Send + Sync + 'static;
24}
25
26struct UseStateImpl<T>
27where
28    T: Unpin + Send + Sync + 'static,
29{
30    state: State<T>,
31    _storage: Owner<SyncStorage>,
32}
33
34impl<T> UseStateImpl<T>
35where
36    T: Unpin + Send + Sync + 'static,
37{
38    pub fn new(initial_value: T) -> Self {
39        let storage = Owner::default();
40        UseStateImpl {
41            state: State {
42                inner: storage.insert(StateValue {
43                    value: initial_value,
44                    waker: None,
45                    is_changed: false,
46                }),
47            },
48            _storage: storage,
49        }
50    }
51}
52
53impl<T> Hook for UseStateImpl<T>
54where
55    T: Unpin + Send + Sync + 'static,
56{
57    fn poll_change(
58        mut self: std::pin::Pin<&mut Self>,
59        cx: &mut std::task::Context,
60    ) -> std::task::Poll<()> {
61        if let Ok(mut value) = self.state.inner.try_write() {
62            if value.is_changed {
63                value.is_changed = false;
64                Poll::Ready(())
65            } else {
66                value.waker = Some(cx.waker().clone());
67                Poll::Pending
68            }
69        } else {
70            Poll::Pending
71        }
72    }
73}
74
75impl UseState for Hooks<'_, '_> {
76    fn use_state<T, F>(&mut self, init: F) -> State<T>
77    where
78        F: FnOnce() -> T,
79        T: Unpin + Send + Sync + 'static,
80    {
81        self.use_hook(move || UseStateImpl::new(init())).state
82    }
83}
84
/// Storage cell behind a `State`: the user value plus the change-tracking bookkeeping.
struct StateValue<T> {
    /// The value exposed through `State::read` / `State::write`.
    value: T,
    /// Waker stored by the last `poll_change` that returned `Pending`. It is taken and woken
    /// when a mutably-dereferenced `StateMutRef` is dropped.
    waker: Option<Waker>,
    /// Set when a mutable borrow ends; cleared when `poll_change` reports the change.
    is_changed: bool,
}
90
91pub struct StateRef<'a, T: 'static> {
92    inner: <SyncStorage as AnyStorage>::Ref<'a, StateValue<T>>,
93}
94
95impl<T: 'static> Deref for StateRef<'_, T> {
96    type Target = T;
97
98    fn deref(&self) -> &Self::Target {
99        &self.inner.value
100    }
101}
102
103pub struct StateMutRef<'a, T: 'static> {
104    inner: <SyncStorage as AnyStorage>::Mut<'a, StateValue<T>>,
105    is_deref_mut: bool,
106}
107
108impl<T: 'static> Deref for StateMutRef<'_, T> {
109    type Target = T;
110
111    fn deref(&self) -> &Self::Target {
112        &self.inner.value
113    }
114}
115
116impl<T: 'static> DerefMut for StateMutRef<'_, T> {
117    fn deref_mut(&mut self) -> &mut Self::Target {
118        self.is_deref_mut = true;
119        &mut self.inner.value
120    }
121}
122
123impl<T: 'static> Drop for StateMutRef<'_, T> {
124    fn drop(&mut self) {
125        if self.is_deref_mut {
126            self.inner.is_changed = true;
127            if let Some(waker) = self.inner.waker.take() {
128                waker.wake();
129            }
130        }
131    }
132}
133
134pub struct State<T: Send + Sync + 'static> {
135    inner: GenerationalBox<StateValue<T>, SyncStorage>,
136}
137
138impl<T: Send + Sync + 'static> Clone for State<T> {
139    fn clone(&self) -> Self {
140        *self
141    }
142}
143
144impl<T: Send + Sync + 'static> Copy for State<T> {}
145
146impl<T: Send + Sync + Copy + 'static> State<T> {
147    pub fn get(&self) -> T {
148        *self.read()
149    }
150}
151
152impl<T: Send + Sync + 'static> State<T> {
153    pub fn try_read(&'_ self) -> Option<StateRef<'_, T>> {
154        loop {
155            match self.inner.try_read() {
156                Ok(inner) => return Some(StateRef { inner }),
157                Err(BorrowError::Dropped(_)) => {
158                    return None;
159                }
160                Err(BorrowError::AlreadyBorrowedMut(_)) => match self.inner.try_write() {
161                    Err(BorrowMutError::Dropped(_)) => {
162                        return None;
163                    }
164                    _ => continue,
165                },
166            }
167        }
168    }
169
170    pub fn read(&'_ self) -> StateRef<'_, T> {
171        self.try_read()
172            .expect("attempt to read state after owner was dropped")
173    }
174
175    pub fn try_write(&'_ self) -> Option<StateMutRef<'_, T>> {
176        self.inner
177            .try_write()
178            .map(|inner| StateMutRef {
179                inner,
180                is_deref_mut: false,
181            })
182            .ok()
183    }
184
185    pub fn write(&'_ self) -> StateMutRef<'_, T> {
186        self.try_write()
187            .expect("attempt to write state after owner was dropped")
188    }
189
190    pub fn set(&mut self, value: T) {
191        if let Some(mut v) = self.try_write() {
192            *v = value;
193        }
194    }
195}
196
197impl<T: Debug + Sync + Send + 'static> Debug for State<T> {
198    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
199        self.read().fmt(f)
200    }
201}
202
203impl<T: Display + Sync + Send + 'static> Display for State<T> {
204    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
205        self.read().fmt(f)
206    }
207}
208
209impl<T: ops::Add<Output = T> + Copy + Sync + Send + 'static> ops::Add<T> for State<T> {
210    type Output = T;
211
212    fn add(self, rhs: T) -> Self::Output {
213        self.get() + rhs
214    }
215}
216
217impl<T: ops::AddAssign<T> + Copy + Sync + Send + 'static> ops::AddAssign<T> for State<T> {
218    fn add_assign(&mut self, rhs: T) {
219        if let Some(mut v) = self.try_write() {
220            *v += rhs;
221        }
222    }
223}
224
225impl<T: ops::Sub<Output = T> + Copy + Sync + Send + 'static> ops::Sub<T> for State<T> {
226    type Output = T;
227
228    fn sub(self, rhs: T) -> Self::Output {
229        self.get() - rhs
230    }
231}
232
233impl<T: ops::SubAssign<T> + Copy + Sync + Send + 'static> ops::SubAssign<T> for State<T> {
234    fn sub_assign(&mut self, rhs: T) {
235        if let Some(mut v) = self.try_write() {
236            *v -= rhs;
237        }
238    }
239}
240
241impl<T: ops::Mul<Output = T> + Copy + Sync + Send + 'static> ops::Mul<T> for State<T> {
242    type Output = T;
243
244    fn mul(self, rhs: T) -> Self::Output {
245        self.get() * rhs
246    }
247}
248
249impl<T: ops::MulAssign<T> + Copy + Sync + Send + 'static> ops::MulAssign<T> for State<T> {
250    fn mul_assign(&mut self, rhs: T) {
251        if let Some(mut v) = self.try_write() {
252            *v *= rhs;
253        }
254    }
255}
256
257impl<T: ops::Div<Output = T> + Copy + Sync + Send + 'static> ops::Div<T> for State<T> {
258    type Output = T;
259
260    fn div(self, rhs: T) -> Self::Output {
261        self.get() / rhs
262    }
263}
264
265impl<T: ops::DivAssign<T> + Copy + Sync + Send + 'static> ops::DivAssign<T> for State<T> {
266    fn div_assign(&mut self, rhs: T) {
267        if let Some(mut v) = self.try_write() {
268            *v /= rhs;
269        }
270    }
271}
272
273impl<T: Hash + Sync + Send> Hash for State<T> {
274    fn hash<H: Hasher>(&self, state: &mut H) {
275        self.read().hash(state)
276    }
277}
278
279impl<T: cmp::PartialEq<T> + Sync + Send + 'static> cmp::PartialEq<T> for State<T> {
280    fn eq(&self, other: &T) -> bool {
281        *self.read() == *other
282    }
283}
284
285impl<T: cmp::PartialOrd<T> + Sync + Send + 'static> cmp::PartialOrd<T> for State<T> {
286    fn partial_cmp(&self, other: &T) -> Option<cmp::Ordering> {
287        self.read().partial_cmp(other)
288    }
289}
290
291impl<T: cmp::PartialEq<T> + Sync + Send + 'static> cmp::PartialEq<State<T>> for State<T> {
292    fn eq(&self, other: &State<T>) -> bool {
293        *self.read() == *other.read()
294    }
295}
296
297impl<T: cmp::PartialOrd<T> + Sync + Send + 'static> cmp::PartialOrd<State<T>> for State<T> {
298    fn partial_cmp(&self, other: &State<T>) -> Option<cmp::Ordering> {
299        self.read().partial_cmp(&other.read())
300    }
301}
302
303impl<T: cmp::Eq + Sync + Send + 'static> cmp::Eq for State<T> {}