1use crate::{Session, SessionConfig, SessionStore};
6use axum::{
7 body::Body,
8 extract::Request,
9 http::{Response, header::SET_COOKIE},
10 middleware::Next,
11};
12use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
13use rand::Rng;
14use std::sync::Arc;
15use tower::Layer;
16
17#[derive(Debug, Clone)]
21pub struct SessionLayer<S>
22where
23 S: SessionStore,
24{
25 store: S,
27 config: SessionConfig,
29}
30
31impl<S> SessionLayer<S>
32where
33 S: SessionStore,
34{
35 pub fn new(store: S, config: SessionConfig) -> Self {
37 Self { store, config }
38 }
39
40 pub fn with_store(store: S) -> Self {
42 Self::new(store, SessionConfig::default())
43 }
44
45 fn generate_session_id(&self) -> String {
47 let mut bytes = vec![0u8; self.config.id_length];
48 rand::rng().fill_bytes(&mut bytes);
49 URL_SAFE_NO_PAD.encode(&bytes)
50 }
51
52 fn extract_session_id(&self, request: &Request) -> Option<String> {
54 let cookie_header = request.headers().get("cookie")?.to_str().ok()?;
55
56 for cookie in cookie_header.split(';') {
57 let cookie = cookie.trim();
58 if let Some(value) = cookie.strip_prefix(&format!("{}=", self.config.cookie_name)) {
59 return Some(value.to_string());
60 }
61 }
62
63 None
64 }
65
66 pub async fn middleware(request: Request, next: Next) -> Response<Body> {
68 let layer = request.extensions().get::<Arc<Self>>().cloned().expect("SessionLayer not found in extensions");
69
70 let session_id = layer.extract_session_id(&request);
71 let session = if let Some(id) = session_id {
72 if let Some(data) = layer.store.get(&id).await {
73 if let Ok(data_map) = serde_json::from_str(&data) {
74 Session::from_data(id, data_map).await
75 }
76 else {
77 Session::new(layer.generate_session_id())
78 }
79 }
80 else {
81 Session::new(layer.generate_session_id())
82 }
83 }
84 else {
85 Session::new(layer.generate_session_id())
86 };
87
88 let session = Arc::new(session);
89 let mut request = request;
90 request.extensions_mut().insert(session.clone());
91
92 let mut response = next.run(request).await;
93
94 if session.is_dirty().await || session.is_new() {
95 let data = session.to_json().await;
96 layer.store.set(session.id(), &data, layer.config.ttl).await;
97
98 let cookie_value = layer.config.build_cookie_header(session.id());
99 if let Ok(header_value) = cookie_value.parse() {
100 response.headers_mut().append(SET_COOKIE, header_value);
101 }
102 }
103
104 response
105 }
106}
107
108impl<S> Layer<Arc<Self>> for SessionLayer<S>
109where
110 S: SessionStore,
111{
112 type Service = SessionService;
113
114 fn layer(&self, _inner: Arc<Self>) -> Self::Service {
115 SessionService
116 }
117}
118
/// Service produced by `SessionLayer`'s `tower::Layer` implementation.
///
/// It carries no state: no inner service, store, or config. `Clone` is derived
/// because tower/axum consumers of services commonly require it, and `Debug`
/// because it is a public type.
///
/// NOTE(review): with no state it cannot perform session handling itself.
#[derive(Debug, Clone, Default)]
pub struct SessionService;
123
124impl tower::Service<Request> for SessionService {
125 type Response = Response<Body>;
126 type Error = std::convert::Infallible;
127 type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
128
129 fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
130 std::task::Poll::Ready(Ok(()))
131 }
132
133 fn call(&mut self, _req: Request) -> Self::Future {
134 Box::pin(async { Ok(Response::new(Body::empty())) })
135 }
136}