Skip to main content

modo_session/
middleware.rs

1//! Tower middleware layer that loads, validates, and persists sessions for every
2//! request.
3//!
4//! The public entry point is [`layer`], which wraps a [`SessionStore`] in a
5//! Tower [`Layer`] / [`Service`] pair. The middleware:
6//!
7//! 1. Reads the session token from the request cookie.
8//! 2. Loads the matching, non-expired session from the database.
9//! 3. Optionally validates a server-side request fingerprint.
10//! 4. Injects shared state so the [`crate::SessionManager`] extractor can
11//!    operate within the handler.
//! 5. After the handler returns, applies any pending session action (set cookie,
//!    remove cookie, or touch expiry), and clears a session cookie that no
//!    longer resolves to a valid session.
14
15use crate::meta::{SessionMeta, extract_client_ip, header_str};
16use crate::store::SessionStore;
17use crate::types::{SessionData, SessionToken};
18use chrono::Utc;
19use futures_util::future::BoxFuture;
20use http::{Request, Response};
21use modo::axum::extract::connect_info::ConnectInfo;
22use modo::cookies::{CookieOptions, build_cookie};
23use std::net::SocketAddr;
24use std::sync::Arc;
25use std::task::{Context, Poll};
26use tokio::sync::Mutex;
27use tower::{Layer, Service};
28
// --- Crate-internal types shared with SessionManager ---
30
31#[derive(Clone)]
32pub(crate) enum SessionAction {
33    None,
34    Set(SessionToken),
35    Remove,
36}
37
38pub(crate) struct SessionManagerState {
39    pub store: SessionStore,
40    pub current_session: Mutex<Option<SessionData>>,
41    pub meta: SessionMeta,
42    pub action: Mutex<SessionAction>,
43}
44
45// --- Layer ---
46
47/// Tower [`Layer`] produced by [`layer`].
48///
49/// Obtain via [`layer(store)`][layer] — do not construct directly.
50#[derive(Clone)]
51pub struct SessionContextLayer {
52    store: Arc<SessionStore>,
53}
54
55impl SessionContextLayer {
56    fn new(store: SessionStore) -> Self {
57        Self {
58            store: Arc::new(store),
59        }
60    }
61}
62
63impl<S> Layer<S> for SessionContextLayer {
64    type Service = SessionMiddleware<S>;
65
66    fn layer(&self, inner: S) -> Self::Service {
67        SessionMiddleware {
68            inner,
69            store: self.store.clone(),
70        }
71    }
72}
73
74/// Create a session middleware layer from a `SessionStore`.
75///
76/// Install the returned layer with `.layer(modo_session::layer(session_store))`
77/// on your app builder.  The layer must be present for [`crate::SessionManager`]
78/// to function as an extractor.
79pub fn layer(store: SessionStore) -> SessionContextLayer {
80    SessionContextLayer::new(store)
81}
82
83// --- Service ---
84
85/// Tower [`Service`] produced by [`SessionContextLayer`].
86///
87/// Handles per-request session loading, fingerprint validation, and cookie
88/// management.  Do not construct directly; use [`layer`] instead.
89#[derive(Clone)]
90pub struct SessionMiddleware<S> {
91    inner: S,
92    store: Arc<SessionStore>,
93}
94
95impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for SessionMiddleware<S>
96where
97    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
98    S::Future: Send + 'static,
99    ReqBody: Send + 'static,
100    ResBody: Default + Send + 'static,
101{
102    type Response = Response<ResBody>;
103    type Error = S::Error;
104    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
105
106    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
107        self.inner.poll_ready(cx)
108    }
109
110    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
111        let store = self.store.clone();
112        let mut inner = self.inner.clone();
113
114        Box::pin(async move {
115            let config = store.config();
116            let cookie_name = &config.cookie_name;
117
118            // Extract meta from request headers
119            let connect_ip = request
120                .extensions()
121                .get::<ConnectInfo<SocketAddr>>()
122                .map(|ci| ci.0.ip());
123            let headers = request.headers();
124            let ip = extract_client_ip(headers, &config.trusted_proxies, connect_ip);
125            let ua = header_str(headers, "user-agent");
126            let accept_lang = header_str(headers, "accept-language");
127            let accept_enc = header_str(headers, "accept-encoding");
128            let meta = SessionMeta::from_headers(ip, ua, accept_lang, accept_enc);
129
130            // Read session token from cookie
131            let session_token = read_session_cookie(headers, cookie_name);
132            let had_cookie = session_token.is_some();
133
134            // Load session from store
135            let (current_session, read_failed) = if let Some(ref token) = session_token {
136                match store.read_by_token(token).await {
137                    Ok(session) => (session, false),
138                    Err(e) => {
139                        tracing::error!("Failed to read session: {e}");
140                        (None, true)
141                    }
142                }
143            } else {
144                (None, false)
145            };
146
147            // Validate fingerprint
148            let current_session = if let Some(session) = current_session {
149                if config.validate_fingerprint && meta.fingerprint != session.fingerprint {
150                    tracing::warn!(
151                        session_id = session.id.as_str(),
152                        user_id = session.user_id,
153                        "Session fingerprint mismatch — possible hijack, destroying session"
154                    );
155                    let _ = store.destroy(&session.id).await;
156                    None
157                } else {
158                    Some(session)
159                }
160            } else {
161                None
162            };
163
164            // Check if we need to touch
165            let should_touch = current_session.as_ref().is_some_and(|s| {
166                let elapsed = Utc::now() - s.last_active_at;
167                elapsed >= chrono::Duration::seconds(config.touch_interval_secs as i64)
168            });
169
170            // Build shared state for SessionManager
171            let manager_state = Arc::new(SessionManagerState {
172                store: (*store).clone(),
173                current_session: Mutex::new(current_session.clone()),
174                meta,
175                action: Mutex::new(SessionAction::None),
176            });
177
178            request.extensions_mut().insert(manager_state.clone());
179
180            // Run inner service
181            let mut response = inner.call(request).await?;
182
183            // Response path: apply session action
184            let action = {
185                let guard = manager_state.action.lock().await;
186                guard.clone()
187            };
188
189            let ttl_secs = config.session_ttl_secs;
190
191            match action {
192                SessionAction::Set(token) => {
193                    let opts = CookieOptions::from_config(store.cookie_config()).max_age(ttl_secs);
194                    append_cookie_header(&mut response, cookie_name, &token.as_hex(), &opts);
195                }
196                SessionAction::Remove => {
197                    // Max-Age=0 instructs the browser to delete the cookie
198                    let opts = CookieOptions::from_config(store.cookie_config()).max_age(0);
199                    append_cookie_header(&mut response, cookie_name, "", &opts);
200                }
201                SessionAction::None => {
202                    if should_touch && let Some(ref session) = current_session {
203                        let new_expires = Utc::now() + chrono::Duration::seconds(ttl_secs as i64);
204                        if let Err(e) = store.touch(&session.id, new_expires).await {
205                            tracing::error!(
206                                session_id = session.id.as_str(),
207                                "Failed to touch session: {e}"
208                            );
209                        } else if let Some(ref token) = session_token {
210                            let opts =
211                                CookieOptions::from_config(store.cookie_config()).max_age(ttl_secs);
212                            append_cookie_header(
213                                &mut response,
214                                cookie_name,
215                                &token.as_hex(),
216                                &opts,
217                            );
218                        }
219                    }
220
221                    // Remove stale cookie (session not found, but cookie existed)
222                    if had_cookie && current_session.is_none() && !read_failed {
223                        // Max-Age=0 instructs the browser to delete the cookie
224                        let opts = CookieOptions::from_config(store.cookie_config()).max_age(0);
225                        append_cookie_header(&mut response, cookie_name, "", &opts);
226                    }
227                }
228            }
229
230            Ok(response)
231        })
232    }
233}
234
235/// Extract the current user ID from request extensions without going through
236/// the full [`crate::SessionManager`] extractor.
237///
238/// Useful inside Tower layers that run after the session middleware but before
239/// (or instead of) a handler.  Uses `try_lock()` internally to avoid deadlocks
240/// when [`crate::SessionManager::set`] or [`crate::SessionManager::remove_key`]
241/// hold the mutex across `.await`.  Returns `None` if no session exists or if
242/// the lock is contended (logged at trace level).
243pub fn user_id_from_extensions(extensions: &http::Extensions) -> Option<String> {
244    extensions
245        .get::<Arc<SessionManagerState>>()
246        .and_then(|state| match state.current_session.try_lock() {
247            Ok(guard) => guard.as_ref().map(|s| s.user_id.clone()),
248            Err(_) => {
249                tracing::trace!("user_id_from_extensions: session lock contended, returning None");
250                None
251            }
252        })
253}
254
255// --- Cookie helpers ---
256
257fn read_session_cookie(headers: &http::HeaderMap, cookie_name: &str) -> Option<SessionToken> {
258    headers
259        .get_all(http::header::COOKIE)
260        .iter()
261        .find_map(|val| {
262            let val = val.to_str().ok()?;
263            for pair in val.split(';') {
264                let pair = pair.trim();
265                if let Some(value) = pair.strip_prefix(cookie_name) {
266                    let value = value.strip_prefix('=')?;
267                    return SessionToken::from_hex(value).ok();
268                }
269            }
270            None
271        })
272}
273
274fn append_cookie_header<B>(
275    response: &mut Response<B>,
276    name: &str,
277    value: &str,
278    opts: &CookieOptions,
279) {
280    let cookie = build_cookie(name, value, opts);
281    match http::HeaderValue::try_from(cookie.to_string()) {
282        Ok(val) => {
283            response.headers_mut().append(http::header::SET_COOKIE, val);
284        }
285        Err(e) => {
286            tracing::warn!(
287                cookie_name = name,
288                "Failed to serialize session cookie: {e}"
289            );
290        }
291    }
292}