use std::{
sync::LazyLock,
time::{Duration, SystemTime},
};
use bytes::Bytes;
use http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode, Uri, header};
use huskarl::{
core::{
BoxedError, Error as _,
crypto::cipher::{AeadV1Sealer, AeadV1Unsealer, BoxedAeadCipher},
http::HttpClient,
},
grant::{authorization_code::PendingState, core::TokenResponse},
token::RefreshToken,
};
use rand::RngExt as _;
use serde::{Deserialize, Serialize};
use crate::{
DefaultErrorPage, ErrorPage, LoginConfig, LoginGrant, Session, SessionDriver, SessionError,
};
mod callback;
mod logout;
mod redirect;
#[cfg(test)]
mod tests;
/// Type-erased, thread-safe (`Send + Sync`) error used within the engine.
type EngineError = Box<dyn std::error::Error + Sync + Send>;
/// A complete HTTP response produced by the engine: status, headers and body.
///
/// NOTE(review): presumably converted into the host framework's response type by the
/// caller. `headers` may carry cookie values, so be careful before adding a `Debug`
/// impl that could end up in logs.
pub struct LoginResponse {
/// Response status code.
pub status: StatusCode,
/// Response headers, in the order they should be emitted.
pub headers: Vec<(HeaderName, HeaderValue)>,
/// Response body.
pub body: Bytes,
}
/// Outcome of `LoginEngine::load_session`.
pub struct LoadedSession<S> {
/// The usable session and how it must be written back, or `None` when the request
/// has no valid session (none stored, expired, or the token refresh failed).
pub session: Option<(S, SessionPersistence)>,
/// Header values returned by the session driver when a session was deleted while
/// loading; empty otherwise. NOTE(review): presumably `Set-Cookie` values that the
/// caller must forward on the response — confirm against the driver.
pub clear_cookies: Vec<HeaderValue>,
}
/// How a loaded session has to be written back through the session driver.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionPersistence {
    /// Store the whole session (driver `save`), e.g. after a token refresh.
    Save,
    /// Only record activity (driver `touch`).
    Touch,
    /// No write needed for this request.
    Skip,
}
/// Decides what to do when writing a session back (see `SessionPersistence`) fails.
///
/// `handle` receives the persistence mode that was attempted and the error. Returning
/// `Some(response)` supplies a response to send; NOTE(review): `None` presumably means
/// the failure is tolerated and the request proceeds — confirm against the caller.
pub trait PersistFailurePolicy: Send + Sync + 'static {
fn handle(
&self,
persistence: SessionPersistence,
error: &dyn std::error::Error,
) -> Option<LoginResponse>;
}
/// Default [`PersistFailurePolicy`].
///
/// A failed `Save` yields an empty `503 Service Unavailable` response. Failures of
/// `Touch` or `Skip` yield `None`, so no replacement response is produced. The error
/// itself is not inspected.
pub struct DefaultPersistFailurePolicy;

