1use super::{Hook, Hooks};
24use generational_box::{
25 AnyStorage, BorrowError, BorrowMutError, GenerationalBox, Owner, SyncStorage,
26};
27use std::{
28 cmp,
29 fmt::{self, Debug, Display, Formatter},
30 hash::{Hash, Hasher},
31 ops::{self, Deref, DerefMut},
32 task::{Poll, Waker},
33};
34
35mod private {
36 pub trait Sealed {}
37 impl Sealed for crate::hooks::Hooks<'_, '_> {}
38}
39
40pub trait UseState: private::Sealed {
41 fn use_state<T, F>(&mut self, init: F) -> State<T>
46 where
47 F: FnOnce() -> T,
48 T: Unpin + Send + Sync + 'static;
49}
50
51struct UseStateImpl<T>
52where
53 T: Unpin + Send + Sync + 'static,
54{
55 state: State<T>,
56 _storage: Owner<SyncStorage>,
57}
58
59impl<T> UseStateImpl<T>
60where
61 T: Unpin + Send + Sync + 'static,
62{
63 pub fn new(initial_value: T) -> Self {
65 let storage = Owner::default();
66 UseStateImpl {
67 state: State {
68 inner: storage.insert(StateValue {
69 value: initial_value,
70 waker: None,
71 is_changed: false,
72 }),
73 },
74 _storage: storage,
75 }
76 }
77}
78
79impl<T> Hook for UseStateImpl<T>
80where
81 T: Unpin + Send + Sync + 'static,
82{
83 fn poll_change(
84 mut self: std::pin::Pin<&mut Self>,
85 cx: &mut std::task::Context,
86 ) -> std::task::Poll<()> {
87 if let Ok(mut value) = self.state.inner.try_write() {
88 if value.is_changed {
89 value.is_changed = false;
90 Poll::Ready(())
91 } else {
92 value.waker = Some(cx.waker().clone());
93 Poll::Pending
94 }
95 } else {
96 Poll::Pending
97 }
98 }
99}
100
101impl UseState for Hooks<'_, '_> {
102 fn use_state<T, F>(&mut self, init: F) -> State<T>
103 where
104 F: FnOnce() -> T,
105 T: Unpin + Send + Sync + 'static,
106 {
107 self.use_hook(move || UseStateImpl::new(init())).state
108 }
109}
110
/// Slot stored inside the generational box behind a `State`.
struct StateValue<T> {
    /// The current user value.
    value: T,
    /// Waker stored by the most recent `poll_change` that returned `Pending`; taken and
    /// woken when a mutably dereferenced `StateMutRef` is dropped.
    waker: Option<Waker>,
    /// Set when a `StateMutRef` was mutably dereferenced; cleared when `poll_change`
    /// reports the change.
    is_changed: bool,
}
116
117pub struct StateRef<'a, T: 'static> {
120 inner: <SyncStorage as AnyStorage>::Ref<'a, StateValue<T>>,
121}
122
123impl<T: 'static> Deref for StateRef<'_, T> {
124 type Target = T;
125
126 fn deref(&self) -> &Self::Target {
127 &self.inner.value
128 }
129}
130
131pub struct StateMutRef<'a, T: 'static> {
134 inner: <SyncStorage as AnyStorage>::Mut<'a, StateValue<T>>,
135 is_deref_mut: bool,
136}
137
138impl<T: 'static> Deref for StateMutRef<'_, T> {
139 type Target = T;
140
141 fn deref(&self) -> &Self::Target {
142 &self.inner.value
143 }
144}
145
146impl<T: 'static> DerefMut for StateMutRef<'_, T> {
147 fn deref_mut(&mut self) -> &mut Self::Target {
148 self.is_deref_mut = true;
149 &mut self.inner.value
150 }
151}
152
153impl<T: 'static> Drop for StateMutRef<'_, T> {
154 fn drop(&mut self) {
155 if self.is_deref_mut {
156 self.inner.is_changed = true;
157 if let Some(waker) = self.inner.waker.take() {
158 waker.wake();
159 }
160 }
161 }
162}
163
164pub struct StateMutNoUpdate<'a, T: 'static> {
167 inner: <SyncStorage as AnyStorage>::Mut<'a, StateValue<T>>,
168}
169
170impl<T: 'static> Deref for StateMutNoUpdate<'_, T> {
171 type Target = T;
172
173 fn deref(&self) -> &Self::Target {
174 &self.inner.value
175 }
176}
177
178impl<T: 'static> DerefMut for StateMutNoUpdate<'_, T> {
179 fn deref_mut(&mut self) -> &mut Self::Target {
180 &mut self.inner.value
181 }
182}
183
184pub struct State<T: Send + Sync + 'static> {
187 inner: GenerationalBox<StateValue<T>, SyncStorage>,
188}
189
190impl<T: Send + Sync + 'static> Clone for State<T> {
191 fn clone(&self) -> Self {
192 *self
193 }
194}
195
196impl<T: Send + Sync + 'static> Copy for State<T> {}
197
198impl<T: Send + Sync + Copy + 'static> State<T> {
199 pub fn get(&self) -> T {
200 *self.read()
201 }
202}
203
204impl<T: Send + Sync + 'static> State<T> {
205 pub fn try_read(&'_ self) -> Option<StateRef<'_, T>> {
207 loop {
208 match self.inner.try_read() {
209 Ok(inner) => return Some(StateRef { inner }),
210 Err(BorrowError::Dropped(_)) => {
211 return None;
212 }
213 Err(BorrowError::AlreadyBorrowedMut(_)) => match self.inner.try_write() {
214 Err(BorrowMutError::Dropped(_)) => {
215 return None;
216 }
217 _ => continue,
218 },
219 }
220 }
221 }
222
223 pub fn read(&'_ self) -> StateRef<'_, T> {
225 self.try_read()
226 .expect("attempt to read state after owner was dropped")
227 }
228
229 pub fn try_write(&'_ self) -> Option<StateMutRef<'_, T>> {
231 self.inner
232 .try_write()
233 .map(|inner| StateMutRef {
234 inner,
235 is_deref_mut: false,
236 })
237 .ok()
238 }
239
240 pub fn write(&'_ self) -> StateMutRef<'_, T> {
242 self.try_write()
243 .expect("attempt to write state after owner was dropped")
244 }
245
246 pub fn try_write_no_update(&'_ self) -> Option<StateMutNoUpdate<'_, T>> {
248 self.inner
249 .try_write()
250 .map(|inner| StateMutNoUpdate { inner })
251 .ok()
252 }
253
254 pub fn write_no_update(&'_ self) -> StateMutNoUpdate<'_, T> {
256 self.try_write_no_update()
257 .expect("attempt to write state after owner was dropped")
258 }
259
260 pub fn set(&mut self, value: T) {
262 if let Some(mut v) = self.try_write() {
263 *v = value;
264 }
265 }
266
267 pub fn set_no_update(&mut self, value: T) {
269 if let Some(mut v) = self.try_write_no_update() {
270 *v = value;
271 }
272 }
273}
274
275impl<T: Debug + Sync + Send + 'static> Debug for State<T> {
276 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
277 self.read().fmt(f)
278 }
279}
280
281impl<T: Display + Sync + Send + 'static> Display for State<T> {
282 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
283 self.read().fmt(f)
284 }
285}
286
287impl<T: ops::Add<Output = T> + Copy + Sync + Send + 'static> ops::Add<T> for State<T> {
288 type Output = T;
289
290 fn add(self, rhs: T) -> Self::Output {
291 self.get() + rhs
292 }
293}
294
295impl<T: ops::AddAssign<T> + Copy + Sync + Send + 'static> ops::AddAssign<T> for State<T> {
296 fn add_assign(&mut self, rhs: T) {
297 if let Some(mut v) = self.try_write() {
298 *v += rhs;
299 }
300 }
301}
302
303impl<T: ops::Sub<Output = T> + Copy + Sync + Send + 'static> ops::Sub<T> for State<T> {
304 type Output = T;
305
306 fn sub(self, rhs: T) -> Self::Output {
307 self.get() - rhs
308 }
309}
310
311impl<T: ops::SubAssign<T> + Copy + Sync + Send + 'static> ops::SubAssign<T> for State<T> {
312 fn sub_assign(&mut self, rhs: T) {
313 if let Some(mut v) = self.try_write() {
314 *v -= rhs;
315 }
316 }
317}
318
319impl<T: ops::Mul<Output = T> + Copy + Sync + Send + 'static> ops::Mul<T> for State<T> {
320 type Output = T;
321
322 fn mul(self, rhs: T) -> Self::Output {
323 self.get() * rhs
324 }
325}
326
327impl<T: ops::MulAssign<T> + Copy + Sync + Send + 'static> ops::MulAssign<T> for State<T> {
328 fn mul_assign(&mut self, rhs: T) {
329 if let Some(mut v) = self.try_write() {
330 *v *= rhs;
331 }
332 }
333}
334
335impl<T: ops::Div<Output = T> + Copy + Sync + Send + 'static> ops::Div<T> for State<T> {
336 type Output = T;
337
338 fn div(self, rhs: T) -> Self::Output {
339 self.get() / rhs
340 }
341}
342
343impl<T: ops::DivAssign<T> + Copy + Sync + Send + 'static> ops::DivAssign<T> for State<T> {
344 fn div_assign(&mut self, rhs: T) {
345 if let Some(mut v) = self.try_write() {
346 *v /= rhs;
347 }
348 }
349}
350
351impl<T: Hash + Sync + Send> Hash for State<T> {
352 fn hash<H: Hasher>(&self, state: &mut H) {
353 self.read().hash(state)
354 }
355}
356
357impl<T: cmp::PartialEq<T> + Sync + Send + 'static> cmp::PartialEq<T> for State<T> {
358 fn eq(&self, other: &T) -> bool {
359 *self.read() == *other
360 }
361}
362
363impl<T: cmp::PartialOrd<T> + Sync + Send + 'static> cmp::PartialOrd<T> for State<T> {
364 fn partial_cmp(&self, other: &T) -> Option<cmp::Ordering> {
365 self.read().partial_cmp(other)
366 }
367}
368
369impl<T: cmp::PartialEq<T> + Sync + Send + 'static> cmp::PartialEq<State<T>> for State<T> {
370 fn eq(&self, other: &State<T>) -> bool {
371 *self.read() == *other.read()
372 }
373}
374
375impl<T: cmp::PartialOrd<T> + Sync + Send + 'static> cmp::PartialOrd<State<T>> for State<T> {
376 fn partial_cmp(&self, other: &State<T>) -> Option<cmp::Ordering> {
377 self.read().partial_cmp(&other.read())
378 }
379}
380
381impl<T: cmp::Eq + Sync + Send + 'static> cmp::Eq for State<T> {}