use std::sync::Arc;
use super::mf_function::MfFunctionRegistry;
use std::marker::PhantomData;
/// Scope guard that keeps a state installed as the current state of
/// `MfFunctionRegistry` for as long as the guard is alive.
///
/// Create one with `StateGuard::new` (installs a state) or `StateGuard::empty`
/// (clears any state). Dropping the guard calls
/// `MfFunctionRegistry::clear_current_state()`, which also happens while a panic unwinds.
///
/// The type parameter `S` is carried only as a marker; the guard stores no data and is
/// zero-sized.
///
/// Always bind the guard to a named variable (e.g. `let _guard = ...`). Discarding it,
/// either as a bare statement or via `let _ = ...`, drops it immediately and therefore
/// clears the state right away.
#[must_use = "dropping a StateGuard immediately clears the current state; bind it to a named variable such as `_guard`"]
pub struct StateGuard<S> {
    /// Ties the guard to the state type `S` without storing a value of it.
    _private: PhantomData<S>,
}
impl<S: Send + Sync + 'static> StateGuard<S> {
pub fn new(state: Arc<S>) -> Self {
MfFunctionRegistry::set_current_state(Some(state));
Self { _private: PhantomData }
}
pub fn empty() -> Self {
MfFunctionRegistry::clear_current_state();
Self { _private: PhantomData }
}
pub fn has_active_state() -> bool {
MfFunctionRegistry::has_current_state()
}
}
/// Clears the registry's current state when the guard goes out of scope, including
/// while a panic unwinds.
///
/// NOTE(review): `clear_current_state()` is called unconditionally, and `StateGuard::new`
/// does not save any previously installed state. Nested guards (e.g. nested `with_state!`
/// scopes) may therefore leave the registry empty when the inner guard drops, instead of
/// restoring the outer state. This holds unless the registry itself implements stack
/// semantics; confirm before relying on nesting.
impl<S> Drop for StateGuard<S> {
fn drop(&mut self) {
MfFunctionRegistry::clear_current_state();
}
}
/// Evaluates `$block` with `$state` installed as the registry's current state.
///
/// Usage: `with_state!(state_arc => { ... })`. A `StateGuard` is created before the block
/// runs and dropped when the macro's scope ends, clearing the state. The macro evaluates
/// to the value of the block.
#[macro_export]
macro_rules! with_state {
    ($state:expr => $block:block) => {{
        // A named binding (not `_`) keeps the guard alive until the block has finished.
        let _state_scope = $crate::functions::StateGuard::new($state);
        $block
    }};
}
/// Runs an async computation with `state` installed as the registry's current state.
///
/// The state is installed when the returned future is first polled, before `future` is
/// invoked. It is cleared when the computation completes, or when the returned future is
/// dropped before completion, because the guard is dropped either way.
///
/// NOTE(review): the guard is held across an `.await`. If `MfFunctionRegistry` keeps its
/// current state in thread-local storage, two problems can arise:
/// - A multi-threaded executor may resume this task on a thread where the state is absent.
/// - Other tasks polled on the original thread may observe this state.
///
/// Verify against the registry implementation.
pub async fn with_state_async<S, T, F, Fut>(state: Arc<S>, future: F) -> T
where
    S: Send + Sync + 'static,
    F: FnOnce() -> Fut,
    Fut: std::future::Future<Output = T>,
{
    // Keep the guard alive for the whole body so the state outlives the await.
    let _scope = StateGuard::new(state);
    let pending = future();
    pending.await
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Zero-sized placeholder used only to instantiate the generic guard.
    struct DummyState;

    /// Shorthand for the guard type exercised by every test below.
    type Guard = StateGuard<DummyState>;

    /// Creating a guard activates the state; leaving its scope clears it.
    #[test]
    fn test_state_guard_basic() {
        assert!(!Guard::has_active_state());
        {
            let _active = Guard::new(Arc::new(DummyState));
            assert!(Guard::has_active_state());
        }
        assert!(!Guard::has_active_state());
    }

    /// A panic while a guard is alive still clears the state during unwinding.
    #[test]
    fn test_state_guard_panic_safety() {
        assert!(!Guard::has_active_state());
        let outcome = std::panic::catch_unwind(|| {
            let _active = Guard::new(Arc::new(DummyState));
            panic!("测试 panic 安全性");
        });
        assert!(!Guard::has_active_state());
        assert!(outcome.is_err());
    }

    /// `empty()` clears a state installed by an earlier, still-alive guard.
    #[test]
    fn test_empty_guard() {
        let _populated = Guard::new(Arc::new(DummyState));
        assert!(Guard::has_active_state());
        let _cleared = Guard::empty();
        assert!(!Guard::has_active_state());
    }
}