// ratatui_kit/hooks/use_state.rs
use std::{
    cmp,
    fmt::{self, Debug, Display, Formatter},
    hash::{Hash, Hasher},
    ops::{self, Deref, DerefMut},
    pin::Pin,
    task::{Context, Poll, Waker},
};

use generational_box::{
    AnyStorage, BorrowError, BorrowMutError, GenerationalBox, Owner, SyncStorage,
};

use super::{Hook, Hooks};

13mod private {
14 pub trait Sealed {}
15 impl Sealed for crate::hooks::Hooks<'_, '_> {}
16}
17
18pub trait UseState: private::Sealed {
19 fn use_state<T, F>(&mut self, init: F) -> State<T>
20 where
21 F: FnOnce() -> T,
22 T: Unpin + Send + Sync + 'static;
23}
24
25struct UseStateImpl<T>
26where
27 T: Unpin + Send + Sync + 'static,
28{
29 state: State<T>,
30 _storage: Owner<SyncStorage>,
31}
32
33impl<T> UseStateImpl<T>
34where
35 T: Unpin + Send + Sync + 'static,
36{
37 pub fn new(initial_value: T) -> Self {
38 let storage = Owner::default();
39 UseStateImpl {
40 state: State {
41 inner: storage.insert(StateValue {
42 value: initial_value,
43 waker: None,
44 is_changed: false,
45 }),
46 },
47 _storage: storage,
48 }
49 }
50}
51
52impl<T> Hook for UseStateImpl<T>
53where
54 T: Unpin + Send + Sync + 'static,
55{
56 fn poll_change(
57 mut self: std::pin::Pin<&mut Self>,
58 cx: &mut std::task::Context,
59 ) -> std::task::Poll<()> {
60 if let Ok(mut value) = self.state.inner.try_write() {
61 if value.is_changed {
62 value.is_changed = false;
63 Poll::Ready(())
64 } else {
65 value.waker = Some(cx.waker().clone());
66 Poll::Pending
67 }
68 } else {
69 Poll::Pending
70 }
71 }
72}
73
74impl UseState for Hooks<'_, '_> {
75 fn use_state<T, F>(&mut self, init: F) -> State<T>
76 where
77 F: FnOnce() -> T,
78 T: Unpin + Send + Sync + 'static,
79 {
80 self.use_hook(move || UseStateImpl::new(init())).state
81 }
82}
83
/// Contents of the shared cell behind a `State` handle.
struct StateValue<T> {
    /// The user's value.
    value: T,
    /// Waker stored by the hook's last pending `poll_change`; taken and woken when a
    /// mutation is reported.
    waker: Option<Waker>,
    /// Set when a `StateMutRef` that was mutably dereferenced is dropped; cleared when
    /// `poll_change` reports the change.
    is_changed: bool,
}

