use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use chrono::{DateTime, Utc};
use http::header::HeaderMap;
use std::str::FromStr;
use url::Url;
use crate::oauth2::{OAuth2Error, OAuth2Mode, OAuth2State, StateParams, StoredToken, TokenType};
use crate::session::{
SessionId, User as SessionUser, delete_session_from_store_by_session_id, get_user_from_session,
};
use crate::storage::{
CacheErrorConversion, CacheKey, CachePrefix, get_data, remove_data, store_cache_auto,
};
use crate::utils::gen_random_string_with_entropy_validation;
/// Serializes `state_params` to JSON and wraps the unpadded base64url encoding of
/// that JSON in an [`OAuth2State`].
///
/// # Errors
/// Returns [`OAuth2Error::Serde`] if `state_params` cannot be serialized to JSON, or
/// whatever error [`OAuth2State::new`] reports for the encoded string.
pub(super) fn encode_state(state_params: StateParams) -> Result<OAuth2State, OAuth2Error> {
    let json = match serde_json::to_string(&state_params) {
        Ok(json) => json,
        Err(err) => return Err(OAuth2Error::Serde(err.to_string())),
    };
    OAuth2State::new(URL_SAFE_NO_PAD.encode(json))
}
pub(crate) fn decode_state(state: &OAuth2State) -> Result<StateParams, OAuth2Error> {
let decoded_bytes = URL_SAFE_NO_PAD
.decode(state.as_str())
.expect("OAuth2State should contain valid base64url");
let decoded_state_string =
String::from_utf8(decoded_bytes).expect("OAuth2State should contain valid UTF-8");
let state_in_response: StateParams =
serde_json::from_str(&decoded_state_string).expect("OAuth2State should contain valid JSON");
Ok(state_in_response)
}
/// Generates a fresh random token and caches it under the prefix derived from
/// `token_type`.
///
/// * `ttl` – passed to the cache as the entry lifetime and also recorded in the
///   stored [`StoredToken`].
/// * `expires_at` – absolute expiry recorded alongside the token.
/// * `user_agent` – optional user agent recorded alongside the token.
///
/// Returns `(token, token_id)`, where `token_id` is the cache key assigned by
/// `store_cache_auto`.
///
/// # Errors
/// Propagates failures from random-string generation, cache-prefix construction, and
/// the cache write.
pub(super) async fn generate_store_token(
    token_type: TokenType,
    ttl: u64,
    expires_at: DateTime<Utc>,
    user_agent: Option<String>,
) -> Result<(String, String), OAuth2Error> {
    let secret = gen_random_string_with_entropy_validation(32)?;
    let prefix = CachePrefix::new(token_type.to_string())
        .map_err(OAuth2Error::convert_storage_error)?;
    let record = StoredToken {
        token: secret.clone(),
        expires_at,
        user_agent,
        ttl,
    };
    let id = store_cache_auto::<_, OAuth2Error>(prefix, record, ttl).await?;
    Ok((secret, id.as_str().to_owned()))
}
/// Checks the nonce cached under `nonce_id` against `expected_nonce` and its expiry,
/// then removes it from the cache so it cannot be used again.
///
/// The cache entry is removed only after both checks pass. On expiry or mismatch it
/// is left in place.
///
/// # Errors
/// * [`OAuth2Error::SecurityTokenNotFound`] if no nonce is cached under `nonce_id`.
/// * [`OAuth2Error::NonceExpired`] if the stored `expires_at` is in the past.
/// * [`OAuth2Error::NonceMismatch`] if `expected_nonce` is `None` or differs from the
///   stored value.
/// * Storage errors from building the cache key, reading, or removing the entry.
pub(super) async fn verify_and_consume_nonce(
    nonce_id: &str,
    expected_nonce: Option<&str>,
) -> Result<(), OAuth2Error> {
    let nonce_cache_key =
        CacheKey::new(nonce_id.to_string()).map_err(OAuth2Error::convert_storage_error)?;
    let nonce_session: StoredToken =
        get_data::<StoredToken, OAuth2Error>(CachePrefix::nonce(), nonce_cache_key.clone())
            .await?
            .ok_or_else(|| {
                OAuth2Error::SecurityTokenNotFound("nonce not found in cache".to_string())
            })?;
    tracing::debug!("Nonce data: {:?}", nonce_session);

    // Read the clock once so the time that is logged is the time that was compared.
    let now = Utc::now();
    if now > nonce_session.expires_at {
        tracing::error!(
            "Nonce expired: {:?}, now: {:?}",
            nonce_session.expires_at,
            now
        );
        return Err(OAuth2Error::NonceExpired);
    }

    // A missing expected nonce never matches.
    if expected_nonce != Some(nonce_session.token.as_str()) {
        tracing::error!(
            "Nonce mismatch: expected={:?}, stored={:?}",
            expected_nonce,
            nonce_session.token
        );
        return Err(OAuth2Error::NonceMismatch);
    }

    remove_data::<OAuth2Error>(CachePrefix::nonce(), nonce_cache_key).await?;
    Ok(())
}
/// Parses `s` as a URL and returns its `(scheme, host, port)` triple for origin
/// comparison.
///
/// Scheme and host are lowercased. The port falls back to the scheme's known default
/// when none is given explicitly.
///
/// Returns `None` if `s` does not parse as a URL or has no host.
fn origin_triple(s: &str) -> Option<(String, String, Option<u16>)> {
    let parsed = Url::parse(s).ok()?;
    let host = parsed.host_str()?.to_ascii_lowercase();
    let scheme = parsed.scheme().to_ascii_lowercase();
    Some((scheme, host, parsed.port_or_known_default()))
}
/// Checks that the request comes from an allowed origin.
///
/// The origin is taken from the `Origin` header. When that header is absent, is not
/// visible ASCII, or is the literal `"null"`, the `Referer` header is used instead.
///
/// The origin is allowed when its `(scheme, host, port)` triple (see `origin_triple`)
/// equals that of `auth_url` or of any entry in `additional_allowed_origins`. Allowed
/// entries that do not parse as URLs are ignored.
///
/// # Errors
/// Returns [`OAuth2Error::InvalidOrigin`] when no usable origin header is present or
/// it matches none of the allowed origins.
pub(crate) async fn validate_origin(
    headers: &HeaderMap,
    auth_url: &str,
    additional_allowed_origins: &[String],
) -> Result<(), OAuth2Error> {
    let allowed_triples: Vec<_> = std::iter::once(auth_url)
        .chain(additional_allowed_origins.iter().map(String::as_str))
        .filter_map(origin_triple)
        .collect();

    let origin = match headers.get("Origin").and_then(|h| h.to_str().ok()) {
        Some("null") | None => headers.get("Referer").and_then(|h| h.to_str().ok()),
        present => present,
    };

    let is_allowed = origin.is_some_and(|candidate| {
        origin_triple(candidate).is_some_and(|triple| allowed_triples.contains(&triple))
    });
    if is_allowed {
        return Ok(());
    }

    // The human-readable expected origin is only needed for diagnostics. A malformed
    // `auth_url` is echoed verbatim rather than causing a panic.
    let expected_origin = match Url::parse(auth_url) {
        Ok(parsed_url) => {
            let scheme = parsed_url.scheme();
            let host = parsed_url.host_str().unwrap_or_default();
            let port = parsed_url
                .port()
                .map_or_else(String::new, |p| format!(":{p}"));
            format!("{scheme}://{host}{port}")
        }
        Err(_) => auth_url.to_string(),
    };
    tracing::error!("Expected Origin: {:#?}", expected_origin);
    tracing::error!(
        "Additional allowed origins: {:#?}",
        additional_allowed_origins
    );
    tracing::error!("Actual Origin: {:#?}", origin);
    Err(OAuth2Error::InvalidOrigin(format!(
        "Expected Origin: {expected_origin:#?} (or one of {additional_allowed_origins:?}), Actual Origin: {origin:#?}"
    )))
}
/// Resolves the session user referenced by `state_params.misc_id`, if any.
///
/// The `misc_id` is looked up in the misc-session cache. The stored token is treated
/// as a session cookie value and resolved to a user.
///
/// This is best-effort. A missing `misc_id`, an invalid cache key, a cache miss or
/// read error, or a failed user lookup all yield `Ok(None)` (logged at debug level).
///
/// # Errors
/// Returns [`OAuth2Error::Storage`] only if the cached token is rejected by
/// `SessionCookie::new`.
pub(crate) async fn get_uid_from_stored_session_by_state_param(
    state_params: &StateParams,
) -> Result<Option<SessionUser>, OAuth2Error> {
    let Some(misc_id) = &state_params.misc_id else {
        tracing::debug!("No misc_id in state");
        return Ok(None);
    };
    tracing::debug!("misc_id: {:#?}", misc_id);
    let misc_cache_key = match CacheKey::new(misc_id.clone()) {
        Ok(key) => key,
        Err(e) => {
            tracing::debug!("Failed to create cache key: {}", e);
            return Ok(None);
        }
    };
    let token =
        match get_data::<StoredToken, OAuth2Error>(CachePrefix::misc_session(), misc_cache_key)
            .await
        {
            Ok(Some(token)) => token,
            Ok(None) => {
                tracing::debug!("Failed to get session from cache");
                return Ok(None);
            }
            Err(e) => {
                // Still best-effort, but keep the cause visible instead of discarding it.
                tracing::debug!("Failed to get session from cache: {:?}", e);
                return Ok(None);
            }
        };
    // The token field holds a session cookie value, which is a credential. Log only
    // non-secret metadata.
    tracing::debug!("Found misc session token (expires_at: {:?})", token.expires_at);
    let session_cookie = crate::SessionCookie::new(token.token)
        .map_err(|e| OAuth2Error::Storage(format!("Invalid session cookie: {e}")))?;
    match get_user_from_session(&session_cookie).await {
        Ok(user) => {
            tracing::debug!("Found user ID: {}", user.id);
            Ok(Some(user))
        }
        Err(e) => {
            tracing::debug!("Failed to get user from session: {}", e);
            Ok(None)
        }
    }
}
/// Deletes the session referenced by `state_params.misc_id`, then removes the
/// misc-session cache entry itself.
///
/// Does nothing and returns `Ok(())` when there is no `misc_id`, the cache key is
/// invalid, or the cache entry cannot be read.
///
/// # Errors
/// Returns [`OAuth2Error::Storage`] if the cached token is not a valid `SessionId` or
/// the session deletion fails. Propagates errors from removing the cache entry.
pub(crate) async fn delete_session_and_misc_token_from_store(
    state_params: &StateParams,
) -> Result<(), OAuth2Error> {
    let Some(misc_id) = state_params.misc_id.as_ref() else {
        return Ok(());
    };

    let cache_key = match CacheKey::new(misc_id.clone()) {
        Ok(key) => key,
        Err(err) => {
            tracing::debug!("Failed to create cache key: {}", err);
            return Ok(());
        }
    };

    let stored = match get_data::<StoredToken, OAuth2Error>(
        CachePrefix::misc_session(),
        cache_key.clone(),
    )
    .await
    {
        Ok(Some(stored)) => stored,
        _ => {
            tracing::debug!("Failed to get session from cache");
            return Ok(());
        }
    };

    let session_id = SessionId::new(stored.token)
        .map_err(|err| OAuth2Error::Storage(format!("Invalid session ID: {err}")))?;
    delete_session_from_store_by_session_id(session_id)
        .await
        .map_err(|err| OAuth2Error::Storage(err.to_string()))?;
    remove_data::<OAuth2Error>(CachePrefix::misc_session(), cache_key).await?;
    Ok(())
}
/// Looks up the [`OAuth2Mode`] cached under `mode_id`.
///
/// Returns `Ok(None)`, with a log entry, when the key is invalid, the entry is
/// missing or unreadable, or the stored value does not parse as an [`OAuth2Mode`].
/// As written, this function never returns `Err`.
pub(crate) async fn get_mode_from_stored_session(
    mode_id: &str,
) -> Result<Option<OAuth2Mode>, OAuth2Error> {
    let key = match CacheKey::new(mode_id.to_string()) {
        Ok(key) => key,
        Err(err) => {
            tracing::debug!("Failed to create cache key: {}", err);
            return Ok(None);
        }
    };

    let stored = match get_data::<StoredToken, OAuth2Error>(CachePrefix::mode(), key).await {
        Ok(Some(stored)) => stored,
        _ => {
            tracing::debug!("Failed to get mode from cache");
            return Ok(None);
        }
    };

    if let Ok(mode) = OAuth2Mode::from_str(&stored.token) {
        return Ok(Some(mode));
    }
    tracing::warn!("Invalid mode value in cache: {}", stored.token);
    Ok(None)
}
#[cfg(test)]
mod tests;