modo/auth/session/cookie/
middleware.rs1use std::future::Future;
2use std::pin::Pin;
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6
7use axum::body::Body;
8use axum::extract::connect_info::ConnectInfo;
9use cookie::{Cookie, CookieJar, SameSite};
10use http::{HeaderValue, Request, Response};
11use tower::{Layer, Service};
12
13use crate::client::{ClientInfo, header_str};
14use crate::ip::ClientIp;
15
16use super::CookieSessionService;
17use super::extractor::{SessionAction, SessionState};
18use crate::auth::session::data::Session;
19use crate::auth::session::token::SessionToken;
20use crate::cookie::{CookieConfig, Key};
21
22#[derive(Clone)]
37pub struct CookieSessionLayer {
38 service: CookieSessionService,
39}
40
41pub fn layer(service: CookieSessionService) -> CookieSessionLayer {
47 CookieSessionLayer { service }
48}
49
50impl<S> Layer<S> for CookieSessionLayer {
51 type Service = CookieSessionMiddleware<S>;
52
53 fn layer(&self, inner: S) -> Self::Service {
54 CookieSessionMiddleware {
55 inner,
56 service: self.service.clone(),
57 }
58 }
59}
60
61#[derive(Clone)]
67pub struct CookieSessionMiddleware<S> {
68 inner: S,
69 service: CookieSessionService,
70}
71
72impl<S, ReqBody> Service<Request<ReqBody>> for CookieSessionMiddleware<S>
73where
74 S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
75 S::Future: Send + 'static,
76 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
77 ReqBody: Send + 'static,
78{
79 type Response = Response<Body>;
80 type Error = S::Error;
81 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
82
83 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84 self.inner.poll_ready(cx)
85 }
86
87 fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
88 let svc = self.service.clone();
89 let mut inner = self.inner.clone();
90 std::mem::swap(&mut self.inner, &mut inner);
91
92 Box::pin(async move {
93 let store = svc.store();
94 let config = store.config();
95 let cookie_name = &config.cookie_name;
96 let key = svc.cookie_key();
97 let cookie_config = svc.config().cookie.clone();
98
99 let ip = request
100 .extensions()
101 .get::<ClientIp>()
102 .map(|c| c.0.to_string())
103 .unwrap_or_else(|| {
104 request
105 .extensions()
106 .get::<ConnectInfo<std::net::SocketAddr>>()
107 .map(|ci| ci.0.ip().to_string())
108 .unwrap_or_else(|| "unknown".to_string())
109 });
110 let headers = request.headers();
111
112 let ua = header_str(headers, "user-agent");
113 let accept_lang = header_str(headers, "accept-language");
114 let accept_enc = header_str(headers, "accept-encoding");
115 let info = ClientInfo::from_headers(Some(ip), ua, accept_lang, accept_enc);
116
117 let session_token = read_signed_cookie(request.headers(), cookie_name, key);
118 let had_cookie = session_token.is_some();
119
120 let (current_session, read_failed) = if let Some(ref token) = session_token {
121 match store.read_by_token(token).await {
122 Ok(session) => (session, false),
123 Err(e) => {
124 tracing::error!("failed to read session: {e}");
125 (None, true)
126 }
127 }
128 } else {
129 (None, false)
130 };
131
132 let current_session = if let Some(session) = current_session {
133 if config.validate_fingerprint
134 && info.fingerprint_value() != Some(session.fingerprint.as_str())
135 {
136 tracing::warn!(
137 session_id = session.id,
138 user_id = session.user_id,
139 "session fingerprint mismatch — possible hijack, destroying session"
140 );
141 let _ = store.destroy(&session.id).await;
142 None
143 } else {
144 Some(session)
145 }
146 } else {
147 None
148 };
149
150 let should_touch = current_session.as_ref().is_some_and(|s| {
151 let elapsed = chrono::Utc::now() - s.last_active_at;
152 elapsed >= chrono::Duration::seconds(config.touch_interval_secs as i64)
153 });
154
155 if let Some(raw) = current_session.as_ref() {
156 let session_data = Session::from(raw.clone());
157 request.extensions_mut().insert(session_data);
158 }
159
160 let session_state = Arc::new(SessionState {
161 service: svc.clone(),
162 info,
163 current: Mutex::new(current_session.clone()),
164 dirty: AtomicBool::new(false),
165 action: Mutex::new(SessionAction::None),
166 });
167
168 request.extensions_mut().insert(session_state.clone());
169
170 let mut response = inner.call(request).await?;
171
172 let action = {
173 let guard = session_state.action.lock().expect("session mutex poisoned");
174 guard.clone()
175 };
176 let is_dirty = session_state.dirty.load(Ordering::SeqCst);
177 let ttl_secs = config.session_ttl_secs;
178
179 match action {
180 SessionAction::Set(token) => {
181 set_signed_cookie(
182 &mut response,
183 cookie_name,
184 &token.as_hex(),
185 ttl_secs,
186 &cookie_config,
187 key,
188 );
189 }
190 SessionAction::Remove => {
191 remove_signed_cookie(&mut response, cookie_name, &cookie_config, key);
192 }
193 SessionAction::None => {
194 if let Some(ref session) = current_session {
195 let now = chrono::Utc::now();
196 let new_expires = now + chrono::Duration::seconds(ttl_secs as i64);
197
198 if is_dirty {
199 let data = {
200 let guard = session_state
201 .current
202 .lock()
203 .expect("session mutex poisoned");
204 guard.as_ref().map(|s| s.data.clone())
205 };
206 if let Some(data) = data
207 && let Err(e) =
208 store.flush(&session.id, &data, now, new_expires).await
209 {
210 tracing::error!(
211 session_id = session.id,
212 "failed to flush session data: {e}"
213 );
214 }
215 } else if should_touch
216 && let Err(e) = store.touch(&session.id, now, new_expires).await
217 {
218 tracing::error!(
219 session_id = session.id,
220 "failed to touch session: {e}"
221 );
222 }
223
224 if (is_dirty || should_touch)
225 && let Some(ref token) = session_token
226 {
227 set_signed_cookie(
228 &mut response,
229 cookie_name,
230 &token.as_hex(),
231 ttl_secs,
232 &cookie_config,
233 key,
234 );
235 }
236 }
237
238 if had_cookie && current_session.is_none() && !read_failed {
239 remove_signed_cookie(&mut response, cookie_name, &cookie_config, key);
240 }
241 }
242 }
243
244 Ok(response)
245 })
246 }
247}
248
249fn read_signed_cookie(
252 headers: &http::HeaderMap,
253 cookie_name: &str,
254 key: &Key,
255) -> Option<SessionToken> {
256 let cookie_header = headers.get(http::header::COOKIE)?;
257 let cookie_str = cookie_header.to_str().ok()?;
258
259 for pair in cookie_str.split(';') {
260 let pair = pair.trim();
261 if let Some((name, value)) = pair.split_once('=')
262 && name.trim() == cookie_name
263 {
264 let mut jar = CookieJar::new();
265 jar.add_original(Cookie::new(
266 cookie_name.to_string(),
267 value.trim().to_string(),
268 ));
269 let verified = jar.signed(key).get(cookie_name)?;
270 return SessionToken::from_hex(verified.value()).ok();
271 }
272 }
273 None
274}
275
276fn set_signed_cookie(
278 response: &mut Response<Body>,
279 name: &str,
280 value: &str,
281 max_age_secs: u64,
282 config: &CookieConfig,
283 key: &Key,
284) {
285 let mut jar = CookieJar::new();
287 jar.signed_mut(key)
288 .add(Cookie::new(name.to_string(), value.to_string()));
289 let signed_value = jar
290 .get(name)
291 .expect("cookie was just added")
292 .value()
293 .to_string();
294
295 let same_site = match config.same_site.as_str() {
296 "strict" => SameSite::Strict,
297 "none" => SameSite::None,
298 _ => SameSite::Lax,
299 };
300 let set_cookie_str = Cookie::build((name.to_string(), signed_value))
301 .path("/")
302 .secure(config.secure)
303 .http_only(config.http_only)
304 .same_site(same_site)
305 .max_age(cookie::time::Duration::seconds(max_age_secs as i64))
306 .build()
307 .to_string();
308
309 match HeaderValue::from_str(&set_cookie_str) {
310 Ok(v) => {
311 response.headers_mut().append(http::header::SET_COOKIE, v);
312 }
313 Err(e) => {
314 tracing::error!(
315 cookie_name = name,
316 "failed to set session cookie header: {e}"
317 );
318 }
319 }
320}
321
322fn remove_signed_cookie(
323 response: &mut Response<Body>,
324 name: &str,
325 config: &CookieConfig,
326 key: &Key,
327) {
328 set_signed_cookie(response, name, "", 0, config, key);
329}