use std::sync::Arc;
use time::OffsetDateTime;
use crate::config::AuthConfig;
use crate::crypto::{hash, token};
use crate::email::EmailSender;
use crate::error::{AuthError, OAuthError};
use crate::events::{AuthEvent, LoginFailReason, LoginMethod};
use crate::hooks::EventEmitter;
use crate::oauth::{OAuthProviderConfig, OAuthTokens, OAuthUserInfo, client};
use crate::store::{AccountStore, OAuthStateStore, SessionStore, UserStore, VerificationStore};
use crate::types::{
NewAccount, NewSession, NewUser, NewVerification, PublicAccount, Session, User, Verification,
};
/// Outcome of [`AuthService::signup`].
#[derive(Debug)]
pub struct SignupResult {
    /// The newly created user (email stored trimmed and lowercased).
    pub user: User,
    /// Session created when `config.email.auto_sign_in_after_signup` is set.
    pub session: Option<Session>,
    /// Raw session token for `session`; only its hash is persisted.
    // NOTE(review): `#[derive(Debug)]` prints this raw token verbatim — avoid
    // logging this struct with `{:?}`.
    pub session_token: Option<String>,
    /// Raw email-verification token, present when a verification email was
    /// sent (`config.email.send_verification_on_signup`).
    pub verification_token: Option<String>,
}
/// Outcome of a successful password login or OAuth callback.
#[derive(Debug)]
pub struct LoginResult {
    /// The authenticated user.
    pub user: User,
    /// The freshly created session.
    pub session: Session,
    /// Raw session token to hand to the client; only its hash is persisted.
    // NOTE(review): printed verbatim by the derived `Debug`.
    pub session_token: String,
}
/// Outcome of [`AuthService::verify_email`].
#[derive(Debug)]
pub struct VerifyEmailResult {
    /// The user, re-read after being marked verified.
    pub user: User,
    /// Session created when `config.email.auto_sign_in_after_verification` is set.
    pub session: Option<Session>,
    /// Raw session token for `session`, if one was created.
    pub session_token: Option<String>,
}
/// Outcome of [`AuthService::request_password_reset`].
///
/// Carries no data: the same value is returned whether or not the email
/// matched a user, so callers cannot use it to probe for accounts.
#[derive(Debug, Default)]
pub struct RequestResetResult {
    /// Placeholder field; construct via `Default`.
    pub _private: (),
}
/// Outcome of [`AuthService::reset_password`].
#[derive(Debug)]
pub struct ResetPasswordResult {
    /// The user, re-read after the password change.
    pub user: User,
}
/// A resolved, unexpired session together with its owner.
#[derive(Debug)]
pub struct SessionResult {
    /// Owner of the session.
    pub user: User,
    /// The session record looked up by token hash.
    pub session: Session,
}
/// Outcome of [`AuthService::link_account`].
#[derive(Debug)]
pub struct LinkAccountResult {
    /// The user the provider account is now linked to.
    pub user: User,
}
/// Outcome of [`AuthService::unlink_account`]; carries no data.
#[derive(Debug, Default)]
pub struct UnlinkAccountResult {
    /// Placeholder field; construct via `Default`.
    pub _private: (),
}
/// Outcome of [`AuthService::refresh_oauth_token`].
#[derive(Debug)]
pub struct RefreshTokenResult {
    /// Tokens exactly as returned by the provider's refresh endpoint.
    pub tokens: OAuthTokens,
}
/// Authentication service: email/password signup and login, email
/// verification, password reset, session management and OAuth account
/// linking.
///
/// Persistence and email delivery are abstracted behind the store traits and
/// [`EmailSender`]; every dependency is held behind an [`Arc`].
pub struct AuthService<U, S, V, A, O, E>
where
    U: UserStore,
    S: SessionStore,
    V: VerificationStore,
    A: AccountStore,
    O: OAuthStateStore,
    E: EmailSender,
{
    /// Token length, TTLs and the email / OAuth policy switches.
    pub config: AuthConfig,
    /// User records.
    pub users: Arc<U>,
    /// Sessions, looked up by the hash of the raw session token.
    pub sessions: Arc<S>,
    /// One-time email-verification and password-reset tokens (stored hashed).
    pub verifications: Arc<V>,
    /// OAuth provider accounts linked to users.
    pub accounts: Arc<A>,
    /// OAuth state records; only pruned here, via `cleanup_expired`.
    pub oauth_states: Arc<O>,
    /// Sends verification and password-reset emails.
    pub email: Arc<E>,
    /// Receives auth lifecycle events (signup, login, session creation, ...).
    pub events: Arc<EventEmitter>,
}
impl<U, S, V, A, O, E> AuthService<U, S, V, A, O, E>
where
U: UserStore,
S: SessionStore,
V: VerificationStore,
A: AccountStore,
O: OAuthStateStore,
E: EmailSender,
{
pub fn new(
config: AuthConfig,
users: U,
sessions: S,
verifications: V,
accounts: A,
oauth_states: O,
email: E,
) -> Self {
Self {
config,
users: Arc::new(users),
sessions: Arc::new(sessions),
verifications: Arc::new(verifications),
accounts: Arc::new(accounts),
oauth_states: Arc::new(oauth_states),
email: Arc::new(email),
events: Arc::new(EventEmitter::new()),
}
}
pub fn with_events(mut self, events: EventEmitter) -> Self {
self.events = Arc::new(events);
self
}
pub async fn signup(
&self,
input: NewUser,
ip: Option<String>,
user_agent: Option<String>,
) -> Result<SignupResult, AuthError> {
if input.password.len() < 8 {
return Err(AuthError::WeakPassword(8));
}
let email = input.email.trim().to_lowercase();
if self.users.find_by_email(&email).await?.is_some() {
return Err(AuthError::EmailTaken);
}
let password_hash = hash::hash_password(&input.password)?;
let user = self
.users
.create_user(&email, input.name.as_deref(), Some(&password_hash))
.await?;
let verification_token = if self.config.email.send_verification_on_signup {
let identifier = format!("email-verify:{}", user.email.to_lowercase());
let raw_token = token::generate_token(self.config.token_length);
let _ = self.verifications.delete_by_identifier(&identifier).await;
self.verifications
.create_verification(NewVerification {
identifier,
token_hash: token::hash_token(&raw_token),
expires_at: OffsetDateTime::now_utc() + self.config.verification_ttl,
})
.await?;
self.email
.send_verification_email(&user, &raw_token)
.await?;
Some(raw_token)
} else {
None
};
self.events
.emit(AuthEvent::UserSignedUp {
user_id: user.id,
email: user.email.clone(),
})
.await;
let (session, session_token) = if self.config.email.auto_sign_in_after_signup {
let (session, raw_token) = self
.create_session_internal(user.id, ip, user_agent)
.await?;
(Some(session), Some(raw_token))
} else {
(None, None)
};
Ok(SignupResult {
user,
session,
session_token,
verification_token,
})
}
pub async fn login(
&self,
email: &str,
password: &str,
ip: Option<String>,
user_agent: Option<String>,
) -> Result<LoginResult, AuthError> {
let user = self
.users
.find_by_email(&email.trim().to_lowercase())
.await?
.ok_or(AuthError::InvalidCredentials)?;
let password_hash = user
.password_hash
.as_deref()
.ok_or(AuthError::InvalidCredentials)?;
if !hash::verify_password(password, password_hash)? {
self.events
.emit(AuthEvent::UserLoginFailed {
email: email.to_string(),
reason: LoginFailReason::InvalidCredentials,
})
.await;
return Err(AuthError::InvalidCredentials);
}
if self.config.email.require_verification_to_login && !user.is_verified() {
self.events
.emit(AuthEvent::UserLoginFailed {
email: email.to_string(),
reason: LoginFailReason::EmailNotVerified,
})
.await;
return Err(AuthError::EmailNotVerified);
}
let (session, session_token) = self
.create_session_internal(user.id, ip, user_agent)
.await?;
self.events
.emit(AuthEvent::UserLoggedIn {
user_id: user.id,
method: LoginMethod::Password,
})
.await;
Ok(LoginResult {
user,
session,
session_token,
})
}
pub async fn logout(&self, session_id: i64) -> Result<(), AuthError> {
self.sessions.delete_session(session_id).await
}
pub async fn logout_all(&self, user_id: i64) -> Result<(), AuthError> {
self.sessions.delete_by_user_id(user_id).await
}
pub async fn get_session(&self, raw_token: &str) -> Result<SessionResult, AuthError> {
let session = self
.sessions
.find_by_token_hash(&token::hash_token(raw_token))
.await?
.ok_or(AuthError::SessionNotFound)?;
if session.expires_at < OffsetDateTime::now_utc() {
self.sessions.delete_session(session.id).await?;
return Err(AuthError::SessionNotFound);
}
let user = self
.users
.find_by_id(session.user_id)
.await?
.ok_or(AuthError::UserNotFound)?;
Ok(SessionResult { user, session })
}
pub async fn list_sessions(&self, user_id: i64) -> Result<Vec<Session>, AuthError> {
self.sessions.find_by_user_id(user_id).await
}
pub async fn verify_email(
&self,
raw_token: &str,
ip: Option<String>,
user_agent: Option<String>,
) -> Result<VerifyEmailResult, AuthError> {
let verification = self.lookup_verification(raw_token, "email-verify:").await?;
let email = verification
.identifier
.strip_prefix("email-verify:")
.ok_or(AuthError::InvalidToken)?;
let user = self
.users
.find_by_email(email)
.await?
.ok_or(AuthError::UserNotFound)?;
self.users.set_email_verified(user.id).await?;
self.verifications
.delete_verification(verification.id)
.await?;
self.events
.emit(AuthEvent::EmailVerified { user_id: user.id })
.await;
let user = self
.users
.find_by_id(user.id)
.await?
.ok_or(AuthError::UserNotFound)?;
let (session, session_token) = if self.config.email.auto_sign_in_after_verification {
let (session, raw_token) = self
.create_session_internal(user.id, ip, user_agent)
.await?;
(Some(session), Some(raw_token))
} else {
(None, None)
};
Ok(VerifyEmailResult {
user,
session,
session_token,
})
}
pub async fn request_password_reset(
&self,
email: &str,
) -> Result<RequestResetResult, AuthError> {
let email = email.trim().to_lowercase();
if let Some(user) = self.users.find_by_email(&email).await? {
let identifier = format!("password-reset:{}", user.email.to_lowercase());
let _ = self.verifications.delete_by_identifier(&identifier).await;
let raw_token = token::generate_token(self.config.token_length);
self.verifications
.create_verification(NewVerification {
identifier,
token_hash: token::hash_token(&raw_token),
expires_at: OffsetDateTime::now_utc() + self.config.reset_ttl,
})
.await?;
self.email
.send_password_reset_email(&user, &raw_token)
.await?;
self.events
.emit(AuthEvent::PasswordResetRequested { user_id: user.id })
.await;
}
Ok(RequestResetResult::default())
}
pub async fn reset_password(
&self,
raw_token: &str,
new_password: &str,
) -> Result<ResetPasswordResult, AuthError> {
if new_password.len() < 8 {
return Err(AuthError::WeakPassword(8));
}
let verification = self
.lookup_verification(raw_token, "password-reset:")
.await?;
let email = verification
.identifier
.strip_prefix("password-reset:")
.ok_or(AuthError::InvalidToken)?;
let user = self
.users
.find_by_email(email)
.await?
.ok_or(AuthError::UserNotFound)?;
self.users
.update_password(user.id, &hash::hash_password(new_password)?)
.await?;
self.sessions.delete_by_user_id(user.id).await?;
self.verifications
.delete_verification(verification.id)
.await?;
self.events
.emit(AuthEvent::PasswordResetCompleted { user_id: user.id })
.await;
let user = self
.users
.find_by_id(user.id)
.await?
.ok_or(AuthError::UserNotFound)?;
Ok(ResetPasswordResult { user })
}
pub async fn cleanup_expired(&self) -> Result<(u64, u64, u64), AuthError> {
let sessions_deleted = self.sessions.delete_expired().await?;
let verifications_deleted = self.verifications.delete_expired().await?;
let oauth_states_deleted = self.oauth_states.delete_expired_oauth_states().await?;
Ok((
sessions_deleted,
verifications_deleted,
oauth_states_deleted,
))
}
pub async fn oauth_callback(
&self,
info: OAuthUserInfo,
tokens: OAuthTokens,
ip: Option<String>,
user_agent: Option<String>,
) -> Result<LoginResult, AuthError> {
let oauth_provider_id = info.provider_id.clone();
if let Some(account) = self
.accounts
.find_by_provider(&info.provider_id, &info.account_id)
.await?
{
let user = self
.users
.find_by_id(account.user_id)
.await?
.ok_or(AuthError::UserNotFound)?;
let (session, session_token) = self
.create_session_internal(user.id, ip, user_agent)
.await?;
self.events
.emit(AuthEvent::UserLoggedIn {
user_id: user.id,
method: LoginMethod::OAuth {
provider_id: oauth_provider_id,
},
})
.await;
return Ok(LoginResult {
user,
session,
session_token,
});
}
let user = if let Some(existing_user) = self.users.find_by_email(&info.email).await? {
if !self.config.oauth.allow_implicit_account_linking {
return Err(AuthError::OAuth(OAuthError::LinkingDisabled));
}
existing_user
} else {
self.users
.create_user(&info.email, info.name.as_deref(), None)
.await?
};
let access_token_expires_at = tokens.expires_in.map(|d| OffsetDateTime::now_utc() + d);
self.accounts
.create_account(NewAccount {
user_id: user.id,
provider_id: info.provider_id,
account_id: info.account_id,
access_token: tokens.access_token,
refresh_token: tokens.refresh_token,
access_token_expires_at,
scope: tokens.scope,
})
.await?;
if !user.is_verified() {
self.users.set_email_verified(user.id).await?;
}
let user = self
.users
.find_by_id(user.id)
.await?
.ok_or(AuthError::UserNotFound)?;
let (session, session_token) = self
.create_session_internal(user.id, ip, user_agent)
.await?;
self.events
.emit(AuthEvent::UserLoggedIn {
user_id: user.id,
method: LoginMethod::OAuth {
provider_id: oauth_provider_id,
},
})
.await;
Ok(LoginResult {
user,
session,
session_token,
})
}
pub async fn list_accounts(&self, user_id: i64) -> Result<Vec<PublicAccount>, AuthError> {
let accounts = self.accounts.find_by_user_id(user_id).await?;
Ok(accounts.into_iter().map(PublicAccount::from).collect())
}
pub async fn link_account(
&self,
user_id: i64,
info: OAuthUserInfo,
tokens: OAuthTokens,
) -> Result<LinkAccountResult, AuthError> {
let linked_provider_id = info.provider_id.clone();
if let Some(existing) = self
.accounts
.find_by_provider(&info.provider_id, &info.account_id)
.await?
{
if existing.user_id != user_id {
return Err(AuthError::OAuth(OAuthError::AccountAlreadyLinked));
}
let expires_at = tokens.expires_in.map(|d| OffsetDateTime::now_utc() + d);
self.accounts
.update_account(
existing.id,
tokens.access_token,
tokens.refresh_token,
expires_at,
tokens.scope,
)
.await?;
} else {
let expires_at = tokens.expires_in.map(|d| OffsetDateTime::now_utc() + d);
self.accounts
.create_account(NewAccount {
user_id,
provider_id: info.provider_id,
account_id: info.account_id,
access_token: tokens.access_token,
refresh_token: tokens.refresh_token,
access_token_expires_at: expires_at,
scope: tokens.scope,
})
.await?;
}
let user = self
.users
.find_by_id(user_id)
.await?
.ok_or(AuthError::UserNotFound)?;
self.events
.emit(AuthEvent::OAuthAccountLinked {
user_id,
provider_id: linked_provider_id,
})
.await;
Ok(LinkAccountResult { user })
}
pub async fn unlink_account(
&self,
user_id: i64,
account_id: i64,
) -> Result<UnlinkAccountResult, AuthError> {
let accounts = self.accounts.find_by_user_id(user_id).await?;
let target = accounts
.iter()
.find(|a| a.id == account_id)
.ok_or(AuthError::OAuth(OAuthError::AccountNotFound))?;
let user = self
.users
.find_by_id(user_id)
.await?
.ok_or(AuthError::UserNotFound)?;
if user.password_hash.is_none() && accounts.len() == 1 {
return Err(AuthError::OAuth(OAuthError::LastAuthMethod));
}
let unlinked_provider_id = target.provider_id.clone();
self.accounts.delete_account(target.id).await?;
self.events
.emit(AuthEvent::OAuthAccountUnlinked {
user_id,
provider_id: unlinked_provider_id,
})
.await;
Ok(UnlinkAccountResult::default())
}
pub async fn refresh_oauth_token(
&self,
user_id: i64,
account_id: i64,
provider_config: &OAuthProviderConfig,
) -> Result<RefreshTokenResult, AuthError> {
let accounts = self.accounts.find_by_user_id(user_id).await?;
let account = accounts
.iter()
.find(|a| a.id == account_id)
.ok_or(AuthError::OAuth(OAuthError::AccountNotFound))?;
let refresh_token_str = account
.refresh_token
.as_deref()
.ok_or(AuthError::OAuth(OAuthError::NoRefreshToken))?;
let tokens = client::refresh_access_token(provider_config, refresh_token_str).await?;
let expires_at = tokens.expires_in.map(|d| OffsetDateTime::now_utc() + d);
self.accounts
.update_account(
account.id,
tokens.access_token.clone(),
tokens.refresh_token.clone(),
expires_at,
tokens.scope.clone(),
)
.await?;
Ok(RefreshTokenResult { tokens })
}
async fn create_session_internal(
&self,
user_id: i64,
ip: Option<String>,
user_agent: Option<String>,
) -> Result<(Session, String), AuthError> {
let raw_token = token::generate_token(self.config.token_length);
let session = self
.sessions
.create_session(NewSession {
token_hash: token::hash_token(&raw_token),
user_id,
expires_at: OffsetDateTime::now_utc() + self.config.session_ttl,
ip_address: ip.clone(),
user_agent,
})
.await?;
self.events
.emit(AuthEvent::SessionCreated {
user_id,
session_id: session.id,
ip: ip.clone(),
})
.await;
Ok((session, raw_token))
}
async fn lookup_verification(
&self,
raw_token: &str,
prefix: &str,
) -> Result<Verification, AuthError> {
let verification = self
.verifications
.find_by_token_hash(&token::hash_token(raw_token))
.await?
.ok_or(AuthError::InvalidToken)?;
if !verification.identifier.starts_with(prefix) {
return Err(AuthError::InvalidToken);
}
if verification.expires_at < OffsetDateTime::now_utc() {
self.verifications
.delete_verification(verification.id)
.await?;
return Err(AuthError::InvalidToken);
}
Ok(verification)
}
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use time::OffsetDateTime;
use super::AuthService;
use crate::config::AuthConfig;
use crate::email::EmailSender;
use crate::error::{AuthError, OAuthError};
use crate::oauth::{OAuthProviderConfig, OAuthTokens};
use crate::store::{AccountStore, OAuthStateStore, SessionStore, UserStore, VerificationStore};
use crate::types::{
Account, NewAccount, NewOAuthState, NewSession, NewUser, NewVerification, OAuthIntent,
OAuthState, Session, User, Verification,
};
/// Backing tables for the in-memory test store: one map per table keyed by
/// row id, plus one id counter per table.
#[derive(Default)]
struct MemoryState {
    // Last id handed out per table. Each insert increments before use, so
    // the first row of every table gets id 1.
    next_user_id: i64,
    next_session_id: i64,
    next_verification_id: i64,
    next_account_id: i64,
    next_oauth_state_id: i64,
    users: HashMap<i64, User>,
    sessions: HashMap<i64, Session>,
    verifications: HashMap<i64, Verification>,
    accounts: HashMap<i64, Account>,
    oauth_states: HashMap<i64, OAuthState>,
}
/// In-memory implementation of every store trait, for tests.
///
/// Clones share the same `MemoryState` through the `Arc`. One instance can
/// therefore fill all five store roles of `AuthService::new` while the test
/// inspects the same data.
#[derive(Clone, Default)]
struct MemoryStore {
    inner: Arc<Mutex<MemoryState>>,
}
#[async_trait]
impl UserStore for MemoryStore {
    /// Inserts a user under the next id; rejects an exact-match duplicate email.
    async fn create_user(
        &self,
        email: &str,
        name: Option<&str>,
        password_hash: Option<&str>,
    ) -> Result<User, AuthError> {
        let mut guard = self.inner.lock().unwrap();
        let duplicate = guard.users.values().any(|existing| existing.email == email);
        if duplicate {
            return Err(AuthError::EmailTaken);
        }
        guard.next_user_id += 1;
        let id = guard.next_user_id;
        let created_at = OffsetDateTime::now_utc();
        let record = User {
            id,
            email: email.to_string(),
            name: name.map(str::to_owned),
            password_hash: password_hash.map(str::to_owned),
            email_verified_at: None,
            image: None,
            created_at,
            updated_at: created_at,
        };
        guard.users.insert(id, record.clone());
        Ok(record)
    }

    /// Exact (case-sensitive) email lookup.
    async fn find_by_email(&self, email: &str) -> Result<Option<User>, AuthError> {
        let guard = self.inner.lock().unwrap();
        let found = guard
            .users
            .values()
            .find(|candidate| candidate.email == email)
            .cloned();
        Ok(found)
    }

    async fn find_by_id(&self, id: i64) -> Result<Option<User>, AuthError> {
        let guard = self.inner.lock().unwrap();
        Ok(guard.users.get(&id).cloned())
    }

    /// Stamps `email_verified_at` and `updated_at` with the current time.
    async fn set_email_verified(&self, user_id: i64) -> Result<(), AuthError> {
        let mut guard = self.inner.lock().unwrap();
        match guard.users.get_mut(&user_id) {
            Some(record) => {
                record.email_verified_at = Some(OffsetDateTime::now_utc());
                record.updated_at = OffsetDateTime::now_utc();
                Ok(())
            }
            None => Err(AuthError::UserNotFound),
        }
    }

    async fn update_password(
        &self,
        user_id: i64,
        password_hash: &str,
    ) -> Result<(), AuthError> {
        let mut guard = self.inner.lock().unwrap();
        match guard.users.get_mut(&user_id) {
            Some(record) => {
                record.password_hash = Some(password_hash.to_string());
                record.updated_at = OffsetDateTime::now_utc();
                Ok(())
            }
            None => Err(AuthError::UserNotFound),
        }
    }

    /// Removing a missing user is not an error.
    async fn delete_user(&self, user_id: i64) -> Result<(), AuthError> {
        let mut guard = self.inner.lock().unwrap();
        guard.users.remove(&user_id);
        Ok(())
    }
}
#[async_trait]
impl SessionStore for MemoryStore {
    /// Stores the session under the next id, stamping both timestamps with "now".
    async fn create_session(&self, session: NewSession) -> Result<Session, AuthError> {
        let mut guard = self.inner.lock().unwrap();
        guard.next_session_id += 1;
        let id = guard.next_session_id;
        let created_at = OffsetDateTime::now_utc();
        let record = Session {
            id,
            token_hash: session.token_hash,
            user_id: session.user_id,
            expires_at: session.expires_at,
            ip_address: session.ip_address,
            user_agent: session.user_agent,
            created_at,
            updated_at: created_at,
        };
        guard.sessions.insert(id, record.clone());
        Ok(record)
    }

    async fn find_by_token_hash(&self, token_hash: &str) -> Result<Option<Session>, AuthError> {
        let guard = self.inner.lock().unwrap();
        let found = guard
            .sessions
            .values()
            .find(|candidate| candidate.token_hash == token_hash)
            .cloned();
        Ok(found)
    }

    async fn find_by_user_id(&self, user_id: i64) -> Result<Vec<Session>, AuthError> {
        let guard = self.inner.lock().unwrap();
        let owned: Vec<Session> = guard
            .sessions
            .values()
            .filter(|candidate| candidate.user_id == user_id)
            .cloned()
            .collect();
        Ok(owned)
    }

    async fn delete_session(&self, id: i64) -> Result<(), AuthError> {
        let mut guard = self.inner.lock().unwrap();
        guard.sessions.remove(&id);
        Ok(())
    }

    async fn delete_by_user_id(&self, user_id: i64) -> Result<(), AuthError> {
        let mut guard = self.inner.lock().unwrap();
        guard
            .sessions
            .retain(|_, candidate| candidate.user_id != user_id);
        Ok(())
    }

    /// Drops sessions whose expiry is strictly before now; returns how many.
    async fn delete_expired(&self) -> Result<u64, AuthError> {
        let cutoff = OffsetDateTime::now_utc();
        let mut guard = self.inner.lock().unwrap();
        let count_before = guard.sessions.len();
        guard
            .sessions
            .retain(|_, candidate| candidate.expires_at >= cutoff);
        let removed = count_before - guard.sessions.len();
        Ok(removed as u64)
    }
}
#[async_trait]
impl VerificationStore for MemoryStore {
    /// Stores the verification under the next id, stamping both timestamps with "now".
    async fn create_verification(
        &self,
        verification: NewVerification,
    ) -> Result<Verification, AuthError> {
        let mut guard = self.inner.lock().unwrap();
        guard.next_verification_id += 1;
        let id = guard.next_verification_id;
        let created_at = OffsetDateTime::now_utc();
        let record = Verification {
            id,
            identifier: verification.identifier,
            token_hash: verification.token_hash,
            expires_at: verification.expires_at,
            created_at,
            updated_at: created_at,
        };
        guard.verifications.insert(id, record.clone());
        Ok(record)
    }

    async fn find_by_identifier(
        &self,
        identifier: &str,
    ) -> Result<Option<Verification>, AuthError> {
        let guard = self.inner.lock().unwrap();
        let found = guard
            .verifications
            .values()
            .find(|candidate| candidate.identifier == identifier)
            .cloned();
        Ok(found)
    }

    async fn find_by_token_hash(
        &self,
        token_hash: &str,
    ) -> Result<Option<Verification>, AuthError> {
        let guard = self.inner.lock().unwrap();
        let found = guard
            .verifications
            .values()
            .find(|candidate| candidate.token_hash == token_hash)
            .cloned();
        Ok(found)
    }

    async fn delete_verification(&self, id: i64) -> Result<(), AuthError> {
        let mut guard = self.inner.lock().unwrap();
        guard.verifications.remove(&id);
        Ok(())
    }

    async fn delete_by_identifier(&self, identifier: &str) -> Result<(), AuthError> {
        let mut guard = self.inner.lock().unwrap();
        guard
            .verifications
            .retain(|_, candidate| candidate.identifier != identifier);
        Ok(())
    }

    /// Drops verifications whose expiry is strictly before now; returns how many.
    async fn delete_expired(&self) -> Result<u64, AuthError> {
        let cutoff = OffsetDateTime::now_utc();
        let mut guard = self.inner.lock().unwrap();
        let count_before = guard.verifications.len();
        guard
            .verifications
            .retain(|_, candidate| candidate.expires_at >= cutoff);
        let removed = count_before - guard.verifications.len();
        Ok(removed as u64)
    }
}
#[async_trait]
impl AccountStore for MemoryStore {
    /// Stores the account under the next id, stamping both timestamps with "now".
    async fn create_account(&self, account: NewAccount) -> Result<Account, AuthError> {
        let mut guard = self.inner.lock().unwrap();
        guard.next_account_id += 1;
        let id = guard.next_account_id;
        let created_at = OffsetDateTime::now_utc();
        let record = Account {
            id,
            user_id: account.user_id,
            provider_id: account.provider_id,
            account_id: account.account_id,
            access_token: account.access_token,
            refresh_token: account.refresh_token,
            access_token_expires_at: account.access_token_expires_at,
            scope: account.scope,
            created_at,
            updated_at: created_at,
        };
        guard.accounts.insert(id, record.clone());
        Ok(record)
    }

    async fn find_by_provider(
        &self,
        provider_id: &str,
        account_id: &str,
    ) -> Result<Option<Account>, AuthError> {
        let guard = self.inner.lock().unwrap();
        let found = guard
            .accounts
            .values()
            .filter(|candidate| candidate.provider_id == provider_id)
            .find(|candidate| candidate.account_id == account_id)
            .cloned();
        Ok(found)
    }

    async fn find_by_user_id(&self, user_id: i64) -> Result<Vec<Account>, AuthError> {
        let guard = self.inner.lock().unwrap();
        let owned: Vec<Account> = guard
            .accounts
            .values()
            .filter(|candidate| candidate.user_id == user_id)
            .cloned()
            .collect();
        Ok(owned)
    }

    async fn delete_account(&self, id: i64) -> Result<(), AuthError> {
        let mut guard = self.inner.lock().unwrap();
        guard.accounts.remove(&id);
        Ok(())
    }

    /// Overwrites every token field (a `None` argument clears the field).
    async fn update_account(
        &self,
        id: i64,
        access_token: Option<String>,
        refresh_token: Option<String>,
        access_token_expires_at: Option<OffsetDateTime>,
        scope: Option<String>,
    ) -> Result<(), AuthError> {
        let mut guard = self.inner.lock().unwrap();
        if let Some(record) = guard.accounts.get_mut(&id) {
            record.access_token = access_token;
            record.refresh_token = refresh_token;
            record.access_token_expires_at = access_token_expires_at;
            record.scope = scope;
            record.updated_at = OffsetDateTime::now_utc();
            Ok(())
        } else {
            Err(AuthError::OAuth(OAuthError::AccountNotFound))
        }
    }
}
#[async_trait]
impl OAuthStateStore for MemoryStore {
    /// Stores the OAuth state under the next id.
    async fn create_oauth_state(
        &self,
        new_state: NewOAuthState,
    ) -> Result<OAuthState, AuthError> {
        let mut guard = self.inner.lock().unwrap();
        guard.next_oauth_state_id += 1;
        let id = guard.next_oauth_state_id;
        let created_at = OffsetDateTime::now_utc();
        let record = OAuthState {
            id,
            provider_id: new_state.provider_id,
            csrf_state: new_state.csrf_state,
            pkce_verifier: new_state.pkce_verifier,
            intent: new_state.intent,
            link_user_id: new_state.link_user_id,
            expires_at: new_state.expires_at,
            created_at,
        };
        guard.oauth_states.insert(id, record.clone());
        Ok(record)
    }

    async fn find_by_csrf_state(
        &self,
        csrf_state: &str,
    ) -> Result<Option<OAuthState>, AuthError> {
        let guard = self.inner.lock().unwrap();
        let found = guard
            .oauth_states
            .values()
            .find(|candidate| candidate.csrf_state == csrf_state)
            .cloned();
        Ok(found)
    }

    async fn delete_oauth_state(&self, id: i64) -> Result<(), AuthError> {
        let mut guard = self.inner.lock().unwrap();
        guard.oauth_states.remove(&id);
        Ok(())
    }

    /// Drops states whose expiry is strictly before now; returns how many.
    async fn delete_expired_oauth_states(&self) -> Result<u64, AuthError> {
        let cutoff = OffsetDateTime::now_utc();
        let mut guard = self.inner.lock().unwrap();
        let count_before = guard.oauth_states.len();
        guard
            .oauth_states
            .retain(|_, candidate| candidate.expires_at >= cutoff);
        let removed = count_before - guard.oauth_states.len();
        Ok(removed as u64)
    }
}
/// Email sender for tests: records raw tokens instead of sending mail.
///
/// Clones share the same vectors, so a test can read the tokens the service
/// "sent" through its own clone.
#[derive(Clone, Default)]
struct TestEmailSender {
    // Tokens passed to `send_verification_email`, in call order.
    verification_tokens: Arc<Mutex<Vec<String>>>,
    // Tokens passed to `send_password_reset_email`, in call order.
    reset_tokens: Arc<Mutex<Vec<String>>>,
}
#[async_trait]
impl EmailSender for TestEmailSender {
    /// Records the verification token; never fails.
    async fn send_verification_email(
        &self,
        _user: &User,
        token: &str,
    ) -> Result<(), AuthError> {
        let mut captured = self.verification_tokens.lock().unwrap();
        captured.push(token.to_owned());
        Ok(())
    }

    /// Records the password-reset token; never fails.
    async fn send_password_reset_email(
        &self,
        _user: &User,
        token: &str,
    ) -> Result<(), AuthError> {
        let mut captured = self.reset_tokens.lock().unwrap();
        captured.push(token.to_owned());
        Ok(())
    }
}
#[tokio::test]
async fn signup_verify_login_and_reset_flow_works() {
let store = MemoryStore::default();
let email = TestEmailSender::default();
let service = AuthService::new(
AuthConfig::default(),
store.clone(),
store.clone(),
store.clone(),
store.clone(),
store.clone(),
email.clone(),
);
let signup = service
.signup(
NewUser {
email: "test@example.com".to_string(),
name: Some("Test".to_string()),
password: "supersecret".to_string(),
},
Some("127.0.0.1".to_string()),
Some("test-agent".to_string()),
)
.await
.unwrap();
assert_eq!(signup.user.email, "test@example.com");
assert!(signup.session.is_some());
assert_eq!(email.verification_tokens.lock().unwrap().len(), 1);
let verification_token = email.verification_tokens.lock().unwrap()[0].clone();
let verify = service
.verify_email(&verification_token, None, None)
.await
.unwrap();
assert!(verify.user.is_verified());
let login = service
.login("test@example.com", "supersecret", None, None)
.await
.unwrap();
assert_eq!(login.user.email, "test@example.com");
service
.request_password_reset("test@example.com")
.await
.unwrap();
let reset_token = email.reset_tokens.lock().unwrap()[0].clone();
service
.reset_password(&reset_token, "newpassword")
.await
.unwrap();
let login = service
.login("test@example.com", "newpassword", None, None)
.await
.unwrap();
assert_eq!(login.user.email, "test@example.com");
}
/// A first-time OAuth login for an unknown email creates a verified,
/// password-less user with exactly one linked account.
#[tokio::test]
async fn oauth_callback_creates_new_user_and_account() {
    let store = MemoryStore::default();
    let sender = TestEmailSender::default();
    let service = AuthService::new(
        AuthConfig::default(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        sender.clone(),
    );
    let profile = crate::oauth::OAuthUserInfo {
        provider_id: "google".to_string(),
        account_id: "google-123".to_string(),
        email: "oauth@example.com".to_string(),
        name: Some("OAuth User".to_string()),
        image: Some("https://example.com/avatar.jpg".to_string()),
    };
    let user = service
        .oauth_callback(
            profile,
            OAuthTokens::default(),
            Some("127.0.0.1".to_string()),
            Some("test-agent".to_string()),
        )
        .await
        .unwrap()
        .user;
    assert_eq!(user.email, "oauth@example.com");
    assert_eq!(user.name.as_deref(), Some("OAuth User"));
    assert!(user.is_verified());
    assert!(user.password_hash.is_none());
    let linked = AccountStore::find_by_user_id(&store, user.id)
        .await
        .unwrap();
    assert_eq!(linked.len(), 1);
    let only = &linked[0];
    assert_eq!(only.provider_id, "google");
    assert_eq!(only.account_id, "google-123");
}
/// With implicit linking allowed (the default config), an OAuth login whose
/// email matches a password user links to that user and marks them verified.
#[tokio::test]
async fn oauth_callback_links_existing_user_by_email() {
    let store = MemoryStore::default();
    let sender = TestEmailSender::default();
    let service = AuthService::new(
        AuthConfig::default(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        sender.clone(),
    );
    let existing_id = store
        .create_user(
            "existing@example.com",
            Some("Existing User"),
            Some("hash123"),
        )
        .await
        .unwrap()
        .id;
    let profile = crate::oauth::OAuthUserInfo {
        provider_id: "github".to_string(),
        account_id: "github-456".to_string(),
        email: "existing@example.com".to_string(),
        name: Some("GitHub User".to_string()),
        image: None,
    };
    let user = service
        .oauth_callback(
            profile,
            OAuthTokens::default(),
            Some("127.0.0.1".to_string()),
            Some("test-agent".to_string()),
        )
        .await
        .unwrap()
        .user;
    assert_eq!(user.id, existing_id);
    assert_eq!(user.email, "existing@example.com");
    assert!(user.is_verified());
    let linked = AccountStore::find_by_user_id(&store, user.id)
        .await
        .unwrap();
    assert_eq!(linked.len(), 1);
    let only = &linked[0];
    assert_eq!(only.provider_id, "github");
    assert_eq!(only.account_id, "github-456");
}
/// An OAuth login for an already-linked provider account signs in the owning
/// user without creating a second account row.
#[tokio::test]
async fn oauth_callback_logs_in_existing_account() {
    let store = MemoryStore::default();
    let sender = TestEmailSender::default();
    let service = AuthService::new(
        AuthConfig::default(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        sender.clone(),
    );
    let owner_id = store
        .create_user("oauth@example.com", Some("OAuth User"), None)
        .await
        .unwrap()
        .id;
    let preexisting = crate::types::NewAccount {
        user_id: owner_id,
        provider_id: "google".to_string(),
        account_id: "google-789".to_string(),
        access_token: None,
        refresh_token: None,
        access_token_expires_at: None,
        scope: None,
    };
    store.create_account(preexisting).await.unwrap();
    let profile = crate::oauth::OAuthUserInfo {
        provider_id: "google".to_string(),
        account_id: "google-789".to_string(),
        email: "oauth@example.com".to_string(),
        name: Some("OAuth User".to_string()),
        image: None,
    };
    let user = service
        .oauth_callback(
            profile,
            OAuthTokens::default(),
            Some("127.0.0.1".to_string()),
            Some("test-agent".to_string()),
        )
        .await
        .unwrap()
        .user;
    assert_eq!(user.id, owner_id);
    assert_eq!(user.email, "oauth@example.com");
    let linked = AccountStore::find_by_user_id(&store, user.id)
        .await
        .unwrap();
    assert_eq!(linked.len(), 1);
}
/// With implicit linking disabled, an OAuth login whose email matches an
/// existing user is rejected with `LinkingDisabled`.
#[tokio::test]
async fn oauth_callback_respects_linking_policy() {
    let store = MemoryStore::default();
    let sender = TestEmailSender::default();
    let config = {
        let mut config = AuthConfig::default();
        config.oauth.allow_implicit_account_linking = false;
        config
    };
    let service = AuthService::new(
        config,
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        sender.clone(),
    );
    store
        .create_user(
            "existing@example.com",
            Some("Existing User"),
            Some("hash123"),
        )
        .await
        .unwrap();
    let profile = crate::oauth::OAuthUserInfo {
        provider_id: "google".to_string(),
        account_id: "google-999".to_string(),
        email: "existing@example.com".to_string(),
        name: Some("OAuth User".to_string()),
        image: None,
    };
    let outcome = service
        .oauth_callback(
            profile,
            OAuthTokens::default(),
            Some("127.0.0.1".to_string()),
            Some("test-agent".to_string()),
        )
        .await;
    assert!(outcome.is_err());
    if !matches!(outcome, Err(AuthError::OAuth(OAuthError::LinkingDisabled))) {
        panic!("Expected OAuth linking-disabled error");
    }
}
/// `cleanup_expired` removes one expired row from each of the three tables
/// and reports the counts in `(sessions, verifications, oauth_states)` order.
#[tokio::test]
async fn cleanup_expired_removes_sessions_verifications_and_oauth_states() {
    let store = MemoryStore::default();
    let service = AuthService::new(
        AuthConfig::default(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        TestEmailSender::default(),
    );
    let an_hour_ago = OffsetDateTime::now_utc() - time::Duration::hours(1);
    let stale_session = NewSession {
        token_hash: "expired-session".to_string(),
        user_id: 1,
        expires_at: an_hour_ago,
        ip_address: None,
        user_agent: None,
    };
    store.create_session(stale_session).await.unwrap();
    let stale_verification = NewVerification {
        identifier: "email-verify:test@example.com".to_string(),
        token_hash: "expired-verification".to_string(),
        expires_at: an_hour_ago,
    };
    store.create_verification(stale_verification).await.unwrap();
    let stale_state = NewOAuthState {
        provider_id: "google".to_string(),
        csrf_state: "expired-oauth-state".to_string(),
        pkce_verifier: "pkce-verifier".to_string(),
        intent: OAuthIntent::Login,
        link_user_id: None,
        expires_at: an_hour_ago,
    };
    store.create_oauth_state(stale_state).await.unwrap();
    let counts = service.cleanup_expired().await.unwrap();
    assert_eq!(counts, (1, 1, 1));
}
/// A user with no linked providers gets an empty account list.
#[tokio::test]
async fn list_accounts_returns_empty_for_user_with_no_accounts() {
    let store = MemoryStore::default();
    let service = AuthService::new(
        AuthConfig::default(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        TestEmailSender::default(),
    );
    let user_id = store
        .create_user("test@example.com", Some("Test User"), Some("hash123"))
        .await
        .unwrap()
        .id;
    let listed = service.list_accounts(user_id).await.unwrap();
    assert!(listed.is_empty());
}
/// `list_accounts` returns every linked provider, and the public
/// representation leaks no token material in its `Debug` output.
#[tokio::test]
async fn list_accounts_returns_public_accounts_without_tokens() {
    let store = MemoryStore::default();
    let service = AuthService::new(
        AuthConfig::default(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        TestEmailSender::default(),
    );
    let user_id = store
        .create_user("test@example.com", Some("Test User"), Some("hash123"))
        .await
        .unwrap()
        .id;
    let seeded = [
        NewAccount {
            user_id,
            provider_id: "google".to_string(),
            account_id: "google-123".to_string(),
            access_token: Some("secret-token".to_string()),
            refresh_token: Some("refresh-secret".to_string()),
            access_token_expires_at: None,
            scope: Some("openid,email".to_string()),
        },
        NewAccount {
            user_id,
            provider_id: "github".to_string(),
            account_id: "github-456".to_string(),
            access_token: Some("another-token".to_string()),
            refresh_token: None,
            access_token_expires_at: None,
            scope: None,
        },
    ];
    for account in seeded {
        store.create_account(account).await.unwrap();
    }
    let listed = service.list_accounts(user_id).await.unwrap();
    assert_eq!(listed.len(), 2);
    assert!(listed.iter().any(|a| a.provider_id.as_str() == "google"));
    assert!(listed.iter().any(|a| a.provider_id.as_str() == "github"));
    let rendered = format!("{:?}", listed);
    assert!(!rendered.contains("secret"));
}
/// Linking a provider identity to an existing user with no accounts creates
/// one account row carrying the provider id and provider-side account id.
#[tokio::test]
async fn link_account_creates_account_for_authenticated_user() {
    let store = MemoryStore::default();
    let email = TestEmailSender::default();
    let service = AuthService::new(
        AuthConfig::default(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        email,
    );
    let user = store
        .create_user("test@example.com", Some("Test User"), Some("hash123"))
        .await
        .unwrap();
    let oauth_info = crate::oauth::OAuthUserInfo {
        provider_id: "google".to_string(),
        account_id: "google-123".to_string(),
        email: "test@example.com".to_string(),
        name: None,
        image: None,
    };

    // `expect` (unlike `assert!(result.is_ok())`) prints the AuthError on failure.
    service
        .link_account(user.id, oauth_info, OAuthTokens::default())
        .await
        .expect("linking a new provider account should succeed");

    let accounts = AccountStore::find_by_user_id(&store, user.id)
        .await
        .unwrap();
    assert_eq!(accounts.len(), 1);
    assert_eq!(accounts[0].provider_id, "google");
    assert_eq!(accounts[0].account_id, "google-123");
}
/// Linking a provider identity the same user already owns succeeds and does
/// not create a second account row.
#[tokio::test]
async fn link_account_is_idempotent_when_already_linked_to_same_user() {
    let store = MemoryStore::default();
    let email = TestEmailSender::default();
    let service = AuthService::new(
        AuthConfig::default(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        email,
    );
    let user = store
        .create_user("test@example.com", Some("Test User"), Some("hash123"))
        .await
        .unwrap();
    store
        .create_account(NewAccount {
            user_id: user.id,
            provider_id: "google".to_string(),
            account_id: "google-123".to_string(),
            access_token: Some("old-token".to_string()),
            refresh_token: Some("old-refresh".to_string()),
            access_token_expires_at: None,
            scope: None,
        })
        .await
        .unwrap();
    let oauth_info = crate::oauth::OAuthUserInfo {
        provider_id: "google".to_string(),
        account_id: "google-123".to_string(),
        email: "test@example.com".to_string(),
        name: None,
        image: None,
    };

    // `expect` surfaces the AuthError if the re-link is wrongly rejected.
    service
        .link_account(user.id, oauth_info, OAuthTokens::default())
        .await
        .expect("re-linking an account the user already owns should succeed");

    let accounts = AccountStore::find_by_user_id(&store, user.id)
        .await
        .unwrap();
    assert_eq!(accounts.len(), 1);
    // The single remaining row must still be the original provider identity.
    assert_eq!(accounts[0].provider_id, "google");
    assert_eq!(accounts[0].account_id, "google-123");
}
#[tokio::test]
async fn link_account_rejects_when_already_linked_to_different_user() {
let store = MemoryStore::default();
let email = TestEmailSender::default();
let service = AuthService::new(
AuthConfig::default(),
store.clone(),
store.clone(),
store.clone(),
store.clone(),
store.clone(),
email,
);
let user_a = store
.create_user("usera@example.com", Some("User A"), Some("hash123"))
.await
.unwrap();
let user_b = store
.create_user("userb@example.com", Some("User B"), Some("hash456"))
.await
.unwrap();
store
.create_account(NewAccount {
user_id: user_a.id,
provider_id: "google".to_string(),
account_id: "google-123".to_string(),
access_token: None,
refresh_token: None,
access_token_expires_at: None,
scope: None,
})
.await
.unwrap();
let oauth_info = crate::oauth::OAuthUserInfo {
provider_id: "google".to_string(),
account_id: "google-123".to_string(),
email: "userb@example.com".to_string(),
name: None,
image: None,
};
let result = service
.link_account(user_b.id, oauth_info, OAuthTokens::default())
.await;
assert!(result.is_err());
match result {
Err(AuthError::OAuth(OAuthError::AccountAlreadyLinked)) => {}
_ => panic!("Expected AccountAlreadyLinked error"),
}
}
/// Re-linking a provider identity the user already owns replaces the stored
/// access and refresh tokens with the newly supplied ones.
#[tokio::test]
async fn link_account_updates_existing_account_tokens() {
    let mem = MemoryStore::default();
    let service = AuthService::new(
        AuthConfig::default(),
        mem.clone(),
        mem.clone(),
        mem.clone(),
        mem.clone(),
        mem.clone(),
        TestEmailSender::default(),
    );
    let owner = mem
        .create_user("test@example.com", Some("Test User"), Some("hash123"))
        .await
        .unwrap();

    // Seed the link with stale tokens that the re-link should overwrite.
    let stale = NewAccount {
        user_id: owner.id,
        provider_id: "google".to_string(),
        account_id: "google-123".to_string(),
        access_token: Some("old-token".to_string()),
        refresh_token: Some("old-refresh".to_string()),
        access_token_expires_at: None,
        scope: None,
    };
    mem.create_account(stale).await.unwrap();

    let profile = crate::oauth::OAuthUserInfo {
        provider_id: "google".to_string(),
        account_id: "google-123".to_string(),
        email: "test@example.com".to_string(),
        name: None,
        image: None,
    };
    let fresh = OAuthTokens {
        access_token: Some("new-token".to_string()),
        refresh_token: Some("new-refresh".to_string()),
        ..Default::default()
    };
    service.link_account(owner.id, profile, fresh).await.unwrap();

    let stored = AccountStore::find_by_user_id(&mem, owner.id).await.unwrap();
    assert_eq!(stored.len(), 1);
    let account = &stored[0];
    assert_eq!(account.access_token, Some("new-token".to_string()));
    assert_eq!(account.refresh_token, Some("new-refresh".to_string()));
}
/// Unlinking one of two linked providers deletes only that account row. The
/// user also has a password hash and the second (GitHub) link.
#[tokio::test]
async fn unlink_account_removes_account() {
    let mem = MemoryStore::default();
    let service = AuthService::new(
        AuthConfig::default(),
        mem.clone(),
        mem.clone(),
        mem.clone(),
        mem.clone(),
        mem.clone(),
        TestEmailSender::default(),
    );
    let owner = mem
        .create_user("test@example.com", Some("Test User"), Some("password-hash"))
        .await
        .unwrap();

    let google = mem
        .create_account(NewAccount {
            user_id: owner.id,
            provider_id: "google".to_string(),
            account_id: "google-123".to_string(),
            access_token: None,
            refresh_token: None,
            access_token_expires_at: None,
            scope: None,
        })
        .await
        .unwrap();
    mem.create_account(NewAccount {
        user_id: owner.id,
        provider_id: "github".to_string(),
        account_id: "github-456".to_string(),
        access_token: None,
        refresh_token: None,
        access_token_expires_at: None,
        scope: None,
    })
    .await
    .unwrap();

    service.unlink_account(owner.id, google.id).await.unwrap();

    let remaining = AccountStore::find_by_user_id(&mem, owner.id).await.unwrap();
    assert_eq!(remaining.len(), 1);
    assert_eq!(remaining[0].provider_id, "github");
}
#[tokio::test]
async fn unlink_account_rejects_last_auth_method() {
let store = MemoryStore::default();
let email = TestEmailSender::default();
let service = AuthService::new(
AuthConfig::default(),
store.clone(),
store.clone(),
store.clone(),
store.clone(),
store.clone(),
email,
);
let user = store
.create_user("test@example.com", Some("Test User"), None)
.await
.unwrap();
let account = store
.create_account(NewAccount {
user_id: user.id,
provider_id: "google".to_string(),
account_id: "google-123".to_string(),
access_token: None,
refresh_token: None,
access_token_expires_at: None,
scope: None,
})
.await
.unwrap();
let result = service.unlink_account(user.id, account.id).await;
assert!(result.is_err());
match result {
Err(AuthError::OAuth(OAuthError::LastAuthMethod)) => {}
_ => panic!("Expected LastAuthMethod error"),
}
}
/// Refreshing an account that has no stored refresh token fails with
/// `NoRefreshToken`.
#[tokio::test]
async fn refresh_oauth_token_rejects_when_no_refresh_token() {
    let store = MemoryStore::default();
    let email = TestEmailSender::default();
    let service = AuthService::new(
        AuthConfig::default(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        email,
    );
    let user = store
        .create_user("test@example.com", Some("Test User"), None)
        .await
        .unwrap();
    let account = store
        .create_account(NewAccount {
            user_id: user.id,
            provider_id: "google".to_string(),
            account_id: "google-123".to_string(),
            access_token: Some("old-token".to_string()),
            refresh_token: None,
            access_token_expires_at: None,
            scope: None,
        })
        .await
        .unwrap();
    // NOTE(review): presumably the missing refresh token is detected before any
    // request is sent to `token_url`; confirm, since a network call would make
    // this test flaky.
    let provider_config = OAuthProviderConfig {
        provider_id: "google".to_string(),
        client_id: "test".to_string(),
        client_secret: "test".to_string(),
        auth_url: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
        token_url: "https://oauth2.googleapis.com/token".to_string(),
        userinfo_url: "https://www.googleapis.com/oauth2/v2/userinfo".to_string(),
        redirect_url: "https://localhost/callback".to_string(),
        scopes: vec!["openid".to_string()],
    };

    let result = service
        .refresh_oauth_token(user.id, account.id, &provider_config)
        .await;
    match result {
        Err(AuthError::OAuth(OAuthError::NoRefreshToken)) => {}
        other => panic!("Expected NoRefreshToken error, got {:?}", other),
    }
}
/// Refreshing with an account id that does not exist fails with
/// `AccountNotFound`, even though the user has a refreshable account.
#[tokio::test]
async fn refresh_oauth_token_rejects_for_wrong_account_id() {
    let store = MemoryStore::default();
    let email = TestEmailSender::default();
    let service = AuthService::new(
        AuthConfig::default(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        store.clone(),
        email,
    );
    let user = store
        .create_user("test@example.com", Some("Test User"), None)
        .await
        .unwrap();
    store
        .create_account(NewAccount {
            user_id: user.id,
            provider_id: "google".to_string(),
            account_id: "google-123".to_string(),
            access_token: Some("old-token".to_string()),
            refresh_token: Some("refresh-token".to_string()),
            access_token_expires_at: None,
            scope: None,
        })
        .await
        .unwrap();
    let provider_config = OAuthProviderConfig {
        provider_id: "google".to_string(),
        client_id: "test".to_string(),
        client_secret: "test".to_string(),
        auth_url: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
        token_url: "https://oauth2.googleapis.com/token".to_string(),
        userinfo_url: "https://www.googleapis.com/oauth2/v2/userinfo".to_string(),
        redirect_url: "https://localhost/callback".to_string(),
        scopes: vec!["openid".to_string()],
    };

    // 9999: an id that this freshly created store has not handed out.
    let result = service
        .refresh_oauth_token(user.id, 9999, &provider_config)
        .await;
    match result {
        Err(AuthError::OAuth(OAuthError::AccountNotFound)) => {}
        other => panic!("Expected AccountNotFound error, got {:?}", other),
    }
}
}