Skip to main content

wae_session/
layer.rs

1//! Session 中间件层
2//!
3//! 提供 Session 中间件实现,用于在请求处理过程中管理 Session。
4
5use crate::{Session, SessionConfig, SessionStore};
6use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
7use http::{Request, Response};
8use http_body::Body;
9use rand::Rng;
10use std::{marker::PhantomData, sync::Arc};
11use tower::{Layer, Service};
12
13/// Session 中间件层
14///
15/// 用于在 tower 应用中添加 Session 支持。
16#[derive(Debug, Clone)]
17pub struct SessionLayer<S, ReqBody, ResBody>
18where
19    S: SessionStore,
20    ReqBody: Body + Send + 'static,
21    ResBody: Body + Send + 'static,
22{
23    /// Session 存储
24    store: S,
25    /// Session 配置
26    config: SessionConfig,
27    /// 用于标记请求体类型
28    _phantom: PhantomData<(ReqBody, ResBody)>,
29}
30
31impl<S, ReqBody, ResBody> SessionLayer<S, ReqBody, ResBody>
32where
33    S: SessionStore,
34    ReqBody: Body + Send + 'static,
35    ResBody: Body + Send + 'static,
36{
37    /// 创建新的 Session 中间件层
38    pub fn new(store: S, config: SessionConfig) -> Self {
39        Self { store, config, _phantom: PhantomData }
40    }
41
42    /// 使用默认配置创建 Session 中间件层
43    pub fn with_store(store: S) -> Self {
44        Self::new(store, SessionConfig::default())
45    }
46}
47
48impl<S, T, ReqBody, ResBody> Layer<T> for SessionLayer<S, ReqBody, ResBody>
49where
50    S: SessionStore,
51    T: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
52    T::Future: Send,
53    ReqBody: Body + Send + 'static,
54    ResBody: Body + Send + 'static,
55{
56    type Service = SessionService<T, S, ReqBody, ResBody>;
57
58    fn layer(&self, inner: T) -> Self::Service {
59        SessionService { inner, store: self.store.clone(), config: self.config.clone(), _phantom: PhantomData }
60    }
61}
62
63/// Session 服务
64///
65/// 用于处理 Session 的服务包装器。
66#[derive(Debug, Clone)]
67pub struct SessionService<T, S, ReqBody, ResBody>
68where
69    S: SessionStore,
70    ReqBody: Body + Send + 'static,
71    ResBody: Body + Send + 'static,
72{
73    /// 内部服务
74    inner: T,
75    /// Session 存储
76    store: S,
77    /// Session 配置
78    config: SessionConfig,
79    /// 用于标记请求体类型
80    _phantom: PhantomData<(ReqBody, ResBody)>,
81}
82
83impl<T, S, ReqBody, ResBody> SessionService<T, S, ReqBody, ResBody>
84where
85    S: SessionStore,
86    ReqBody: Body + Send + 'static,
87    ResBody: Body + Send + 'static,
88{
89    /// 从请求中提取 Session ID
90    fn extract_session_id<B>(&self, request: &Request<B>) -> Option<String> {
91        let cookie_header = request.headers().get("cookie")?.to_str().ok()?;
92
93        for cookie in cookie_header.split(';') {
94            let cookie = cookie.trim();
95            if let Some(value) = cookie.strip_prefix(&format!("{}=", self.config.cookie_name)) {
96                return Some(value.to_string());
97            }
98        }
99
100        None
101    }
102}
103
104impl<T, S, ReqBody, ResBody> Service<Request<ReqBody>> for SessionService<T, S, ReqBody, ResBody>
105where
106    S: SessionStore,
107    T: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
108    T::Future: Send,
109    ReqBody: Body + Send + 'static,
110    ResBody: Body + Send + 'static,
111{
112    type Response = Response<ResBody>;
113    type Error = T::Error;
114    type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
115
116    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
117        self.inner.poll_ready(cx)
118    }
119
120    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
121        let mut inner = self.inner.clone();
122        let session_id = self.extract_session_id(&request);
123        let config = self.config.clone();
124        let store = self.store.clone();
125
126        Box::pin(async move {
127            let generate_session_id = || -> String {
128                let mut bytes = vec![0u8; config.id_length];
129                rand::rng().fill_bytes(&mut bytes);
130                URL_SAFE_NO_PAD.encode(&bytes)
131            };
132
133            let session = if let Some(id) = session_id {
134                if let Some(data) = store.get(&id).await {
135                    if let Ok(data_map) = serde_json::from_str(&data) {
136                        Session::from_data(id.to_string(), data_map).await
137                    }
138                    else {
139                        Session::new(generate_session_id())
140                    }
141                }
142                else {
143                    Session::new(generate_session_id())
144                }
145            }
146            else {
147                Session::new(generate_session_id())
148            };
149
150            let session = Arc::new(session);
151            request.extensions_mut().insert(session.clone());
152
153            let mut response = inner.call(request).await?;
154
155            if session.is_dirty().await || session.is_new() {
156                let data = session.to_json().await;
157                store.set(session.id(), &data, config.ttl).await;
158
159                let cookie_value = config.build_cookie_header(session.id());
160                if let Ok(header_value) = cookie_value.parse() {
161                    response.headers_mut().append(http::header::SET_COOKIE, header_value);
162                }
163            }
164
165            Ok(response)
166        })
167    }
168}