1use std::future::Future;
2use std::pin::Pin;
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6
7use axum::body::Body;
8use axum::extract::connect_info::ConnectInfo;
9use cookie::{Cookie, CookieJar, SameSite};
10use http::{HeaderValue, Request, Response};
11use tower::{Layer, Service};
12
13use crate::cookie::{CookieConfig, Key};
14use crate::ip::ClientIp;
15
16use super::extractor::{SessionAction, SessionState};
17use super::meta::{SessionMeta, header_str};
18use super::store::Store;
19use super::token::SessionToken;
20
21#[derive(Clone)]
36pub struct SessionLayer {
37 store: Arc<Store>,
38 cookie_config: CookieConfig,
39 key: Key,
40}
41
42pub fn layer(store: Store, cookie_config: &CookieConfig, key: &Key) -> SessionLayer {
60 SessionLayer {
61 store: Arc::new(store),
62 cookie_config: cookie_config.clone(),
63 key: key.clone(),
64 }
65}
66
67impl<S> Layer<S> for SessionLayer {
68 type Service = SessionMiddleware<S>;
69
70 fn layer(&self, inner: S) -> Self::Service {
71 SessionMiddleware {
72 inner,
73 store: self.store.clone(),
74 cookie_config: self.cookie_config.clone(),
75 key: self.key.clone(),
76 }
77 }
78}
79
80#[derive(Clone)]
86pub struct SessionMiddleware<S> {
87 inner: S,
88 store: Arc<Store>,
89 cookie_config: CookieConfig,
90 key: Key,
91}
92
93impl<S, ReqBody> Service<Request<ReqBody>> for SessionMiddleware<S>
94where
95 S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
96 S::Future: Send + 'static,
97 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
98 ReqBody: Send + 'static,
99{
100 type Response = Response<Body>;
101 type Error = S::Error;
102 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
103
104 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
105 self.inner.poll_ready(cx)
106 }
107
108 fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
109 let store = self.store.clone();
110 let cookie_config = self.cookie_config.clone();
111 let key = self.key.clone();
112 let mut inner = self.inner.clone();
113 std::mem::swap(&mut self.inner, &mut inner);
114
115 Box::pin(async move {
116 let config = store.config();
117 let cookie_name = &config.cookie_name;
118
119 let ip = request
121 .extensions()
122 .get::<ClientIp>()
123 .map(|c| c.0.to_string())
124 .unwrap_or_else(|| {
125 request
127 .extensions()
128 .get::<ConnectInfo<std::net::SocketAddr>>()
129 .map(|ci| ci.0.ip().to_string())
130 .unwrap_or_else(|| "unknown".to_string())
131 });
132 let headers = request.headers();
133
134 let ua = header_str(headers, "user-agent");
136 let accept_lang = header_str(headers, "accept-language");
137 let accept_enc = header_str(headers, "accept-encoding");
138 let meta = SessionMeta::from_headers(ip, ua, accept_lang, accept_enc);
139
140 let session_token = read_signed_cookie(request.headers(), cookie_name, &key);
142 let had_cookie = session_token.is_some();
143
144 let (current_session, read_failed) = if let Some(ref token) = session_token {
146 match store.read_by_token(token).await {
147 Ok(session) => (session, false),
148 Err(e) => {
149 tracing::error!("failed to read session: {e}");
150 (None, true)
151 }
152 }
153 } else {
154 (None, false)
155 };
156
157 let current_session = if let Some(session) = current_session {
159 if config.validate_fingerprint && meta.fingerprint != session.fingerprint {
160 tracing::warn!(
161 session_id = session.id,
162 user_id = session.user_id,
163 "session fingerprint mismatch — possible hijack, destroying session"
164 );
165 let _ = store.destroy(&session.id).await;
166 None
167 } else {
168 Some(session)
169 }
170 } else {
171 None
172 };
173
174 let should_touch = current_session.as_ref().is_some_and(|s| {
176 let elapsed = chrono::Utc::now() - s.last_active_at;
177 elapsed >= chrono::Duration::seconds(config.touch_interval_secs as i64)
178 });
179
180 let session_state = Arc::new(SessionState {
182 store: (*store).clone(),
183 meta,
184 current: Mutex::new(current_session.clone()),
185 dirty: AtomicBool::new(false),
186 action: Mutex::new(SessionAction::None),
187 });
188
189 request.extensions_mut().insert(session_state.clone());
190
191 let mut response = inner.call(request).await?;
193
194 let action = {
197 let guard = session_state.action.lock().expect("session mutex poisoned");
198 guard.clone()
199 };
200 let is_dirty = session_state.dirty.load(Ordering::SeqCst);
201 let ttl_secs = config.session_ttl_secs;
202
203 match action {
204 SessionAction::Set(token) => {
205 set_signed_cookie(
206 &mut response,
207 cookie_name,
208 &token.as_hex(),
209 ttl_secs,
210 &cookie_config,
211 &key,
212 );
213 }
214 SessionAction::Remove => {
215 remove_signed_cookie(&mut response, cookie_name, &cookie_config, &key);
216 }
217 SessionAction::None => {
218 if let Some(ref session) = current_session {
219 let now = chrono::Utc::now();
220 let new_expires = now + chrono::Duration::seconds(ttl_secs as i64);
221
222 if is_dirty {
223 let data = {
224 let guard = session_state
225 .current
226 .lock()
227 .expect("session mutex poisoned");
228 guard.as_ref().map(|s| s.data.clone())
229 };
230 if let Some(data) = data
231 && let Err(e) =
232 store.flush(&session.id, &data, now, new_expires).await
233 {
234 tracing::error!(
235 session_id = session.id,
236 "failed to flush session data: {e}"
237 );
238 }
239 } else if should_touch
240 && let Err(e) = store.touch(&session.id, now, new_expires).await
241 {
242 tracing::error!(
243 session_id = session.id,
244 "failed to touch session: {e}"
245 );
246 }
247
248 if (is_dirty || should_touch)
250 && let Some(ref token) = session_token
251 {
252 set_signed_cookie(
253 &mut response,
254 cookie_name,
255 &token.as_hex(),
256 ttl_secs,
257 &cookie_config,
258 &key,
259 );
260 }
261 }
262
263 if had_cookie && current_session.is_none() && !read_failed {
265 remove_signed_cookie(&mut response, cookie_name, &cookie_config, &key);
266 }
267 }
268 }
269
270 Ok(response)
271 })
272 }
273}
274
275fn read_signed_cookie(
278 headers: &http::HeaderMap,
279 cookie_name: &str,
280 key: &Key,
281) -> Option<SessionToken> {
282 let cookie_header = headers.get(http::header::COOKIE)?;
283 let cookie_str = cookie_header.to_str().ok()?;
284
285 for pair in cookie_str.split(';') {
286 let pair = pair.trim();
287 if let Some((name, value)) = pair.split_once('=')
288 && name.trim() == cookie_name
289 {
290 let mut jar = CookieJar::new();
292 jar.add_original(Cookie::new(
293 cookie_name.to_string(),
294 value.trim().to_string(),
295 ));
296 let verified = jar.signed(key).get(cookie_name)?;
297 return SessionToken::from_hex(verified.value()).ok();
298 }
299 }
300 None
301}
302
303fn set_signed_cookie(
305 response: &mut Response<Body>,
306 name: &str,
307 value: &str,
308 max_age_secs: u64,
309 config: &CookieConfig,
310 key: &Key,
311) {
312 let mut jar = CookieJar::new();
314 jar.signed_mut(key)
315 .add(Cookie::new(name.to_string(), value.to_string()));
316 let signed_value = jar
317 .get(name)
318 .expect("cookie was just added")
319 .value()
320 .to_string();
321
322 let same_site = match config.same_site.as_str() {
324 "strict" => SameSite::Strict,
325 "none" => SameSite::None,
326 _ => SameSite::Lax,
327 };
328 let set_cookie_str = Cookie::build((name.to_string(), signed_value))
329 .path("/")
330 .secure(config.secure)
331 .http_only(config.http_only)
332 .same_site(same_site)
333 .max_age(cookie::time::Duration::seconds(max_age_secs as i64))
334 .build()
335 .to_string();
336
337 match HeaderValue::from_str(&set_cookie_str) {
338 Ok(v) => {
339 response.headers_mut().append(http::header::SET_COOKIE, v);
340 }
341 Err(e) => {
342 tracing::error!(
343 cookie_name = name,
344 "failed to set session cookie header: {e}"
345 );
346 }
347 }
348}
349
350fn remove_signed_cookie(
351 response: &mut Response<Body>,
352 name: &str,
353 config: &CookieConfig,
354 key: &Key,
355) {
356 set_signed_cookie(response, name, "", 0, config, key);
357}