Skip to main content

modo/auth/session/
middleware.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6
7use axum::body::Body;
8use axum::extract::connect_info::ConnectInfo;
9use cookie::{Cookie, CookieJar, SameSite};
10use http::{HeaderValue, Request, Response};
11use tower::{Layer, Service};
12
13use crate::cookie::{CookieConfig, Key};
14use crate::ip::ClientIp;
15
16use super::extractor::{SessionAction, SessionState};
17use super::meta::{SessionMeta, header_str};
18use super::store::Store;
19use super::token::SessionToken;
20
21// --- Layer ---
22
23/// Tower [`Layer`] that installs the session middleware into the request pipeline.
24///
25/// Construct with [`layer`] rather than directly. Apply before route handlers
26/// with `Router::layer(session_layer)`.
27///
28/// The middleware reads the signed session cookie, loads the session from the
29/// database, validates the browser fingerprint (when configured), and inserts
30/// an `Arc<SessionState>` into the request extensions so the [`super::extractor::Session`]
31/// extractor can access it.
32///
33/// On the response path it flushes dirty session data, touches the expiry
34/// timestamp, and sets or clears the session cookie as needed.
35#[derive(Clone)]
36pub struct SessionLayer {
37    store: Arc<Store>,
38    cookie_config: CookieConfig,
39    key: Key,
40}
41
42/// Create a [`SessionLayer`] from a [`Store`], [`CookieConfig`], and signing [`Key`].
43///
44/// # Example
45///
46/// ```rust,no_run
47/// use modo::auth::session::{self, SessionConfig, Store};
48/// use modo::cookie::{CookieConfig, key_from_config};
49/// use modo::db::Database;
50///
51/// # async fn example(db: Database) -> modo::Result<()> {
52/// let store = Store::new(db, SessionConfig::default());
53/// let cookie_config: CookieConfig = todo!("load from config");
54/// let key = key_from_config(&cookie_config)?;
55/// let session_layer = session::layer(store, &cookie_config, &key);
56/// # Ok(())
57/// # }
58/// ```
59pub fn layer(store: Store, cookie_config: &CookieConfig, key: &Key) -> SessionLayer {
60    SessionLayer {
61        store: Arc::new(store),
62        cookie_config: cookie_config.clone(),
63        key: key.clone(),
64    }
65}
66
67impl<S> Layer<S> for SessionLayer {
68    type Service = SessionMiddleware<S>;
69
70    fn layer(&self, inner: S) -> Self::Service {
71        SessionMiddleware {
72            inner,
73            store: self.store.clone(),
74            cookie_config: self.cookie_config.clone(),
75            key: self.key.clone(),
76        }
77    }
78}
79
80// --- Service ---
81
82/// Tower [`Service`] that manages the session lifecycle for each request.
83///
84/// Produced by [`SessionLayer`]; not constructed directly.
85#[derive(Clone)]
86pub struct SessionMiddleware<S> {
87    inner: S,
88    store: Arc<Store>,
89    cookie_config: CookieConfig,
90    key: Key,
91}
92
93impl<S, ReqBody> Service<Request<ReqBody>> for SessionMiddleware<S>
94where
95    S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
96    S::Future: Send + 'static,
97    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
98    ReqBody: Send + 'static,
99{
100    type Response = Response<Body>;
101    type Error = S::Error;
102    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
103
104    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
105        self.inner.poll_ready(cx)
106    }
107
108    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
109        let store = self.store.clone();
110        let cookie_config = self.cookie_config.clone();
111        let key = self.key.clone();
112        let mut inner = self.inner.clone();
113        std::mem::swap(&mut self.inner, &mut inner);
114
115        Box::pin(async move {
116            let config = store.config();
117            let cookie_name = &config.cookie_name;
118
119            // 1. Extract client IP
120            let ip = request
121                .extensions()
122                .get::<ClientIp>()
123                .map(|c| c.0.to_string())
124                .unwrap_or_else(|| {
125                    // Fallback: no ClientIpLayer applied — use ConnectInfo directly
126                    request
127                        .extensions()
128                        .get::<ConnectInfo<std::net::SocketAddr>>()
129                        .map(|ci| ci.0.ip().to_string())
130                        .unwrap_or_else(|| "unknown".to_string())
131                });
132            let headers = request.headers();
133
134            // 2. Build SessionMeta
135            let ua = header_str(headers, "user-agent");
136            let accept_lang = header_str(headers, "accept-language");
137            let accept_enc = header_str(headers, "accept-encoding");
138            let meta = SessionMeta::from_headers(ip, ua, accept_lang, accept_enc);
139
140            // 3. Read signed session cookie
141            let session_token = read_signed_cookie(request.headers(), cookie_name, &key);
142            let had_cookie = session_token.is_some();
143
144            // 4. Load session from DB
145            let (current_session, read_failed) = if let Some(ref token) = session_token {
146                match store.read_by_token(token).await {
147                    Ok(session) => (session, false),
148                    Err(e) => {
149                        tracing::error!("failed to read session: {e}");
150                        (None, true)
151                    }
152                }
153            } else {
154                (None, false)
155            };
156
157            // 5/6. Validate fingerprint
158            let current_session = if let Some(session) = current_session {
159                if config.validate_fingerprint && meta.fingerprint != session.fingerprint {
160                    tracing::warn!(
161                        session_id = session.id,
162                        user_id = session.user_id,
163                        "session fingerprint mismatch — possible hijack, destroying session"
164                    );
165                    let _ = store.destroy(&session.id).await;
166                    None
167                } else {
168                    Some(session)
169                }
170            } else {
171                None
172            };
173
174            // Check if touch interval elapsed
175            let should_touch = current_session.as_ref().is_some_and(|s| {
176                let elapsed = chrono::Utc::now() - s.last_active_at;
177                elapsed >= chrono::Duration::seconds(config.touch_interval_secs as i64)
178            });
179
180            // 7. Build SessionState
181            let session_state = Arc::new(SessionState {
182                store: (*store).clone(),
183                meta,
184                current: Mutex::new(current_session.clone()),
185                dirty: AtomicBool::new(false),
186                action: Mutex::new(SessionAction::None),
187            });
188
189            request.extensions_mut().insert(session_state.clone());
190
191            // Run inner service
192            let mut response = inner.call(request).await?;
193
194            // --- Response path ---
195
196            let action = {
197                let guard = session_state.action.lock().expect("session mutex poisoned");
198                guard.clone()
199            };
200            let is_dirty = session_state.dirty.load(Ordering::SeqCst);
201            let ttl_secs = config.session_ttl_secs;
202
203            match action {
204                SessionAction::Set(token) => {
205                    set_signed_cookie(
206                        &mut response,
207                        cookie_name,
208                        &token.as_hex(),
209                        ttl_secs,
210                        &cookie_config,
211                        &key,
212                    );
213                }
214                SessionAction::Remove => {
215                    remove_signed_cookie(&mut response, cookie_name, &cookie_config, &key);
216                }
217                SessionAction::None => {
218                    if let Some(ref session) = current_session {
219                        let now = chrono::Utc::now();
220                        let new_expires = now + chrono::Duration::seconds(ttl_secs as i64);
221
222                        if is_dirty {
223                            let data = {
224                                let guard = session_state
225                                    .current
226                                    .lock()
227                                    .expect("session mutex poisoned");
228                                guard.as_ref().map(|s| s.data.clone())
229                            };
230                            if let Some(data) = data
231                                && let Err(e) =
232                                    store.flush(&session.id, &data, now, new_expires).await
233                            {
234                                tracing::error!(
235                                    session_id = session.id,
236                                    "failed to flush session data: {e}"
237                                );
238                            }
239                        } else if should_touch
240                            && let Err(e) = store.touch(&session.id, now, new_expires).await
241                        {
242                            tracing::error!(
243                                session_id = session.id,
244                                "failed to touch session: {e}"
245                            );
246                        }
247
248                        // Refresh cookie if we did a flush or touch
249                        if (is_dirty || should_touch)
250                            && let Some(ref token) = session_token
251                        {
252                            set_signed_cookie(
253                                &mut response,
254                                cookie_name,
255                                &token.as_hex(),
256                                ttl_secs,
257                                &cookie_config,
258                                &key,
259                            );
260                        }
261                    }
262
263                    // Stale cookie cleanup
264                    if had_cookie && current_session.is_none() && !read_failed {
265                        remove_signed_cookie(&mut response, cookie_name, &cookie_config, &key);
266                    }
267                }
268            }
269
270            Ok(response)
271        })
272    }
273}
274
275/// Read a signed cookie value from request headers.
276/// Returns `Some(SessionToken)` if the cookie exists, signature is valid, and hex decodes.
277fn read_signed_cookie(
278    headers: &http::HeaderMap,
279    cookie_name: &str,
280    key: &Key,
281) -> Option<SessionToken> {
282    let cookie_header = headers.get(http::header::COOKIE)?;
283    let cookie_str = cookie_header.to_str().ok()?;
284
285    for pair in cookie_str.split(';') {
286        let pair = pair.trim();
287        if let Some((name, value)) = pair.split_once('=')
288            && name.trim() == cookie_name
289        {
290            // Verify signature using cookie crate's signed jar
291            let mut jar = CookieJar::new();
292            jar.add_original(Cookie::new(
293                cookie_name.to_string(),
294                value.trim().to_string(),
295            ));
296            let verified = jar.signed(key).get(cookie_name)?;
297            return SessionToken::from_hex(verified.value()).ok();
298        }
299    }
300    None
301}
302
303/// Sign a cookie value and append Set-Cookie header to response.
304fn set_signed_cookie(
305    response: &mut Response<Body>,
306    name: &str,
307    value: &str,
308    max_age_secs: u64,
309    config: &CookieConfig,
310    key: &Key,
311) {
312    // Sign the value
313    let mut jar = CookieJar::new();
314    jar.signed_mut(key)
315        .add(Cookie::new(name.to_string(), value.to_string()));
316    let signed_value = jar
317        .get(name)
318        .expect("cookie was just added")
319        .value()
320        .to_string();
321
322    // Build Set-Cookie header with attributes
323    let same_site = match config.same_site.as_str() {
324        "strict" => SameSite::Strict,
325        "none" => SameSite::None,
326        _ => SameSite::Lax,
327    };
328    let set_cookie_str = Cookie::build((name.to_string(), signed_value))
329        .path("/")
330        .secure(config.secure)
331        .http_only(config.http_only)
332        .same_site(same_site)
333        .max_age(cookie::time::Duration::seconds(max_age_secs as i64))
334        .build()
335        .to_string();
336
337    match HeaderValue::from_str(&set_cookie_str) {
338        Ok(v) => {
339            response.headers_mut().append(http::header::SET_COOKIE, v);
340        }
341        Err(e) => {
342            tracing::error!(
343                cookie_name = name,
344                "failed to set session cookie header: {e}"
345            );
346        }
347    }
348}
349
350fn remove_signed_cookie(
351    response: &mut Response<Body>,
352    name: &str,
353    config: &CookieConfig,
354    key: &Key,
355) {
356    set_signed_cookie(response, name, "", 0, config, key);
357}