90pub struct StateRef<'a, T: 'static> {
91 inner: <SyncStorage as AnyStorage>::Ref<'a, StateValue<T>>,
92}
93
94impl<T: 'static> Deref for StateRef<'_, T> {
95 type Target = T;
96
97 fn deref(&self) -> &Self::Target {
98 &self.inner.value
99 }
100}
101
102pub struct StateMutRef<'a, T: 'static> {
103 inner: <SyncStorage as AnyStorage>::Mut<'a, StateValue<T>>,
104 is_deref_mut: bool,
105}
106
107impl<T: 'static> Deref for StateMutRef<'_, T> {
108 type Target = T;
109
110 fn deref(&self) -> &Self::Target {
111 &self.inner.value
112 }
113}
114
115impl<T: 'static> DerefMut for StateMutRef<'_, T> {
116 fn deref_mut(&mut self) -> &mut Self::Target {
117 self.is_deref_mut = true;
118 &mut self.inner.value
119 }
120}
121
122impl<T: 'static> Drop for StateMutRef<'_, T> {
123 fn drop(&mut self) {
124 if self.is_deref_mut {
125 self.inner.is_changed = true;
126 if let Some(waker) = self.inner.waker.take() {
127 waker.wake();
128 }
129 }
130 }
131}
132
133pub struct State<T: Send + Sync + 'static> {
134 inner: GenerationalBox<StateValue<T>, SyncStorage>,
135}
136
137impl<T: Send + Sync + 'static> Clone for State<T> {
138 fn clone(&self) -> Self {
139 *self
140 }
141}
142
143impl<T: Send + Sync + 'static> Copy for State<T> {}
144
145impl<T: Send + Sync + Copy + 'static> State<T> {
146 pub fn get(&self) -> T {
147 *self.read()
148 }
149}
150
151impl<T: Send + Sync + 'static> State<T> {
152 pub fn try_read(&self) -> Option<StateRef<T>> {
153 loop {
154 match self.inner.try_read() {
155 Ok(inner) => return Some(StateRef { inner }),
156 Err(BorrowError::Dropped(_)) => {
157 return None;
158 }
159 Err(BorrowError::AlreadyBorrowedMut(_)) => match self.inner.try_write() {
160 Err(BorrowMutError::Dropped(_)) => {
161 return None;
162 }
163 _ => continue,
164 },
165 }
166 }
167 }
168
169 pub fn read(&self) -> StateRef<T> {
170 self.try_read()
171 .expect("attempt to read state after owner was dropped")
172 }
173
174 pub fn try_write(&self) -> Option<StateMutRef<T>> {
175 self.inner
176 .try_write()
177 .map(|inner| StateMutRef {
178 inner,
179 is_deref_mut: false,
180 })
181 .ok()
182 }
183
184 pub fn write(&self) -> StateMutRef<T> {
185 self.try_write()
186 .expect("attempt to write state after owner was dropped")
187 }
188
189 pub fn set(&mut self, value: T) {
190 if let Some(mut v) = self.try_write() {
191 *v = value;
192 }
193 }
194}
195
196impl<T: Debug + Sync + Send + 'static> Debug for State<T> {
197 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
198 self.read().fmt(f)
199 }
200}
201
202impl<T: Display + Sync + Send + 'static> Display for State<T> {
203 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
204 self.read().fmt(f)
205 }
206}
207
208impl<T: ops::Add<Output = T> + Copy + Sync + Send + 'static> ops::Add<T> for State<T> {
209 type Output = T;
210
211 fn add(self, rhs: T) -> Self::Output {
212 self.get() + rhs
213 }
214}
215
216impl<T: ops::AddAssign<T> + Copy + Sync + Send + 'static> ops::AddAssign<T> for State<T> {
217 fn add_assign(&mut self, rhs: T) {
218 if let Some(mut v) = self.try_write() {
219 *v += rhs;
220 }
221 }
222}
223
224impl<T: ops::Sub<Output = T> + Copy + Sync + Send + 'static> ops::Sub<T> for State<T> {
225 type Output = T;
226
227 fn sub(self, rhs: T) -> Self::Output {
228 self.get() - rhs
229 }
230}
231
232impl<T: ops::SubAssign<T> + Copy + Sync + Send + 'static> ops::SubAssign<T> for State<T> {
233 fn sub_assign(&mut self, rhs: T) {
234 if let Some(mut v) = self.try_write() {
235 *v -= rhs;
236 }
237 }
238}
239
240impl<T: ops::Mul<Output = T> + Copy + Sync + Send + 'static> ops::Mul<T> for State<T> {
241 type Output = T;
242
243 fn mul(self, rhs: T) -> Self::Output {
244 self.get() * rhs
245 }
246}
247
248impl<T: ops::MulAssign<T> + Copy + Sync + Send + 'static> ops::MulAssign<T> for State<T> {
249 fn mul_assign(&mut self, rhs: T) {
250 if let Some(mut v) = self.try_write() {
251 *v *= rhs;
252 }
253 }
254}
255
256impl<T: ops::Div<Output = T> + Copy + Sync + Send + 'static> ops::Div<T> for State<T> {
257 type Output = T;
258
259 fn div(self, rhs: T) -> Self::Output {
260 self.get() / rhs
261 }
262}
263
264impl<T: ops::DivAssign<T> + Copy + Sync + Send + 'static> ops::DivAssign<T> for State<T> {
265 fn div_assign(&mut self, rhs: T) {
266 if let Some(mut v) = self.try_write() {
267 *v /= rhs;
268 }
269 }
270}
271
272impl<T: Hash + Sync + Send> Hash for State<T> {
273 fn hash<H: Hasher>(&self, state: &mut H) {
274 self.read().hash(state)
275 }
276}
277
278impl<T: cmp::PartialEq<T> + Sync + Send + 'static> cmp::PartialEq<T> for State<T> {
279 fn eq(&self, other: &T) -> bool {
280 *self.read() == *other
281 }
282}
283
284impl<T: cmp::PartialOrd<T> + Sync + Send + 'static> cmp::PartialOrd<T> for State<T> {
285 fn partial_cmp(&self, other: &T) -> Option<cmp::Ordering> {
286 self.read().partial_cmp(other)
287 }
288}
289
290impl<T: cmp::PartialEq<T> + Sync + Send + 'static> cmp::PartialEq<State<T>> for State<T> {
291 fn eq(&self, other: &State<T>) -> bool {
292 *self.read() == *other.read()
293 }
294}
295
296impl<T: cmp::PartialOrd<T> + Sync + Send + 'static> cmp::PartialOrd<State<T>> for State<T> {
297 fn partial_cmp(&self, other: &State<T>) -> Option<cmp::Ordering> {
298 self.read().partial_cmp(&other.read())
299 }
300}
301
302impl<T: cmp::Eq + Sync + Send + 'static> cmp::Eq for State<T> {}