1use std::sync::Arc;
2
3use time::OffsetDateTime;
4
5use crate::config::AuthConfig;
6use crate::crypto::{hash, token};
7use crate::email::EmailSender;
8use crate::error::{AuthError, OAuthError};
9use crate::oauth::{OAuthTokens, OAuthUserInfo};
10use crate::store::{AccountStore, OAuthStateStore, SessionStore, UserStore, VerificationStore};
11use crate::types::{NewAccount, NewSession, NewUser, NewVerification, Session, User, Verification};
12
/// Outcome of [`AuthService::signup`].
#[derive(Debug)]
pub struct SignupResult {
    /// The newly created user (email already trimmed and lowercased).
    pub user: User,
    /// Session opened when `config.email.auto_sign_in_after_signup` is set; `None` otherwise.
    pub session: Option<Session>,
    /// Raw token for `session`; only its hash is persisted. `None` when no session was opened.
    pub session_token: Option<String>,
    /// Raw email-verification token, present when
    /// `config.email.send_verification_on_signup` is set (it is also emailed to the user).
    pub verification_token: Option<String>,
}
25
/// Outcome of a successful sign-in, returned by both [`AuthService::login`]
/// and [`AuthService::oauth_callback`].
#[derive(Debug)]
pub struct LoginResult {
    /// The authenticated user.
    pub user: User,
    /// The session opened for this sign-in.
    pub session: Session,
    /// Raw token for `session`; only its hash is persisted, so this is the
    /// only place the caller can obtain it.
    pub session_token: String,
}
36
/// Outcome of [`AuthService::verify_email`].
#[derive(Debug)]
pub struct VerifyEmailResult {
    /// The user re-read from the store after being marked verified.
    pub user: User,
    /// Session opened when `config.email.auto_sign_in_after_verification` is set; `None` otherwise.
    pub session: Option<Session>,
    /// Raw token for `session`; `None` when no session was opened.
    pub session_token: Option<String>,
}
47
/// Outcome of [`AuthService::request_password_reset`].
///
/// Carries no data: the same value is returned whether or not the email
/// matched a registered user.
#[derive(Debug, Default)]
pub struct RequestResetResult {
    // NOTE(review): presumably a placeholder to keep the struct extensible.
    // Because it is `pub`, code outside this module can still build the struct
    // with a literal, so adding fields later would be a breaking change — confirm
    // whether it should be private or the struct `#[non_exhaustive]`.
    pub _private: (),
}
54
/// Outcome of [`AuthService::reset_password`].
#[derive(Debug)]
pub struct ResetPasswordResult {
    /// The user re-read from the store after the password update.
    pub user: User,
}
61
/// A valid (non-expired) session together with its owning user, returned by
/// [`AuthService::get_session`].
#[derive(Debug)]
pub struct SessionResult {
    /// The user that owns `session`.
    pub user: User,
    /// The session matched by the presented raw token.
    pub session: Session,
}
70
71pub struct AuthService<U, S, V, A, O, E>
73where
74 U: UserStore,
75 S: SessionStore,
76 V: VerificationStore,
77 A: AccountStore,
78 O: OAuthStateStore,
79 E: EmailSender,
80{
81 pub config: AuthConfig,
83 pub users: Arc<U>,
85 pub sessions: Arc<S>,
87 pub verifications: Arc<V>,
89 pub accounts: Arc<A>,
91 pub oauth_states: Arc<O>,
93 pub email: Arc<E>,
95}
96
97impl<U, S, V, A, O, E> AuthService<U, S, V, A, O, E>
98where
99 U: UserStore,
100 S: SessionStore,
101 V: VerificationStore,
102 A: AccountStore,
103 O: OAuthStateStore,
104 E: EmailSender,
105{
106 pub fn new(
108 config: AuthConfig,
109 users: U,
110 sessions: S,
111 verifications: V,
112 accounts: A,
113 oauth_states: O,
114 email: E,
115 ) -> Self {
116 Self {
117 config,
118 users: Arc::new(users),
119 sessions: Arc::new(sessions),
120 verifications: Arc::new(verifications),
121 accounts: Arc::new(accounts),
122 oauth_states: Arc::new(oauth_states),
123 email: Arc::new(email),
124 }
125 }
126
127 pub async fn signup(
129 &self,
130 input: NewUser,
131 ip: Option<String>,
132 user_agent: Option<String>,
133 ) -> Result<SignupResult, AuthError> {
134 if input.password.len() < 8 {
135 return Err(AuthError::WeakPassword(8));
136 }
137
138 let email = input.email.trim().to_lowercase();
139 if self.users.find_by_email(&email).await?.is_some() {
140 return Err(AuthError::EmailTaken);
141 }
142
143 let password_hash = hash::hash_password(&input.password)?;
144 let user = self
145 .users
146 .create_user(&email, input.name.as_deref(), Some(&password_hash))
147 .await?;
148
149 let verification_token = if self.config.email.send_verification_on_signup {
150 let identifier = format!("email-verify:{}", user.email.to_lowercase());
151 let raw_token = token::generate_token(self.config.token_length);
152
153 let _ = self.verifications.delete_by_identifier(&identifier).await;
154 self.verifications
155 .create_verification(NewVerification {
156 identifier,
157 token_hash: token::hash_token(&raw_token),
158 expires_at: OffsetDateTime::now_utc() + self.config.verification_ttl,
159 })
160 .await?;
161 self.email
162 .send_verification_email(&user, &raw_token)
163 .await?;
164 Some(raw_token)
165 } else {
166 None
167 };
168
169 let (session, session_token) = if self.config.email.auto_sign_in_after_signup {
170 let (session, raw_token) = self
171 .create_session_internal(user.id, ip, user_agent)
172 .await?;
173 (Some(session), Some(raw_token))
174 } else {
175 (None, None)
176 };
177
178 Ok(SignupResult {
179 user,
180 session,
181 session_token,
182 verification_token,
183 })
184 }
185
186 pub async fn login(
188 &self,
189 email: &str,
190 password: &str,
191 ip: Option<String>,
192 user_agent: Option<String>,
193 ) -> Result<LoginResult, AuthError> {
194 let user = self
195 .users
196 .find_by_email(&email.trim().to_lowercase())
197 .await?
198 .ok_or(AuthError::InvalidCredentials)?;
199
200 let password_hash = user
201 .password_hash
202 .as_deref()
203 .ok_or(AuthError::InvalidCredentials)?;
204 if !hash::verify_password(password, password_hash)? {
205 return Err(AuthError::InvalidCredentials);
206 }
207
208 if self.config.email.require_verification_to_login && !user.is_verified() {
209 return Err(AuthError::EmailNotVerified);
210 }
211
212 let (session, session_token) = self
213 .create_session_internal(user.id, ip, user_agent)
214 .await?;
215
216 Ok(LoginResult {
217 user,
218 session,
219 session_token,
220 })
221 }
222
223 pub async fn logout(&self, session_id: i64) -> Result<(), AuthError> {
225 self.sessions.delete_session(session_id).await
226 }
227
228 pub async fn logout_all(&self, user_id: i64) -> Result<(), AuthError> {
230 self.sessions.delete_by_user_id(user_id).await
231 }
232
233 pub async fn get_session(&self, raw_token: &str) -> Result<SessionResult, AuthError> {
235 let session = self
236 .sessions
237 .find_by_token_hash(&token::hash_token(raw_token))
238 .await?
239 .ok_or(AuthError::SessionNotFound)?;
240
241 if session.expires_at < OffsetDateTime::now_utc() {
242 self.sessions.delete_session(session.id).await?;
243 return Err(AuthError::SessionNotFound);
244 }
245
246 let user = self
247 .users
248 .find_by_id(session.user_id)
249 .await?
250 .ok_or(AuthError::UserNotFound)?;
251
252 Ok(SessionResult { user, session })
253 }
254
255 pub async fn list_sessions(&self, user_id: i64) -> Result<Vec<Session>, AuthError> {
257 self.sessions.find_by_user_id(user_id).await
258 }
259
260 pub async fn verify_email(
262 &self,
263 raw_token: &str,
264 ip: Option<String>,
265 user_agent: Option<String>,
266 ) -> Result<VerifyEmailResult, AuthError> {
267 let verification = self.lookup_verification(raw_token, "email-verify:").await?;
268 let email = verification
269 .identifier
270 .strip_prefix("email-verify:")
271 .ok_or(AuthError::InvalidToken)?;
272
273 let user = self
274 .users
275 .find_by_email(email)
276 .await?
277 .ok_or(AuthError::UserNotFound)?;
278 self.users.set_email_verified(user.id).await?;
279 self.verifications
280 .delete_verification(verification.id)
281 .await?;
282
283 let user = self
284 .users
285 .find_by_id(user.id)
286 .await?
287 .ok_or(AuthError::UserNotFound)?;
288
289 let (session, session_token) = if self.config.email.auto_sign_in_after_verification {
290 let (session, raw_token) = self
291 .create_session_internal(user.id, ip, user_agent)
292 .await?;
293 (Some(session), Some(raw_token))
294 } else {
295 (None, None)
296 };
297
298 Ok(VerifyEmailResult {
299 user,
300 session,
301 session_token,
302 })
303 }
304
305 pub async fn request_password_reset(
307 &self,
308 email: &str,
309 ) -> Result<RequestResetResult, AuthError> {
310 let email = email.trim().to_lowercase();
311 if let Some(user) = self.users.find_by_email(&email).await? {
312 let identifier = format!("password-reset:{}", user.email.to_lowercase());
313 let _ = self.verifications.delete_by_identifier(&identifier).await;
314
315 let raw_token = token::generate_token(self.config.token_length);
316 self.verifications
317 .create_verification(NewVerification {
318 identifier,
319 token_hash: token::hash_token(&raw_token),
320 expires_at: OffsetDateTime::now_utc() + self.config.reset_ttl,
321 })
322 .await?;
323 self.email
324 .send_password_reset_email(&user, &raw_token)
325 .await?;
326 }
327
328 Ok(RequestResetResult::default())
329 }
330
331 pub async fn reset_password(
333 &self,
334 raw_token: &str,
335 new_password: &str,
336 ) -> Result<ResetPasswordResult, AuthError> {
337 if new_password.len() < 8 {
338 return Err(AuthError::WeakPassword(8));
339 }
340
341 let verification = self
342 .lookup_verification(raw_token, "password-reset:")
343 .await?;
344 let email = verification
345 .identifier
346 .strip_prefix("password-reset:")
347 .ok_or(AuthError::InvalidToken)?;
348 let user = self
349 .users
350 .find_by_email(email)
351 .await?
352 .ok_or(AuthError::UserNotFound)?;
353
354 self.users
355 .update_password(user.id, &hash::hash_password(new_password)?)
356 .await?;
357 self.sessions.delete_by_user_id(user.id).await?;
358 self.verifications
359 .delete_verification(verification.id)
360 .await?;
361
362 let user = self
363 .users
364 .find_by_id(user.id)
365 .await?
366 .ok_or(AuthError::UserNotFound)?;
367
368 Ok(ResetPasswordResult { user })
369 }
370
371 pub async fn cleanup_expired(&self) -> Result<(u64, u64, u64), AuthError> {
374 let sessions_deleted = self.sessions.delete_expired().await?;
375 let verifications_deleted = self.verifications.delete_expired().await?;
376 let oauth_states_deleted = self.oauth_states.delete_expired_oauth_states().await?;
377 Ok((
378 sessions_deleted,
379 verifications_deleted,
380 oauth_states_deleted,
381 ))
382 }
383
384 pub async fn oauth_callback(
389 &self,
390 info: OAuthUserInfo,
391 tokens: OAuthTokens,
392 ip: Option<String>,
393 user_agent: Option<String>,
394 ) -> Result<LoginResult, AuthError> {
395 if let Some(account) = self
397 .accounts
398 .find_by_provider(&info.provider_id, &info.account_id)
399 .await?
400 {
401 let user = self
403 .users
404 .find_by_id(account.user_id)
405 .await?
406 .ok_or(AuthError::UserNotFound)?;
407 let (session, session_token) = self
408 .create_session_internal(user.id, ip, user_agent)
409 .await?;
410 return Ok(LoginResult {
411 user,
412 session,
413 session_token,
414 });
415 }
416
417 let user = if let Some(existing_user) = self.users.find_by_email(&info.email).await? {
419 if !self.config.oauth.allow_implicit_account_linking {
421 return Err(AuthError::OAuth(OAuthError::LinkingDisabled));
422 }
423 existing_user
424 } else {
425 self.users
427 .create_user(&info.email, info.name.as_deref(), None)
428 .await?
429 };
430
431 let access_token_expires_at = tokens.expires_in.map(|d| OffsetDateTime::now_utc() + d);
433 self.accounts
434 .create_account(NewAccount {
435 user_id: user.id,
436 provider_id: info.provider_id,
437 account_id: info.account_id,
438 access_token: tokens.access_token,
439 refresh_token: tokens.refresh_token,
440 access_token_expires_at,
441 scope: tokens.scope,
442 })
443 .await?;
444
445 if !user.is_verified() {
447 self.users.set_email_verified(user.id).await?;
448 }
449
450 let user = self
452 .users
453 .find_by_id(user.id)
454 .await?
455 .ok_or(AuthError::UserNotFound)?;
456 let (session, session_token) = self
457 .create_session_internal(user.id, ip, user_agent)
458 .await?;
459 Ok(LoginResult {
460 user,
461 session,
462 session_token,
463 })
464 }
465
466 async fn create_session_internal(
467 &self,
468 user_id: i64,
469 ip: Option<String>,
470 user_agent: Option<String>,
471 ) -> Result<(Session, String), AuthError> {
472 let raw_token = token::generate_token(self.config.token_length);
473 let session = self
474 .sessions
475 .create_session(NewSession {
476 token_hash: token::hash_token(&raw_token),
477 user_id,
478 expires_at: OffsetDateTime::now_utc() + self.config.session_ttl,
479 ip_address: ip,
480 user_agent,
481 })
482 .await?;
483
484 Ok((session, raw_token))
485 }
486
487 async fn lookup_verification(
488 &self,
489 raw_token: &str,
490 prefix: &str,
491 ) -> Result<Verification, AuthError> {
492 let verification = self
493 .verifications
494 .find_by_token_hash(&token::hash_token(raw_token))
495 .await?
496 .ok_or(AuthError::InvalidToken)?;
497
498 if !verification.identifier.starts_with(prefix) {
499 return Err(AuthError::InvalidToken);
500 }
501
502 if verification.expires_at < OffsetDateTime::now_utc() {
503 self.verifications
504 .delete_verification(verification.id)
505 .await?;
506 return Err(AuthError::InvalidToken);
507 }
508
509 Ok(verification)
510 }
511}
512
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    use async_trait::async_trait;
    use time::OffsetDateTime;

    use super::AuthService;
    use crate::config::AuthConfig;
    use crate::email::EmailSender;
    use crate::error::{AuthError, OAuthError};
    use crate::oauth::{OAuthTokens, OAuthUserInfo};
    use crate::store::{AccountStore, OAuthStateStore, SessionStore, UserStore, VerificationStore};
    use crate::types::{
        Account, NewAccount, NewOAuthState, NewSession, NewUser, NewVerification, OAuthState,
        Session, User, Verification,
    };

    /// Backing data for the in-memory stores. Ids are assigned by
    /// pre-incrementing the matching `next_*` counter, so the first id is 1.
    #[derive(Default)]
    struct MemoryState {
        next_user_id: i64,
        next_session_id: i64,
        next_verification_id: i64,
        next_account_id: i64,
        next_oauth_state_id: i64,
        users: HashMap<i64, User>,
        sessions: HashMap<i64, Session>,
        verifications: HashMap<i64, Verification>,
        accounts: HashMap<i64, Account>,
        oauth_states: HashMap<i64, OAuthState>,
    }

    /// In-memory implementation of every store trait. Clones share one
    /// `MemoryState`, so a single instance can back all five stores.
    #[derive(Clone, Default)]
    struct MemoryStore {
        inner: Arc<Mutex<MemoryState>>,
    }

    #[async_trait]
    impl UserStore for MemoryStore {
        async fn create_user(
            &self,
            email: &str,
            name: Option<&str>,
            password_hash: Option<&str>,
        ) -> Result<User, AuthError> {
            let mut state = self.inner.lock().unwrap();
            // Exact (case-sensitive) uniqueness check, mimicking a UNIQUE column.
            if state.users.values().any(|user| user.email == email) {
                return Err(AuthError::EmailTaken);
            }

            state.next_user_id += 1;
            let now = OffsetDateTime::now_utc();
            let user = User {
                id: state.next_user_id,
                email: email.to_string(),
                name: name.map(str::to_owned),
                password_hash: password_hash.map(str::to_owned),
                email_verified_at: None,
                image: None,
                created_at: now,
                updated_at: now,
            };
            state.users.insert(user.id, user.clone());
            Ok(user)
        }

        async fn find_by_email(&self, email: &str) -> Result<Option<User>, AuthError> {
            let state = self.inner.lock().unwrap();
            Ok(state
                .users
                .values()
                .find(|user| user.email == email)
                .cloned())
        }

        async fn find_by_id(&self, id: i64) -> Result<Option<User>, AuthError> {
            Ok(self.inner.lock().unwrap().users.get(&id).cloned())
        }

        async fn set_email_verified(&self, user_id: i64) -> Result<(), AuthError> {
            let mut state = self.inner.lock().unwrap();
            let user = state
                .users
                .get_mut(&user_id)
                .ok_or(AuthError::UserNotFound)?;
            user.email_verified_at = Some(OffsetDateTime::now_utc());
            user.updated_at = OffsetDateTime::now_utc();
            Ok(())
        }

        async fn update_password(
            &self,
            user_id: i64,
            password_hash: &str,
        ) -> Result<(), AuthError> {
            let mut state = self.inner.lock().unwrap();
            let user = state
                .users
                .get_mut(&user_id)
                .ok_or(AuthError::UserNotFound)?;
            user.password_hash = Some(password_hash.to_string());
            user.updated_at = OffsetDateTime::now_utc();
            Ok(())
        }

        async fn delete_user(&self, user_id: i64) -> Result<(), AuthError> {
            self.inner.lock().unwrap().users.remove(&user_id);
            Ok(())
        }
    }

    #[async_trait]
    impl SessionStore for MemoryStore {
        async fn create_session(&self, session: NewSession) -> Result<Session, AuthError> {
            let mut state = self.inner.lock().unwrap();
            state.next_session_id += 1;
            let now = OffsetDateTime::now_utc();
            let session = Session {
                id: state.next_session_id,
                token_hash: session.token_hash,
                user_id: session.user_id,
                expires_at: session.expires_at,
                ip_address: session.ip_address,
                user_agent: session.user_agent,
                created_at: now,
                updated_at: now,
            };
            state.sessions.insert(session.id, session.clone());
            Ok(session)
        }

        async fn find_by_token_hash(&self, token_hash: &str) -> Result<Option<Session>, AuthError> {
            let state = self.inner.lock().unwrap();
            Ok(state
                .sessions
                .values()
                .find(|session| session.token_hash == token_hash)
                .cloned())
        }

        async fn find_by_user_id(&self, user_id: i64) -> Result<Vec<Session>, AuthError> {
            let state = self.inner.lock().unwrap();
            Ok(state
                .sessions
                .values()
                .filter(|session| session.user_id == user_id)
                .cloned()
                .collect())
        }

        async fn delete_session(&self, id: i64) -> Result<(), AuthError> {
            self.inner.lock().unwrap().sessions.remove(&id);
            Ok(())
        }

        async fn delete_by_user_id(&self, user_id: i64) -> Result<(), AuthError> {
            self.inner
                .lock()
                .unwrap()
                .sessions
                .retain(|_, session| session.user_id != user_id);
            Ok(())
        }

        async fn delete_expired(&self) -> Result<u64, AuthError> {
            let now = OffsetDateTime::now_utc();
            let mut state = self.inner.lock().unwrap();
            let before = state.sessions.len();
            state
                .sessions
                .retain(|_, session| session.expires_at >= now);
            Ok((before - state.sessions.len()) as u64)
        }
    }

    #[async_trait]
    impl VerificationStore for MemoryStore {
        async fn create_verification(
            &self,
            verification: NewVerification,
        ) -> Result<Verification, AuthError> {
            let mut state = self.inner.lock().unwrap();
            state.next_verification_id += 1;
            let now = OffsetDateTime::now_utc();
            let verification = Verification {
                id: state.next_verification_id,
                identifier: verification.identifier,
                token_hash: verification.token_hash,
                expires_at: verification.expires_at,
                created_at: now,
                updated_at: now,
            };
            state
                .verifications
                .insert(verification.id, verification.clone());
            Ok(verification)
        }

        async fn find_by_identifier(
            &self,
            identifier: &str,
        ) -> Result<Option<Verification>, AuthError> {
            let state = self.inner.lock().unwrap();
            Ok(state
                .verifications
                .values()
                .find(|verification| verification.identifier == identifier)
                .cloned())
        }

        async fn find_by_token_hash(
            &self,
            token_hash: &str,
        ) -> Result<Option<Verification>, AuthError> {
            let state = self.inner.lock().unwrap();
            Ok(state
                .verifications
                .values()
                .find(|verification| verification.token_hash == token_hash)
                .cloned())
        }

        async fn delete_verification(&self, id: i64) -> Result<(), AuthError> {
            self.inner.lock().unwrap().verifications.remove(&id);
            Ok(())
        }

        async fn delete_by_identifier(&self, identifier: &str) -> Result<(), AuthError> {
            self.inner
                .lock()
                .unwrap()
                .verifications
                .retain(|_, verification| verification.identifier != identifier);
            Ok(())
        }

        async fn delete_expired(&self) -> Result<u64, AuthError> {
            let now = OffsetDateTime::now_utc();
            let mut state = self.inner.lock().unwrap();
            let before = state.verifications.len();
            state
                .verifications
                .retain(|_, verification| verification.expires_at >= now);
            Ok((before - state.verifications.len()) as u64)
        }
    }

    #[async_trait]
    impl AccountStore for MemoryStore {
        async fn create_account(&self, account: NewAccount) -> Result<Account, AuthError> {
            let mut state = self.inner.lock().unwrap();
            state.next_account_id += 1;
            let now = OffsetDateTime::now_utc();
            let account = Account {
                id: state.next_account_id,
                user_id: account.user_id,
                provider_id: account.provider_id,
                account_id: account.account_id,
                access_token: account.access_token,
                refresh_token: account.refresh_token,
                access_token_expires_at: account.access_token_expires_at,
                scope: account.scope,
                created_at: now,
                updated_at: now,
            };
            state.accounts.insert(account.id, account.clone());
            Ok(account)
        }

        async fn find_by_provider(
            &self,
            provider_id: &str,
            account_id: &str,
        ) -> Result<Option<Account>, AuthError> {
            let state = self.inner.lock().unwrap();
            Ok(state
                .accounts
                .values()
                .find(|account| {
                    account.provider_id == provider_id && account.account_id == account_id
                })
                .cloned())
        }

        async fn find_by_user_id(&self, user_id: i64) -> Result<Vec<Account>, AuthError> {
            let state = self.inner.lock().unwrap();
            Ok(state
                .accounts
                .values()
                .filter(|account| account.user_id == user_id)
                .cloned()
                .collect())
        }

        async fn delete_account(&self, id: i64) -> Result<(), AuthError> {
            self.inner.lock().unwrap().accounts.remove(&id);
            Ok(())
        }
    }

    #[async_trait]
    impl OAuthStateStore for MemoryStore {
        async fn create_oauth_state(
            &self,
            new_state: NewOAuthState,
        ) -> Result<OAuthState, AuthError> {
            let mut state = self.inner.lock().unwrap();
            state.next_oauth_state_id += 1;
            let now = OffsetDateTime::now_utc();
            let oauth_state = OAuthState {
                id: state.next_oauth_state_id,
                provider_id: new_state.provider_id,
                csrf_state: new_state.csrf_state,
                pkce_verifier: new_state.pkce_verifier,
                expires_at: new_state.expires_at,
                created_at: now,
            };
            state
                .oauth_states
                .insert(oauth_state.id, oauth_state.clone());
            Ok(oauth_state)
        }

        async fn find_by_csrf_state(
            &self,
            csrf_state: &str,
        ) -> Result<Option<OAuthState>, AuthError> {
            let state = self.inner.lock().unwrap();
            Ok(state
                .oauth_states
                .values()
                .find(|s| s.csrf_state == csrf_state)
                .cloned())
        }

        async fn delete_oauth_state(&self, id: i64) -> Result<(), AuthError> {
            self.inner.lock().unwrap().oauth_states.remove(&id);
            Ok(())
        }

        async fn delete_expired_oauth_states(&self) -> Result<u64, AuthError> {
            let now = OffsetDateTime::now_utc();
            let mut state = self.inner.lock().unwrap();
            let before = state.oauth_states.len();
            state.oauth_states.retain(|_, s| s.expires_at >= now);
            Ok((before - state.oauth_states.len()) as u64)
        }
    }

    /// Email sender that records every token it is asked to send, so tests can
    /// read them back instead of parsing emails.
    #[derive(Clone, Default)]
    struct TestEmailSender {
        verification_tokens: Arc<Mutex<Vec<String>>>,
        reset_tokens: Arc<Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl EmailSender for TestEmailSender {
        async fn send_verification_email(
            &self,
            _user: &User,
            token: &str,
        ) -> Result<(), AuthError> {
            self.verification_tokens
                .lock()
                .unwrap()
                .push(token.to_string());
            Ok(())
        }

        async fn send_password_reset_email(
            &self,
            _user: &User,
            token: &str,
        ) -> Result<(), AuthError> {
            self.reset_tokens.lock().unwrap().push(token.to_string());
            Ok(())
        }
    }

    /// The concrete service type exercised by every test.
    type TestService = AuthService<
        MemoryStore,
        MemoryStore,
        MemoryStore,
        MemoryStore,
        MemoryStore,
        TestEmailSender,
    >;

    /// Builds a service whose five stores share one in-memory state, and
    /// returns handles to that state and to the recording email sender.
    fn build_service(config: AuthConfig) -> (MemoryStore, TestEmailSender, TestService) {
        let store = MemoryStore::default();
        let email = TestEmailSender::default();
        let service = AuthService::new(
            config,
            store.clone(),
            store.clone(),
            store.clone(),
            store.clone(),
            store.clone(),
            email.clone(),
        );
        (store, email, service)
    }

    #[tokio::test]
    async fn signup_verify_login_and_reset_flow_works() {
        let (_store, email, service) = build_service(AuthConfig::default());

        let signup = service
            .signup(
                NewUser {
                    email: "test@example.com".to_string(),
                    name: Some("Test".to_string()),
                    password: "supersecret".to_string(),
                },
                Some("127.0.0.1".to_string()),
                Some("test-agent".to_string()),
            )
            .await
            .unwrap();

        assert_eq!(signup.user.email, "test@example.com");
        assert!(signup.session.is_some());
        assert_eq!(email.verification_tokens.lock().unwrap().len(), 1);

        let verification_token = email.verification_tokens.lock().unwrap()[0].clone();
        let verify = service
            .verify_email(&verification_token, None, None)
            .await
            .unwrap();
        assert!(verify.user.is_verified());

        let login = service
            .login("test@example.com", "supersecret", None, None)
            .await
            .unwrap();
        assert_eq!(login.user.email, "test@example.com");

        service
            .request_password_reset("test@example.com")
            .await
            .unwrap();
        let reset_token = email.reset_tokens.lock().unwrap()[0].clone();
        service
            .reset_password(&reset_token, "newpassword")
            .await
            .unwrap();

        // The reset revokes every session opened before it...
        let stale_session = service.get_session(&login.session_token).await;
        assert!(matches!(stale_session, Err(AuthError::SessionNotFound)));

        // ...and the old password no longer authenticates.
        let old_password = service
            .login("test@example.com", "supersecret", None, None)
            .await;
        assert!(matches!(old_password, Err(AuthError::InvalidCredentials)));

        let login = service
            .login("test@example.com", "newpassword", None, None)
            .await
            .unwrap();
        assert_eq!(login.user.email, "test@example.com");
    }

    #[tokio::test]
    async fn oauth_callback_creates_new_user_and_account() {
        let (store, _email, service) = build_service(AuthConfig::default());

        let oauth_info = OAuthUserInfo {
            provider_id: "google".to_string(),
            account_id: "google-123".to_string(),
            email: "oauth@example.com".to_string(),
            name: Some("OAuth User".to_string()),
            image: Some("https://example.com/avatar.jpg".to_string()),
        };

        let result = service
            .oauth_callback(
                oauth_info,
                OAuthTokens::default(),
                Some("127.0.0.1".to_string()),
                Some("test-agent".to_string()),
            )
            .await
            .unwrap();

        assert_eq!(result.user.email, "oauth@example.com");
        assert_eq!(result.user.name, Some("OAuth User".to_string()));
        assert!(result.user.is_verified());
        assert!(result.user.password_hash.is_none());

        let accounts = AccountStore::find_by_user_id(&store, result.user.id)
            .await
            .unwrap();
        assert_eq!(accounts.len(), 1);
        assert_eq!(accounts[0].provider_id, "google");
        assert_eq!(accounts[0].account_id, "google-123");
    }

    #[tokio::test]
    async fn oauth_callback_links_existing_user_by_email() {
        let (store, _email, service) = build_service(AuthConfig::default());

        let existing_user = store
            .create_user(
                "existing@example.com",
                Some("Existing User"),
                Some("hash123"),
            )
            .await
            .unwrap();

        let oauth_info = OAuthUserInfo {
            provider_id: "github".to_string(),
            account_id: "github-456".to_string(),
            email: "existing@example.com".to_string(),
            name: Some("GitHub User".to_string()),
            image: None,
        };

        let result = service
            .oauth_callback(
                oauth_info,
                OAuthTokens::default(),
                Some("127.0.0.1".to_string()),
                Some("test-agent".to_string()),
            )
            .await
            .unwrap();

        assert_eq!(result.user.id, existing_user.id);
        assert_eq!(result.user.email, "existing@example.com");
        assert!(result.user.is_verified());

        let accounts = AccountStore::find_by_user_id(&store, result.user.id)
            .await
            .unwrap();
        assert_eq!(accounts.len(), 1);
        assert_eq!(accounts[0].provider_id, "github");
        assert_eq!(accounts[0].account_id, "github-456");
    }

    #[tokio::test]
    async fn oauth_callback_logs_in_existing_account() {
        let (store, _email, service) = build_service(AuthConfig::default());

        let user = store
            .create_user("oauth@example.com", Some("OAuth User"), None)
            .await
            .unwrap();
        store
            .create_account(NewAccount {
                user_id: user.id,
                provider_id: "google".to_string(),
                account_id: "google-789".to_string(),
                access_token: None,
                refresh_token: None,
                access_token_expires_at: None,
                scope: None,
            })
            .await
            .unwrap();

        let oauth_info = OAuthUserInfo {
            provider_id: "google".to_string(),
            account_id: "google-789".to_string(),
            email: "oauth@example.com".to_string(),
            name: Some("OAuth User".to_string()),
            image: None,
        };

        let result = service
            .oauth_callback(
                oauth_info,
                OAuthTokens::default(),
                Some("127.0.0.1".to_string()),
                Some("test-agent".to_string()),
            )
            .await
            .unwrap();

        assert_eq!(result.user.id, user.id);
        assert_eq!(result.user.email, "oauth@example.com");

        // Signing in through an already-linked account must not add another link.
        let accounts = AccountStore::find_by_user_id(&store, result.user.id)
            .await
            .unwrap();
        assert_eq!(accounts.len(), 1);
    }

    #[tokio::test]
    async fn oauth_callback_respects_linking_policy() {
        let mut config = AuthConfig::default();
        config.oauth.allow_implicit_account_linking = false;
        let (store, _email, service) = build_service(config);

        store
            .create_user(
                "existing@example.com",
                Some("Existing User"),
                Some("hash123"),
            )
            .await
            .unwrap();

        let oauth_info = OAuthUserInfo {
            provider_id: "google".to_string(),
            account_id: "google-999".to_string(),
            email: "existing@example.com".to_string(),
            name: Some("OAuth User".to_string()),
            image: None,
        };

        let result = service
            .oauth_callback(
                oauth_info,
                OAuthTokens::default(),
                Some("127.0.0.1".to_string()),
                Some("test-agent".to_string()),
            )
            .await;

        assert!(
            matches!(result, Err(AuthError::OAuth(OAuthError::LinkingDisabled))),
            "Expected OAuth linking-disabled error"
        );
    }

    #[tokio::test]
    async fn cleanup_expired_removes_sessions_verifications_and_oauth_states() {
        let (store, _email, service) = build_service(AuthConfig::default());
        let an_hour_ago = OffsetDateTime::now_utc() - time::Duration::hours(1);

        store
            .create_session(NewSession {
                token_hash: "expired-session".to_string(),
                user_id: 1,
                expires_at: an_hour_ago,
                ip_address: None,
                user_agent: None,
            })
            .await
            .unwrap();
        store
            .create_verification(NewVerification {
                identifier: "email-verify:test@example.com".to_string(),
                token_hash: "expired-verification".to_string(),
                expires_at: an_hour_ago,
            })
            .await
            .unwrap();
        store
            .create_oauth_state(NewOAuthState {
                provider_id: "google".to_string(),
                csrf_state: "expired-oauth-state".to_string(),
                pkce_verifier: "pkce-verifier".to_string(),
                expires_at: an_hour_ago,
            })
            .await
            .unwrap();

        let deleted = service.cleanup_expired().await.unwrap();

        assert_eq!(deleted, (1, 1, 1));
    }
}