ratatui_kit/hooks/
use_state.rs1use super::{Hook, Hooks};
2use generational_box::{
3 AnyStorage, BorrowError, BorrowMutError, GenerationalBox, Owner, SyncStorage,
4};
5use std::{
6 cmp,
7 fmt::{self, Debug, Display, Formatter},
8 hash::{Hash, Hasher},
9 ops::{self, Deref, DerefMut},
10 task::{Poll, Waker},
11};
12
13mod private {
14 pub trait Sealed {}
15 impl Sealed for crate::hooks::Hooks<'_, '_> {}
16}
17
18pub trait UseState: private::Sealed {
19 fn use_state<T, F>(&mut self, init: F) -> State<T>
21 where
22 F: FnOnce() -> T,
23 T: Unpin + Send + Sync + 'static;
24}
25
26struct UseStateImpl<T>
27where
28 T: Unpin + Send + Sync + 'static,
29{
30 state: State<T>,
31 _storage: Owner<SyncStorage>,
32}
33
34impl<T> UseStateImpl<T>
35where
36 T: Unpin + Send + Sync + 'static,
37{
38 pub fn new(initial_value: T) -> Self {
39 let storage = Owner::default();
40 UseStateImpl {
41 state: State {
42 inner: storage.insert(StateValue {
43 value: initial_value,
44 waker: None,
45 is_changed: false,
46 }),
47 },
48 _storage: storage,
49 }
50 }
51}
52
53impl<T> Hook for UseStateImpl<T>
54where
55 T: Unpin + Send + Sync + 'static,
56{
57 fn poll_change(
58 mut self: std::pin::Pin<&mut Self>,
59 cx: &mut std::task::Context,
60 ) -> std::task::Poll<()> {
61 if let Ok(mut value) = self.state.inner.try_write() {
62 if value.is_changed {
63 value.is_changed = false;
64 Poll::Ready(())
65 } else {
66 value.waker = Some(cx.waker().clone());
67 Poll::Pending
68 }
69 } else {
70 Poll::Pending
71 }
72 }
73}
74
75impl UseState for Hooks<'_, '_> {
76 fn use_state<T, F>(&mut self, init: F) -> State<T>
77 where
78 F: FnOnce() -> T,
79 T: Unpin + Send + Sync + 'static,
80 {
81 self.use_hook(move || UseStateImpl::new(init())).state
82 }
83}
84
/// Contents of one state slot: the value plus change-notification bookkeeping.
struct StateValue<T> {
    /// The user's value.
    value: T,
    /// Waker registered by `poll_change`, taken and woken when the value changes.
    waker: Option<Waker>,
    /// Set when a mutable borrow that touched the value is dropped; cleared by `poll_change`.
    is_changed: bool,
}
90
91pub struct StateRef<'a, T: 'static> {
92 inner: <SyncStorage as AnyStorage>::Ref<'a, StateValue<T>>,
93}
94
95impl<T: 'static> Deref for StateRef<'_, T> {
96 type Target = T;
97
98 fn deref(&self) -> &Self::Target {
99 &self.inner.value
100 }
101}
102
103pub struct StateMutRef<'a, T: 'static> {
104 inner: <SyncStorage as AnyStorage>::Mut<'a, StateValue<T>>,
105 is_deref_mut: bool,
106}
107
108impl<T: 'static> Deref for StateMutRef<'_, T> {
109 type Target = T;
110
111 fn deref(&self) -> &Self::Target {
112 &self.inner.value
113 }
114}
115
116impl<T: 'static> DerefMut for StateMutRef<'_, T> {
117 fn deref_mut(&mut self) -> &mut Self::Target {
118 self.is_deref_mut = true;
119 &mut self.inner.value
120 }
121}
122
123impl<T: 'static> Drop for StateMutRef<'_, T> {
124 fn drop(&mut self) {
125 if self.is_deref_mut {
126 self.inner.is_changed = true;
127 if let Some(waker) = self.inner.waker.take() {
128 waker.wake();
129 }
130 }
131 }
132}
133
134pub struct State<T: Send + Sync + 'static> {
135 inner: GenerationalBox<StateValue<T>, SyncStorage>,
136}
137
138impl<T: Send + Sync + 'static> Clone for State<T> {
139 fn clone(&self) -> Self {
140 *self
141 }
142}
143
144impl<T: Send + Sync + 'static> Copy for State<T> {}
145
146impl<T: Send + Sync + Copy + 'static> State<T> {
147 pub fn get(&self) -> T {
148 *self.read()
149 }
150}
151
152impl<T: Send + Sync + 'static> State<T> {
153 pub fn try_read(&'_ self) -> Option<StateRef<'_, T>> {
154 loop {
155 match self.inner.try_read() {
156 Ok(inner) => return Some(StateRef { inner }),
157 Err(BorrowError::Dropped(_)) => {
158 return None;
159 }
160 Err(BorrowError::AlreadyBorrowedMut(_)) => match self.inner.try_write() {
161 Err(BorrowMutError::Dropped(_)) => {
162 return None;
163 }
164 _ => continue,
165 },
166 }
167 }
168 }
169
170 pub fn read(&'_ self) -> StateRef<'_, T> {
171 self.try_read()
172 .expect("attempt to read state after owner was dropped")
173 }
174
175 pub fn try_write(&'_ self) -> Option<StateMutRef<'_, T>> {
176 self.inner
177 .try_write()
178 .map(|inner| StateMutRef {
179 inner,
180 is_deref_mut: false,
181 })
182 .ok()
183 }
184
185 pub fn write(&'_ self) -> StateMutRef<'_, T> {
186 self.try_write()
187 .expect("attempt to write state after owner was dropped")
188 }
189
190 pub fn set(&mut self, value: T) {
191 if let Some(mut v) = self.try_write() {
192 *v = value;
193 }
194 }
195}
196
197impl<T: Debug + Sync + Send + 'static> Debug for State<T> {
198 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
199 self.read().fmt(f)
200 }
201}
202
203impl<T: Display + Sync + Send + 'static> Display for State<T> {
204 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
205 self.read().fmt(f)
206 }
207}
208
209impl<T: ops::Add<Output = T> + Copy + Sync + Send + 'static> ops::Add<T> for State<T> {
210 type Output = T;
211
212 fn add(self, rhs: T) -> Self::Output {
213 self.get() + rhs
214 }
215}
216
217impl<T: ops::AddAssign<T> + Copy + Sync + Send + 'static> ops::AddAssign<T> for State<T> {
218 fn add_assign(&mut self, rhs: T) {
219 if let Some(mut v) = self.try_write() {
220 *v += rhs;
221 }
222 }
223}
224
225impl<T: ops::Sub<Output = T> + Copy + Sync + Send + 'static> ops::Sub<T> for State<T> {
226 type Output = T;
227
228 fn sub(self, rhs: T) -> Self::Output {
229 self.get() - rhs
230 }
231}
232
233impl<T: ops::SubAssign<T> + Copy + Sync + Send + 'static> ops::SubAssign<T> for State<T> {
234 fn sub_assign(&mut self, rhs: T) {
235 if let Some(mut v) = self.try_write() {
236 *v -= rhs;
237 }
238 }
239}
240
241impl<T: ops::Mul<Output = T> + Copy + Sync + Send + 'static> ops::Mul<T> for State<T> {
242 type Output = T;
243
244 fn mul(self, rhs: T) -> Self::Output {
245 self.get() * rhs
246 }
247}
248
249impl<T: ops::MulAssign<T> + Copy + Sync + Send + 'static> ops::MulAssign<T> for State<T> {
250 fn mul_assign(&mut self, rhs: T) {
251 if let Some(mut v) = self.try_write() {
252 *v *= rhs;
253 }
254 }
255}
256
257impl<T: ops::Div<Output = T> + Copy + Sync + Send + 'static> ops::Div<T> for State<T> {
258 type Output = T;
259
260 fn div(self, rhs: T) -> Self::Output {
261 self.get() / rhs
262 }
263}
264
265impl<T: ops::DivAssign<T> + Copy + Sync + Send + 'static> ops::DivAssign<T> for State<T> {
266 fn div_assign(&mut self, rhs: T) {
267 if let Some(mut v) = self.try_write() {
268 *v /= rhs;
269 }
270 }
271}
272
273impl<T: Hash + Sync + Send> Hash for State<T> {
274 fn hash<H: Hasher>(&self, state: &mut H) {
275 self.read().hash(state)
276 }
277}
278
279impl<T: cmp::PartialEq<T> + Sync + Send + 'static> cmp::PartialEq<T> for State<T> {
280 fn eq(&self, other: &T) -> bool {
281 *self.read() == *other
282 }
283}
284
285impl<T: cmp::PartialOrd<T> + Sync + Send + 'static> cmp::PartialOrd<T> for State<T> {
286 fn partial_cmp(&self, other: &T) -> Option<cmp::Ordering> {
287 self.read().partial_cmp(other)
288 }
289}
290
291impl<T: cmp::PartialEq<T> + Sync + Send + 'static> cmp::PartialEq<State<T>> for State<T> {
292 fn eq(&self, other: &State<T>) -> bool {
293 *self.read() == *other.read()
294 }
295}
296
297impl<T: cmp::PartialOrd<T> + Sync + Send + 'static> cmp::PartialOrd<State<T>> for State<T> {
298 fn partial_cmp(&self, other: &State<T>) -> Option<cmp::Ordering> {
299 self.read().partial_cmp(&other.read())
300 }
301}
302
303impl<T: cmp::Eq + Sync + Send + 'static> cmp::Eq for State<T> {}