oauth2-passkey 0.6.0

OAuth2 and Passkey authentication library for Rust web applications
Documentation
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use chrono::{DateTime, Utc};
use http::header::HeaderMap;
use std::str::FromStr;
use url::Url;

use crate::oauth2::{OAuth2Error, OAuth2Mode, OAuth2State, StateParams, StoredToken, TokenType};

use crate::session::{
    SessionId, User as SessionUser, delete_session_from_store_by_session_id, get_user_from_session,
};
use crate::storage::{
    CacheErrorConversion, CacheKey, CachePrefix, get_data, remove_data, store_cache_auto,
};

use crate::utils::gen_random_string_with_entropy_validation;

/// Serialize `state_params` to JSON and wrap it as a base64url (unpadded)
/// [`OAuth2State`].
///
/// # Errors
/// Returns `OAuth2Error::Serde` if JSON serialization fails, or whatever
/// error `OAuth2State::new` reports for the encoded string.
pub(super) fn encode_state(state_params: StateParams) -> Result<OAuth2State, OAuth2Error> {
    let json = match serde_json::to_string(&state_params) {
        Ok(json) => json,
        Err(err) => return Err(OAuth2Error::Serde(err.to_string())),
    };
    OAuth2State::new(URL_SAFE_NO_PAD.encode(json))
}

pub(crate) fn decode_state(state: &OAuth2State) -> Result<StateParams, OAuth2Error> {
    // Since OAuth2State is validated during construction, we know these operations will succeed
    // This is safe because validation in OAuth2State::new() already verified:
    // 1. Valid base64url encoding
    // 2. Valid UTF-8 content
    // 3. Valid JSON structure
    let decoded_bytes = URL_SAFE_NO_PAD
        .decode(state.as_str())
        .expect("OAuth2State should contain valid base64url");
    let decoded_state_string =
        String::from_utf8(decoded_bytes).expect("OAuth2State should contain valid UTF-8");
    let state_in_response: StateParams =
        serde_json::from_str(&decoded_state_string).expect("OAuth2State should contain valid JSON");
    Ok(state_in_response)
}

/// Create a random 32-character token, cache it as a [`StoredToken`] under the
/// prefix named by `token_type`, and return `(token, cache_key)`.
///
/// `ttl` is both recorded in the stored value and passed to the cache as the
/// entry's time-to-live; `expires_at` and `user_agent` are stored verbatim.
///
/// # Errors
/// Propagates failures from random-string generation, prefix construction
/// (converted via `OAuth2Error::convert_storage_error`), and cache storage.
pub(super) async fn generate_store_token(
    token_type: TokenType,
    ttl: u64,
    expires_at: DateTime<Utc>,
    user_agent: Option<String>,
) -> Result<(String, String), OAuth2Error> {
    let secret = gen_random_string_with_entropy_validation(32)?;

    let record = StoredToken {
        token: secret.clone(),
        expires_at,
        user_agent,
        ttl,
    };

    let prefix = CachePrefix::new(token_type.to_string())
        .map_err(OAuth2Error::convert_storage_error)?;
    let key = store_cache_auto::<_, OAuth2Error>(prefix, record, ttl).await?;

    Ok((secret, key.as_str().to_owned()))
}

/// Verify a nonce against its cached value and consume it (single-use).
///
/// Retrieves the stored nonce from cache, checks expiration, compares with
/// the expected value (typically from an ID token), and removes it from cache.
///
/// The cache entry is removed only after both checks pass. On expiry or
/// mismatch it is left in place (it is still subject to the cache TTL).
///
/// # Errors
/// - `OAuth2Error::SecurityTokenNotFound` if no nonce is cached under `nonce_id`.
/// - `OAuth2Error::NonceExpired` if the stored `expires_at` is in the past.
/// - `OAuth2Error::NonceMismatch` if `expected_nonce` is `None` or differs from
///   the stored value.
/// - Storage errors from key construction, lookup, or removal.
pub(super) async fn verify_and_consume_nonce(
    nonce_id: &str,
    expected_nonce: Option<&str>,
) -> Result<(), OAuth2Error> {
    let nonce_cache_key =
        CacheKey::new(nonce_id.to_string()).map_err(OAuth2Error::convert_storage_error)?;
    let nonce_session: StoredToken =
        get_data::<StoredToken, OAuth2Error>(CachePrefix::nonce(), nonce_cache_key.clone())
            .await?
            .ok_or_else(|| {
                OAuth2Error::SecurityTokenNotFound("nonce not found in cache".to_string())
            })?;

    tracing::debug!("Nonce data: {:?}", nonce_session);

    // Sample the clock once so the logged "now" is the instant actually compared.
    let now = Utc::now();
    if now > nonce_session.expires_at {
        tracing::error!(
            "Nonce expired: {:?}, now: {:?}",
            nonce_session.expires_at,
            now
        );
        return Err(OAuth2Error::NonceExpired);
    }

    if expected_nonce != Some(nonce_session.token.as_str()) {
        tracing::error!(
            "Nonce mismatch: expected={:?}, stored={:?}",
            expected_nonce,
            nonce_session.token
        );
        return Err(OAuth2Error::NonceMismatch);
    }

    remove_data::<OAuth2Error>(CachePrefix::nonce(), nonce_cache_key).await?;

    Ok(())
}