impl PersistFailurePolicy for DefaultPersistFailurePolicy {
    fn handle(
        &self,
        persistence: SessionPersistence,
        _error: &dyn std::error::Error,
    ) -> Option<LoginResponse> {
        // Kept exhaustive so that adding a persistence mode forces a decision here.
        let replace_response = match persistence {
            SessionPersistence::Save => true,
            SessionPersistence::Touch | SessionPersistence::Skip => false,
        };
        replace_response.then(|| LoginResponse {
            status: StatusCode::SERVICE_UNAVAILABLE,
            headers: Vec::new(),
            body: Bytes::new(),
        })
    }
}
/// State kept between starting a login and handling its callback: the URL originally
/// requested and the pending authorization-code grant state.
///
/// NOTE(review): presumably serialized and sealed (see `LoginEngine::sealer`) into a
/// cookie by the redirect/callback submodules — confirm. Depending on the serializer,
/// renaming or reordering fields may make previously issued cookies undecodable.
#[derive(Serialize, Deserialize)]
struct LoginStateCookie {
// URL the user asked for before being sent to log in.
original_url: String,
// In-flight authorization-code grant state.
pending_state: PendingState,
}
/// Login/session engine.
///
/// It serves the callback and logout routes, loads, expires and refreshes sessions, and
/// redirects unauthenticated navigations to the authorization server. It is generic over
/// the login grant `G`, the session driver `SD` and the HTTP client `H`.
pub struct LoginEngine<G, SD, H> {
// Paths, timeouts and refresh settings.
config: LoginConfig,
// Grant used to start logins and refresh tokens.
grant: G,
// Session backend: load / save / touch / delete.
session_store: SD,
// Sealer and unsealer are built from clones of the same cipher in `new`.
// NOTE(review): presumably used for the login-state cookie — confirm in submodules.
sealer: AeadV1Sealer<BoxedAeadCipher>,
unsealer: AeadV1Unsealer<BoxedAeadCipher>,
// HTTP client handed to the grant (e.g. for token refresh).
http_client: H,
// Renders bodies for error responses.
error_page: Box<dyn ErrorPage>,
}
#[bon::bon]
impl<G, SD, H> LoginEngine<G, SD, H>
where
    G: LoginGrant,
    SD: SessionDriver,
    H: HttpClient + Send + Sync,
{
    /// Creates an engine; exposed through a `bon`-generated builder.
    ///
    /// `cipher` is cloned once so that the sealer and the unsealer each own a copy.
    /// `error_page` defaults to [`DefaultErrorPage`] when not supplied.
    #[builder]
    pub fn new(
        config: LoginConfig,
        grant: G,
        session_store: SD,
        cipher: BoxedAeadCipher,
        http_client: H,
        #[builder(default = Box::new(DefaultErrorPage) as Box<dyn ErrorPage>)]
        error_page: Box<dyn ErrorPage>,
    ) -> Self {
        let sealer = AeadV1Sealer::new(cipher.clone());
        let unsealer = AeadV1Unsealer::new(cipher);
        Self {
            config,
            grant,
            session_store,
            sealer,
            unsealer,
            http_client,
            error_page,
        }
    }
}
impl<G, SD, H> LoginEngine<G, SD, H> {
    /// Returns the configuration the engine was built with.
    pub fn config(&self) -> &LoginConfig {
        let Self { config, .. } = self;
        config
    }

