modo/auth/session/cookie/
middleware.rs1use std::future::Future;
2use std::pin::Pin;
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6
7use axum::body::Body;
8use axum::extract::connect_info::ConnectInfo;
9use cookie::{Cookie, CookieJar, SameSite};
10use http::{HeaderValue, Request, Response};
11use tower::{Layer, Service};
12
13use crate::ip::ClientIp;
14
15use super::CookieSessionService;
16use super::extractor::{SessionAction, SessionState};
17use crate::auth::session::data::Session;
18use crate::auth::session::meta::{SessionMeta, header_str};
19use crate::auth::session::token::SessionToken;
20use crate::cookie::{CookieConfig, Key};
21
/// Tower [`Layer`] that adds cookie-based session handling to a service.
///
/// Built by [`layer()`]. Every service it wraps becomes a
/// [`CookieSessionMiddleware`] holding its own clone of `service`.
#[derive(Clone)]
pub struct CookieSessionLayer {
    /// Session store, configuration and cookie signing key shared with each middleware.
    service: CookieSessionService,
}
40
41pub fn layer(service: CookieSessionService) -> CookieSessionLayer {
47 CookieSessionLayer { service }
48}
49
50impl<S> Layer<S> for CookieSessionLayer {
51 type Service = CookieSessionMiddleware<S>;
52
53 fn layer(&self, inner: S) -> Self::Service {
54 CookieSessionMiddleware {
55 inner,
56 service: self.service.clone(),
57 }
58 }
59}
60
/// Tower [`Service`] produced by [`CookieSessionLayer`].
///
/// For each request it resolves the session named by the signed session
/// cookie and exposes it to `inner` through request extensions. After `inner`
/// responds, it persists session changes and sets or removes the cookie on the
/// response.
#[derive(Clone)]
pub struct CookieSessionMiddleware<S> {
    /// The wrapped service.
    inner: S,
    /// Provides the session store, configuration and cookie signing key.
    service: CookieSessionService,
}
71
72impl<S, ReqBody> Service<Request<ReqBody>> for CookieSessionMiddleware<S>
73where
74 S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
75 S::Future: Send + 'static,
76 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
77 ReqBody: Send + 'static,
78{
79 type Response = Response<Body>;
80 type Error = S::Error;
81 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
82
83 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84 self.inner.poll_ready(cx)
85 }
86
87 fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
88 let svc = self.service.clone();
89 let mut inner = self.inner.clone();
90 std::mem::swap(&mut self.inner, &mut inner);
91
92 Box::pin(async move {
93 let store = svc.store();
94 let config = store.config();
95 let cookie_name = &config.cookie_name;
96 let key = svc.cookie_key();
97 let cookie_config = svc.config().cookie.clone();
98
99 let ip = request
100 .extensions()
101 .get::<ClientIp>()
102 .map(|c| c.0.to_string())
103 .unwrap_or_else(|| {
104 request
105 .extensions()
106 .get::<ConnectInfo<std::net::SocketAddr>>()
107 .map(|ci| ci.0.ip().to_string())
108 .unwrap_or_else(|| "unknown".to_string())
109 });
110 let headers = request.headers();
111
112 let ua = header_str(headers, "user-agent");
113 let accept_lang = header_str(headers, "accept-language");
114 let accept_enc = header_str(headers, "accept-encoding");
115 let meta = SessionMeta::from_headers(ip, ua, accept_lang, accept_enc);
116
117 let session_token = read_signed_cookie(request.headers(), cookie_name, key);
118 let had_cookie = session_token.is_some();
119
120 let (current_session, read_failed) = if let Some(ref token) = session_token {
121 match store.read_by_token(token).await {
122 Ok(session) => (session, false),
123 Err(e) => {
124 tracing::error!("failed to read session: {e}");
125 (None, true)
126 }
127 }
128 } else {
129 (None, false)
130 };
131
132 let current_session = if let Some(session) = current_session {
133 if config.validate_fingerprint && meta.fingerprint != session.fingerprint {
134 tracing::warn!(
135 session_id = session.id,
136 user_id = session.user_id,
137 "session fingerprint mismatch — possible hijack, destroying session"
138 );
139 let _ = store.destroy(&session.id).await;
140 None
141 } else {
142 Some(session)
143 }
144 } else {
145 None
146 };
147
148 let should_touch = current_session.as_ref().is_some_and(|s| {
149 let elapsed = chrono::Utc::now() - s.last_active_at;
150 elapsed >= chrono::Duration::seconds(config.touch_interval_secs as i64)
151 });
152
153 if let Some(raw) = current_session.as_ref() {
154 let session_data = Session::from(raw.clone());
155 request.extensions_mut().insert(session_data);
156 }
157
158 let session_state = Arc::new(SessionState {
159 service: svc.clone(),
160 meta,
161 current: Mutex::new(current_session.clone()),
162 dirty: AtomicBool::new(false),
163 action: Mutex::new(SessionAction::None),
164 });
165
166 request.extensions_mut().insert(session_state.clone());
167
168 let mut response = inner.call(request).await?;
169
170 let action = {
171 let guard = session_state.action.lock().expect("session mutex poisoned");
172 guard.clone()
173 };
174 let is_dirty = session_state.dirty.load(Ordering::SeqCst);
175 let ttl_secs = config.session_ttl_secs;
176
177 match action {
178 SessionAction::Set(token) => {
179 set_signed_cookie(
180 &mut response,
181 cookie_name,
182 &token.as_hex(),
183 ttl_secs,
184 &cookie_config,
185 key,
186 );
187 }
188 SessionAction::Remove => {
189 remove_signed_cookie(&mut response, cookie_name, &cookie_config, key);
190 }
191 SessionAction::None => {
192 if let Some(ref session) = current_session {
193 let now = chrono::Utc::now();
194 let new_expires = now + chrono::Duration::seconds(ttl_secs as i64);
195
196 if is_dirty {
197 let data = {
198 let guard = session_state
199 .current
200 .lock()
201 .expect("session mutex poisoned");
202 guard.as_ref().map(|s| s.data.clone())
203 };
204 if let Some(data) = data
205 && let Err(e) =
206 store.flush(&session.id, &data, now, new_expires).await
207 {
208 tracing::error!(
209 session_id = session.id,
210 "failed to flush session data: {e}"
211 );
212 }
213 } else if should_touch
214 && let Err(e) = store.touch(&session.id, now, new_expires).await
215 {
216 tracing::error!(
217 session_id = session.id,
218 "failed to touch session: {e}"
219 );
220 }
221
222 if (is_dirty || should_touch)
223 && let Some(ref token) = session_token
224 {
225 set_signed_cookie(
226 &mut response,
227 cookie_name,
228 &token.as_hex(),
229 ttl_secs,
230 &cookie_config,
231 key,
232 );
233 }
234 }
235
236 if had_cookie && current_session.is_none() && !read_failed {
237 remove_signed_cookie(&mut response, cookie_name, &cookie_config, key);
238 }
239 }
240 }
241
242 Ok(response)
243 })
244 }
245}
246
247fn read_signed_cookie(
250 headers: &http::HeaderMap,
251 cookie_name: &str,
252 key: &Key,
253) -> Option<SessionToken> {
254 let cookie_header = headers.get(http::header::COOKIE)?;
255 let cookie_str = cookie_header.to_str().ok()?;
256
257 for pair in cookie_str.split(';') {
258 let pair = pair.trim();
259 if let Some((name, value)) = pair.split_once('=')
260 && name.trim() == cookie_name
261 {
262 let mut jar = CookieJar::new();
263 jar.add_original(Cookie::new(
264 cookie_name.to_string(),
265 value.trim().to_string(),
266 ));
267 let verified = jar.signed(key).get(cookie_name)?;
268 return SessionToken::from_hex(verified.value()).ok();
269 }
270 }
271 None
272}
273
274fn set_signed_cookie(
276 response: &mut Response<Body>,
277 name: &str,
278 value: &str,
279 max_age_secs: u64,
280 config: &CookieConfig,
281 key: &Key,
282) {
283 let mut jar = CookieJar::new();
285 jar.signed_mut(key)
286 .add(Cookie::new(name.to_string(), value.to_string()));
287 let signed_value = jar
288 .get(name)
289 .expect("cookie was just added")
290 .value()
291 .to_string();
292
293 let same_site = match config.same_site.as_str() {
294 "strict" => SameSite::Strict,
295 "none" => SameSite::None,
296 _ => SameSite::Lax,
297 };
298 let set_cookie_str = Cookie::build((name.to_string(), signed_value))
299 .path("/")
300 .secure(config.secure)
301 .http_only(config.http_only)
302 .same_site(same_site)
303 .max_age(cookie::time::Duration::seconds(max_age_secs as i64))
304 .build()
305 .to_string();
306
307 match HeaderValue::from_str(&set_cookie_str) {
308 Ok(v) => {
309 response.headers_mut().append(http::header::SET_COOKIE, v);
310 }
311 Err(e) => {
312 tracing::error!(
313 cookie_name = name,
314 "failed to set session cookie header: {e}"
315 );
316 }
317 }
318}
319
320fn remove_signed_cookie(
321 response: &mut Response<Body>,
322 name: &str,
323 config: &CookieConfig,
324 key: &Key,
325) {
326 set_signed_cookie(response, name, "", 0, config, key);
327}