ratatui_kit/hooks/
use_state.rs

1//! 响应式状态管理 Hook 实现。
2//!
3//! 本模块为 ratatui-kit 提供了类似 React useState 的响应式状态管理能力,适用于计数器、输入框等本地状态。
4//!
5//! ## 主要类型
6//! - [`State<T>`]:响应式状态持有者,支持原子读写、变更通知、算术运算等。
7//! - [`StateRef<'a, T>`]:状态的只读引用。
8//! - [`StateMutRef<'a, T>`]:状态的可变引用,支持变更通知。
9//! - [`StateMutNoUpdate<'a, T>`]:状态的可变引用,不触发变更通知。
10//! - [`UseState`] trait:为 [`Hooks`] 提供 use_state 方法。
11//!
12//! ## 用法示例
//! ```rust,ignore
//! let mut hooks = ...;
//! let mut counter = hooks.use_state(|| 0);
//! counter.set(1);
//! let value = counter.read();
//! ```
19//!
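//! ## 线程安全
//! [`State<T>`] is a `Copy` handle backed by `SyncStorage` and is intended to be shared
//! between threads when `T: Send + Sync`. The borrow guards ([`StateRef`], [`StateMutRef`],
//! [`StateMutNoUpdate`]) wrap `SyncStorage` borrow guards and should be kept short-lived;
//! NOTE(review): confirm their exact `Send`/`Sync` status for the `generational_box`
//! version in use before sending them across threads.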
22
23use super::{Hook, Hooks};
24use generational_box::{
25    AnyStorage, BorrowError, BorrowMutError, GenerationalBox, Owner, SyncStorage,
26};
27use std::{
28    cmp,
29    fmt::{self, Debug, Display, Formatter},
30    hash::{Hash, Hasher},
31    ops::{self, Deref, DerefMut},
32    task::{Poll, Waker},
33};
34
35mod private {
36    pub trait Sealed {}
37    impl Sealed for crate::hooks::Hooks<'_, '_> {}
38}
39
40pub trait UseState: private::Sealed {
41    /// 为 [`Hooks`] 提供 use_state 方法,创建响应式状态。
42    ///
43    /// - `init`:状态初始化闭包。
44    /// - 返回 [`State<T>`]。
45    fn use_state<T, F>(&mut self, init: F) -> State<T>
46    where
47        F: FnOnce() -> T,
48        T: Unpin + Send + Sync + 'static;
49}
50
51struct UseStateImpl<T>
52where
53    T: Unpin + Send + Sync + 'static,
54{
55    state: State<T>,
56    _storage: Owner<SyncStorage>,
57}
58
59impl<T> UseStateImpl<T>
60where
61    T: Unpin + Send + Sync + 'static,
62{
63    /// use_state 的内部实现,持有状态和存储。
64    pub fn new(initial_value: T) -> Self {
65        let storage = Owner::default();
66        UseStateImpl {
67            state: State {
68                inner: storage.insert(StateValue {
69                    value: initial_value,
70                    waker: None,
71                    is_changed: false,
72                }),
73            },
74            _storage: storage,
75        }
76    }
77}
78
79impl<T> Hook for UseStateImpl<T>
80where
81    T: Unpin + Send + Sync + 'static,
82{
83    fn poll_change(
84        mut self: std::pin::Pin<&mut Self>,
85        cx: &mut std::task::Context,
86    ) -> std::task::Poll<()> {
87        if let Ok(mut value) = self.state.inner.try_write() {
88            if value.is_changed {
89                value.is_changed = false;
90                Poll::Ready(())
91            } else {
92                value.waker = Some(cx.waker().clone());
93                Poll::Pending
94            }
95        } else {
96            Poll::Pending
97        }
98    }
99}
100
101impl UseState for Hooks<'_, '_> {
102    fn use_state<T, F>(&mut self, init: F) -> State<T>
103    where
104        F: FnOnce() -> T,
105        T: Unpin + Send + Sync + 'static,
106    {
107        self.use_hook(move || UseStateImpl::new(init())).state
108    }
109}
110
/// Slot contents shared by every copy of a [`State<T>`].
struct StateValue<T> {
    /// The user-visible value.
    value: T,
    /// Waker stored by the last pending `poll_change`; taken and woken when a
    /// mutably-dereferenced [`StateMutRef`] is dropped.
    waker: Option<Waker>,
    /// Set by a notifying write ([`StateMutRef`]); cleared when `poll_change` reports it.
    is_changed: bool,
}
116
117/// 状态的只读引用。
118/// 通过 [`State::read`] 或 [`State::try_read`] 获取。
119pub struct StateRef<'a, T: 'static> {
120    inner: <SyncStorage as AnyStorage>::Ref<'a, StateValue<T>>,
121}
122
123impl<T: 'static> Deref for StateRef<'_, T> {
124    type Target = T;
125
126    fn deref(&self) -> &Self::Target {
127        &self.inner.value
128    }
129}
130
131/// 状态的可变引用,支持变更通知。
132/// 通过 [`State::write`] 或 [`State::try_write`] 获取。
133pub struct StateMutRef<'a, T: 'static> {
134    inner: <SyncStorage as AnyStorage>::Mut<'a, StateValue<T>>,
135    is_deref_mut: bool,
136}
137
138impl<T: 'static> Deref for StateMutRef<'_, T> {
139    type Target = T;
140
141    fn deref(&self) -> &Self::Target {
142        &self.inner.value
143    }
144}
145
146impl<T: 'static> DerefMut for StateMutRef<'_, T> {
147    fn deref_mut(&mut self) -> &mut Self::Target {
148        self.is_deref_mut = true;
149        &mut self.inner.value
150    }
151}
152
153impl<T: 'static> Drop for StateMutRef<'_, T> {
154    fn drop(&mut self) {
155        if self.is_deref_mut {
156            self.inner.is_changed = true;
157            if let Some(waker) = self.inner.waker.take() {
158                waker.wake();
159            }
160        }
161    }
162}
163
164/// 状态的可变引用,不触发变更通知。
165/// 通过 [`State::write_no_update`] 或 [`State::try_write_no_update`] 获取。
166pub struct StateMutNoUpdate<'a, T: 'static> {
167    inner: <SyncStorage as AnyStorage>::Mut<'a, StateValue<T>>,
168}
169
170impl<T: 'static> Deref for StateMutNoUpdate<'_, T> {
171    type Target = T;
172
173    fn deref(&self) -> &Self::Target {
174        &self.inner.value
175    }
176}
177
178impl<T: 'static> DerefMut for StateMutNoUpdate<'_, T> {
179    fn deref_mut(&mut self) -> &mut Self::Target {
180        &mut self.inner.value
181    }
182}
183
184/// 响应式状态持有者。
185/// 支持原子读写、变更通知、算术运算等。
186pub struct State<T: Send + Sync + 'static> {
187    inner: GenerationalBox<StateValue<T>, SyncStorage>,
188}
189
190impl<T: Send + Sync + 'static> Clone for State<T> {
191    fn clone(&self) -> Self {
192        *self
193    }
194}
195
196impl<T: Send + Sync + 'static> Copy for State<T> {}
197
198impl<T: Send + Sync + Copy + 'static> State<T> {
199    pub fn get(&self) -> T {
200        *self.read()
201    }
202}
203
204impl<T: Send + Sync + 'static> State<T> {
205    /// 尝试获取只读引用,失败时返回 None。
206    pub fn try_read(&'_ self) -> Option<StateRef<'_, T>> {
207        loop {
208            match self.inner.try_read() {
209                Ok(inner) => return Some(StateRef { inner }),
210                Err(BorrowError::Dropped(_)) => {
211                    return None;
212                }
213                Err(BorrowError::AlreadyBorrowedMut(_)) => match self.inner.try_write() {
214                    Err(BorrowMutError::Dropped(_)) => {
215                        return None;
216                    }
217                    _ => continue,
218                },
219            }
220        }
221    }
222
223    /// 获取只读引用,失败时 panic。
224    pub fn read(&'_ self) -> StateRef<'_, T> {
225        self.try_read()
226            .expect("attempt to read state after owner was dropped")
227    }
228
229    /// 尝试获取可变引用,支持变更通知,失败时返回 None。
230    pub fn try_write(&'_ self) -> Option<StateMutRef<'_, T>> {
231        self.inner
232            .try_write()
233            .map(|inner| StateMutRef {
234                inner,
235                is_deref_mut: false,
236            })
237            .ok()
238    }
239
240    /// 获取可变引用,支持变更通知,失败时 panic。
241    pub fn write(&'_ self) -> StateMutRef<'_, T> {
242        self.try_write()
243            .expect("attempt to write state after owner was dropped")
244    }
245
246    /// 尝试获取可变引用,不触发变更通知,失败时返回 None。
247    pub fn try_write_no_update(&'_ self) -> Option<StateMutNoUpdate<'_, T>> {
248        self.inner
249            .try_write()
250            .map(|inner| StateMutNoUpdate { inner })
251            .ok()
252    }
253
254    /// 获取可变引用,不触发变更通知,失败时 panic。
255    pub fn write_no_update(&'_ self) -> StateMutNoUpdate<'_, T> {
256        self.try_write_no_update()
257            .expect("attempt to write state after owner was dropped")
258    }
259
260    /// 设置状态值,触发变更通知。
261    pub fn set(&mut self, value: T) {
262        if let Some(mut v) = self.try_write() {
263            *v = value;
264        }
265    }
266
267    /// 设置状态值,不触发变更通知。
268    pub fn set_no_update(&mut self, value: T) {
269        if let Some(mut v) = self.try_write_no_update() {
270            *v = value;
271        }
272    }
273}
274
275impl<T: Debug + Sync + Send + 'static> Debug for State<T> {
276    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
277        self.read().fmt(f)
278    }
279}
280
281impl<T: Display + Sync + Send + 'static> Display for State<T> {
282    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
283        self.read().fmt(f)
284    }
285}
286
287impl<T: ops::Add<Output = T> + Copy + Sync + Send + 'static> ops::Add<T> for State<T> {
288    type Output = T;
289
290    fn add(self, rhs: T) -> Self::Output {
291        self.get() + rhs
292    }
293}
294
295impl<T: ops::AddAssign<T> + Copy + Sync + Send + 'static> ops::AddAssign<T> for State<T> {
296    fn add_assign(&mut self, rhs: T) {
297        if let Some(mut v) = self.try_write() {
298            *v += rhs;
299        }
300    }
301}
302
303impl<T: ops::Sub<Output = T> + Copy + Sync + Send + 'static> ops::Sub<T> for State<T> {
304    type Output = T;
305
306    fn sub(self, rhs: T) -> Self::Output {
307        self.get() - rhs
308    }
309}
310
311impl<T: ops::SubAssign<T> + Copy + Sync + Send + 'static> ops::SubAssign<T> for State<T> {
312    fn sub_assign(&mut self, rhs: T) {
313        if let Some(mut v) = self.try_write() {
314            *v -= rhs;
315        }
316    }
317}
318
319impl<T: ops::Mul<Output = T> + Copy + Sync + Send + 'static> ops::Mul<T> for State<T> {
320    type Output = T;
321
322    fn mul(self, rhs: T) -> Self::Output {
323        self.get() * rhs
324    }
325}
326
327impl<T: ops::MulAssign<T> + Copy + Sync + Send + 'static> ops::MulAssign<T> for State<T> {
328    fn mul_assign(&mut self, rhs: T) {
329        if let Some(mut v) = self.try_write() {
330            *v *= rhs;
331        }
332    }
333}
334
335impl<T: ops::Div<Output = T> + Copy + Sync + Send + 'static> ops::Div<T> for State<T> {
336    type Output = T;
337
338    fn div(self, rhs: T) -> Self::Output {
339        self.get() / rhs
340    }
341}
342
343impl<T: ops::DivAssign<T> + Copy + Sync + Send + 'static> ops::DivAssign<T> for State<T> {
344    fn div_assign(&mut self, rhs: T) {
345        if let Some(mut v) = self.try_write() {
346            *v /= rhs;
347        }
348    }
349}
350
351impl<T: Hash + Sync + Send> Hash for State<T> {
352    fn hash<H: Hasher>(&self, state: &mut H) {
353        self.read().hash(state)
354    }
355}
356
357impl<T: cmp::PartialEq<T> + Sync + Send + 'static> cmp::PartialEq<T> for State<T> {
358    fn eq(&self, other: &T) -> bool {
359        *self.read() == *other
360    }
361}
362
363impl<T: cmp::PartialOrd<T> + Sync + Send + 'static> cmp::PartialOrd<T> for State<T> {
364    fn partial_cmp(&self, other: &T) -> Option<cmp::Ordering> {
365        self.read().partial_cmp(other)
366    }
367}
368
369impl<T: cmp::PartialEq<T> + Sync + Send + 'static> cmp::PartialEq<State<T>> for State<T> {
370    fn eq(&self, other: &State<T>) -> bool {
371        *self.read() == *other.read()
372    }
373}
374
375impl<T: cmp::PartialOrd<T> + Sync + Send + 'static> cmp::PartialOrd<State<T>> for State<T> {
376    fn partial_cmp(&self, other: &State<T>) -> Option<cmp::Ordering> {
377        self.read().partial_cmp(&other.read())
378    }
379}
380
381impl<T: cmp::Eq + Sync + Send + 'static> cmp::Eq for State<T> {}