Skip to main content

modo/auth/session/cookie/
middleware.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6
7use axum::body::Body;
8use axum::extract::connect_info::ConnectInfo;
9use cookie::{Cookie, CookieJar, SameSite};
10use http::{HeaderValue, Request, Response};
11use tower::{Layer, Service};
12
13use crate::ip::ClientIp;
14
15use super::CookieSessionService;
16use super::extractor::{SessionAction, SessionState};
17use crate::auth::session::data::Session;
18use crate::auth::session::meta::{SessionMeta, header_str};
19use crate::auth::session::token::SessionToken;
20use crate::cookie::{CookieConfig, Key};
21
22// --- Layer ---
23
24/// Tower [`Layer`] that installs the session middleware into the request pipeline.
25///
26/// Construct with [`layer`] or [`CookieSessionService::layer`] rather than
27/// directly. Apply before route handlers with `Router::layer(session_layer)`.
28///
29/// The middleware reads the signed session cookie, loads the session from the
30/// database, validates the browser fingerprint (when configured), and inserts:
31/// - an `Arc<SessionState>` for the [`super::extractor::CookieSession`] extractor
32/// - a [`crate::auth::session::data::Session`] snapshot for the data extractor
33///
34/// On the response path it flushes dirty session data, touches the expiry
35/// timestamp, and sets or clears the session cookie as needed.
36#[derive(Clone)]
37pub struct CookieSessionLayer {
38    service: CookieSessionService,
39}
40
41/// Create a [`CookieSessionLayer`] from a [`CookieSessionService`].
42pub fn layer(service: CookieSessionService) -> CookieSessionLayer {
43    CookieSessionLayer { service }
44}
45
46impl<S> Layer<S> for CookieSessionLayer {
47    type Service = CookieSessionMiddleware<S>;
48
49    fn layer(&self, inner: S) -> Self::Service {
50        CookieSessionMiddleware {
51            inner,
52            service: self.service.clone(),
53        }
54    }
55}
56
57// --- Service ---
58
59/// Tower [`Service`] that manages the session lifecycle for each request.
60///
61/// Produced by [`CookieSessionLayer`]; not constructed directly.
62#[derive(Clone)]
63pub struct CookieSessionMiddleware<S> {
64    inner: S,
65    service: CookieSessionService,
66}
67
68impl<S, ReqBody> Service<Request<ReqBody>> for CookieSessionMiddleware<S>
69where
70    S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
71    S::Future: Send + 'static,
72    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
73    ReqBody: Send + 'static,
74{
75    type Response = Response<Body>;
76    type Error = S::Error;
77    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
78
79    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
80        self.inner.poll_ready(cx)
81    }
82
83    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
84        let svc = self.service.clone();
85        let mut inner = self.inner.clone();
86        std::mem::swap(&mut self.inner, &mut inner);
87
88        Box::pin(async move {
89            let store = svc.store();
90            let config = store.config();
91            let cookie_name = &config.cookie_name;
92            let key = svc.cookie_key();
93            let cookie_config = svc.config().cookie.clone();
94
95            // 1. Extract client IP
96            let ip = request
97                .extensions()
98                .get::<ClientIp>()
99                .map(|c| c.0.to_string())
100                .unwrap_or_else(|| {
101                    // Fallback: no ClientIpLayer applied — use ConnectInfo directly
102                    request
103                        .extensions()
104                        .get::<ConnectInfo<std::net::SocketAddr>>()
105                        .map(|ci| ci.0.ip().to_string())
106                        .unwrap_or_else(|| "unknown".to_string())
107                });
108            let headers = request.headers();
109
110            // 2. Build SessionMeta
111            let ua = header_str(headers, "user-agent");
112            let accept_lang = header_str(headers, "accept-language");
113            let accept_enc = header_str(headers, "accept-encoding");
114            let meta = SessionMeta::from_headers(ip, ua, accept_lang, accept_enc);
115
116            // 3. Read signed session cookie
117            let session_token = read_signed_cookie(request.headers(), cookie_name, key);
118            let had_cookie = session_token.is_some();
119
120            // 4. Load session from DB
121            let (current_session, read_failed) = if let Some(ref token) = session_token {
122                match store.read_by_token(token).await {
123                    Ok(session) => (session, false),
124                    Err(e) => {
125                        tracing::error!("failed to read session: {e}");
126                        (None, true)
127                    }
128                }
129            } else {
130                (None, false)
131            };
132
133            // 5/6. Validate fingerprint
134            let current_session = if let Some(session) = current_session {
135                if config.validate_fingerprint && meta.fingerprint != session.fingerprint {
136                    tracing::warn!(
137                        session_id = session.id,
138                        user_id = session.user_id,
139                        "session fingerprint mismatch — possible hijack, destroying session"
140                    );
141                    let _ = store.destroy(&session.id).await;
142                    None
143                } else {
144                    Some(session)
145                }
146            } else {
147                None
148            };
149
150            // Check if touch interval elapsed
151            let should_touch = current_session.as_ref().is_some_and(|s| {
152                let elapsed = chrono::Utc::now() - s.last_active_at;
153                elapsed >= chrono::Duration::seconds(config.touch_interval_secs as i64)
154            });
155
156            // 7. Insert Session data view for the data extractor
157            if let Some(raw) = current_session.as_ref() {
158                let session_data = Session::from(raw.clone());
159                request.extensions_mut().insert(session_data);
160            }
161
162            // 8. Build SessionState and insert for CookieSession extractor
163            let session_state = Arc::new(SessionState {
164                service: svc.clone(),
165                meta,
166                current: Mutex::new(current_session.clone()),
167                dirty: AtomicBool::new(false),
168                action: Mutex::new(SessionAction::None),
169            });
170
171            request.extensions_mut().insert(session_state.clone());
172
173            // Run inner service
174            let mut response = inner.call(request).await?;
175
176            // --- Response path ---
177
178            let action = {
179                let guard = session_state.action.lock().expect("session mutex poisoned");
180                guard.clone()
181            };
182            let is_dirty = session_state.dirty.load(Ordering::SeqCst);
183            let ttl_secs = config.session_ttl_secs;
184
185            match action {
186                SessionAction::Set(token) => {
187                    set_signed_cookie(
188                        &mut response,
189                        cookie_name,
190                        &token.as_hex(),
191                        ttl_secs,
192                        &cookie_config,
193                        key,
194                    );
195                }
196                SessionAction::Remove => {
197                    remove_signed_cookie(&mut response, cookie_name, &cookie_config, key);
198                }
199                SessionAction::None => {
200                    if let Some(ref session) = current_session {
201                        let now = chrono::Utc::now();
202                        let new_expires = now + chrono::Duration::seconds(ttl_secs as i64);
203
204                        if is_dirty {
205                            let data = {
206                                let guard = session_state
207                                    .current
208                                    .lock()
209                                    .expect("session mutex poisoned");
210                                guard.as_ref().map(|s| s.data.clone())
211                            };
212                            if let Some(data) = data
213                                && let Err(e) =
214                                    store.flush(&session.id, &data, now, new_expires).await
215                            {
216                                tracing::error!(
217                                    session_id = session.id,
218                                    "failed to flush session data: {e}"
219                                );
220                            }
221                        } else if should_touch
222                            && let Err(e) = store.touch(&session.id, now, new_expires).await
223                        {
224                            tracing::error!(
225                                session_id = session.id,
226                                "failed to touch session: {e}"
227                            );
228                        }
229
230                        // Refresh cookie if we did a flush or touch
231                        if (is_dirty || should_touch)
232                            && let Some(ref token) = session_token
233                        {
234                            set_signed_cookie(
235                                &mut response,
236                                cookie_name,
237                                &token.as_hex(),
238                                ttl_secs,
239                                &cookie_config,
240                                key,
241                            );
242                        }
243                    }
244
245                    // Stale cookie cleanup
246                    if had_cookie && current_session.is_none() && !read_failed {
247                        remove_signed_cookie(&mut response, cookie_name, &cookie_config, key);
248                    }
249                }
250            }
251
252            Ok(response)
253        })
254    }
255}
256
257/// Read a signed cookie value from request headers.
258/// Returns `Some(SessionToken)` if the cookie exists, signature is valid, and hex decodes.
259fn read_signed_cookie(
260    headers: &http::HeaderMap,
261    cookie_name: &str,
262    key: &Key,
263) -> Option<SessionToken> {
264    let cookie_header = headers.get(http::header::COOKIE)?;
265    let cookie_str = cookie_header.to_str().ok()?;
266
267    for pair in cookie_str.split(';') {
268        let pair = pair.trim();
269        if let Some((name, value)) = pair.split_once('=')
270            && name.trim() == cookie_name
271        {
272            // Verify signature using cookie crate's signed jar
273            let mut jar = CookieJar::new();
274            jar.add_original(Cookie::new(
275                cookie_name.to_string(),
276                value.trim().to_string(),
277            ));
278            let verified = jar.signed(key).get(cookie_name)?;
279            return SessionToken::from_hex(verified.value()).ok();
280        }
281    }
282    None
283}
284
285/// Sign a cookie value and append Set-Cookie header to response.
286fn set_signed_cookie(
287    response: &mut Response<Body>,
288    name: &str,
289    value: &str,
290    max_age_secs: u64,
291    config: &CookieConfig,
292    key: &Key,
293) {
294    // Sign the value
295    let mut jar = CookieJar::new();
296    jar.signed_mut(key)
297        .add(Cookie::new(name.to_string(), value.to_string()));
298    let signed_value = jar
299        .get(name)
300        .expect("cookie was just added")
301        .value()
302        .to_string();
303
304    // Build Set-Cookie header with attributes
305    let same_site = match config.same_site.as_str() {
306        "strict" => SameSite::Strict,
307        "none" => SameSite::None,
308        _ => SameSite::Lax,
309    };
310    let set_cookie_str = Cookie::build((name.to_string(), signed_value))
311        .path("/")
312        .secure(config.secure)
313        .http_only(config.http_only)
314        .same_site(same_site)
315        .max_age(cookie::time::Duration::seconds(max_age_secs as i64))
316        .build()
317        .to_string();
318
319    match HeaderValue::from_str(&set_cookie_str) {
320        Ok(v) => {
321            response.headers_mut().append(http::header::SET_COOKIE, v);
322        }
323        Err(e) => {
324            tracing::error!(
325                cookie_name = name,
326                "failed to set session cookie header: {e}"
327            );
328        }
329    }
330}
331
332fn remove_signed_cookie(
333    response: &mut Response<Body>,
334    name: &str,
335    config: &CookieConfig,
336    key: &Key,
337) {
338    set_signed_cookie(response, name, "", 0, config, key);
339}