#[cfg(feature = "device")]
use crate::device::resolver::ErasedDeviceResolver;
use crate::session::binding::{self, SessionBinding};
use crate::session::config::SessionConfig;
use crate::session::layer::SessionLayer;
use crate::session::layer::handle::{SessionHandle, SessionInner};
use crate::session::layer::lifecycle::{build_set_cookie, finalize_session, load_session};
use crate::session::layer::signing::SigningKeys;
use crate::session::store::SessionStore;
use axum::{body::Body, http::Request, response::Response};
use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tokio::sync::RwLock;
use tower::{Layer, Service};
/// Tower service produced by the `Layer` impl for [`SessionLayer`].
///
/// For each request it loads the session, exposes it to the wrapped service via a
/// [`SessionHandle`] in the request extensions, then finalizes it and (when the
/// session changed) appends a `Set-Cookie` header to the response.
#[doc(hidden)]
#[derive(Clone)]
pub struct SessionService<S, Inner> {
    /// The wrapped service that ultimately handles the request.
    inner: Inner,
    /// Session store passed to `load_session` and `finalize_session`.
    store: S,
    /// Keys passed to session loading, fingerprint computation, and `build_set_cookie`.
    signing_keys: Arc<SigningKeys>,
    /// Shared session configuration (also supplies `max_custom_bytes`).
    config: Arc<SessionConfig>,
    /// Optional binding strategy; when set, a request fingerprint is computed via
    /// `binding::compute_fingerprint`.
    binding: Option<Arc<dyn SessionBinding>>,
    /// Optional metrics sink forwarded to session loading and finalization.
    metrics: Option<Arc<dyn crate::metrics::AuthnMetrics>>,
    /// Optional resolver that derives a device id from the request head; a value
    /// differing from the stored one marks the session modified.
    #[cfg(feature = "device")]
    device_resolver: Option<Arc<dyn ErasedDeviceResolver>>,
}
impl<S, Inner> Layer<Inner> for SessionLayer<S>
where
S: SessionStore + Clone,
{
type Service = SessionService<S, Inner>;
fn layer(&self, inner: Inner) -> Self::Service {
SessionService {
inner,
store: self.store.clone(),
signing_keys: self.signing_keys.clone(),
config: self.config.clone(),
binding: self.binding.clone(),
metrics: self.metrics.clone(),
#[cfg(feature = "device")]
device_resolver: self.device_resolver.clone(),
}
}
}
impl<S, Inner, ResBody> Service<Request<Body>> for SessionService<S, Inner>
where
    S: SessionStore + Clone + Send + Sync + 'static,
    S::Error: Send + Sync + 'static,
    Inner: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
    Inner::Future: Send + 'static,
    Inner::Error: Send + 'static,
    ResBody: Send + 'static,
{
    type Response = Response<ResBody>;
    type Error = Inner::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Runs one request through the session lifecycle.
    ///
    /// 1. Computes the client fingerprint (only when a binding is configured) and
    ///    loads the session with `load_session`.
    /// 2. With the `device` feature, resolves the device id from the request head
    ///    and marks the session modified if it differs from the stored one.
    /// 3. Inserts a [`SessionHandle`] into the request extensions and calls the
    ///    wrapped service.
    /// 4. Finalizes the session and, if `finalize_session` reports a change,
    ///    appends a `Set-Cookie` header built from the final session id.
    ///
    /// # Errors
    ///
    /// Errors from the wrapped service are returned unchanged. In that case
    /// `finalize_session` is not called and no cookie is set.
    fn call(&mut self, mut req: Request<Body>) -> Self::Future {
        let store = self.store.clone();
        let config = self.config.clone();
        let signing_keys = self.signing_keys.clone();
        let session_binding = self.binding.clone();
        let metrics = self.metrics.clone();
        #[cfg(feature = "device")]
        let device_resolver = self.device_resolver.clone();

        // `poll_ready` readied `self.inner`: move that instance into the future and
        // leave a fresh clone behind for the next `poll_ready`/`call` cycle.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);

        // `None` when no binding is configured (or it yields no fingerprint).
        let current_fingerprint = session_binding
            .as_deref()
            .and_then(|b| binding::compute_fingerprint(b, &req, &signing_keys.fingerprint));

        Box::pin(async move {
            let load = load_session(
                &store,
                &signing_keys,
                &config,
                metrics.as_deref(),
                req.headers(),
                current_fingerprint.as_deref(),
            )
            .await;

            #[cfg(feature = "device")]
            let (load, device_changed) = {
                let mut load = load;
                let changed = if let Some(ref resolver) = device_resolver {
                    // The resolver only needs the request head, so split the request
                    // temporarily and reassemble it before passing it on.
                    let (parts, body) = req.into_parts();
                    let resolved = resolver.resolve_erased(&parts).await;
                    req = Request::from_parts(parts, body);

                    let differs = resolved != load.data.device_id;
                    if differs {
                        load.data.device_id = resolved;
                    }
                    differs
                } else {
                    false
                };
                (load, changed)
            };
            #[cfg(not(feature = "device"))]
            let device_changed = false;

            // Bound session without a stored fingerprint: stage the current one.
            // This is the fingerprint's last use, so it is moved rather than cloned.
            let pending_fp = if session_binding.is_some() && load.data.fingerprint.is_none() {
                current_fingerprint
            } else {
                None
            };

            let inner_state = SessionInner {
                id: load.id,
                data: load.data,
                // A binding mismatch sets both `modified` and `regenerate`;
                // a device-id change only sets `modified`.
                modified: load.binding_invalidated || device_changed,
                regenerate: load.binding_invalidated,
                pre_cycle_id: None,
                pending_fingerprint: pending_fp,
                max_custom_bytes: config.max_custom_bytes,
            };
            let handle = SessionHandle(Arc::new(RwLock::new(inner_state)));

            // Downstream code can reach the session through the request extensions;
            // a clone is kept here so the session can be finalized afterwards.
            req.extensions_mut().insert(handle.clone());

            let mut response = inner.call(req).await?;

            let outcome = finalize_session(
                &store,
                &config,
                metrics.as_deref(),
                &handle,
                load.existing_id,
            )
            .await;

            if outcome.session_changed
                && let Some(cookie) = build_set_cookie(&signing_keys, &config, outcome.final_id)
            {
                // `append` keeps any `Set-Cookie` headers the handler already set.
                response
                    .headers_mut()
                    .append(axum::http::header::SET_COOKIE, cookie);
            }
            Ok(response)
        })
    }
}