Skip to main content

modo/auth/session/cookie/
middleware.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6
7use axum::body::Body;
8use axum::extract::connect_info::ConnectInfo;
9use cookie::{Cookie, CookieJar, SameSite};
10use http::{HeaderValue, Request, Response};
11use tower::{Layer, Service};
12
13use crate::client::{ClientInfo, header_str};
14use crate::ip::ClientIp;
15
16use super::CookieSessionService;
17use super::extractor::{SessionAction, SessionState};
18use crate::auth::session::data::Session;
19use crate::auth::session::token::SessionToken;
20use crate::cookie::{CookieConfig, Key};
21
22// --- Layer ---
23
24/// Tower [`Layer`] that installs the session middleware into the request pipeline.
25///
26/// Construct with [`layer`] or [`CookieSessionService::layer`] rather than
27/// directly. Apply before route handlers with `Router::layer(session_layer)`.
28///
29/// The middleware reads the signed session cookie, loads the session from the
30/// database, validates the browser fingerprint (when configured), and inserts:
31/// - an `Arc<SessionState>` for the [`super::extractor::CookieSession`] extractor
32/// - a [`crate::auth::session::data::Session`] snapshot for the data extractor
33///
34/// On the response path it flushes dirty session data, touches the expiry
35/// timestamp, and sets or clears the session cookie as needed.
36#[derive(Clone)]
37pub struct CookieSessionLayer {
38    service: CookieSessionService,
39}
40
41/// Create a [`CookieSessionLayer`] from a [`CookieSessionService`].
42///
43/// Prefer [`CookieSessionService::layer`] in application code — this free
44/// function exists so integration tests and advanced callers can assemble the
45/// layer without borrowing the service.
46pub fn layer(service: CookieSessionService) -> CookieSessionLayer {
47    CookieSessionLayer { service }
48}
49
50impl<S> Layer<S> for CookieSessionLayer {
51    type Service = CookieSessionMiddleware<S>;
52
53    fn layer(&self, inner: S) -> Self::Service {
54        CookieSessionMiddleware {
55            inner,
56            service: self.service.clone(),
57        }
58    }
59}
60
61// --- Service ---
62
63/// Tower [`Service`] that manages the session lifecycle for each request.
64///
65/// Produced by [`CookieSessionLayer`]; not constructed directly.
66#[derive(Clone)]
67pub struct CookieSessionMiddleware<S> {
68    inner: S,
69    service: CookieSessionService,
70}
71
72impl<S, ReqBody> Service<Request<ReqBody>> for CookieSessionMiddleware<S>
73where
74    S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
75    S::Future: Send + 'static,
76    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
77    ReqBody: Send + 'static,
78{
79    type Response = Response<Body>;
80    type Error = S::Error;
81    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
82
83    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84        self.inner.poll_ready(cx)
85    }
86
87    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
88        let svc = self.service.clone();
89        let mut inner = self.inner.clone();
90        std::mem::swap(&mut self.inner, &mut inner);
91
92        Box::pin(async move {
93            let store = svc.store();
94            let config = store.config();
95            let cookie_name = &config.cookie_name;
96            let key = svc.cookie_key();
97            let cookie_config = svc.config().cookie.clone();
98
99            let ip = request
100                .extensions()
101                .get::<ClientIp>()
102                .map(|c| c.0.to_string())
103                .unwrap_or_else(|| {
104                    request
105                        .extensions()
106                        .get::<ConnectInfo<std::net::SocketAddr>>()
107                        .map(|ci| ci.0.ip().to_string())
108                        .unwrap_or_else(|| "unknown".to_string())
109                });
110            let headers = request.headers();
111
112            let ua = header_str(headers, "user-agent");
113            let accept_lang = header_str(headers, "accept-language");
114            let accept_enc = header_str(headers, "accept-encoding");
115            let info = ClientInfo::from_headers(Some(ip), ua, accept_lang, accept_enc);
116
117            let session_token = read_signed_cookie(request.headers(), cookie_name, key);
118            let had_cookie = session_token.is_some();
119
120            let (current_session, read_failed) = if let Some(ref token) = session_token {
121                match store.read_by_token(token).await {
122                    Ok(session) => (session, false),
123                    Err(e) => {
124                        tracing::error!("failed to read session: {e}");
125                        (None, true)
126                    }
127                }
128            } else {
129                (None, false)
130            };
131
132            let current_session = if let Some(session) = current_session {
133                if config.validate_fingerprint
134                    && info.fingerprint_value() != Some(session.fingerprint.as_str())
135                {
136                    tracing::warn!(
137                        session_id = session.id,
138                        user_id = session.user_id,
139                        "session fingerprint mismatch — possible hijack, destroying session"
140                    );
141                    let _ = store.destroy(&session.id).await;
142                    None
143                } else {
144                    Some(session)
145                }
146            } else {
147                None
148            };
149
150            let should_touch = current_session.as_ref().is_some_and(|s| {
151                let elapsed = chrono::Utc::now() - s.last_active_at;
152                elapsed >= chrono::Duration::seconds(config.touch_interval_secs as i64)
153            });
154
155            if let Some(raw) = current_session.as_ref() {
156                let session_data = Session::from(raw.clone());
157                request.extensions_mut().insert(session_data);
158            }
159
160            let session_state = Arc::new(SessionState {
161                service: svc.clone(),
162                info,
163                current: Mutex::new(current_session.clone()),
164                dirty: AtomicBool::new(false),
165                action: Mutex::new(SessionAction::None),
166            });
167
168            request.extensions_mut().insert(session_state.clone());
169
170            let mut response = inner.call(request).await?;
171
172            let action = {
173                let guard = session_state.action.lock().expect("session mutex poisoned");
174                guard.clone()
175            };
176            let is_dirty = session_state.dirty.load(Ordering::SeqCst);
177            let ttl_secs = config.session_ttl_secs;
178
179            match action {
180                SessionAction::Set(token) => {
181                    set_signed_cookie(
182                        &mut response,
183                        cookie_name,
184                        &token.as_hex(),
185                        ttl_secs,
186                        &cookie_config,
187                        key,
188                    );
189                }
190                SessionAction::Remove => {
191                    remove_signed_cookie(&mut response, cookie_name, &cookie_config, key);
192                }
193                SessionAction::None => {
194                    if let Some(ref session) = current_session {
195                        let now = chrono::Utc::now();
196                        let new_expires = now + chrono::Duration::seconds(ttl_secs as i64);
197
198                        if is_dirty {
199                            let data = {
200                                let guard = session_state
201                                    .current
202                                    .lock()
203                                    .expect("session mutex poisoned");
204                                guard.as_ref().map(|s| s.data.clone())
205                            };
206                            if let Some(data) = data
207                                && let Err(e) =
208                                    store.flush(&session.id, &data, now, new_expires).await
209                            {
210                                tracing::error!(
211                                    session_id = session.id,
212                                    "failed to flush session data: {e}"
213                                );
214                            }
215                        } else if should_touch
216                            && let Err(e) = store.touch(&session.id, now, new_expires).await
217                        {
218                            tracing::error!(
219                                session_id = session.id,
220                                "failed to touch session: {e}"
221                            );
222                        }
223
224                        if (is_dirty || should_touch)
225                            && let Some(ref token) = session_token
226                        {
227                            set_signed_cookie(
228                                &mut response,
229                                cookie_name,
230                                &token.as_hex(),
231                                ttl_secs,
232                                &cookie_config,
233                                key,
234                            );
235                        }
236                    }
237
238                    if had_cookie && current_session.is_none() && !read_failed {
239                        remove_signed_cookie(&mut response, cookie_name, &cookie_config, key);
240                    }
241                }
242            }
243
244            Ok(response)
245        })
246    }
247}
248
249/// Read a signed cookie value from request headers.
250/// Returns `Some(SessionToken)` if the cookie exists, signature is valid, and hex decodes.
251fn read_signed_cookie(
252    headers: &http::HeaderMap,
253    cookie_name: &str,
254    key: &Key,
255) -> Option<SessionToken> {
256    let cookie_header = headers.get(http::header::COOKIE)?;
257    let cookie_str = cookie_header.to_str().ok()?;
258
259    for pair in cookie_str.split(';') {
260        let pair = pair.trim();
261        if let Some((name, value)) = pair.split_once('=')
262            && name.trim() == cookie_name
263        {
264            let mut jar = CookieJar::new();
265            jar.add_original(Cookie::new(
266                cookie_name.to_string(),
267                value.trim().to_string(),
268            ));
269            let verified = jar.signed(key).get(cookie_name)?;
270            return SessionToken::from_hex(verified.value()).ok();
271        }
272    }
273    None
274}
275
276/// Sign a cookie value and append Set-Cookie header to response.
277fn set_signed_cookie(
278    response: &mut Response<Body>,
279    name: &str,
280    value: &str,
281    max_age_secs: u64,
282    config: &CookieConfig,
283    key: &Key,
284) {
285    // Sign the value
286    let mut jar = CookieJar::new();
287    jar.signed_mut(key)
288        .add(Cookie::new(name.to_string(), value.to_string()));
289    let signed_value = jar
290        .get(name)
291        .expect("cookie was just added")
292        .value()
293        .to_string();
294
295    let same_site = match config.same_site.as_str() {
296        "strict" => SameSite::Strict,
297        "none" => SameSite::None,
298        _ => SameSite::Lax,
299    };
300    let set_cookie_str = Cookie::build((name.to_string(), signed_value))
301        .path("/")
302        .secure(config.secure)
303        .http_only(config.http_only)
304        .same_site(same_site)
305        .max_age(cookie::time::Duration::seconds(max_age_secs as i64))
306        .build()
307        .to_string();
308
309    match HeaderValue::from_str(&set_cookie_str) {
310        Ok(v) => {
311            response.headers_mut().append(http::header::SET_COOKIE, v);
312        }
313        Err(e) => {
314            tracing::error!(
315                cookie_name = name,
316                "failed to set session cookie header: {e}"
317            );
318        }
319    }
320}
321
322fn remove_signed_cookie(
323    response: &mut Response<Body>,
324    name: &str,
325    config: &CookieConfig,
326    key: &Key,
327) {
328    set_signed_cookie(response, name, "", 0, config, key);
329}