    /// Returns the session driver the engine uses to load, save, touch and delete sessions.
    pub fn session_store(&self) -> &SD {
        let Self { session_store, .. } = self;
        session_store
    }
}
impl<G, SD, H> LoginEngine<G, SD, H>
where
    G: LoginGrant,
    SD: SessionDriver,
    H: HttpClient + Send + Sync,
{
    /// Handles requests addressed to the engine's own endpoints.
    ///
    /// Returns `Some(response)` when `path` equals the configured callback path (checked
    /// first) or the configured logout path, if one is set. Returns `None` for every other
    /// path so the caller can process the request normally. The method is not consulted,
    /// so both routes answer any HTTP method.
    pub async fn try_handle_login_route(
        &self,
        path: &str,
        _method: &Method,
        headers: &HeaderMap,
        uri: &Uri,
    ) -> Option<LoginResponse> {
        if path == self.config.callback_path {
            return Some(self.handle_callback(uri, headers).await);
        }
        if self
            .config
            .logout_path
            .as_deref()
            .is_some_and(|p| path == p)
        {
            return Some(self.handle_logout(headers).await);
        }
        None
    }

    /// Loads the session identified by the request `headers` and classifies it.
    ///
    /// - No stored session: `session: None`, nothing to clear.
    /// - Expired session (see `session_is_expired`): deleted best-effort, giving
    ///   `session: None` plus any clear-cookie headers the driver returned.
    /// - Access token expiring within `token_refresh_margin`: handed to
    ///   `refresh_or_clear`.
    /// - Otherwise: returned with `Touch` or `Skip` (see `touch_or_skip`).
    ///
    /// # Errors
    ///
    /// Returns the driver's error when loading the session fails.
    pub async fn load_session(
        &self,
        headers: &HeaderMap,
    ) -> Result<LoadedSession<SD::SessionType>, SessionError> {
        let Some(mut session) = self.session_store.load(headers).await? else {
            return Ok(LoadedSession {
                session: None,
                clear_cookies: vec![],
            });
        };
        let now = SystemTime::now();
        if self.session_is_expired(&session, now) {
            let clear_cookies = self.delete_best_effort(&session, headers).await;
            return Ok(LoadedSession {
                session: None,
                clear_cookies,
            });
        }
        if self.access_token_due_for_refresh(&session, now) {
            return Ok(self.refresh_or_clear(session, headers).await);
        }
        let persistence = self.touch_or_skip(&mut session, now);
        Ok(LoadedSession {
            session: Some((session, persistence)),
            clear_cookies: vec![],
        })
    }

    /// Produces the response for a request that needs authentication.
    ///
    /// Requests that do not look like browser navigations (see `is_navigation_request`)
    /// get a `401 Unauthorized` error page. Navigations are redirected to the
    /// authorization server via `redirect_to_as`. If building that redirect fails, the
    /// error chain is logged and a `500 Internal Server Error` page is returned.
    pub async fn redirect_to_login(&self, headers: &HeaderMap, uri: &Uri) -> LoginResponse {
        if !is_navigation_request(headers) {
            return self.build_error_response(StatusCode::UNAUTHORIZED, "authentication required");
        }
        match self.redirect_to_as(headers, uri, None).await {
            Ok(resp) => resp,
            Err(e) => {
                log::error!(
                    "failed to redirect to authorization server: {}",
                    error_chain(&*e)
                );
                self.build_error_response(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    "failed to start login",
                )
            }
        }
    }

    /// Returns `true` if the session must be discarded at `now`.
    ///
    /// That is the case when either timestamp lies more than `MAX_CLOCK_SKEW` in the
    /// future (logged as a warning). It is also the case when the session is older than
    /// `max_lifetime`, or has been idle longer than `idle_timeout`; each limit applies
    /// only when configured.
    fn session_is_expired(&self, session: &SD::SessionType, now: SystemTime) -> bool {
        if is_too_far_future(session.created_at(), now)
            || is_too_far_future(session.last_active(), now)
        {
            log::warn!("session timestamps are too far in the future — treating as expired");
            return true;
        }
        if let Some(max_lifetime) = self.config.max_lifetime
            && elapsed_since(session.created_at(), now) > max_lifetime
        {
            return true;
        }
        if let Some(idle_timeout) = self.config.idle_timeout
            && elapsed_since(session.last_active(), now) > idle_timeout
        {
            return true;
        }
        false
    }

    /// Returns `true` when the session's access token expires within
    /// `token_refresh_margin` of `now`, or has already expired.
    ///
    /// This is the same test as `now + margin >= expiry`. It is written with
    /// `duration_since` so that a very large configured margin cannot overflow
    /// `SystemTime` (whose `+` panics on overflow).
    fn access_token_due_for_refresh(&self, session: &SD::SessionType, now: SystemTime) -> bool {
        match session.token_expiry().duration_since(now) {
            Ok(remaining) => remaining <= self.config.token_refresh_margin,
            // Expiry is already in the past.
            Err(_) => true,
        }
    }

    /// Refreshes the session's tokens, or discards the session if that is impossible.
    ///
    /// On success the refreshed session is returned with `SessionPersistence::Save` so
    /// the updated session is written back. Without a refresh token, or when
    /// `refresh_with_retry` ultimately fails (error chain logged), the session is deleted
    /// best-effort. In that case `session: None` is returned along with the driver's
    /// clear-cookie headers.
    async fn refresh_or_clear(
        &self,
        mut session: SD::SessionType,
        headers: &HeaderMap,
    ) -> LoadedSession<SD::SessionType> {
        let Some(rt) = session.refresh_token().cloned() else {
            let clear_cookies = self.delete_best_effort(&session, headers).await;
            return LoadedSession {
                session: None,
                clear_cookies,
            };
        };
        match self.refresh_with_retry(&rt).await {
            Ok(token_response) => {
                session.apply_refresh(&token_response, self.config.default_token_lifetime);
                LoadedSession {
                    session: Some((session, SessionPersistence::Save)),
                    clear_cookies: vec![],
                }
            }
            Err(e) => {
                log::error!("token refresh failed: {}", error_chain(&e));
                let clear_cookies = self.delete_best_effort(&session, headers).await;
                LoadedSession {
                    session: None,
                    clear_cookies,
                }
            }
        }
    }

    /// Deletes `session` through the driver and returns its clear-cookie headers.
    ///
    /// A deletion failure is logged and reported as an empty list instead of an error.
    async fn delete_best_effort(
        &self,
        session: &SD::SessionType,
        headers: &HeaderMap,
    ) -> Vec<HeaderValue> {
        match self.session_store.delete(session, headers).await {
            Ok(c) => c,
            Err(e) => {
                log::error!("failed to delete session: {}", error_chain(&*e));
                vec![]
            }
        }
    }

    /// Calls the grant's refresh, retrying failures that report themselves retryable.
    ///
    /// At most `REFRESH_MAX_ATTEMPTS` attempts are made (the first try included), with
    /// `refresh_retry_delay(attempt)` slept between them. Non-retryable errors, or the
    /// error from the last attempt, are returned unchanged.
    async fn refresh_with_retry(&self, rt: &RefreshToken) -> Result<TokenResponse, BoxedError> {
        let mut attempt = 0;
        loop {
            attempt += 1;
            match self.grant.refresh(&self.http_client, rt).await {
                Ok(tr) => return Ok(tr),
                Err(e) if attempt < REFRESH_MAX_ATTEMPTS && e.is_retryable() => {
                    let delay = refresh_retry_delay(attempt);
                    // Log the full cause chain, like the other error logs in this engine,
                    // so the underlying transport cause of a transient failure is kept.
                    log::warn!(
                        "token refresh failed (attempt {attempt}/{REFRESH_MAX_ATTEMPTS}, retrying in {delay:?}): {}",
                        error_chain(&e)
                    );
                    tokio::time::sleep(delay).await;
                }
                Err(e) => return Err(e),
            }
        }
    }

    /// Records activity on `session` when at least `touch_min_interval` has passed since
    /// its last recorded activity, and returns `Touch`. Otherwise the session is left
    /// untouched and `Skip` is returned. This limits how often touches are written.
    fn touch_or_skip(&self, session: &mut SD::SessionType, now: SystemTime) -> SessionPersistence {
        if elapsed_since(session.last_active(), now) >= self.config.touch_min_interval {
            session.record_activity();
            SessionPersistence::Touch
        } else {
            SessionPersistence::Skip
        }
    }

    /// Writes `session` back according to `persistence`.
    ///
    /// `Save` calls the driver's `save`, `Touch` calls its `touch`, and `Skip` does
    /// nothing. Returns the header values produced by the driver.
    ///
    /// # Errors
    ///
    /// Propagates the driver's error.
    pub async fn persist_session(
        &self,
        session: &SD::SessionType,
        persistence: SessionPersistence,
        request_headers: &HeaderMap,
    ) -> Result<Vec<HeaderValue>, SessionError> {
        match persistence {
            SessionPersistence::Save => self.session_store.save(session, request_headers).await,
            SessionPersistence::Touch => self.session_store.touch(session, request_headers).await,
            SessionPersistence::Skip => Ok(vec![]),
        }
    }

    /// Deletes `session` via the driver and returns the driver's header values.
    ///
    /// # Errors
    ///
    /// Propagates the driver's error (unlike the internal best-effort delete).
    pub async fn delete_session(
        &self,
        session: &SD::SessionType,
        request_headers: &HeaderMap,
    ) -> Result<Vec<HeaderValue>, SessionError> {
        self.session_store.delete(session, request_headers).await
    }

    /// Saves `session` via the driver and returns the driver's header values.
    ///
    /// # Errors
    ///
    /// Propagates the driver's error.
    pub async fn save_session(
        &self,
        session: &SD::SessionType,
        request_headers: &HeaderMap,
    ) -> Result<Vec<HeaderValue>, SessionError> {
        self.session_store.save(session, request_headers).await
    }

    /// Touches `session` via the driver and returns the driver's header values.
    ///
    /// # Errors
    ///
    /// Propagates the driver's error.
    pub async fn touch_session(
        &self,
        session: &SD::SessionType,
        request_headers: &HeaderMap,
    ) -> Result<Vec<HeaderValue>, SessionError> {
        self.session_store.touch(session, request_headers).await
    }

    /// Renders an error response with the configured error page (public entry point).
    pub fn render_error(&self, status: StatusCode, message: &str) -> LoginResponse {
        self.build_error_response(status, message)
    }

    /// Builds an error response whose body and content type come from the configured
    /// `ErrorPage`. The response always carries `Cache-Control: no-store`.
    ///
    /// NOTE(review): `HeaderValue::from_static` panics on an invalid header value, so this
    /// assumes the page's `content_type` is a well-formed static string.
    fn build_error_response(&self, status: StatusCode, message: &str) -> LoginResponse {
        let rendered = self.error_page.render(status, message);
        LoginResponse {
            status,
            headers: vec![
                (
                    header::CONTENT_TYPE,
                    HeaderValue::from_static(rendered.content_type),
                ),
                (header::CACHE_CONTROL, HeaderValue::from_static("no-store")),
            ],
            body: rendered.body,
        }
    }
}
/// Maximum number of token-refresh attempts, the first try included.
const REFRESH_MAX_ATTEMPTS: u32 = 3;
/// Delay before the first retry; it doubles for every further attempt.
const REFRESH_RETRY_BASE_DELAY: Duration = Duration::from_millis(100);
/// Exclusive upper bound of the random jitter added to each retry delay.
const REFRESH_RETRY_JITTER_MAX: Duration = Duration::from_millis(50);

