Skip to main content

modo_session/
middleware.rs

//! Tower middleware layer that loads, validates, and persists sessions for every
//! request.
//!
//! The public entry point is [`layer`], which wraps a [`SessionStore`] in a
//! Tower [`Layer`] / [`Service`] pair. The middleware:
//!
//! 1. Reads the session token from the request cookie.
//! 2. Loads the matching, non-expired session from the database.
//! 3. Optionally validates a server-side request fingerprint, destroying the
//!    session when it does not match.
//! 4. Injects shared state so the [`crate::SessionManager`] extractor can
//!    operate within the handler.
//! 5. After the handler returns, applies any pending session action (set cookie,
//!    remove cookie, or touch expiry), and clears a cookie whose session could
//!    not be found.
14
15use crate::meta::{SessionMeta, extract_client_ip, header_str};
16use crate::store::SessionStore;
17use crate::types::{SessionData, SessionToken};
18use chrono::Utc;
19use futures_util::future::BoxFuture;
20use http::{Request, Response};
21use modo::axum::extract::connect_info::ConnectInfo;
22use modo::cookies::{CookieOptions, build_cookie};
23use std::net::SocketAddr;
24use std::sync::Arc;
25use std::task::{Context, Poll};
26use tokio::sync::Mutex;
27use tower::{Layer, Service};
28
29// --- Public types shared with SessionManager ---
30
31#[derive(Clone)]
32pub(crate) enum SessionAction {
33    None,
34    Set(SessionToken),
35    Remove,
36}
37
38pub(crate) struct SessionManagerState {
39    pub store: SessionStore,
40    pub current_session: Mutex<Option<SessionData>>,
41    pub meta: SessionMeta,
42    pub action: Mutex<SessionAction>,
43}
44
45// --- Layer ---
46
47#[derive(Clone)]
48pub struct SessionLayer {
49    store: Arc<SessionStore>,
50}
51
52impl SessionLayer {
53    fn new(store: SessionStore) -> Self {
54        Self {
55            store: Arc::new(store),
56        }
57    }
58}
59
60impl<S> Layer<S> for SessionLayer {
61    type Service = SessionMiddleware<S>;
62
63    fn layer(&self, inner: S) -> Self::Service {
64        SessionMiddleware {
65            inner,
66            store: self.store.clone(),
67        }
68    }
69}
70
71/// Create a session middleware layer from a `SessionStore`.
72pub fn layer(store: SessionStore) -> SessionLayer {
73    SessionLayer::new(store)
74}
75
76// --- Service ---
77
78#[derive(Clone)]
79pub struct SessionMiddleware<S> {
80    inner: S,
81    store: Arc<SessionStore>,
82}
83
84impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for SessionMiddleware<S>
85where
86    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
87    S::Future: Send + 'static,
88    ReqBody: Send + 'static,
89    ResBody: Default + Send + 'static,
90{
91    type Response = Response<ResBody>;
92    type Error = S::Error;
93    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
94
95    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
96        self.inner.poll_ready(cx)
97    }
98
99    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
100        let store = self.store.clone();
101        let mut inner = self.inner.clone();
102
103        Box::pin(async move {
104            let config = store.config();
105            let cookie_name = &config.cookie_name;
106
107            // Extract meta from request headers
108            let connect_ip = request
109                .extensions()
110                .get::<ConnectInfo<SocketAddr>>()
111                .map(|ci| ci.0.ip());
112            let headers = request.headers();
113            let ip = extract_client_ip(headers, &config.trusted_proxies, connect_ip);
114            let ua = header_str(headers, "user-agent");
115            let accept_lang = header_str(headers, "accept-language");
116            let accept_enc = header_str(headers, "accept-encoding");
117            let meta = SessionMeta::from_headers(ip, ua, accept_lang, accept_enc);
118
119            // Read session token from cookie
120            let session_token = read_session_cookie(headers, cookie_name);
121            let had_cookie = session_token.is_some();
122
123            // Load session from store
124            let (current_session, read_failed) = if let Some(ref token) = session_token {
125                match store.read_by_token(token).await {
126                    Ok(session) => (session, false),
127                    Err(e) => {
128                        tracing::error!("Failed to read session: {e}");
129                        (None, true)
130                    }
131                }
132            } else {
133                (None, false)
134            };
135
136            // Validate fingerprint
137            let current_session = if let Some(session) = current_session {
138                if config.validate_fingerprint && meta.fingerprint != session.fingerprint {
139                    tracing::warn!(
140                        session_id = session.id.as_str(),
141                        user_id = session.user_id,
142                        "Session fingerprint mismatch — possible hijack, destroying session"
143                    );
144                    let _ = store.destroy(&session.id).await;
145                    None
146                } else {
147                    Some(session)
148                }
149            } else {
150                None
151            };
152
153            // Check if we need to touch
154            let should_touch = current_session.as_ref().is_some_and(|s| {
155                let elapsed = Utc::now() - s.last_active_at;
156                elapsed >= chrono::Duration::seconds(config.touch_interval_secs as i64)
157            });
158
159            // Build shared state for SessionManager
160            let manager_state = Arc::new(SessionManagerState {
161                store: (*store).clone(),
162                current_session: Mutex::new(current_session.clone()),
163                meta,
164                action: Mutex::new(SessionAction::None),
165            });
166
167            request.extensions_mut().insert(manager_state.clone());
168
169            // Run inner service
170            let mut response = inner.call(request).await?;
171
172            // Response path: apply session action
173            let action = {
174                let guard = manager_state.action.lock().await;
175                guard.clone()
176            };
177
178            let ttl_secs = config.session_ttl_secs;
179
180            match action {
181                SessionAction::Set(token) => {
182                    let opts = CookieOptions::from_config(store.cookie_config()).max_age(ttl_secs);
183                    append_cookie_header(&mut response, cookie_name, &token.as_hex(), &opts);
184                }
185                SessionAction::Remove => {
186                    // Max-Age=0 instructs the browser to delete the cookie
187                    let opts = CookieOptions::from_config(store.cookie_config()).max_age(0);
188                    append_cookie_header(&mut response, cookie_name, "", &opts);
189                }
190                SessionAction::None => {
191                    if should_touch && let Some(ref session) = current_session {
192                        let new_expires = Utc::now() + chrono::Duration::seconds(ttl_secs as i64);
193                        if let Err(e) = store.touch(&session.id, new_expires).await {
194                            tracing::error!(
195                                session_id = session.id.as_str(),
196                                "Failed to touch session: {e}"
197                            );
198                        } else if let Some(ref token) = session_token {
199                            let opts =
200                                CookieOptions::from_config(store.cookie_config()).max_age(ttl_secs);
201                            append_cookie_header(
202                                &mut response,
203                                cookie_name,
204                                &token.as_hex(),
205                                &opts,
206                            );
207                        }
208                    }
209
210                    // Remove stale cookie (session not found, but cookie existed)
211                    if had_cookie && current_session.is_none() && !read_failed {
212                        // Max-Age=0 instructs the browser to delete the cookie
213                        let opts = CookieOptions::from_config(store.cookie_config()).max_age(0);
214                        append_cookie_header(&mut response, cookie_name, "", &opts);
215                    }
216                }
217            }
218
219            Ok(response)
220        })
221    }
222}
223
224/// Extract the current user ID from request extensions without going through
225/// the full `SessionManager` extractor. Useful for middleware/layers.
226///
227/// Uses `try_lock()` to avoid deadlocks when `SessionManager::set()` or
228/// `remove_key()` hold the mutex across `.await`. Returns `None` if no session
229/// exists or the lock is contended (logged at trace level).
230pub fn user_id_from_extensions(extensions: &http::Extensions) -> Option<String> {
231    extensions
232        .get::<Arc<SessionManagerState>>()
233        .and_then(|state| match state.current_session.try_lock() {
234            Ok(guard) => guard.as_ref().map(|s| s.user_id.clone()),
235            Err(_) => {
236                tracing::trace!("user_id_from_extensions: session lock contended, returning None");
237                None
238            }
239        })
240}
241
242// --- Cookie helpers ---
243
244fn read_session_cookie(headers: &http::HeaderMap, cookie_name: &str) -> Option<SessionToken> {
245    headers
246        .get_all(http::header::COOKIE)
247        .iter()
248        .find_map(|val| {
249            let val = val.to_str().ok()?;
250            for pair in val.split(';') {
251                let pair = pair.trim();
252                if let Some(value) = pair.strip_prefix(cookie_name) {
253                    let value = value.strip_prefix('=')?;
254                    return SessionToken::from_hex(value).ok();
255                }
256            }
257            None
258        })
259}
260
261fn append_cookie_header<B>(
262    response: &mut Response<B>,
263    name: &str,
264    value: &str,
265    opts: &CookieOptions,
266) {
267    let cookie = build_cookie(name, value, opts);
268    match http::HeaderValue::try_from(cookie.to_string()) {
269        Ok(val) => {
270            response.headers_mut().append(http::header::SET_COOKIE, val);
271        }
272        Err(e) => {
273            tracing::warn!(
274                cookie_name = name,
275                "Failed to serialize session cookie: {e}"
276            );
277        }
278    }
279}