use chrono::{Duration, Utc};
use http::HeaderMap;
use std::{env, sync::LazyLock};
use crate::audit::{AuthMethod, AuthMethodDetails, LoginContext};
use crate::oauth2::provider::{ProviderConfig, ProviderKind, ProviderName, provider_for};
use crate::oauth2::{
AccountSearchField, AuthResponse, FedCMCallbackRequest, OAUTH2_CSRF_COOKIE_NAME, OAuth2Account,
OAuth2Mode, OAuth2Store, Provider, ProviderUserId, StateParams, csrf_checks, decode_state,
delete_session_and_misc_token_from_store, get_idinfo_userinfo, get_mode_from_stored_session,
get_uid_from_stored_session_by_state_param, oauth2_account_from_idinfo,
oauth2_account_from_idinfo_and_userinfo, validate_fedcm_token, validate_origin,
};
use crate::userdb::{User as DbUser, UserStore};
use crate::utils::header_set_cookie;
use super::errors::CoordinationError;
use super::login_history::{record_login_failure, record_login_success};
use super::user::gen_new_user_id;
use crate::session::{UserId, new_session_header};
static OAUTH2_USER_ACCOUNT_FIELD: LazyLock<String> =
LazyLock::new(|| env::var("OAUTH2_USER_ACCOUNT_FIELD").unwrap_or_else(|_| "email".to_string()));
static OAUTH2_USER_LABEL_FIELD: LazyLock<String> =
LazyLock::new(|| env::var("OAUTH2_USER_LABEL_FIELD").unwrap_or_else(|_| "name".to_string()));
#[derive(Debug, PartialEq, Eq)]
enum HttpMethod {
Get,
Post,
}
#[tracing::instrument(skip(ctx, auth_response, cookies, headers), fields(user_id, provider, state = %auth_response.state))]
async fn authorized_core(
ctx: &ProviderConfig,
method: HttpMethod,
auth_response: &AuthResponse,
cookies: &headers::Cookie,
headers: &HeaderMap,
) -> Result<(HeaderMap, String), CoordinationError> {
tracing::Span::current().record("provider", ctx.kind.as_str());
tracing::info!(?method, "Processing OAuth2 authorization callback");
match (method, ctx.response_mode.to_lowercase().as_str()) {
(HttpMethod::Get, "form_post") => {
return Err(CoordinationError::InvalidResponseMode(
"GET is not allowed for form_post response mode".to_string(),
));
}
(HttpMethod::Post, "query") => {
return Err(CoordinationError::InvalidResponseMode(
"POST is not allowed for query response mode".to_string(),
));
}
_ => {} }
let auth_url = ctx
.auth_url()
.await
.map_err(|e| CoordinationError::InvalidState(format!("Failed to get auth url: {e}")))?;
if ctx.redirect_uri.starts_with("http://localhost")
|| ctx.redirect_uri.starts_with("http://127.0.0.1")
{
tracing::warn!(
redirect_uri = %ctx.redirect_uri,
"Skipping origin check for HTTP localhost callback"
);
} else {
validate_origin(headers, &auth_url, &ctx.additional_allowed_origins).await?;
}
if auth_response.state.is_empty() {
return Err(CoordinationError::InvalidState(
"State is empty".to_string(),
));
}
let login_context = LoginContext::from_headers(headers);
match csrf_checks(cookies.clone(), auth_response, headers.clone()).await {
Err(e) => {
let failure_reason = format!("oauth2_csrf_failure: {}", e);
let _ = record_login_failure(
None,
AuthMethod::OAuth2,
login_context,
None,
failure_reason,
)
.await;
Err(e.into())
}
Ok(()) => process_oauth2_authorization(ctx, auth_response, login_context).await,
}
}
pub async fn get_authorized_core(
provider: ProviderName,
auth_response: &AuthResponse,
cookies: &headers::Cookie,
headers: &HeaderMap,
) -> Result<(HeaderMap, String), CoordinationError> {
let ctx = resolve_ctx(provider)?;
authorized_core(ctx, HttpMethod::Get, auth_response, cookies, headers).await
}
pub async fn post_authorized_core(
provider: ProviderName,
auth_response: &AuthResponse,
cookies: &headers::Cookie,
headers: &HeaderMap,
) -> Result<(HeaderMap, String), CoordinationError> {
let ctx = resolve_ctx(provider)?;
authorized_core(ctx, HttpMethod::Post, auth_response, cookies, headers).await
}
fn resolve_ctx(provider: ProviderName) -> Result<&'static ProviderConfig, CoordinationError> {
ProviderKind::from_provider_name(provider.as_str())
.and_then(provider_for)
.ok_or_else(|| {
CoordinationError::InvalidState(format!("OAuth2 provider not enabled: {provider}"))
})
}
#[tracing::instrument(skip(ctx, auth_response, login_context), fields(user_id, provider, state = %auth_response.state))]
async fn process_oauth2_authorization(
ctx: &ProviderConfig,
auth_response: &AuthResponse,
login_context: LoginContext,
) -> Result<(HeaderMap, String), CoordinationError> {
let provider_name = ctx.provider_name;
tracing::Span::current().record("provider", provider_name.as_str());
tracing::info!("Processing OAuth2 authorization core logic");
let oauth2_state = crate::OAuth2State::new(auth_response.state.clone())?;
let state_in_response = decode_state(&oauth2_state)?;
if state_in_response.provider.as_str() != provider_name.as_str() {
tracing::error!(
security_event = "provider_mismatch",
state_provider = %state_in_response.provider,
url_path_provider = %provider_name,
"Provider mismatch: state.provider does not match URL path provider"
);
return Err(CoordinationError::InvalidState(format!(
"Provider mismatch: state contains '{}' but URL path is '{}'",
state_in_response.provider, provider_name
)));
}
let (idinfo, userinfo) = get_idinfo_userinfo(ctx, auth_response).await?;
let oauth2_account = oauth2_account_from_idinfo_and_userinfo(&idinfo, &userinfo, ctx)?;
let state_user = get_uid_from_stored_session_by_state_param(&state_in_response).await?;
let (uid_in_state, account_in_state) = match &state_user {
Some(user) => (Some(user.id.as_str()), Some(user.account.as_str())),
None => (None, None),
};
let mode = match &state_in_response.mode_id {
Some(mode_id) => get_mode_from_stored_session(mode_id).await?,
None => {
tracing::debug!("No mode ID found");
None
}
};
let (mut headers, message) = process_authenticated_oauth2_user(
oauth2_account,
mode,
AuthMethod::OAuth2,
login_context,
uid_in_state,
account_in_state,
Some(&state_in_response),
)
.await?;
let _ = header_set_cookie(
&mut headers,
OAUTH2_CSRF_COOKIE_NAME.to_string(),
"value".to_string(),
Utc::now() - Duration::seconds(86400),
-86400,
None, )?;
Ok((headers, message))
}
async fn process_authenticated_oauth2_user(
mut oauth2_account: OAuth2Account,
mode: Option<OAuth2Mode>,
auth_method: AuthMethod,
login_context: LoginContext,
uid_in_state: Option<&str>,
account_in_state: Option<&str>,
state_params: Option<&StateParams>,
) -> Result<(HeaderMap, String), CoordinationError> {
let provider = Provider::new(oauth2_account.provider.clone())
.map_err(|e| CoordinationError::Validation(format!("Invalid provider: {e}")))?;
let provider_user_id = ProviderUserId::new(oauth2_account.provider_user_id.clone())
.map_err(|e| CoordinationError::Validation(format!("Invalid provider user ID: {e}")))?;
let existing_account =
OAuth2Store::get_oauth2_account_by_provider(provider, provider_user_id).await?;
let provider_for_history = oauth2_account.provider.clone();
let provider_user_id_for_history = oauth2_account.provider_user_id.clone();
let email_for_history = oauth2_account.email.clone();
tracing::debug!("Mode: {:?}", mode);
tracing::debug!("User ID in state: {:?}", uid_in_state);
tracing::debug!("Existing account: {:?}", existing_account);
tracing::debug!("Account in state: {:?}", account_in_state);
let (user_id, message) = match (mode.clone(), uid_in_state, &existing_account) {
(Some(OAuth2Mode::AddToUser), Some(uid), None) => {
let account_info = account_in_state.ok_or_else(|| {
CoordinationError::InvalidState(
"Missing account information in OAuth2 state".to_string(),
)
})?;
let state_params = state_params.ok_or_else(|| {
CoordinationError::InvalidState("AddToUser requires state parameters".to_string())
})?;
let message = format!("Successfully linked to {account_info}");
tracing::debug!("{}", message);
oauth2_account.user_id = uid.to_string();
OAuth2Store::upsert_oauth2_account(oauth2_account).await?;
delete_session_and_misc_token_from_store(state_params).await?;
(uid.to_string(), message)
}
(Some(OAuth2Mode::AddToUser), Some(uid), Some(existing)) => {
if uid == existing.user_id {
let account_info = account_in_state.ok_or_else(|| {
CoordinationError::InvalidState(
"Missing account information in OAuth2 state".to_string(),
)
})?;
let state_params = state_params.ok_or_else(|| {
CoordinationError::InvalidState(
"AddToUser requires state parameters".to_string(),
)
})?;
let msg = format!("Already linked to current user {account_info}");
tracing::debug!("{}", msg);
delete_session_and_misc_token_from_store(state_params).await?;
(uid.to_string(), msg)
} else {
return Err(CoordinationError::Conflict(
"This OAuth2 account is already linked to a different user".to_string(),
));
}
}
(Some(OAuth2Mode::Login), None, Some(existing)) => {
let message = format!("Signing in as {}", existing.name);
tracing::debug!("{}", message);
(existing.user_id.clone(), message)
}
(Some(OAuth2Mode::CreateUser), None, None) => {
let name = oauth2_account.name.clone();
let user_id = create_user_and_oauth2account(oauth2_account).await?;
let message = format!("Created new user {name}");
tracing::debug!("{}", message);
(user_id.clone(), message)
}
(Some(OAuth2Mode::CreateUserOrLogin), None, Some(existing)) => {
let message = format!("Signing in as {}", existing.name);
tracing::debug!("{}", message);
(existing.user_id.clone(), message)
}
(Some(OAuth2Mode::CreateUserOrLogin), None, None) => {
let name = oauth2_account.name.clone();
let user_id = create_user_and_oauth2account(oauth2_account).await?;
let message = format!("Created new user {name}");
tracing::debug!("{}", message);
(user_id.clone(), message)
}
(Some(OAuth2Mode::CreateUser), None, Some(_)) => {
tracing::debug!("This OAuth2 account is already registered");
return Err(CoordinationError::Conflict(
"This OAuth2 account is already registered".to_string(),
));
}
(Some(OAuth2Mode::Login), None, None) => {
tracing::debug!("This OAuth2 account is not registered");
return Err(CoordinationError::Conflict(
"This OAuth2 account is not registered".to_string(),
));
}
_ => {
tracing::error!("Invalid combination of mode {:?} and user state", mode);
return Err(CoordinationError::InvalidState(format!(
"Invalid combination of mode {mode:?} and user state"
)));
}
};
tracing::Span::current().record("user_id", &user_id);
tracing::info!(user_id = %user_id, "OAuth2 authorization completed successfully");
let user_id_validated = UserId::new(user_id)
.map_err(|e| CoordinationError::Validation(format!("Invalid user ID: {e}")))?;
let headers = new_session_header(user_id_validated.clone()).await?;
let _ = record_login_success(
user_id_validated,
auth_method,
login_context,
AuthMethodDetails {
provider: Some(provider_for_history),
provider_user_id: Some(provider_user_id_for_history),
email: Some(email_for_history),
..Default::default()
},
)
.await;
Ok((headers, message))
}
#[tracing::instrument(skip(request, headers), fields(user_id, provider = "google"))]
pub async fn fedcm_authorized_core(
request: &FedCMCallbackRequest,
headers: &HeaderMap,
) -> Result<(HeaderMap, String), CoordinationError> {
tracing::info!("Processing FedCM authorization callback");
let ctx = provider_for(ProviderKind::Google).ok_or_else(|| {
CoordinationError::InvalidState("Google provider not available".to_string())
})?;
let idinfo = validate_fedcm_token(ctx, &request.credential, &request.nonce_id).await?;
let oauth2_account = oauth2_account_from_idinfo(&idinfo, ctx)?;
let mode = match &request.mode {
Some(mode_str) => {
let parsed: OAuth2Mode = mode_str.parse().map_err(|_| {
CoordinationError::InvalidState(format!("Invalid FedCM mode: {mode_str}"))
})?;
Some(parsed)
}
None => None,
};
if matches!(mode, Some(OAuth2Mode::AddToUser)) {
return Err(CoordinationError::InvalidState(
"FedCM does not support add_to_user mode".to_string(),
));
}
let login_context = LoginContext::from_headers(headers);
let result = process_authenticated_oauth2_user(
oauth2_account,
mode,
AuthMethod::FedCM,
login_context,
None,
None,
None,
)
.await?;
Ok(result)
}
async fn create_user_and_oauth2account(
mut oauth2_account: OAuth2Account,
) -> Result<String, CoordinationError> {
let (account, label) = get_account_and_label_from_oauth2_account(&oauth2_account);
let new_user = DbUser {
id: gen_new_user_id().await?,
account,
label,
is_admin: *crate::config::O2P_DEMO_MODE,
sequence_number: None,
created_at: Utc::now(),
updated_at: Utc::now(),
};
let stored_user = UserStore::upsert_user(new_user).await?;
oauth2_account.user_id = stored_user.id.clone();
OAuth2Store::upsert_oauth2_account(oauth2_account).await?;
Ok(stored_user.id)
}
fn get_account_and_label_from_oauth2_account(oauth2_account: &OAuth2Account) -> (String, String) {
let (account_field, label_field) = get_oauth2_field_mappings();
let account = match account_field.as_str() {
"email" => oauth2_account.email.clone(),
"name" => oauth2_account.name.clone(),
_ => oauth2_account.email.clone(), };
let label = match label_field.as_str() {
"email" => oauth2_account.email.clone(),
"name" => oauth2_account.name.clone(),
_ => oauth2_account.name.clone(), };
(account, label)
}
fn get_oauth2_field_mappings() -> (String, String) {
(
OAUTH2_USER_ACCOUNT_FIELD.clone(),
OAUTH2_USER_LABEL_FIELD.clone(),
)
}
#[tracing::instrument(fields(user_id, provider, provider_user_id))]
pub async fn delete_oauth2_account_core(
user_id: UserId,
provider: Provider,
provider_user_id: ProviderUserId,
) -> Result<(), CoordinationError> {
tracing::info!("Attempting to delete OAuth2 account");
let accounts = OAuth2Store::get_oauth2_accounts_by(AccountSearchField::ProviderUserId(
provider_user_id.clone(),
))
.await?;
let account = accounts
.into_iter()
.find(|account| {
account.provider == provider.as_str()
&& account.provider_user_id == provider_user_id.as_str()
})
.ok_or(
CoordinationError::ResourceNotFound {
resource_type: "OAuth2Account".to_string(),
resource_id: format!("{}/{}", provider.as_str(), provider_user_id.as_str()),
}
.log(),
)?;
if account.user_id != user_id.as_str() {
return Err(CoordinationError::Unauthorized.log());
}
tracing::info!(
"Successfully deleted OAuth2 account {}/{} for user {}",
provider.as_str(),
provider_user_id.as_str(),
user_id.as_str()
);
OAuth2Store::delete_oauth2_accounts_by(AccountSearchField::ProviderUserId(provider_user_id))
.await?;
Ok(())
}
#[tracing::instrument(fields(user_id))]
pub async fn list_accounts_core(user_id: UserId) -> Result<Vec<OAuth2Account>, CoordinationError> {
tracing::debug!("Listing OAuth2 accounts for user");
let accounts = OAuth2Store::get_oauth2_accounts(user_id).await?;
tracing::info!(account_count = accounts.len(), "Retrieved OAuth2 accounts");
Ok(accounts)
}
#[cfg(test)]
mod tests;