/// Computes how long to sleep after failed refresh attempt number `attempt` (1-based).
///
/// The result is `REFRESH_RETRY_BASE_DELAY * 2^(attempt - 1)`, with the exponent capped
/// at 16, plus a uniformly random jitter of whole milliseconds in
/// `[0, REFRESH_RETRY_JITTER_MAX)`. `attempt == 0` is treated like `attempt == 1` rather
/// than underflowing. A zero jitter bound adds no jitter rather than panicking on an
/// empty range.
fn refresh_retry_delay(attempt: u32) -> Duration {
    let exponent = attempt.saturating_sub(1).min(16);
    let base = REFRESH_RETRY_BASE_DELAY * (1u32 << exponent);
    let jitter_max_ms = u64::try_from(REFRESH_RETRY_JITTER_MAX.as_millis()).unwrap_or(u64::MAX);
    // `random_range` panics on an empty range, so only sample when jitter is possible.
    let jitter_ms = if jitter_max_ms == 0 {
        0
    } else {
        rand::rng().random_range(0..jitter_max_ms)
    };
    base + Duration::from_millis(jitter_ms)
}
/// How far into the future a stored timestamp may lie before it is treated as bogus.
const MAX_CLOCK_SKEW: Duration = Duration::from_secs(60);

/// Returns `true` when `timestamp` is more than `MAX_CLOCK_SKEW` ahead of `now`.
/// Timestamps at or before `now` never qualify.
fn is_too_far_future(timestamp: SystemTime, now: SystemTime) -> bool {
    match timestamp.duration_since(now) {
        Ok(lead) => lead > MAX_CLOCK_SKEW,
        Err(_) => false,
    }
}