/// Reduce a URL string to a `(scheme, host, port)` triple for origin
/// comparison.
///
/// Scheme and host are ASCII-lowercased. The port is the explicit port if
/// present, otherwise the scheme's well-known default (443 for https, 80 for
/// http); it is `None` when neither is available.
///
/// Returns `None` if `s` does not parse as a URL or the URL has no host.
fn origin_triple(s: &str) -> Option<(String, String, Option<u16>)> {
    let parsed = Url::parse(s).ok()?;
    let host = parsed.host_str()?.to_ascii_lowercase();
    let scheme = parsed.scheme().to_ascii_lowercase();
    let port = parsed.port_or_known_default();
    Some((scheme, host, port))
}

pub(crate) async fn validate_origin(
    headers: &HeaderMap,
    auth_url: &str,
    additional_allowed_origins: &[String],
) -> Result<(), OAuth2Error> {
    // `expected_origin` is for error messages and logging only. It preserves
    // operator-facing input (no `:443` / `:80` injection) so messages look
    // like the configured value. Origin matching itself uses the structural
    // comparison in `allowed_triples` below.
    let parsed_url = Url::parse(auth_url).expect("Invalid URL");
    let scheme = parsed_url.scheme();
    let host = parsed_url.host_str().unwrap_or_default();
    let port = parsed_url
        .port()
        .map_or("".to_string(), |p| format!(":{p}"));
    let expected_origin = format!("{scheme}://{host}{port}");

    // Pre-parse the expected origin and each additional allowed origin into
    // (scheme, host, port) triples for structural comparison. This rejects
    // subdomain-confusion candidates like
    // "https://accounts.google.com.attacker.com" against
    // "https://accounts.google.com" that a raw `starts_with` would accept.
    let allowed_triples: Vec<_> = std::iter::once(auth_url)
        .chain(additional_allowed_origins.iter().map(String::as_str))
        .filter_map(origin_triple)
        .collect();

    // Browsers send `Origin: null` for cross-origin form_post redirects (e.g. Auth0).
    // Treat the literal string "null" the same as absent and fall back to Referer.
    let origin = {
        let raw = headers.get("Origin").and_then(|h| h.to_str().ok());
        match raw {
            Some("null") | None => headers.get("Referer").and_then(|h| h.to_str().ok()),
            some => some,
        }
    };

    let matches = |candidate: &str| {
        origin_triple(candidate).is_some_and(|cand| allowed_triples.contains(&cand))
    };

    match origin {
        Some(origin) if matches(origin) => Ok(()),
        _ => {
            tracing::error!("Expected Origin: {:#?}", expected_origin);
            tracing::error!(
                "Additional allowed origins: {:#?}",
                additional_allowed_origins
            );
            tracing::error!("Actual Origin: {:#?}", origin);
            Err(OAuth2Error::InvalidOrigin(format!(
                "Expected Origin: {expected_origin:#?} (or one of {additional_allowed_origins:?}), Actual Origin: {origin:#?}"
            )))
        }
    }
}

