Skip to main content

axess_core/session/
extractor.rs

//! Axum request extractor providing typed, mutable session access.
//!
//! [`AuthSession`] is the primary API surface for handlers. Changes are flushed
//! to the session store by the [`SessionLayer`](super::layer::SessionLayer) middleware after the handler returns.

6use crate::authn::factor::FactorKind;
7use crate::authn::ids::{TenantId, UserId};
8use crate::session::{
9    data::{AdvanceOutcome, AuthState, SessionData, WorkflowState},
10    id::SessionId,
11    layer::SessionHandle,
12};
13use axum::extract::FromRequestParts;
14use axum::http::request::Parts;
15use chrono::{DateTime, Utc};
16use std::sync::Arc;
17
18/// Axum request extractor providing typed, mutable session access.
19///
20/// Zero generic parameters: wraps the `SessionHandle` inserted by [`SessionLayer`](super::layer::SessionLayer).
21/// Obtain one in a handler by listing it as a parameter:
22///
23/// ```text
24/// async fn my_handler(session: AuthSession) -> impl IntoResponse { ... }
25/// ```
26///
27/// Changes are committed to the store automatically when the response is sent.
28#[derive(Clone)]
29pub struct AuthSession(pub(crate) SessionHandle);
30
31/// Identity bundle returned by [`AuthSession::snapshot`]. All
32/// four fields are read under a single `RwLock::read()` acquisition.
33#[derive(Clone, Debug)]
34pub struct AuthSnapshot {
35    /// Identifier of the authenticated user.
36    pub user_id: UserId,
37    /// Tenant the user belongs to.
38    pub tenant_id: TenantId,
39    /// Stable session identifier for this snapshot.
40    pub session_id: SessionId,
41    /// Wall-clock time at which authentication completed.
42    pub authn_time: DateTime<Utc>,
43}
44
45impl AuthSession {
46    /// Return the authenticated user ID, if any.
47    ///
48    /// **Prefer [`snapshot`](Self::snapshot)** when the handler needs more
49    /// than one identity field; a single `snapshot().await` acquires the
50    /// read lock once and returns `user_id` + `tenant_id` + `session_id` +
51    /// `authn_time` together. Calling `user_id().await` followed by
52    /// `tenant_id().await` costs two lock acquisitions and two `clone()`s
53    /// for the same information.
54    #[doc(hidden)]
55    pub async fn user_id(&self) -> Option<UserId> {
56        self.0.0.read().await.data.auth_state.user_id().cloned()
57    }
58
59    /// Return the tenant ID, if any.
60    ///
61    /// See [`user_id`](Self::user_id) for why [`snapshot`](Self::snapshot)
62    /// is the recommended way to read identity fields on the hot path.
63    #[doc(hidden)]
64    pub async fn tenant_id(&self) -> Option<TenantId> {
65        self.0.0.read().await.data.auth_state.tenant_id().cloned()
66    }
67
68    /// Return `true` if the session is fully authenticated.
69    pub async fn is_authenticated(&self) -> bool {
70        self.0.0.read().await.data.auth_state.is_authenticated()
71    }
72
73    /// Clone the current authentication state enum (cheap; fields are `Arc<str>`).
74    pub async fn auth_state(&self) -> AuthState {
75        self.0.0.read().await.data.auth_state.clone()
76    }
77
78    /// Destructure an [`AuthState::Authenticating`] session into its
79    /// `(user_id, tenant_id, remaining_factors)` triple. Returns `None`
80    /// for any other [`AuthState`] variant (Guest / Identifying /
81    /// Authenticated / PendingWorkflow).
82    ///
83    /// Service-side callers in the multi-factor pipeline (`begin_login`,
84    /// `prepare_factor`, `verify_factor`) all need this triple to make
85    /// progress and treat any other state as the canonical
86    /// [`AuthnError::NoFlow`](crate::authn::error::AuthnError::NoFlow),
87    /// typically `session.authenticating_state().await.ok_or(AuthnError::NoFlow)?`.
88    /// This destructuring was lifted out of `AuthnService` because it
89    /// has no service dependency; it belongs alongside the other
90    /// session-state accessors here.
91    ///
92    /// Single-acquisition: the read lock is taken once and the entire
93    /// triple is built under it, matching [`snapshot`](Self::snapshot)'s
94    /// shape rather than the older "`auth_state().await` then match"
95    /// pattern that incurred two lock acquisitions and an extra
96    /// `AuthState::clone`.
97    pub async fn authenticating_state(&self) -> Option<(UserId, TenantId, Vec<FactorKind>)> {
98        let guard = self.0.0.read().await;
99        match &guard.data.auth_state {
100            AuthState::Authenticating {
101                user_id,
102                tenant_id,
103                remaining,
104                ..
105            } => Some((*user_id, *tenant_id, remaining.clone())),
106            _ => None,
107        }
108    }
109
110    /// Return the session ID.
111    pub async fn session_id(&self) -> SessionId {
112        self.0.0.read().await.id
113    }
114
115    /// Return the resolved [`crate::authn::ids::DeviceId`] for this session, if the device
116    /// subsystem populated it. `None` when the `device` feature is
117    /// disabled, when the resolver did not run for this request, or when
118    /// the request did not match a known device.
119    ///
120    /// Single-acquisition read; cheap when the device subsystem is in
121    /// use because `DeviceId` is a 16-byte `Copy` UUID newtype.
122    pub async fn device_id(&self) -> Option<crate::authn::ids::DeviceId> {
123        self.0.0.read().await.data.device_id
124    }
125
126    /// Return a clone of the full session data.
127    pub async fn data(&self) -> SessionData {
128        self.0.0.read().await.data.clone()
129    }
130
131    /// Invoke `f` against a borrow of the current [`AuthState`]
132    /// without cloning. Use this when the caller only needs to read a
133    /// few fields. Avoids the full `AuthState::clone()` that
134    /// [`auth_state`](Self::auth_state) performs (which clones every
135    /// `Arc<str>` plus the workflow state). For `AuthState::Authenticated`
136    /// in particular this halves the per-request lock-held work in
137    /// authenticated handlers that just need the user id.
138    ///
139    /// The closure runs while the session read lock is held, so do not
140    /// `await` inside it on something that needs the same lock.
141    pub async fn with_auth_state<F, T>(&self, f: F) -> T
142    where
143        F: FnOnce(&AuthState) -> T,
144    {
145        let guard = self.0.0.read().await;
146        f(&guard.data.auth_state)
147    }
148
149    /// Invoke `f` against a borrow of the full [`SessionData`]
150    /// without cloning. Useful when reading several fields together,
151    /// where `data()` would copy more than necessary. Same lock-held
152    /// caveat as [`with_auth_state`](Self::with_auth_state).
153    pub async fn with_data<F, T>(&self, f: F) -> T
154    where
155        F: FnOnce(&SessionData) -> T,
156    {
157        let guard = self.0.0.read().await;
158        f(&guard.data)
159    }
160
161    /// Read all the common identity fields under a *single*
162    /// read-lock acquisition. The previous pattern of calling
163    /// `is_authenticated()` then `user_id()` then `tenant_id()` then
164    /// `session_id()` from a typical authenticated handler took four
165    /// `RwLock` reads per request: most contention-free, but each is a
166    /// cross-await synchronization point. `snapshot()` collapses them
167    /// into one acquisition.
168    ///
169    /// Returns `None` when the session is not authenticated.
170    pub async fn snapshot(&self) -> Option<AuthSnapshot> {
171        let guard = self.0.0.read().await;
172        match &guard.data.auth_state {
173            AuthState::Authenticated {
174                user_id,
175                tenant_id,
176                authn_time,
177                ..
178            } => Some(AuthSnapshot {
179                user_id: *user_id,
180                tenant_id: *tenant_id,
181                session_id: guard.id,
182                authn_time: *authn_time,
183            }),
184            _ => None,
185        }
186    }
187
188    // ── State mutation helpers ─────────────────────────────────────────────────
189
190    /// Mark the session as fully authenticated.
191    ///
192    /// Cycles the session id eagerly so handler code that subsequently calls
193    /// `session.session_id().await` (e.g. to register against a
194    /// `SessionRegistry`) reads the post-rotation id, the same id the
195    /// cookie will carry on the next request.
196    pub async fn set_authenticated(
197        &self,
198        user_id: UserId,
199        tenant_id: TenantId,
200        authn_time: DateTime<Utc>,
201    ) {
202        let mut guard = self.0.0.write().await;
203        // Pure state mutation lives on AuthState. Direct
204        // authenticated transition (impersonation, OAuth callback, etc.)
205        // (no specific factor sequence to record).
206        guard
207            .data
208            .auth_state
209            .set_authenticated(user_id, tenant_id, authn_time);
210        // Bind the session fingerprint immediately so it's set before the
211        // response leaves the server (prevents hijack-before-binding window).
212        if guard.data.fingerprint.is_none()
213            && let Some(fp) = guard.pending_fingerprint.take()
214        {
215            guard.data.fingerprint = Some(fp);
216        }
217        guard.modified = true;
218        guard.rotate_id();
219    }
220
221    /// Begin a multi-factor authentication flow.
222    ///
223    /// Sets the state to [`AuthState::Authenticating`] with the given factors in order.
224    pub(crate) async fn begin_authenticating(
225        &self,
226        user_id: UserId,
227        tenant_id: TenantId,
228        method_name: Arc<str>,
229        factors: Vec<FactorKind>,
230    ) {
231        let mut guard = self.0.0.write().await;
232        // Pure state mutation lives on AuthState.
233        guard
234            .data
235            .auth_state
236            .begin_authenticating(user_id, tenant_id, method_name, factors);
237        // Bind fingerprint early; during Authenticating state; so that
238        // mid-MFA sessions are also protected against hijacking.
239        if guard.data.fingerprint.is_none()
240            && let Some(fp) = guard.pending_fingerprint.take()
241        {
242            guard.data.fingerprint = Some(fp);
243        }
244        guard.modified = true;
245    }
246
247    /// Advance a multi-factor flow by removing `kind` from the remaining list.
248    ///
249    /// If `remaining` becomes empty after removal, transitions to
250    /// [`AuthState::Authenticated`] automatically.
251    pub(crate) async fn advance_factor(&self, kind: &FactorKind, authn_time: DateTime<Utc>) {
252        let mut guard = self.0.0.write().await;
253        // Pure state transition lives on AuthState; session-level
254        // orchestration (fingerprint binding, id rotation, dirty flag) is
255        // dispatched on the outcome.
256        match guard.data.auth_state.advance_factor(kind, authn_time) {
257            AdvanceOutcome::Completed => {
258                // Bind fingerprint immediately on auth completion.
259                if guard.data.fingerprint.is_none()
260                    && let Some(fp) = guard.pending_fingerprint.take()
261                {
262                    guard.data.fingerprint = Some(fp);
263                }
264                // Cycle the id NOW so handler code that runs after
265                // advance_factor (e.g. `complete_factor_step` →
266                // `reg.register(user, sid)`) sees the post-rotation id.
267                guard.rotate_id();
268                guard.modified = true;
269            }
270            AdvanceOutcome::StillAuthenticating => {
271                guard.modified = true;
272            }
273            AdvanceOutcome::NotApplicable => {
274                // No-op if not in Authenticating state.
275            }
276        }
277    }
278
279    /// Record a failed attempt at the given time (for UI display / rate-limit feedback).
280    ///
281    /// Callers should supply `clock.now()` rather than `Utc::now()` so that
282    /// deterministic simulation tests control the timestamp.
283    ///
284    /// **This does not enforce lockout**; lockout is based exclusively on the
285    /// DB counter returned by `IdentityStore::record_failed_attempt`.
286    ///
287    /// The in-memory state is updated so the *current* response can
288    /// surface attempt count / last-attempt to the UI, but we deliberately
289    /// do NOT mark the session as modified. Persisting on every wrong
290    /// password forces a full SessionStore save per failed login; net
291    /// load on Valkey/Postgres scales with brute-force traffic, exactly
292    /// the worst time to add round-trips. Applications that need
293    /// cross-request "attempts remaining" UX should read the
294    /// authoritative counter via `IdentityStore::account_status` /
295    /// `record_failed_attempt`'s return value, not session state.
296    pub(crate) async fn record_attempt_at(&self, now: DateTime<Utc>) {
297        let mut guard = self.0.0.write().await;
298        // Pure state mutation lives on AuthState.
299        // Intentionally do NOT set `guard.modified = true`.
300        guard.data.auth_state.record_attempt_at(now);
301    }
302
303    /// Return `(user_id, tenant_id)` if the session is fully authenticated, `None` otherwise.
304    ///
305    /// **Prefer [`snapshot`](Self::snapshot)**; it returns the full
306    /// identity bundle (`user_id` + `tenant_id` + `session_id` +
307    /// `authn_time`) under one lock acquisition. `authenticated_ids` is
308    /// retained for the narrow case where the caller really only wants
309    /// the pair.
310    #[doc(hidden)]
311    pub async fn authenticated_ids(&self) -> Option<(UserId, TenantId)> {
312        let guard = self.0.0.read().await;
313        match &guard.data.auth_state {
314            AuthState::Authenticated {
315                user_id, tenant_id, ..
316            } => Some((*user_id, *tenant_id)),
317            _ => None,
318        }
319    }
320
321    /// Enter the identifying state (user has typed their username).
322    ///
323    /// Binds the session fingerprint early so that even pre-MFA sessions
324    /// are protected against cross-device replay.
325    pub async fn set_identifying(&self, user_id: UserId, tenant_id: TenantId) {
326        let mut guard = self.0.0.write().await;
327        // Pure state mutation lives on AuthState.
328        guard.data.auth_state.set_identifying(user_id, tenant_id);
329        // Bind fingerprint as early as possible (during Identifying state)
330        // so a stolen session cookie cannot be replayed from a different device.
331        if guard.data.fingerprint.is_none()
332            && let Some(fp) = guard.pending_fingerprint.take()
333        {
334            guard.data.fingerprint = Some(fp);
335        }
336        guard.modified = true;
337    }
338
339    /// Transition to a pending workflow state.
340    pub async fn set_pending_workflow(
341        &self,
342        user_id: UserId,
343        tenant_id: TenantId,
344        workflow: WorkflowState,
345    ) {
346        let mut guard = self.0.0.write().await;
347        // Pure state mutation lives on AuthState.
348        guard
349            .data
350            .auth_state
351            .set_pending_workflow(user_id, tenant_id, workflow);
352        guard.modified = true;
353    }
354
355    /// Clear the session (logout). Resets state to `Guest` and marks as modified.
356    ///
357    /// The caller should regenerate the session ID separately to prevent session fixation.
358    pub async fn clear(&self) {
359        let mut guard = self.0.0.write().await;
360        guard.data = SessionData::default();
361        guard.modified = true;
362    }
363
364    /// Cycle the session ID immediately.
365    ///
366    /// Mints a new id, swaps it in, and stashes the old id for the layer to
367    /// `store.cycle` on the way out. Idempotent within a single request;
368    /// calling twice still rotates only once.
369    ///
370    /// Call at any **privilege boundary**; i.e. any change to the
371    /// session's authentication context, scope, or subject identity.
372    /// Login is already auto-cycled; the rest is the app's call.
373    /// Canonical list (OWASP ASVS V3, OWASP Session Management Cheat
374    /// Sheet, NIST SP 800-63B AAL transitions):
375    ///
376    /// - MFA factor added (TOTP, WebAuthn, recovery codes, …)
377    /// - MFA factor removed or disabled (AAL drops)
378    /// - Password / primary-credential change
379    /// - Step-up to a higher assurance level
380    /// - Account-recovery flow completion
381    /// - Impersonation start / stop
382    /// - Role grant / revoke or scope change
383    /// - Tenant switch in a multi-tenant deployment
384    ///
385    /// Do *not* call on routine writes (profile edit, factor config
386    /// tuning, theme change); rotating churns the cookie and the
387    /// store without any security benefit.
388    ///
389    /// On credential changes (password change, recovery completion)
390    /// consider also revoking sibling sessions via
391    /// `SessionRegistry::revoke_user_sessions`; that is a strictly
392    /// stronger statement than rotation and cuts off other devices
393    /// still holding a stale credential-derived cookie.
394    ///
395    /// See `docs/sessions/lifecycle.md` for the full rationale and
396    /// the "library can't hook this for you" discussion.
397    pub async fn regenerate(&self) {
398        let mut guard = self.0.0.write().await;
399        guard.rotate_id();
400        guard.modified = true;
401    }
402
403    /// Read a value from the custom JSON bag.
404    pub async fn get_custom(&self, key: &str) -> Option<serde_json::Value> {
405        self.0.0.read().await.data.custom.get(key).cloned()
406    }
407
408    /// Wipe the entire custom JSON bag in one shot. Used at privilege
409    /// boundaries (e.g., admin impersonation) where the previous session's
410    /// app-controlled custom data must not leak into the assumed identity:
411    /// pre-seeded OAuth ceremony state in `custom` could otherwise be used
412    /// to hijack a subsequent OAuth flow under the new principal.
413    pub async fn clear_custom(&self) {
414        let mut guard = self.0.0.write().await;
415        guard.data.custom = serde_json::Value::Object(serde_json::Map::new());
416        guard.modified = true;
417    }
418
419    /// Remove a key from the custom JSON bag. Returns `true` if a key was
420    /// removed. Use this rather than `set_custom(k, Value::Null)` so the
421    /// JSON object stays compact; repeated nulled-out keys would otherwise
422    /// monotonically grow until the size cap clears the whole bag.
423    pub async fn remove_custom(&self, key: &str) -> bool {
424        let mut guard = self.0.0.write().await;
425        let removed = guard
426            .data
427            .custom
428            .as_object_mut()
429            .is_some_and(|obj| obj.remove(key).is_some());
430        if removed {
431            guard.modified = true;
432        }
433        removed
434    }
435
436    /// Run `f` against a mutable borrow of the custom-bag JSON
437    /// object under a single write-lock acquisition. Returns the value
438    /// produced by `f`. Marks the session modified if `f` actually
439    /// mutated the bag (judged by JSON-byte-length comparison).
440    ///
441    /// Use this in place of a series of `remove_custom`/`set_custom`
442    /// calls when several edits must appear atomic to a concurrent
443    /// reader. The original sequence in `clear_oauth_state` issued six
444    /// independent `remove_custom` calls; a parallel `set_custom` racing
445    /// in between could re-introduce a key after it had already been
446    /// removed.
447    ///
448    /// `f` runs while the write lock is held; do not `await` inside it
449    /// on something that needs the same lock.
450    pub async fn mutate_custom<F, T>(&self, f: F) -> T
451    where
452        F: FnOnce(&mut serde_json::Map<String, serde_json::Value>) -> T,
453    {
454        let mut guard = self.0.0.write().await;
455        // Ensure `custom` is an object so callers can rely on the borrow.
456        if !guard.data.custom.is_object() {
457            guard.data.custom = serde_json::Value::Object(serde_json::Map::new());
458        }
459        let before_len = serde_json::to_vec(&guard.data.custom)
460            .map(|v| v.len())
461            .unwrap_or(0);
462        let obj = guard
463            .data
464            .custom
465            .as_object_mut()
466            .expect("custom forced to object above");
467        let result = f(obj);
468        let after_len = serde_json::to_vec(&guard.data.custom)
469            .map(|v| v.len())
470            .unwrap_or(0);
471        if before_len != after_len {
472            guard.modified = true;
473        }
474        result
475    }
476
477    /// Atomically read **and remove** a key from the custom JSON bag.
478    ///
479    /// Returns the removed value, or `None` if the key was absent. The
480    /// read and the remove run under a single write lock; use this in
481    /// preference to `get_custom(k).await` followed by `remove_custom(k)`
482    /// for any one-shot ceremony state (FIDO2 `auth_state`, OAuth PKCE
483    /// verifier, password-reset token, etc.) where two parallel
484    /// requests must not both observe the same value.
485    pub async fn take_custom(&self, key: &str) -> Option<serde_json::Value> {
486        let mut guard = self.0.0.write().await;
487        let value = guard
488            .data
489            .custom
490            .as_object_mut()
491            .and_then(|obj| obj.remove(key));
492        if value.is_some() {
493            guard.modified = true;
494        }
495        value
496    }
497
498    /// Store a value in the custom JSON bag.
499    ///
500    /// Returns `false` if the write would exceed the configured
501    /// `max_custom_bytes` limit; the value is not stored in that case.
502    pub async fn set_custom(&self, key: impl Into<String>, value: serde_json::Value) -> bool {
503        let mut guard = self.0.0.write().await;
504        let key = key.into();
505        let limit = guard.max_custom_bytes;
506        if limit > 0 {
507            // Speculatively apply the change, check size, and rollback if over limit.
508            let prev = guard.data.custom.get(&key).cloned();
509            guard.data.custom[key.clone()] = value;
510            let size = serde_json::to_vec(&guard.data.custom)
511                .map(|v| v.len())
512                .unwrap_or(0);
513            if size > limit {
514                // Rollback.
515                match prev {
516                    Some(v) => guard.data.custom[key] = v,
517                    None => {
518                        if let Some(obj) = guard.data.custom.as_object_mut() {
519                            obj.remove(&key);
520                        }
521                    }
522                }
523                tracing::warn!(
524                    size,
525                    max = limit,
526                    "set_custom rejected; would exceed max_custom_bytes"
527                );
528                return false;
529            }
530        } else {
531            guard.data.custom[key] = value;
532        }
533        guard.modified = true;
534        true
535    }
536}
537
538// ── Axum extractor impl ────────────────────────────────────────────────────────
539
/// Rejection returned by the [`AuthSession`] extractor when the session
/// layer is not installed, i.e. no `SessionHandle` is present in the request
/// extensions. Rendered as a bare `500 Internal Server Error`.
///
/// A zero-sized marker, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SessionMissing;
543
544impl axum::response::IntoResponse for SessionMissing {
545    fn into_response(self) -> axum::response::Response {
546        (
547            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
548            "Internal Server Error",
549        )
550            .into_response()
551    }
552}
553
554impl<S> FromRequestParts<S> for AuthSession
555where
556    S: Send + Sync,
557{
558    type Rejection = SessionMissing;
559
560    // `_state: &S` is the single documented carve-out from the workspace's
561    // "no `_`-prefixed names" rule. `FromRequestParts<S>` is axum's trait,
562    // so the parameter shape is fixed; `S` is unconstrained (only `Send +
563    // Sync` auto-traits) so no method can be called on it, no `tracing` use
564    // is possible, and `drop(state)` triggers `drop_ref` on the `&S` borrow.
565    // Renaming to `_state` is the language-idiomatic way to express "trait
566    // requires this; this impl genuinely cannot use it" without resorting
567    // to `#[allow(unused_variables)]` or `core::hint::black_box`.
568    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
569        parts
570            .extensions
571            .get::<SessionHandle>()
572            .cloned()
573            .map(AuthSession)
574            .ok_or(SessionMissing)
575    }
576}
577
578#[cfg(test)]
579mod tests;