/// Time elapsed from `earlier` to `now`; zero when `earlier` is after `now`.
fn elapsed_since(earlier: SystemTime, now: SystemTime) -> Duration {
    now.duration_since(earlier).unwrap_or_default()
}
// Header names consulted by `is_navigation_request`. Each is built from a lowercase
// literal once, on first access.
// Fetch-metadata request mode (e.g. `navigate`).
static SEC_FETCH_MODE: LazyLock<HeaderName> =
LazyLock::new(|| HeaderName::from_static("sec-fetch-mode"));
// Fetch-metadata request destination (e.g. `document`).
static SEC_FETCH_DEST: LazyLock<HeaderName> =
LazyLock::new(|| HeaderName::from_static("sec-fetch-dest"));
// Non-standard marker compared against `XMLHttpRequest`.
static X_REQUESTED_WITH: LazyLock<HeaderName> =
LazyLock::new(|| HeaderName::from_static("x-requested-with"));
/// Returns `true` for a CORS preflight request: an `OPTIONS` request that carries an
/// `Access-Control-Request-Method` header.
pub fn is_cors_preflight(method: &Method, headers: &HeaderMap) -> bool {
    if *method != Method::OPTIONS {
        return false;
    }
    headers.contains_key(header::ACCESS_CONTROL_REQUEST_METHOD)
}
/// Heuristically decides whether a request is a top-level browser navigation, and so
/// may be answered with a login redirect rather than a bare `401`.
///
/// Rules are checked in order; the first one that applies decides:
/// 1. `X-Requested-With: XMLHttpRequest` (ASCII case-insensitive): not a navigation.
/// 2. `Sec-Fetch-Mode` present: navigation iff it is exactly `navigate`.
/// 3. `Sec-Fetch-Dest` present: navigation iff it is exactly `document`.
/// 4. Otherwise: navigation iff `Accept` is visible ASCII and mentions `text/html` or
///    `application/xhtml+xml`. Media types are case-insensitive (RFC 9110 §8.3.1), so
///    this match ignores ASCII case.
pub fn is_navigation_request(headers: &HeaderMap) -> bool {
    if headers
        .get(&*X_REQUESTED_WITH)
        .is_some_and(|v| v.as_bytes().eq_ignore_ascii_case(b"XMLHttpRequest"))
    {
        return false;
    }
    if let Some(mode) = headers.get(&*SEC_FETCH_MODE) {
        return mode.as_bytes() == b"navigate";
    }
    if let Some(dest) = headers.get(&*SEC_FETCH_DEST) {
        return dest.as_bytes() == b"document";
    }
    headers
        .get(header::ACCEPT)
        .and_then(|v| v.to_str().ok())
        .is_some_and(|v| {
            contains_ignore_ascii_case(v, "text/html")
                || contains_ignore_ascii_case(v, "application/xhtml+xml")
        })
}

