modo/auth/session/cookie/
middleware.rs1use std::future::Future;
2use std::pin::Pin;
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6
7use axum::body::Body;
8use axum::extract::connect_info::ConnectInfo;
9use cookie::{Cookie, CookieJar, SameSite};
10use http::{HeaderValue, Request, Response};
11use tower::{Layer, Service};
12
13use crate::ip::ClientIp;
14
15use super::CookieSessionService;
16use super::extractor::{SessionAction, SessionState};
17use crate::auth::session::data::Session;
18use crate::auth::session::meta::{SessionMeta, header_str};
19use crate::auth::session::token::SessionToken;
20use crate::cookie::{CookieConfig, Key};
21
/// Tower [`Layer`] that wraps an inner service in [`CookieSessionMiddleware`].
///
/// Holds the [`CookieSessionService`] that is cloned into every middleware
/// instance the layer produces. Construct it with [`layer`].
#[derive(Clone)]
pub struct CookieSessionLayer {
    // Shared session service handed (by clone) to each wrapped service.
    service: CookieSessionService,
}
40
41pub fn layer(service: CookieSessionService) -> CookieSessionLayer {
43 CookieSessionLayer { service }
44}
45
46impl<S> Layer<S> for CookieSessionLayer {
47 type Service = CookieSessionMiddleware<S>;
48
49 fn layer(&self, inner: S) -> Self::Service {
50 CookieSessionMiddleware {
51 inner,
52 service: self.service.clone(),
53 }
54 }
55}
56
/// Tower [`Service`] produced by [`CookieSessionLayer`] that manages a
/// cookie-backed session around each call to the wrapped service `S`.
///
/// Before calling `inner`, it reads and verifies the signed session cookie and
/// loads the session from the store. It then inserts a `SessionState` (and,
/// when a session was loaded, a `Session`) into the request extensions. After
/// `inner` responds, it flushes or touches the session and sets or removes the
/// session cookie.
#[derive(Clone)]
pub struct CookieSessionMiddleware<S> {
    // The wrapped service.
    inner: S,
    // Session store/config/key provider shared with the extractor via `SessionState`.
    service: CookieSessionService,
}
67
68impl<S, ReqBody> Service<Request<ReqBody>> for CookieSessionMiddleware<S>
69where
70 S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
71 S::Future: Send + 'static,
72 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
73 ReqBody: Send + 'static,
74{
75 type Response = Response<Body>;
76 type Error = S::Error;
77 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
78
79 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
80 self.inner.poll_ready(cx)
81 }
82
83 fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
84 let svc = self.service.clone();
85 let mut inner = self.inner.clone();
86 std::mem::swap(&mut self.inner, &mut inner);
87
88 Box::pin(async move {
89 let store = svc.store();
90 let config = store.config();
91 let cookie_name = &config.cookie_name;
92 let key = svc.cookie_key();
93 let cookie_config = svc.config().cookie.clone();
94
95 let ip = request
96 .extensions()
97 .get::<ClientIp>()
98 .map(|c| c.0.to_string())
99 .unwrap_or_else(|| {
100 request
101 .extensions()
102 .get::<ConnectInfo<std::net::SocketAddr>>()
103 .map(|ci| ci.0.ip().to_string())
104 .unwrap_or_else(|| "unknown".to_string())
105 });
106 let headers = request.headers();
107
108 let ua = header_str(headers, "user-agent");
109 let accept_lang = header_str(headers, "accept-language");
110 let accept_enc = header_str(headers, "accept-encoding");
111 let meta = SessionMeta::from_headers(ip, ua, accept_lang, accept_enc);
112
113 let session_token = read_signed_cookie(request.headers(), cookie_name, key);
114 let had_cookie = session_token.is_some();
115
116 let (current_session, read_failed) = if let Some(ref token) = session_token {
117 match store.read_by_token(token).await {
118 Ok(session) => (session, false),
119 Err(e) => {
120 tracing::error!("failed to read session: {e}");
121 (None, true)
122 }
123 }
124 } else {
125 (None, false)
126 };
127
128 let current_session = if let Some(session) = current_session {
129 if config.validate_fingerprint && meta.fingerprint != session.fingerprint {
130 tracing::warn!(
131 session_id = session.id,
132 user_id = session.user_id,
133 "session fingerprint mismatch — possible hijack, destroying session"
134 );
135 let _ = store.destroy(&session.id).await;
136 None
137 } else {
138 Some(session)
139 }
140 } else {
141 None
142 };
143
144 let should_touch = current_session.as_ref().is_some_and(|s| {
145 let elapsed = chrono::Utc::now() - s.last_active_at;
146 elapsed >= chrono::Duration::seconds(config.touch_interval_secs as i64)
147 });
148
149 if let Some(raw) = current_session.as_ref() {
150 let session_data = Session::from(raw.clone());
151 request.extensions_mut().insert(session_data);
152 }
153
154 let session_state = Arc::new(SessionState {
155 service: svc.clone(),
156 meta,
157 current: Mutex::new(current_session.clone()),
158 dirty: AtomicBool::new(false),
159 action: Mutex::new(SessionAction::None),
160 });
161
162 request.extensions_mut().insert(session_state.clone());
163
164 let mut response = inner.call(request).await?;
165
166 let action = {
167 let guard = session_state.action.lock().expect("session mutex poisoned");
168 guard.clone()
169 };
170 let is_dirty = session_state.dirty.load(Ordering::SeqCst);
171 let ttl_secs = config.session_ttl_secs;
172
173 match action {
174 SessionAction::Set(token) => {
175 set_signed_cookie(
176 &mut response,
177 cookie_name,
178 &token.as_hex(),
179 ttl_secs,
180 &cookie_config,
181 key,
182 );
183 }
184 SessionAction::Remove => {
185 remove_signed_cookie(&mut response, cookie_name, &cookie_config, key);
186 }
187 SessionAction::None => {
188 if let Some(ref session) = current_session {
189 let now = chrono::Utc::now();
190 let new_expires = now + chrono::Duration::seconds(ttl_secs as i64);
191
192 if is_dirty {
193 let data = {
194 let guard = session_state
195 .current
196 .lock()
197 .expect("session mutex poisoned");
198 guard.as_ref().map(|s| s.data.clone())
199 };
200 if let Some(data) = data
201 && let Err(e) =
202 store.flush(&session.id, &data, now, new_expires).await
203 {
204 tracing::error!(
205 session_id = session.id,
206 "failed to flush session data: {e}"
207 );
208 }
209 } else if should_touch
210 && let Err(e) = store.touch(&session.id, now, new_expires).await
211 {
212 tracing::error!(
213 session_id = session.id,
214 "failed to touch session: {e}"
215 );
216 }
217
218 if (is_dirty || should_touch)
219 && let Some(ref token) = session_token
220 {
221 set_signed_cookie(
222 &mut response,
223 cookie_name,
224 &token.as_hex(),
225 ttl_secs,
226 &cookie_config,
227 key,
228 );
229 }
230 }
231
232 if had_cookie && current_session.is_none() && !read_failed {
233 remove_signed_cookie(&mut response, cookie_name, &cookie_config, key);
234 }
235 }
236 }
237
238 Ok(response)
239 })
240 }
241}
242
243fn read_signed_cookie(
246 headers: &http::HeaderMap,
247 cookie_name: &str,
248 key: &Key,
249) -> Option<SessionToken> {
250 let cookie_header = headers.get(http::header::COOKIE)?;
251 let cookie_str = cookie_header.to_str().ok()?;
252
253 for pair in cookie_str.split(';') {
254 let pair = pair.trim();
255 if let Some((name, value)) = pair.split_once('=')
256 && name.trim() == cookie_name
257 {
258 let mut jar = CookieJar::new();
259 jar.add_original(Cookie::new(
260 cookie_name.to_string(),
261 value.trim().to_string(),
262 ));
263 let verified = jar.signed(key).get(cookie_name)?;
264 return SessionToken::from_hex(verified.value()).ok();
265 }
266 }
267 None
268}
269
270fn set_signed_cookie(
272 response: &mut Response<Body>,
273 name: &str,
274 value: &str,
275 max_age_secs: u64,
276 config: &CookieConfig,
277 key: &Key,
278) {
279 let mut jar = CookieJar::new();
281 jar.signed_mut(key)
282 .add(Cookie::new(name.to_string(), value.to_string()));
283 let signed_value = jar
284 .get(name)
285 .expect("cookie was just added")
286 .value()
287 .to_string();
288
289 let same_site = match config.same_site.as_str() {
290 "strict" => SameSite::Strict,
291 "none" => SameSite::None,
292 _ => SameSite::Lax,
293 };
294 let set_cookie_str = Cookie::build((name.to_string(), signed_value))
295 .path("/")
296 .secure(config.secure)
297 .http_only(config.http_only)
298 .same_site(same_site)
299 .max_age(cookie::time::Duration::seconds(max_age_secs as i64))
300 .build()
301 .to_string();
302
303 match HeaderValue::from_str(&set_cookie_str) {
304 Ok(v) => {
305 response.headers_mut().append(http::header::SET_COOKIE, v);
306 }
307 Err(e) => {
308 tracing::error!(
309 cookie_name = name,
310 "failed to set session cookie header: {e}"
311 );
312 }
313 }
314}
315
316fn remove_signed_cookie(
317 response: &mut Response<Body>,
318 name: &str,
319 config: &CookieConfig,
320 key: &Key,
321) {
322 set_signed_cookie(response, name, "", 0, config, key);
323}