/// Look up the user owning the session referenced by `state_params.misc_id`.
///
/// `misc_id` is used as a cache key under the `misc_session` prefix. The
/// cached token is treated as a session cookie value and resolved with
/// `get_user_from_session`.
///
/// Returns `Ok(None)` if:
/// - `state_params.misc_id` is `None`
/// - `misc_id` is not a valid cache key
/// - the cache lookup fails or finds no entry
/// - `get_user_from_session` returns an error
///
/// This function does not remove the `misc_session` entry.
/// `delete_session_and_misc_token_from_store` performs that removal.
///
/// # Errors
/// Returns `OAuth2Error::Storage` if the cached token is rejected by
/// `SessionCookie::new`.
pub(crate) async fn get_uid_from_stored_session_by_state_param(
    state_params: &StateParams,
) -> Result<Option<SessionUser>, OAuth2Error> {
    let Some(misc_id) = &state_params.misc_id else {
        tracing::debug!("No misc_id in state");
        return Ok(None);
    };

    tracing::debug!("misc_id: {:#?}", misc_id);

    let misc_cache_key = match CacheKey::new(misc_id.clone()) {
        Ok(key) => key,
        Err(e) => {
            tracing::debug!("Failed to create cache key: {}", e);
            return Ok(None);
        }
    };
    let Ok(Some(token)) =
        get_data::<StoredToken, OAuth2Error>(CachePrefix::misc_session(), misc_cache_key).await
    else {
        tracing::debug!("Failed to get session from cache");
        return Ok(None);
    };

    // The stored token is used as a session cookie value (a credential), so
    // log only its metadata, never the token itself.
    tracing::debug!(
        "Found misc session token (expires_at: {:?})",
        token.expires_at
    );

    let session_cookie = crate::SessionCookie::new(token.token)
        .map_err(|e| OAuth2Error::Storage(format!("Invalid session cookie: {e}")))?;
    match get_user_from_session(&session_cookie).await {
        Ok(user) => {
            tracing::debug!("Found user ID: {}", user.id);
            Ok(Some(user))
        }
        Err(e) => {
            tracing::debug!("Failed to get user from session: {}", e);
            Ok(None)
        }
    }
}

/// Delete the session referenced by `state_params.misc_id`, then remove the
/// `misc_session` cache entry that pointed at it.
///
/// This is a silent no-op (`Ok(())`) when `misc_id` is absent, is not a valid
/// cache key, or the cache lookup fails or finds nothing.
///
/// # Errors
/// Returns `OAuth2Error::Storage` if the cached token is not a valid
/// `SessionId` or session deletion fails. Propagates errors from removing the
/// cache entry. The cache entry is left in place if session deletion fails.
pub(crate) async fn delete_session_and_misc_token_from_store(
    state_params: &StateParams,
) -> Result<(), OAuth2Error> {
    let Some(misc_id) = state_params.misc_id.as_ref() else {
        return Ok(());
    };

    let cache_key = match CacheKey::new(misc_id.clone()) {
        Ok(key) => key,
        Err(err) => {
            tracing::debug!("Failed to create cache key: {}", err);
            return Ok(());
        }
    };

    let lookup =
        get_data::<StoredToken, OAuth2Error>(CachePrefix::misc_session(), cache_key.clone())
            .await;
    let stored = match lookup {
        Ok(Some(stored)) => stored,
        _ => {
            tracing::debug!("Failed to get session from cache");
            return Ok(());
        }
    };

    let session_id = SessionId::new(stored.token)
        .map_err(|err| OAuth2Error::Storage(format!("Invalid session ID: {err}")))?;
    delete_session_from_store_by_session_id(session_id)
        .await
        .map_err(|err| OAuth2Error::Storage(err.to_string()))?;

    remove_data::<OAuth2Error>(CachePrefix::misc_session(), cache_key).await?;
    Ok(())
}

/// Read the [`OAuth2Mode`] cached under `mode_id` (with the `mode` prefix).
///
/// Returns `Ok(None)` when `mode_id` is not a valid cache key, the cache
/// lookup fails or finds nothing, or the stored string does not parse as an
/// `OAuth2Mode` (the last case is logged at warn level). This function never
/// returns `Err`, and it does not remove the cache entry.
pub(crate) async fn get_mode_from_stored_session(
    mode_id: &str,
) -> Result<Option<OAuth2Mode>, OAuth2Error> {
    let cache_key = match CacheKey::new(mode_id.to_string()) {
        Ok(key) => key,
        Err(err) => {
            tracing::debug!("Failed to create cache key: {}", err);
            return Ok(None);
        }
    };

    let stored = match get_data::<StoredToken, OAuth2Error>(CachePrefix::mode(), cache_key).await
    {
        Ok(Some(stored)) => stored,
        _ => {
            tracing::debug!("Failed to get mode from cache");
            return Ok(None);
        }
    };

    if let Ok(mode) = OAuth2Mode::from_str(&stored.token) {
        return Ok(Some(mode));
    }
    tracing::warn!("Invalid mode value in cache: {}", stored.token);
    Ok(None)
}

#[cfg(test)]
mod tests;