Skip to main content

axess_core/session/layer/
service.rs

1//! Tower middleware service that runs the per-request session lifecycle.
2//!
3//! [`SessionService`] is the concrete `tower::Service` constructed by
4//! [`SessionLayer::layer`](super::SessionLayer). Its `call` method
5//! coordinates the four lifecycle steps: load → device-resolve →
6//! handler → finalize, then emits the response cookie.
7
8#[cfg(feature = "device")]
9use crate::device::resolver::ErasedDeviceResolver;
10use crate::session::binding::{self, SessionBinding};
11use crate::session::config::SessionConfig;
12use crate::session::layer::SessionLayer;
13use crate::session::layer::handle::{SessionHandle, SessionInner};
14use crate::session::layer::lifecycle::{build_set_cookie, finalize_session, load_session};
15use crate::session::layer::signing::SigningKeys;
16use crate::session::store::SessionStore;
17use axum::{body::Body, http::Request, response::Response};
18use std::{
19    future::Future,
20    pin::Pin,
21    sync::Arc,
22    task::{Context, Poll},
23};
24use tokio::sync::RwLock;
25use tower::{Layer, Service};
26
27/// Tower service wrapping an inner service with session management.
28///
29/// Returned by [`<SessionLayer as tower::Layer>::layer`](super::SessionLayer).
30/// `pub` because the [`tower::Layer`] impl is on the public
31/// `SessionLayer` and its `type Service = SessionService<…>`
32/// associated type cannot be more private than the trait impl.
33/// `#[doc(hidden)]` keeps the type out of rendered API docs;
34/// adopters reach it only through the trait surface.
35#[doc(hidden)]
36#[derive(Clone)]
37pub struct SessionService<S, Inner> {
38    inner: Inner,
39    store: S,
40    signing_keys: Arc<SigningKeys>,
41    /// Shared with [`SessionLayer`] via `Arc` to keep
42    /// `Layer::layer` and per-request `Service::call` cloning cheap.
43    config: Arc<SessionConfig>,
44    binding: Option<Arc<dyn SessionBinding>>,
45    metrics: Option<Arc<dyn crate::metrics::AuthnMetrics>>,
46    #[cfg(feature = "device")]
47    device_resolver: Option<Arc<dyn ErasedDeviceResolver>>,
48}
49
50impl<S, Inner> Layer<Inner> for SessionLayer<S>
51where
52    S: SessionStore + Clone,
53{
54    type Service = SessionService<S, Inner>;
55
56    fn layer(&self, inner: Inner) -> Self::Service {
57        SessionService {
58            inner,
59            store: self.store.clone(),
60            signing_keys: self.signing_keys.clone(),
61            config: self.config.clone(),
62            binding: self.binding.clone(),
63            metrics: self.metrics.clone(),
64            #[cfg(feature = "device")]
65            device_resolver: self.device_resolver.clone(),
66        }
67    }
68}
69
70impl<S, Inner, ResBody> Service<Request<Body>> for SessionService<S, Inner>
71where
72    S: SessionStore + Clone + Send + Sync + 'static,
73    S::Error: Send + Sync + 'static,
74    Inner: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
75    Inner::Future: Send + 'static,
76    Inner::Error: Send + 'static,
77    ResBody: Send + 'static,
78{
79    type Response = Response<ResBody>;
80    type Error = Inner::Error;
81    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
82
83    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84        self.inner.poll_ready(cx)
85    }
86
87    fn call(&mut self, mut req: Request<Body>) -> Self::Future {
88        let store = self.store.clone();
89        let config = self.config.clone();
90        let signing_keys = self.signing_keys.clone();
91        let session_binding = self.binding.clone();
92        let metrics = self.metrics.clone();
93        #[cfg(feature = "device")]
94        let device_resolver = self.device_resolver.clone();
95
96        // Clone inner *before* the async block; required by tower's contract.
97        let mut inner = self.inner.clone();
98        std::mem::swap(&mut inner, &mut self.inner);
99
100        // Pre-compute the binding HMAC from the request before moving it.
101        // Use the dedicated fingerprint sub-key; distinct from the
102        // cookie-signing key.
103        let current_fingerprint = session_binding
104            .as_deref()
105            .and_then(|b| binding::compute_fingerprint(b, &req, &signing_keys.fingerprint));
106
107        Box::pin(async move {
108            // 1. Load session; cookie verify, store load, fingerprint check, fresh-mint.
109            let load = load_session(
110                &store,
111                &signing_keys,
112                &config,
113                metrics.as_deref(),
114                req.headers(),
115                current_fingerprint.as_deref(),
116            )
117            .await;
118
119            // 1c. Resolve the device for this request. Best-effort;
120            //     `Err(_)` is logged and treated as `None` by the
121            //     erasure wrapper. We mark the session modified iff
122            //     the resolved id differs from the loaded one, so an
123            //     unchanged device on a stable session does not trigger
124            //     a gratuitous save on every request.
125            //
126            //     `axum::body::Body` is `!Sync`, so `&Request<Body>`
127            //     cannot be borrowed across an `await`. Split the
128            //     request into `(Parts, Body)`, run the resolver
129            //     against the parts, then reassemble before the inner
130            //     service call.
131            //
132            //     The `mut load` shadow binding scopes the mutability
133            //     to the device-enabled branch; under
134            //     `cfg(not(feature = "device"))` the outer `load`
135            //     stays immutable so no lint suppression is needed.
136            #[cfg(feature = "device")]
137            let (load, device_changed) = {
138                let mut load = load;
139                let dc = if let Some(ref resolver) = device_resolver {
140                    let (parts, body) = req.into_parts();
141                    let resolved = resolver.resolve_erased(&parts).await;
142                    req = Request::from_parts(parts, body);
143                    let differs = resolved != load.data.device_id;
144                    if differs {
145                        load.data.device_id = resolved;
146                    }
147                    differs
148                } else {
149                    false
150                };
151                (load, dc)
152            };
153            #[cfg(not(feature = "device"))]
154            let device_changed = false;
155
156            // 2. Insert SessionHandle into request extensions.
157            //    If binding is configured and the session has no fingerprint yet,
158            //    pass the pre-computed fingerprint so the extractor can apply it
159            //    immediately when the session transitions to Authenticated.
160            let pending_fp = if session_binding.is_some() && load.data.fingerprint.is_none() {
161                current_fingerprint.clone()
162            } else {
163                None
164            };
165
166            let inner_state = SessionInner {
167                id: load.id,
168                data: load.data,
169                modified: load.binding_invalidated || device_changed,
170                regenerate: load.binding_invalidated,
171                pre_cycle_id: None,
172                pending_fingerprint: pending_fp,
173                max_custom_bytes: config.max_custom_bytes,
174            };
175            let handle = SessionHandle(Arc::new(RwLock::new(inner_state)));
176            req.extensions_mut().insert(handle.clone());
177
178            // 3. Call the inner service.
179            let response = inner.call(req).await?;
180
181            // 4. Finalize; custom-data size enforcement + save-or-cycle decision.
182            let outcome = finalize_session(
183                &store,
184                &config,
185                metrics.as_deref(),
186                &handle,
187                load.existing_id,
188            )
189            .await;
190
191            // 5. Set the cookie only when the session was created or changed.
192            //    Omitting Set-Cookie on unmodified responses reduces header bloat
193            //    and prevents spurious cache invalidation on CDN / reverse proxies.
194            let mut response = response;
195            if outcome.session_changed
196                && let Some(hv) = build_set_cookie(&signing_keys, &config, outcome.final_id)
197            {
198                response
199                    .headers_mut()
200                    .append(axum::http::header::SET_COOKIE, hv);
201            }
202
203            Ok(response)
204        })
205    }
206}