Skip to main content

wae_session/
layer.rs

1//! Session 中间件层
2//!
3//! 提供 Session 中间件实现,用于在请求处理过程中管理 Session。
4
5use crate::{Session, SessionConfig, SessionStore};
6use axum::{
7    body::Body,
8    extract::Request,
9    http::{Response, header::SET_COOKIE},
10    middleware::Next,
11};
12use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
13use rand::Rng;
14use std::sync::Arc;
15use tower::Layer;
16
17/// Session 中间件层
18///
19/// 用于在 axum 应用中添加 Session 支持。
20#[derive(Debug, Clone)]
21pub struct SessionLayer<S>
22where
23    S: SessionStore,
24{
25    /// Session 存储
26    store: S,
27    /// Session 配置
28    config: SessionConfig,
29}
30
31impl<S> SessionLayer<S>
32where
33    S: SessionStore,
34{
35    /// 创建新的 Session 中间件层
36    pub fn new(store: S, config: SessionConfig) -> Self {
37        Self { store, config }
38    }
39
40    /// 使用默认配置创建 Session 中间件层
41    pub fn with_store(store: S) -> Self {
42        Self::new(store, SessionConfig::default())
43    }
44
45    /// 生成新的 Session ID
46    fn generate_session_id(&self) -> String {
47        let mut bytes = vec![0u8; self.config.id_length];
48        rand::rng().fill_bytes(&mut bytes);
49        URL_SAFE_NO_PAD.encode(&bytes)
50    }
51
52    /// 从请求中提取 Session ID
53    fn extract_session_id(&self, request: &Request) -> Option<String> {
54        let cookie_header = request.headers().get("cookie")?.to_str().ok()?;
55
56        for cookie in cookie_header.split(';') {
57            let cookie = cookie.trim();
58            if let Some(value) = cookie.strip_prefix(&format!("{}=", self.config.cookie_name)) {
59                return Some(value.to_string());
60            }
61        }
62
63        None
64    }
65
66    /// 中间件处理函数
67    pub async fn middleware(request: Request, next: Next) -> Response<Body> {
68        let layer = request.extensions().get::<Arc<Self>>().cloned().expect("SessionLayer not found in extensions");
69
70        let session_id = layer.extract_session_id(&request);
71        let session = if let Some(id) = session_id {
72            if let Some(data) = layer.store.get(&id).await {
73                if let Ok(data_map) = serde_json::from_str(&data) {
74                    Session::from_data(id, data_map).await
75                }
76                else {
77                    Session::new(layer.generate_session_id())
78                }
79            }
80            else {
81                Session::new(layer.generate_session_id())
82            }
83        }
84        else {
85            Session::new(layer.generate_session_id())
86        };
87
88        let session = Arc::new(session);
89        let mut request = request;
90        request.extensions_mut().insert(session.clone());
91
92        let mut response = next.run(request).await;
93
94        if session.is_dirty().await || session.is_new() {
95            let data = session.to_json().await;
96            layer.store.set(session.id(), &data, layer.config.ttl).await;
97
98            let cookie_value = layer.config.build_cookie_header(session.id());
99            if let Ok(header_value) = cookie_value.parse() {
100                response.headers_mut().append(SET_COOKIE, header_value);
101            }
102        }
103
104        response
105    }
106}
107
108impl<S> Layer<Arc<Self>> for SessionLayer<S>
109where
110    S: SessionStore,
111{
112    type Service = SessionService;
113
114    fn layer(&self, _inner: Arc<Self>) -> Self::Service {
115        SessionService
116    }
117}
118
/// Session service produced by the `Layer` implementation of `SessionLayer`.
///
/// It is a stateless unit struct that wraps no inner service. It derives the common
/// value traits so it can be debug-printed, cloned, and default-constructed, as
/// tower/axum service stacks frequently require.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SessionService;
123
124impl tower::Service<Request> for SessionService {
125    type Response = Response<Body>;
126    type Error = std::convert::Infallible;
127    type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
128
129    fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
130        std::task::Poll::Ready(Ok(()))
131    }
132
133    fn call(&mut self, _req: Request) -> Self::Future {
134        Box::pin(async { Ok(Response::new(Body::empty())) })
135    }
136}