use axum::{extract::State, http::HeaderMap, response::IntoResponse, Json};
use chrono::{Duration, Utc};
use std::sync::Arc;
use sha2::{Digest, Sha256};
use crate::callback::{AuthCallback, AuthCallbackPayload};
use crate::errors::AppError;
use crate::handlers::auth::{
call_authenticated_callback_with_timeout, call_registered_callback_with_timeout,
};
use crate::models::{AppleAuthRequest, AuthMethod, AuthResponse};
use crate::repositories::{
generate_api_key, normalize_email, ApiKeyEntity, AuditEventType, MembershipEntity,
SessionEntity, TransactionalOps, UserEntity,
};
use crate::services::EmailService;
use crate::utils::{
build_json_response_with_cookies, compute_post_login, extract_client_ip_with_fallback,
get_default_org_context, hash_refresh_token, resolve_org_assignment,
user_entity_to_auth_user, PeerIp,
};
use crate::AppState;
pub async fn apple_auth<C: AuthCallback, E: EmailService>(
State(state): State<Arc<AppState<C, E>>>,
headers: HeaderMap,
PeerIp(peer_ip): PeerIp,
Json(req): Json<AppleAuthRequest>,
) -> Result<impl IntoResponse, AppError> {
let enabled = state
.settings_service
.get_bool("auth_apple_enabled")
.await
.ok()
.flatten()
.unwrap_or(state.config.apple.enabled);
if !enabled {
return Err(AppError::NotFound("Apple auth disabled".into()));
}
let client_id = state
.settings_service
.get("auth_apple_client_id")
.await
.ok()
.flatten()
.filter(|s| !s.is_empty())
.or_else(|| state.config.apple.client_id.clone())
.ok_or_else(|| AppError::Config("Apple client ID not configured".into()))?;
let _team_id = state
.settings_service
.get("auth_apple_team_id")
.await
.ok()
.flatten()
.filter(|s| !s.is_empty())
.or_else(|| state.config.apple.team_id.clone())
.ok_or_else(|| AppError::Config("Apple team ID not configured".into()))?;
let claims = state
.apple_service
.verify_id_token(&req.id_token, &client_id)
.await?;
if let Some(ref client_nonce) = req.nonce {
let expected_hash = hex::encode(Sha256::digest(client_nonce.as_bytes()));
match &claims.nonce {
Some(token_nonce) if token_nonce == &expected_hash => { }
Some(_) => {
tracing::warn!("Apple nonce mismatch: token nonce does not match client nonce hash");
return Err(AppError::InvalidToken);
}
None => {
tracing::warn!("Apple nonce missing from token despite client sending nonce");
return Err(AppError::InvalidToken);
}
}
}
let existing_user = state.user_repo.find_by_apple_id(&claims.sub).await?;
let (user, is_new_user, api_key) = if let Some(user) = existing_user {
(user, false, None)
} else {
if !claims.is_likely_real() {
tracing::warn!(
apple_id = %claims.sub,
real_user_status = ?claims.real_user_status,
"Rejected Apple Sign-In registration: potential bot detected"
);
return Err(AppError::Validation(
"Unable to verify account authenticity. Please try again later.".to_string(),
));
}
let normalized_email = claims.email.as_deref().map(normalize_email);
let autolink_match = if let Some(ref email) = normalized_email {
state.user_repo.find_by_email(email).await?
} else {
None
};
if let Some(mut existing) = autolink_match {
let now = Utc::now();
existing.apple_id = Some(claims.sub);
existing.updated_at = now;
existing.last_login_at = Some(now);
if !existing.auth_methods.contains(&AuthMethod::Apple) {
existing.auth_methods.push(AuthMethod::Apple);
}
if existing.name.is_none() {
existing.name = req.name;
}
let user = state.user_repo.update(existing).await?;
(user, false, None)
} else {
let now = Utc::now();
let user = UserEntity {
id: uuid::Uuid::new_v4(),
email: normalized_email.clone(),
email_verified: claims.is_email_verified(),
password_hash: None,
name: req.name,
username: None,
picture: None, wallet_address: None,
google_id: None,
apple_id: Some(claims.sub),
stripe_customer_id: None,
auth_methods: vec![AuthMethod::Apple],
is_system_admin: false,
created_at: now,
updated_at: now,
last_login_at: Some(now),
welcome_completed_at: None,
};
let org_assignment = resolve_org_assignment(&state, user.id).await?;
let membership = MembershipEntity::new(user.id, org_assignment.org_id, org_assignment.role);
let raw_api_key = generate_api_key();
let api_key_entity = ApiKeyEntity::new(user.id, &raw_api_key, "default");
#[cfg(feature = "postgres")]
let user = if let Some(pool) = state.postgres_pool.as_ref() {
TransactionalOps::create_user_with_membership_and_api_key(
pool, &user, &membership, &api_key_entity,
)
.await?;
user
} else {
let created = state.user_repo.create(user).await?;
state.membership_repo.create(membership).await?;
state.api_key_repo.create(api_key_entity).await?;
created
};
#[cfg(not(feature = "postgres"))]
let user = {
let created = state.user_repo.create(user).await?;
state.membership_repo.create(membership).await?;
state.api_key_repo.create(api_key_entity).await?;
created
};
(user, true, Some(raw_api_key))
}
};
let memberships = state.membership_repo.find_by_user(user.id).await?;
let token_context = get_default_org_context(&memberships, user.is_system_admin, user.email_verified);
let session_id = uuid::Uuid::new_v4();
let token_pair =
state
.jwt_service
.generate_token_pair_with_context(user.id, session_id, &token_context)?;
let refresh_expiry =
Utc::now() + Duration::seconds(state.jwt_service.refresh_expiry_secs() as i64);
let ip_address =
extract_client_ip_with_fallback(&headers, state.config.server.trust_proxy, peer_ip);
let user_agent = headers
.get(axum::http::header::USER_AGENT)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let mut session = SessionEntity::new_with_id(
session_id,
user.id,
hash_refresh_token(&token_pair.refresh_token, &state.config.jwt.secret),
refresh_expiry,
ip_address.clone(),
user_agent.clone(),
);
session.last_strong_auth_at = Some(Utc::now());
state.session_repo.create(session).await?;
let auth_user = user_entity_to_auth_user(&user);
let payload = AuthCallbackPayload {
user: auth_user.clone(),
method: AuthMethod::Apple,
is_new_user,
session_id: session_id.to_string(),
ip_address,
user_agent,
};
let callback_data = if is_new_user {
call_registered_callback_with_timeout(&state.callback, &payload).await
} else {
call_authenticated_callback_with_timeout(&state.callback, &payload).await
};
let audit_event = if is_new_user {
AuditEventType::UserRegister
} else {
AuditEventType::UserLogin
};
let _ = state
.audit_service
.log_user_event(audit_event, user.id, Some(&headers))
.await;
let response_tokens = if state.config.cookie.enabled {
None
} else {
Some(token_pair.clone())
};
let response = AuthResponse {
user: auth_user,
tokens: response_tokens,
is_new_user,
callback_data,
api_key,
email_queued: None,
post_login: compute_post_login(&user, &state.settings_service, &*state.totp_repo, &*state.credential_repo, &*state.wallet_material_repo, &*state.storage.pending_wallet_recovery_repo).await,
};
Ok(build_json_response_with_cookies(
&state.config.cookie,
&token_pair,
state.jwt_service.refresh_expiry_secs(),
response,
))
}