/// Returns `true` if `needle` occurs in `haystack`, comparing ASCII letters
/// case-insensitively. An empty `needle` always matches.
fn contains_ignore_ascii_case(haystack: &str, needle: &str) -> bool {
    let (haystack, needle) = (haystack.as_bytes(), needle.as_bytes());
    // `windows(0)` panics, so the empty needle is answered before it is reached.
    needle.is_empty()
        || haystack
            .windows(needle.len())
            .any(|window| window.eq_ignore_ascii_case(needle))
}
/// Formats `e` followed by every error in its `source()` chain, separated by `": "`.
///
/// For example, an error `top` caused by `mid` caused by `leaf` becomes
/// `"top: mid: leaf"`.
pub fn error_chain(e: &dyn std::error::Error) -> String {
    use std::fmt::Write as _;
    std::iter::successors(e.source(), |cause| cause.source()).fold(
        e.to_string(),
        |mut rendered, cause| {
            // Writing into a `String` cannot fail.
            let _ = write!(rendered, ": {cause}");
            rendered
        },
    )
}
#[cfg(test)]
mod retry_delay_tests {
    use std::{collections::HashSet, time::Duration};

    use super::{REFRESH_RETRY_BASE_DELAY, REFRESH_RETRY_JITTER_MAX, refresh_retry_delay};

    /// Jitter-free delay expected for retry `attempt` (1-based): base doubled per attempt.
    fn jitter_free_delay(attempt: u32) -> Duration {
        REFRESH_RETRY_BASE_DELAY * 2u32.pow(attempt - 1)
    }

    #[test]
    fn delay_stays_within_jitter_window() {
        for attempt in 1..=3 {
            let lo = jitter_free_delay(attempt);
            let hi = lo + REFRESH_RETRY_JITTER_MAX;
            (0..50)
                .map(|_| refresh_retry_delay(attempt))
                .for_each(|d| {
                    assert!(
                        d >= lo && d < hi,
                        "attempt {attempt}: delay {d:?} out of [{lo:?}, {hi:?})",
                    );
                });
        }
    }

    #[test]
    fn later_attempts_strictly_outpace_earlier_ones() {
        // With jitter below the base delay, the window for attempt N+1 starts after the
        // window for attempt N ends.
        assert!(
            REFRESH_RETRY_BASE_DELAY > REFRESH_RETRY_JITTER_MAX,
            "jitter must stay below base or successive windows overlap",
        );
    }

    #[test]
    fn jitter_actually_varies_across_calls() {
        let mut samples = HashSet::new();
        for _ in 0..50 {
            samples.insert(refresh_retry_delay(1));
        }
        assert!(
            samples.len() > 5,
            "expected jitter to vary, got {samples:?}"
        );
    }
}