modo/auth/session/cookie/
middleware.rs1use std::future::Future;
2use std::pin::Pin;
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6
7use axum::body::Body;
8use axum::extract::connect_info::ConnectInfo;
9use cookie::{Cookie, CookieJar, SameSite};
10use http::{HeaderValue, Request, Response};
11use tower::{Layer, Service};
12
13use crate::ip::ClientIp;
14
15use super::CookieSessionService;
16use super::extractor::{SessionAction, SessionState};
17use crate::auth::session::data::Session;
18use crate::auth::session::meta::{SessionMeta, header_str};
19use crate::auth::session::token::SessionToken;
20use crate::cookie::{CookieConfig, Key};
21
22#[derive(Clone)]
37pub struct CookieSessionLayer {
38 service: CookieSessionService,
39}
40
41pub fn layer(service: CookieSessionService) -> CookieSessionLayer {
43 CookieSessionLayer { service }
44}
45
46impl<S> Layer<S> for CookieSessionLayer {
47 type Service = CookieSessionMiddleware<S>;
48
49 fn layer(&self, inner: S) -> Self::Service {
50 CookieSessionMiddleware {
51 inner,
52 service: self.service.clone(),
53 }
54 }
55}
56
57#[derive(Clone)]
63pub struct CookieSessionMiddleware<S> {
64 inner: S,
65 service: CookieSessionService,
66}
67
68impl<S, ReqBody> Service<Request<ReqBody>> for CookieSessionMiddleware<S>
69where
70 S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
71 S::Future: Send + 'static,
72 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
73 ReqBody: Send + 'static,
74{
75 type Response = Response<Body>;
76 type Error = S::Error;
77 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
78
79 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
80 self.inner.poll_ready(cx)
81 }
82
83 fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
84 let svc = self.service.clone();
85 let mut inner = self.inner.clone();
86 std::mem::swap(&mut self.inner, &mut inner);
87
88 Box::pin(async move {
89 let store = svc.store();
90 let config = store.config();
91 let cookie_name = &config.cookie_name;
92 let key = svc.cookie_key();
93 let cookie_config = svc.config().cookie.clone();
94
95 let ip = request
97 .extensions()
98 .get::<ClientIp>()
99 .map(|c| c.0.to_string())
100 .unwrap_or_else(|| {
101 request
103 .extensions()
104 .get::<ConnectInfo<std::net::SocketAddr>>()
105 .map(|ci| ci.0.ip().to_string())
106 .unwrap_or_else(|| "unknown".to_string())
107 });
108 let headers = request.headers();
109
110 let ua = header_str(headers, "user-agent");
112 let accept_lang = header_str(headers, "accept-language");
113 let accept_enc = header_str(headers, "accept-encoding");
114 let meta = SessionMeta::from_headers(ip, ua, accept_lang, accept_enc);
115
116 let session_token = read_signed_cookie(request.headers(), cookie_name, key);
118 let had_cookie = session_token.is_some();
119
120 let (current_session, read_failed) = if let Some(ref token) = session_token {
122 match store.read_by_token(token).await {
123 Ok(session) => (session, false),
124 Err(e) => {
125 tracing::error!("failed to read session: {e}");
126 (None, true)
127 }
128 }
129 } else {
130 (None, false)
131 };
132
133 let current_session = if let Some(session) = current_session {
135 if config.validate_fingerprint && meta.fingerprint != session.fingerprint {
136 tracing::warn!(
137 session_id = session.id,
138 user_id = session.user_id,
139 "session fingerprint mismatch — possible hijack, destroying session"
140 );
141 let _ = store.destroy(&session.id).await;
142 None
143 } else {
144 Some(session)
145 }
146 } else {
147 None
148 };
149
150 let should_touch = current_session.as_ref().is_some_and(|s| {
152 let elapsed = chrono::Utc::now() - s.last_active_at;
153 elapsed >= chrono::Duration::seconds(config.touch_interval_secs as i64)
154 });
155
156 if let Some(raw) = current_session.as_ref() {
158 let session_data = Session::from(raw.clone());
159 request.extensions_mut().insert(session_data);
160 }
161
162 let session_state = Arc::new(SessionState {
164 service: svc.clone(),
165 meta,
166 current: Mutex::new(current_session.clone()),
167 dirty: AtomicBool::new(false),
168 action: Mutex::new(SessionAction::None),
169 });
170
171 request.extensions_mut().insert(session_state.clone());
172
173 let mut response = inner.call(request).await?;
175
176 let action = {
179 let guard = session_state.action.lock().expect("session mutex poisoned");
180 guard.clone()
181 };
182 let is_dirty = session_state.dirty.load(Ordering::SeqCst);
183 let ttl_secs = config.session_ttl_secs;
184
185 match action {
186 SessionAction::Set(token) => {
187 set_signed_cookie(
188 &mut response,
189 cookie_name,
190 &token.as_hex(),
191 ttl_secs,
192 &cookie_config,
193 key,
194 );
195 }
196 SessionAction::Remove => {
197 remove_signed_cookie(&mut response, cookie_name, &cookie_config, key);
198 }
199 SessionAction::None => {
200 if let Some(ref session) = current_session {
201 let now = chrono::Utc::now();
202 let new_expires = now + chrono::Duration::seconds(ttl_secs as i64);
203
204 if is_dirty {
205 let data = {
206 let guard = session_state
207 .current
208 .lock()
209 .expect("session mutex poisoned");
210 guard.as_ref().map(|s| s.data.clone())
211 };
212 if let Some(data) = data
213 && let Err(e) =
214 store.flush(&session.id, &data, now, new_expires).await
215 {
216 tracing::error!(
217 session_id = session.id,
218 "failed to flush session data: {e}"
219 );
220 }
221 } else if should_touch
222 && let Err(e) = store.touch(&session.id, now, new_expires).await
223 {
224 tracing::error!(
225 session_id = session.id,
226 "failed to touch session: {e}"
227 );
228 }
229
230 if (is_dirty || should_touch)
232 && let Some(ref token) = session_token
233 {
234 set_signed_cookie(
235 &mut response,
236 cookie_name,
237 &token.as_hex(),
238 ttl_secs,
239 &cookie_config,
240 key,
241 );
242 }
243 }
244
245 if had_cookie && current_session.is_none() && !read_failed {
247 remove_signed_cookie(&mut response, cookie_name, &cookie_config, key);
248 }
249 }
250 }
251
252 Ok(response)
253 })
254 }
255}
256
257fn read_signed_cookie(
260 headers: &http::HeaderMap,
261 cookie_name: &str,
262 key: &Key,
263) -> Option<SessionToken> {
264 let cookie_header = headers.get(http::header::COOKIE)?;
265 let cookie_str = cookie_header.to_str().ok()?;
266
267 for pair in cookie_str.split(';') {
268 let pair = pair.trim();
269 if let Some((name, value)) = pair.split_once('=')
270 && name.trim() == cookie_name
271 {
272 let mut jar = CookieJar::new();
274 jar.add_original(Cookie::new(
275 cookie_name.to_string(),
276 value.trim().to_string(),
277 ));
278 let verified = jar.signed(key).get(cookie_name)?;
279 return SessionToken::from_hex(verified.value()).ok();
280 }
281 }
282 None
283}
284
285fn set_signed_cookie(
287 response: &mut Response<Body>,
288 name: &str,
289 value: &str,
290 max_age_secs: u64,
291 config: &CookieConfig,
292 key: &Key,
293) {
294 let mut jar = CookieJar::new();
296 jar.signed_mut(key)
297 .add(Cookie::new(name.to_string(), value.to_string()));
298 let signed_value = jar
299 .get(name)
300 .expect("cookie was just added")
301 .value()
302 .to_string();
303
304 let same_site = match config.same_site.as_str() {
306 "strict" => SameSite::Strict,
307 "none" => SameSite::None,
308 _ => SameSite::Lax,
309 };
310 let set_cookie_str = Cookie::build((name.to_string(), signed_value))
311 .path("/")
312 .secure(config.secure)
313 .http_only(config.http_only)
314 .same_site(same_site)
315 .max_age(cookie::time::Duration::seconds(max_age_secs as i64))
316 .build()
317 .to_string();
318
319 match HeaderValue::from_str(&set_cookie_str) {
320 Ok(v) => {
321 response.headers_mut().append(http::header::SET_COOKIE, v);
322 }
323 Err(e) => {
324 tracing::error!(
325 cookie_name = name,
326 "failed to set session cookie header: {e}"
327 );
328 }
329 }
330}
331
332fn remove_signed_cookie(
333 response: &mut Response<Body>,
334 name: &str,
335 config: &CookieConfig,
336 key: &Key,
337) {
338 set_signed_cookie(response, name, "", 0, config, key);
339}