Skip to main content

wae_session/
extract.rs

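//! Session extractor.
//!
//! Provides [`SessionExtractor`], which pulls a shared `Session` out of HTTP request
//! extensions, together with its error type [`SessionRejection`].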
4
5use crate::Session;
6use std::sync::Arc;
7use wae_types::WaeError;
8
9/// Session 提取错误
10#[derive(Debug, Clone)]
11pub struct SessionRejection {
12    inner: WaeError,
13}
14
15impl SessionRejection {
16    fn new(error: WaeError) -> Self {
17        Self { inner: error }
18    }
19
20    /// 获取内部错误
21    pub fn into_inner(self) -> WaeError {
22        self.inner
23    }
24}
25
26impl std::fmt::Display for SessionRejection {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        self.inner.fmt(f)
29    }
30}
31
32impl std::error::Error for SessionRejection {}
33
34/// Session 提取器
35///
36/// 用于从请求中提取 Session。
37///
38/// # 示例
39///
40/// ```rust,ignore
41/// use wae_session::SessionExtractor;
42///
43/// async fn handler(request: &Request) {
44///     let session = SessionExtractor::from_request(request).unwrap();
45///     let user_id: Option<String> = session.get_typed("user_id").await;
46///     // ...
47/// }
48/// ```
49#[derive(Debug, Clone)]
50pub struct SessionExtractor {
51    /// Session 引用
52    session: Arc<Session>,
53}
54
55impl SessionExtractor {
56    /// 创建 Session 提取器
57    pub fn new(session: Arc<Session>) -> Self {
58        Self { session }
59    }
60
61    /// 从请求的扩展中提取 Session
62    pub fn from_request<B>(request: &http::Request<B>) -> Result<Self, SessionRejection> {
63        request
64            .extensions()
65            .get::<Arc<Session>>()
66            .cloned()
67            .map(|session| SessionExtractor { session })
68            .ok_or_else(|| SessionRejection::new(WaeError::internal("Session not found in request extensions")))
69    }
70
71    /// 获取 Session 引用
72    pub fn inner(&self) -> &Session {
73        &self.session
74    }
75
76    /// 获取 Session ID
77    pub fn id(&self) -> &str {
78        self.session.id()
79    }
80
81    /// 检查是否是新创建的 Session
82    pub fn is_new(&self) -> bool {
83        self.session.is_new()
84    }
85
86    /// 获取值
87    pub async fn get(&self, key: &str) -> Option<serde_json::Value> {
88        self.session.get(key).await
89    }
90
91    /// 获取类型化的值
92    pub async fn get_typed<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
93        self.session.get_typed(key).await
94    }
95
96    /// 设置值
97    pub async fn set<T: serde::Serialize>(&self, key: impl Into<String>, value: T) {
98        self.session.set(key, value).await
99    }
100
101    /// 删除值
102    pub async fn remove(&self, key: &str) -> Option<serde_json::Value> {
103        self.session.remove(key).await
104    }
105
106    /// 检查键是否存在
107    pub async fn contains(&self, key: &str) -> bool {
108        self.session.contains(key).await
109    }
110
111    /// 清空所有数据
112    pub async fn clear(&self) {
113        self.session.clear().await
114    }
115
116    /// 获取所有键
117    pub async fn keys(&self) -> Vec<String> {
118        self.session.keys().await
119    }
120
121    /// 获取数据条目数量
122    pub async fn len(&self) -> usize {
123        self.session.len().await
124    }
125
126    /// 检查是否为空
127    pub async fn is_empty(&self) -> bool {
128        self.session.is_empty().await
129    }
130}