use axum::{
extract::State,
http::{header, HeaderMap, StatusCode},
response::IntoResponse,
Json,
};
use chrono::{Duration, Utc};
#[cfg(feature = "postgres")]
use sqlx::PgPool;
use std::sync::Arc;
use crate::callback::{AuthCallback, AuthCallbackPayload};
use crate::errors::AppError;
use crate::models::{AuthMethod, AuthResponse, RegisterRequest};
use crate::repositories::{
default_expiry, generate_api_key, generate_verification_token, hash_verification_token,
normalize_email, ApiKeyEntity, AuditEventType, MembershipEntity, OrgEntity, OrgRole,
SessionEntity, TokenType, UserEntity,
};
use crate::services::{EmailService, TokenContext};
use crate::utils::{
attach_auth_cookies, extract_client_ip_with_fallback, hash_refresh_token, is_disposable_email,
is_valid_email, user_entity_to_auth_user, PeerIp,
};
use crate::AppState;
/// Converts a user's [`AuthMethod`]s into the string labels bound to the
/// `auth_methods` column when inserting a user, preserving input order.
#[cfg(feature = "postgres")]
fn auth_methods_to_strings(methods: &[AuthMethod]) -> Vec<String> {
    let label = |method: &AuthMethod| -> &'static str {
        match method {
            AuthMethod::Email => "email",
            AuthMethod::Google => "google",
            AuthMethod::Apple => "apple",
            AuthMethod::Solana => "solana",
            AuthMethod::WebAuthn => "webauthn",
            AuthMethod::Sso => "sso",
        }
    };
    methods.iter().map(label).map(String::from).collect()
}
/// Persists every row created by a new registration inside one Postgres
/// transaction. The rows are the user, their personal organization, the
/// membership linking them, an optional API key, and the initial session.
///
/// The transaction is committed only after every INSERT succeeds. Any failing
/// statement returns early, and the uncommitted `sqlx::Transaction` is rolled
/// back when it is dropped.
///
/// # Errors
///
/// * [`AppError::EmailExists`] if inserting the user row hits a unique-constraint
///   violation. This happens when a concurrent registration for the same address
///   wins the race after the handler's `email_exists` pre-check. It assumes
///   `users.email` is unique-constrained, which should be confirmed against the schema.
/// * [`AppError::Internal`] for every other database failure.
#[cfg(feature = "postgres")]
async fn register_with_transaction(
    pool: &PgPool,
    user: &UserEntity,
    org: &OrgEntity,
    membership: &MembershipEntity,
    api_key: Option<&ApiKeyEntity>,
    session: &SessionEntity,
) -> Result<(), AppError> {
    // Generic mapping for database failures that carry no more specific meaning.
    // The closure captures nothing, so it is `Copy` and can be reused below.
    let db_err = |e: sqlx::Error| AppError::Internal(e.into());

    let mut tx = pool.begin().await.map_err(db_err)?;

    let auth_methods = auth_methods_to_strings(&user.auth_methods);
    sqlx::query(
        r#"
        INSERT INTO users (id, email, email_verified, password_hash, name, picture,
                           wallet_address, google_id, auth_methods, is_system_admin,
                           created_at, updated_at)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
        "#,
    )
    .bind(user.id)
    .bind(&user.email)
    .bind(user.email_verified)
    .bind(&user.password_hash)
    .bind(&user.name)
    .bind(&user.picture)
    .bind(&user.wallet_address)
    .bind(&user.google_id)
    .bind(&auth_methods)
    .bind(user.is_system_admin)
    .bind(user.created_at)
    .bind(user.updated_at)
    .execute(&mut *tx)
    .await
    .map_err(|e: sqlx::Error| {
        // The caller's `email_exists` check and this INSERT are not atomic. A
        // concurrent registration for the same address shows up here as a
        // unique violation, which is a conflict rather than a server error.
        if matches!(e.as_database_error(), Some(db) if db.is_unique_violation()) {
            AppError::EmailExists
        } else {
            AppError::Internal(e.into())
        }
    })?;

    sqlx::query(
        r#"
        INSERT INTO organizations (id, name, slug, logo_url, is_personal, owner_id)
        VALUES ($1, $2, $3, $4, $5, $6)
        "#,
    )
    .bind(org.id)
    .bind(&org.name)
    .bind(&org.slug)
    .bind(&org.logo_url)
    .bind(org.is_personal)
    .bind(org.owner_id)
    .execute(&mut *tx)
    .await
    .map_err(db_err)?;

    sqlx::query(
        r#"
        INSERT INTO memberships (id, user_id, org_id, role)
        VALUES ($1, $2, $3, $4)
        "#,
    )
    .bind(membership.id)
    .bind(membership.user_id)
    .bind(membership.org_id)
    .bind(membership.role.as_str())
    .execute(&mut *tx)
    .await
    .map_err(db_err)?;

    // The handler only builds an API key when email verification is not required.
    if let Some(api_key) = api_key {
        sqlx::query(
            r#"
            INSERT INTO api_keys (id, user_id, key_hash, key_prefix, created_at, last_used_at)
            VALUES ($1, $2, $3, $4, $5, $6)
            "#,
        )
        .bind(api_key.id)
        .bind(api_key.user_id)
        .bind(&api_key.key_hash)
        .bind(&api_key.key_prefix)
        .bind(api_key.created_at)
        .bind(api_key.last_used_at)
        .execute(&mut *tx)
        .await
        .map_err(db_err)?;
    }

    sqlx::query(
        r#"
        INSERT INTO sessions (id, user_id, refresh_token_hash, ip_address, user_agent,
                              created_at, expires_at, revoked_at)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
        "#,
    )
    .bind(session.id)
    .bind(session.user_id)
    .bind(&session.refresh_token_hash)
    .bind(&session.ip_address)
    .bind(&session.user_agent)
    .bind(session.created_at)
    .bind(session.expires_at)
    .bind(session.revoked_at)
    .execute(&mut *tx)
    .await
    .map_err(db_err)?;

    tx.commit().await.map_err(db_err)?;
    Ok(())
}
pub async fn register<C: AuthCallback, E: EmailService>(
State(state): State<Arc<AppState<C, E>>>,
headers: HeaderMap,
PeerIp(peer_ip): PeerIp,
Json(req): Json<RegisterRequest>,
) -> Result<impl IntoResponse, AppError> {
if !state.config.email.enabled {
return Err(AppError::NotFound("Email auth disabled".into()));
}
if !is_valid_email(&req.email) {
return Err(AppError::Validation("Invalid email format".to_string()));
}
if state.config.email.block_disposable_emails && is_disposable_email(&req.email) {
return Err(AppError::Validation(
"Disposable email addresses are not allowed. Please use a permanent email address."
.to_string(),
));
}
let normalized_email = normalize_email(&req.email);
state.password_service.validate(&req.password)?;
let password_hash = state.password_service.hash(req.password.clone()).await?;
if state.user_repo.email_exists(&normalized_email).await? {
return Err(AppError::EmailExists);
}
let mut user =
UserEntity::new_email_user(normalized_email.clone(), password_hash, req.name.clone());
let personal_org = OrgEntity::new_personal(user.id, user.name.as_deref());
let membership = MembershipEntity::new(user.id, personal_org.id, OrgRole::Owner);
let (raw_api_key, api_key_entity) = if state.config.email.require_verification {
(None, None)
} else {
let raw = generate_api_key();
(Some(raw.clone()), Some(ApiKeyEntity::new(user.id, &raw)))
};
let session_id = uuid::Uuid::new_v4();
let token_context = TokenContext {
org_id: Some(personal_org.id),
role: Some(OrgRole::Owner.as_str().to_string()),
is_system_admin: None,
};
let token_pair =
state
.jwt_service
.generate_token_pair_with_context(user.id, session_id, &token_context)?;
let refresh_expiry =
Utc::now() + Duration::seconds(state.jwt_service.refresh_expiry_secs() as i64);
let ip_address =
extract_client_ip_with_fallback(&headers, state.config.server.trust_proxy, peer_ip);
let user_agent = headers
.get(header::USER_AGENT)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let mut session = SessionEntity::new_with_id(
session_id,
user.id,
hash_refresh_token(&token_pair.refresh_token, &state.config.jwt.secret),
refresh_expiry,
ip_address.clone(),
user_agent.clone(),
);
session.last_strong_auth_at = Some(Utc::now());
#[cfg(feature = "postgres")]
if let Some(pool) = state.postgres_pool.as_ref() {
register_with_transaction(
pool,
&user,
&personal_org,
&membership,
api_key_entity.as_ref(),
&session,
)
.await?;
} else {
user = state.user_repo.create(user).await?;
let _ = state.org_repo.create(personal_org).await?;
state.membership_repo.create(membership).await?;
if let Some(api_key_entity) = api_key_entity {
state.api_key_repo.create(api_key_entity).await?;
}
state.session_repo.create(session).await?;
}
#[cfg(not(feature = "postgres"))]
{
user = state.user_repo.create(user).await?;
let _ = state.org_repo.create(personal_org).await?;
state.membership_repo.create(membership).await?;
if let Some(api_key_entity) = api_key_entity {
state.api_key_repo.create(api_key_entity).await?;
}
state.session_repo.create(session).await?;
}
if state.config.email.require_verification {
let token = generate_verification_token();
let token_hash = hash_verification_token(&token);
state
.verification_repo
.create(
user.id,
&token_hash,
TokenType::EmailVerify,
default_expiry(TokenType::EmailVerify),
)
.await
.map_err(|e| AppError::Internal(anyhow::anyhow!("Failed to create token: {}", e)))?;
state
.comms_service
.queue_verification_email(&req.email, user.name.as_deref(), &token, Some(user.id))
.await
.map_err(|e| {
tracing::warn!(
error = %e,
user_id = %user.id,
"Failed to queue verification email"
);
e
})
.ok();
}
let auth_user = user_entity_to_auth_user(&user);
let payload = AuthCallbackPayload {
user: auth_user.clone(),
method: AuthMethod::Email,
is_new_user: true,
session_id: session_id.to_string(),
ip_address,
user_agent,
};
let callback_data =
super::call_registered_callback_with_timeout(&state.callback, &payload).await;
let _ = state
.audit_service
.log_user_event(AuditEventType::UserRegister, user.id, Some(&headers))
.await;
let response_tokens = if state.config.email.require_verification || state.config.cookie.enabled
{
None
} else {
Some(token_pair.clone())
};
let response = AuthResponse {
user: auth_user,
tokens: response_tokens,
is_new_user: true,
callback_data,
api_key: raw_api_key,
};
let resp = (StatusCode::CREATED, Json(response)).into_response();
if state.config.email.require_verification {
Ok(resp)
} else {
Ok(attach_auth_cookies(
&state.config.cookie,
&token_pair,
state.jwt_service.refresh_expiry_secs(),
resp,
))
}
}