moduforge_rules_expression/functions/
state_guard.rs

1//! State 守卫模块
2//!
3//! 提供 RAII 模式的 State 管理,确保异常安全
4
5        use std::sync::Arc;
6use super::custom::CustomFunctionRegistry;
7use std::marker::PhantomData;
8
9/// State 守卫,使用 RAII 模式自动管理 State 的设置和清理
10///
11/// 当 StateGuard 被创建时,会自动设置当前线程的 State 上下文
12/// 当 StateGuard 被丢弃时(包括异常情况),会自动清理 State 上下文
13///
14/// # 示例
15/// ```rust,ignore
16/// use std::sync::Arc;
17/// use moduforge_state::State;
18/// use moduforge_rules_expression::functions::StateGuard;
19///
20/// // 创建 State
21/// let state = Arc::new(State::default());
22///
23/// {
24///     // 设置 State 上下文
25///     let _guard = StateGuard::new(state);
26///     
27///     // 在这个作用域内,自定义函数可以访问 State
28///     // 即使发生 panic,State 也会被正确清理
29///     
30/// } // 这里 StateGuard 被自动丢弃,State 上下文被清理
31/// ```
32pub struct StateGuard<S> {
33    _private: PhantomData<S>,
34}
35
36impl<S: Send + Sync + 'static> StateGuard<S> {
37    /// 创建新的 State 守卫
38    ///
39    /// # 参数
40    /// * `state` - 要设置的 State 对象
41    ///
42    /// # 返回值
43    /// 返回 StateGuard 实例,当其被丢弃时会自动清理 State
44    pub fn new(state: Arc<S>) -> Self {
45        CustomFunctionRegistry::set_current_state(Some(state));
46        Self { _private: PhantomData }
47    }
48
49    /// 创建空的 State 守卫(用于清理已有的 State)
50    ///
51    /// # 返回值
52    /// 返回 StateGuard 实例,会立即清理当前 State 并在丢弃时保持清理状态
53    pub fn empty() -> Self {
54        CustomFunctionRegistry::clear_current_state();
55        Self { _private: PhantomData }
56    }
57
58    /// 获取当前是否有活跃的 State
59    ///
60    /// # 返回值
61    /// * `true` - 当前线程有活跃的 State
62    /// * `false` - 当前线程没有 State
63    pub fn has_active_state() -> bool {
64        CustomFunctionRegistry::has_current_state()
65    }
66}
67
68impl<S> Drop for StateGuard<S> {
69    /// 自动清理 State 上下文
70    ///
71    /// 当 StateGuard 被丢弃时(正常情况或异常情况),
72    /// 会自动清理当前线程的 State 上下文
73    fn drop(&mut self) {
74        CustomFunctionRegistry::clear_current_state();
75    }
76}
77
78/// 便利宏,用于在指定作用域内设置 State
79///
80/// # 示例
81/// ```rust,ignore
82/// use moduforge_rules_expression::with_state;
83///
84/// let state = Arc::new(State::default());
85/// 
86/// with_state!(state => {
87///     // 在这个块内,State 是活跃的
88/// });
89/// // State 在这里已经被清理
90/// ```
91#[macro_export]
92macro_rules! with_state {
93    ($state:expr => $block:block) => {
94        {
95            let _guard = $crate::functions::StateGuard::new($state);
96            $block
97        }
98    };
99}
100
101/// 异步版本的 State 守卫便利函数
102///
103/// # 参数
104/// * `state` - 要设置的 State 对象
105/// * `future` - 要在 State 上下文中执行的异步操作
106///
107/// # 返回值
108/// 返回异步操作的结果
109///
110/// # 示例
111/// ```rust,ignore
112/// use moduforge_rules_expression::functions::with_state_async;
113///
114/// let state = Arc::new(State::default());
115/// 
116/// let result = with_state_async(state, async {
117///     // aync block
118/// }).await;
119/// ```
120pub async fn with_state_async<S, T, F, Fut>(
121    state: Arc<S>,
122    future: F,
123) -> T
124where
125    S: Send + Sync + 'static,
126    F: FnOnce() -> Fut,
127    Fut: std::future::Future<Output = T>,
128{
129    let _guard = StateGuard::new(state);
130    future().await
131}
132
133#[cfg(test)]
134mod tests {
135    use super::*;
136    use std::sync::Arc;
137
138    // A dummy struct for testing purposes
139    struct DummyState;
140
141    #[test]
142    fn test_state_guard_basic() {
143        // 初始状态应该没有 State
144        assert!(!StateGuard::<DummyState>::has_active_state());
145
146        {
147            // 创建一个模拟的 State
148            let state = Arc::new(DummyState);
149            let _guard = StateGuard::new(state);
150            
151            // 在这个作用域内应该有活跃的 State
152            assert!(StateGuard::<DummyState>::has_active_state());
153        }
154
155        // 离开作用域后,State 应该被清理
156        assert!(!StateGuard::<DummyState>::has_active_state());
157    }
158
159    #[test]
160    fn test_state_guard_panic_safety() {
161        assert!(!StateGuard::<DummyState>::has_active_state());
162
163        let result = std::panic::catch_unwind(|| {
164            let state = Arc::new(DummyState);
165            let _guard = StateGuard::new(state);
166            
167            // 模拟 panic
168            panic!("测试 panic 安全性");
169        });
170
171        // 即使发生了 panic,State 也应该被正确清理
172        assert!(!StateGuard::<DummyState>::has_active_state());
173        assert!(result.is_err());
174    }
175
176    #[test]
177    fn test_empty_guard() {
178        let state = Arc::new(DummyState);
179        let _guard = StateGuard::new(state);
180        assert!(StateGuard::<DummyState>::has_active_state());
181
182        // 创建空守卫应该立即清理 State
183        let _guard_empty = StateGuard::<DummyState>::empty();
184        assert!(!StateGuard::<DummyState>::has_active_state());
185    }
186}
187