1use std::sync::Arc;
2
3use time::OffsetDateTime;
4
5use crate::config::AuthConfig;
6use crate::crypto::{hash, token};
7use crate::email::EmailSender;
8use crate::error::{AuthError, OAuthError};
9use crate::events::{AuthEvent, LoginFailReason, LoginMethod};
10use crate::hooks::EventEmitter;
11use crate::oauth::{OAuthProviderConfig, OAuthTokens, OAuthUserInfo, client};
12use crate::store::{AccountStore, OAuthStateStore, SessionStore, UserStore, VerificationStore};
13use crate::types::{
14 NewAccount, NewSession, NewUser, NewVerification, PublicAccount, Session, User, Verification,
15};
16
/// Outcome of [`AuthService::signup`].
#[derive(Debug)]
pub struct SignupResult {
    /// The newly created user.
    pub user: User,
    /// Session opened for the new user; `Some` only when
    /// `config.email.auto_sign_in_after_signup` is enabled.
    pub session: Option<Session>,
    /// Raw token for `session`, to be handed to the client. Only its hash is
    /// persisted (`NewSession::token_hash`); the service keeps no plaintext copy.
    pub session_token: Option<String>,
    /// Raw email-verification token that was also emailed to the user; `None`
    /// when `config.email.send_verification_on_signup` is disabled.
    pub verification_token: Option<String>,
}
29
/// Outcome of a successful sign-in, returned by [`AuthService::login`] and
/// [`AuthService::oauth_callback`].
#[derive(Debug)]
pub struct LoginResult {
    /// The authenticated user.
    pub user: User,
    /// The session created for this sign-in.
    pub session: Session,
    /// Raw token for `session`; only its hash is stored server-side.
    pub session_token: String,
}
40
/// Outcome of [`AuthService::verify_email`].
#[derive(Debug)]
pub struct VerifyEmailResult {
    /// The user, re-read from the store after being marked verified.
    pub user: User,
    /// Session opened after verification; `Some` only when
    /// `config.email.auto_sign_in_after_verification` is enabled.
    pub session: Option<Session>,
    /// Raw token for `session`, if one was created.
    pub session_token: Option<String>,
}
51
/// Outcome of [`AuthService::request_password_reset`].
///
/// Deliberately empty: the same value is returned whether or not the email
/// matched a user.
// NOTE(review): `_private` is `pub`, so external code can still build this with
// a struct literal; `#[non_exhaustive]` would enforce the apparent intent, but
// changing it now would break such callers.
#[derive(Debug, Default)]
pub struct RequestResetResult {
    /// Placeholder field so the struct can gain fields later.
    pub _private: (),
}
58
/// Outcome of [`AuthService::reset_password`].
#[derive(Debug)]
pub struct ResetPasswordResult {
    /// The user, re-read after the password change. All of this user's
    /// sessions were deleted as part of the reset.
    pub user: User,
}
65
/// A live session together with its owner, returned by
/// [`AuthService::get_session`].
#[derive(Debug)]
pub struct SessionResult {
    /// The user owning the session.
    pub user: User,
    /// The (non-expired) session matched by the presented token.
    pub session: Session,
}
74
/// Outcome of [`AuthService::link_account`].
#[derive(Debug)]
pub struct LinkAccountResult {
    /// The user the provider account is now linked to.
    pub user: User,
}
81
/// Outcome of [`AuthService::unlink_account`]; carries no data.
#[derive(Debug, Default)]
pub struct UnlinkAccountResult {
    /// Placeholder field so the struct can gain fields later.
    pub _private: (),
}
87
/// Outcome of [`AuthService::refresh_oauth_token`].
#[derive(Debug)]
pub struct RefreshTokenResult {
    /// The token set returned by the provider's refresh endpoint, as received.
    pub tokens: OAuthTokens,
}
94
/// Authentication service tying together password sign-up/login, sessions,
/// email verification, password reset, and OAuth account linking.
///
/// Generic over its persistence backends and email transport so each can be
/// swapped independently; [`AuthService::new`] wraps every backend in an [`Arc`].
pub struct AuthService<U, S, V, A, O, E>
where
    U: UserStore,
    S: SessionStore,
    V: VerificationStore,
    A: AccountStore,
    O: OAuthStateStore,
    E: EmailSender,
{
    /// Settings read by the service methods (token length, session /
    /// verification / reset TTLs, `email.*` and `oauth.*` behaviour flags).
    pub config: AuthConfig,
    /// User persistence.
    pub users: Arc<U>,
    /// Session persistence (sessions are looked up by token hash).
    pub sessions: Arc<S>,
    /// Email-verification and password-reset token persistence.
    pub verifications: Arc<V>,
    /// Linked OAuth provider accounts.
    pub accounts: Arc<A>,
    /// OAuth flow state records; only expiry cleanup happens in this file.
    pub oauth_states: Arc<O>,
    /// Transport used to deliver verification and password-reset emails.
    pub email: Arc<E>,
    /// Receives [`AuthEvent`] notifications; replace via `with_events`.
    pub events: Arc<EventEmitter>,
}
122
123impl<U, S, V, A, O, E> AuthService<U, S, V, A, O, E>
124where
125 U: UserStore,
126 S: SessionStore,
127 V: VerificationStore,
128 A: AccountStore,
129 O: OAuthStateStore,
130 E: EmailSender,
131{
132 pub fn new(
134 config: AuthConfig,
135 users: U,
136 sessions: S,
137 verifications: V,
138 accounts: A,
139 oauth_states: O,
140 email: E,
141 ) -> Self {
142 Self {
143 config,
144 users: Arc::new(users),
145 sessions: Arc::new(sessions),
146 verifications: Arc::new(verifications),
147 accounts: Arc::new(accounts),
148 oauth_states: Arc::new(oauth_states),
149 email: Arc::new(email),
150 events: Arc::new(EventEmitter::new()),
151 }
152 }
153
154 pub fn with_events(mut self, events: EventEmitter) -> Self {
155 self.events = Arc::new(events);
156 self
157 }
158
159 pub async fn signup(
161 &self,
162 input: NewUser,
163 ip: Option<String>,
164 user_agent: Option<String>,
165 ) -> Result<SignupResult, AuthError> {
166 if input.password.len() < 8 {
167 return Err(AuthError::WeakPassword(8));
168 }
169
170 let email = input.email.trim().to_lowercase();
171 if self.users.find_by_email(&email).await?.is_some() {
172 return Err(AuthError::EmailTaken);
173 }
174
175 let password_hash = hash::hash_password(&input.password)?;
176 let user = self
177 .users
178 .create_user(&email, input.name.as_deref(), Some(&password_hash))
179 .await?;
180
181 let verification_token = if self.config.email.send_verification_on_signup {
182 let identifier = format!("email-verify:{}", user.email.to_lowercase());
183 let raw_token = token::generate_token(self.config.token_length);
184
185 let _ = self.verifications.delete_by_identifier(&identifier).await;
186 self.verifications
187 .create_verification(NewVerification {
188 identifier,
189 token_hash: token::hash_token(&raw_token),
190 expires_at: OffsetDateTime::now_utc() + self.config.verification_ttl,
191 })
192 .await?;
193 self.email
194 .send_verification_email(&user, &raw_token)
195 .await?;
196 Some(raw_token)
197 } else {
198 None
199 };
200
201 self.events
202 .emit(AuthEvent::UserSignedUp {
203 user_id: user.id,
204 email: user.email.clone(),
205 })
206 .await;
207
208 let (session, session_token) = if self.config.email.auto_sign_in_after_signup {
209 let (session, raw_token) = self
210 .create_session_internal(user.id, ip, user_agent)
211 .await?;
212 (Some(session), Some(raw_token))
213 } else {
214 (None, None)
215 };
216
217 Ok(SignupResult {
218 user,
219 session,
220 session_token,
221 verification_token,
222 })
223 }
224
225 pub async fn login(
227 &self,
228 email: &str,
229 password: &str,
230 ip: Option<String>,
231 user_agent: Option<String>,
232 ) -> Result<LoginResult, AuthError> {
233 let user = self
234 .users
235 .find_by_email(&email.trim().to_lowercase())
236 .await?
237 .ok_or(AuthError::InvalidCredentials)?;
238
239 let password_hash = user
240 .password_hash
241 .as_deref()
242 .ok_or(AuthError::InvalidCredentials)?;
243 if !hash::verify_password(password, password_hash)? {
244 self.events
245 .emit(AuthEvent::UserLoginFailed {
246 email: email.to_string(),
247 reason: LoginFailReason::InvalidCredentials,
248 })
249 .await;
250 return Err(AuthError::InvalidCredentials);
251 }
252
253 if self.config.email.require_verification_to_login && !user.is_verified() {
254 self.events
255 .emit(AuthEvent::UserLoginFailed {
256 email: email.to_string(),
257 reason: LoginFailReason::EmailNotVerified,
258 })
259 .await;
260 return Err(AuthError::EmailNotVerified);
261 }
262
263 let (session, session_token) = self
264 .create_session_internal(user.id, ip, user_agent)
265 .await?;
266
267 self.events
268 .emit(AuthEvent::UserLoggedIn {
269 user_id: user.id,
270 method: LoginMethod::Password,
271 })
272 .await;
273
274 Ok(LoginResult {
275 user,
276 session,
277 session_token,
278 })
279 }
280
281 pub async fn logout(&self, session_id: i64) -> Result<(), AuthError> {
283 self.sessions.delete_session(session_id).await
284 }
285
286 pub async fn logout_all(&self, user_id: i64) -> Result<(), AuthError> {
288 self.sessions.delete_by_user_id(user_id).await
289 }
290
291 pub async fn get_session(&self, raw_token: &str) -> Result<SessionResult, AuthError> {
293 let session = self
294 .sessions
295 .find_by_token_hash(&token::hash_token(raw_token))
296 .await?
297 .ok_or(AuthError::SessionNotFound)?;
298
299 if session.expires_at < OffsetDateTime::now_utc() {
300 self.sessions.delete_session(session.id).await?;
301 return Err(AuthError::SessionNotFound);
302 }
303
304 let user = self
305 .users
306 .find_by_id(session.user_id)
307 .await?
308 .ok_or(AuthError::UserNotFound)?;
309
310 Ok(SessionResult { user, session })
311 }
312
313 pub async fn list_sessions(&self, user_id: i64) -> Result<Vec<Session>, AuthError> {
315 self.sessions.find_by_user_id(user_id).await
316 }
317
318 pub async fn verify_email(
320 &self,
321 raw_token: &str,
322 ip: Option<String>,
323 user_agent: Option<String>,
324 ) -> Result<VerifyEmailResult, AuthError> {
325 let verification = self.lookup_verification(raw_token, "email-verify:").await?;
326 let email = verification
327 .identifier
328 .strip_prefix("email-verify:")
329 .ok_or(AuthError::InvalidToken)?;
330
331 let user = self
332 .users
333 .find_by_email(email)
334 .await?
335 .ok_or(AuthError::UserNotFound)?;
336 self.users.set_email_verified(user.id).await?;
337 self.verifications
338 .delete_verification(verification.id)
339 .await?;
340
341 self.events
342 .emit(AuthEvent::EmailVerified { user_id: user.id })
343 .await;
344
345 let user = self
346 .users
347 .find_by_id(user.id)
348 .await?
349 .ok_or(AuthError::UserNotFound)?;
350
351 let (session, session_token) = if self.config.email.auto_sign_in_after_verification {
352 let (session, raw_token) = self
353 .create_session_internal(user.id, ip, user_agent)
354 .await?;
355 (Some(session), Some(raw_token))
356 } else {
357 (None, None)
358 };
359
360 Ok(VerifyEmailResult {
361 user,
362 session,
363 session_token,
364 })
365 }
366
367 pub async fn request_password_reset(
369 &self,
370 email: &str,
371 ) -> Result<RequestResetResult, AuthError> {
372 let email = email.trim().to_lowercase();
373 if let Some(user) = self.users.find_by_email(&email).await? {
374 let identifier = format!("password-reset:{}", user.email.to_lowercase());
375 let _ = self.verifications.delete_by_identifier(&identifier).await;
376
377 let raw_token = token::generate_token(self.config.token_length);
378 self.verifications
379 .create_verification(NewVerification {
380 identifier,
381 token_hash: token::hash_token(&raw_token),
382 expires_at: OffsetDateTime::now_utc() + self.config.reset_ttl,
383 })
384 .await?;
385 self.email
386 .send_password_reset_email(&user, &raw_token)
387 .await?;
388
389 self.events
390 .emit(AuthEvent::PasswordResetRequested { user_id: user.id })
391 .await;
392 }
393
394 Ok(RequestResetResult::default())
395 }
396
397 pub async fn reset_password(
399 &self,
400 raw_token: &str,
401 new_password: &str,
402 ) -> Result<ResetPasswordResult, AuthError> {
403 if new_password.len() < 8 {
404 return Err(AuthError::WeakPassword(8));
405 }
406
407 let verification = self
408 .lookup_verification(raw_token, "password-reset:")
409 .await?;
410 let email = verification
411 .identifier
412 .strip_prefix("password-reset:")
413 .ok_or(AuthError::InvalidToken)?;
414 let user = self
415 .users
416 .find_by_email(email)
417 .await?
418 .ok_or(AuthError::UserNotFound)?;
419
420 self.users
421 .update_password(user.id, &hash::hash_password(new_password)?)
422 .await?;
423 self.sessions.delete_by_user_id(user.id).await?;
424 self.verifications
425 .delete_verification(verification.id)
426 .await?;
427
428 self.events
429 .emit(AuthEvent::PasswordResetCompleted { user_id: user.id })
430 .await;
431
432 let user = self
433 .users
434 .find_by_id(user.id)
435 .await?
436 .ok_or(AuthError::UserNotFound)?;
437
438 Ok(ResetPasswordResult { user })
439 }
440
441 pub async fn cleanup_expired(&self) -> Result<(u64, u64, u64), AuthError> {
444 let sessions_deleted = self.sessions.delete_expired().await?;
445 let verifications_deleted = self.verifications.delete_expired().await?;
446 let oauth_states_deleted = self.oauth_states.delete_expired_oauth_states().await?;
447 Ok((
448 sessions_deleted,
449 verifications_deleted,
450 oauth_states_deleted,
451 ))
452 }
453
454 pub async fn oauth_callback(
459 &self,
460 info: OAuthUserInfo,
461 tokens: OAuthTokens,
462 ip: Option<String>,
463 user_agent: Option<String>,
464 ) -> Result<LoginResult, AuthError> {
465 let oauth_provider_id = info.provider_id.clone();
466 if let Some(account) = self
468 .accounts
469 .find_by_provider(&info.provider_id, &info.account_id)
470 .await?
471 {
472 let user = self
473 .users
474 .find_by_id(account.user_id)
475 .await?
476 .ok_or(AuthError::UserNotFound)?;
477 let (session, session_token) = self
478 .create_session_internal(user.id, ip, user_agent)
479 .await?;
480 self.events
481 .emit(AuthEvent::UserLoggedIn {
482 user_id: user.id,
483 method: LoginMethod::OAuth {
484 provider_id: oauth_provider_id,
485 },
486 })
487 .await;
488 return Ok(LoginResult {
489 user,
490 session,
491 session_token,
492 });
493 }
494
495 let user = if let Some(existing_user) = self.users.find_by_email(&info.email).await? {
497 if !self.config.oauth.allow_implicit_account_linking {
499 return Err(AuthError::OAuth(OAuthError::LinkingDisabled));
500 }
501 existing_user
502 } else {
503 self.users
505 .create_user(&info.email, info.name.as_deref(), None)
506 .await?
507 };
508
509 let access_token_expires_at = tokens.expires_in.map(|d| OffsetDateTime::now_utc() + d);
511 self.accounts
512 .create_account(NewAccount {
513 user_id: user.id,
514 provider_id: info.provider_id,
515 account_id: info.account_id,
516 access_token: tokens.access_token,
517 refresh_token: tokens.refresh_token,
518 access_token_expires_at,
519 scope: tokens.scope,
520 })
521 .await?;
522
523 if !user.is_verified() {
525 self.users.set_email_verified(user.id).await?;
526 }
527
528 let user = self
530 .users
531 .find_by_id(user.id)
532 .await?
533 .ok_or(AuthError::UserNotFound)?;
534 let (session, session_token) = self
535 .create_session_internal(user.id, ip, user_agent)
536 .await?;
537 self.events
538 .emit(AuthEvent::UserLoggedIn {
539 user_id: user.id,
540 method: LoginMethod::OAuth {
541 provider_id: oauth_provider_id,
542 },
543 })
544 .await;
545 Ok(LoginResult {
546 user,
547 session,
548 session_token,
549 })
550 }
551
552 pub async fn list_accounts(&self, user_id: i64) -> Result<Vec<PublicAccount>, AuthError> {
554 let accounts = self.accounts.find_by_user_id(user_id).await?;
555 Ok(accounts.into_iter().map(PublicAccount::from).collect())
556 }
557
558 pub async fn link_account(
560 &self,
561 user_id: i64,
562 info: OAuthUserInfo,
563 tokens: OAuthTokens,
564 ) -> Result<LinkAccountResult, AuthError> {
565 let linked_provider_id = info.provider_id.clone();
566 if let Some(existing) = self
567 .accounts
568 .find_by_provider(&info.provider_id, &info.account_id)
569 .await?
570 {
571 if existing.user_id != user_id {
572 return Err(AuthError::OAuth(OAuthError::AccountAlreadyLinked));
573 }
574 let expires_at = tokens.expires_in.map(|d| OffsetDateTime::now_utc() + d);
575 self.accounts
576 .update_account(
577 existing.id,
578 tokens.access_token,
579 tokens.refresh_token,
580 expires_at,
581 tokens.scope,
582 )
583 .await?;
584 } else {
585 let expires_at = tokens.expires_in.map(|d| OffsetDateTime::now_utc() + d);
586 self.accounts
587 .create_account(NewAccount {
588 user_id,
589 provider_id: info.provider_id,
590 account_id: info.account_id,
591 access_token: tokens.access_token,
592 refresh_token: tokens.refresh_token,
593 access_token_expires_at: expires_at,
594 scope: tokens.scope,
595 })
596 .await?;
597 }
598
599 let user = self
600 .users
601 .find_by_id(user_id)
602 .await?
603 .ok_or(AuthError::UserNotFound)?;
604
605 self.events
606 .emit(AuthEvent::OAuthAccountLinked {
607 user_id,
608 provider_id: linked_provider_id,
609 })
610 .await;
611
612 Ok(LinkAccountResult { user })
613 }
614
615 pub async fn unlink_account(
617 &self,
618 user_id: i64,
619 account_id: i64,
620 ) -> Result<UnlinkAccountResult, AuthError> {
621 let accounts = self.accounts.find_by_user_id(user_id).await?;
622 let target = accounts
623 .iter()
624 .find(|a| a.id == account_id)
625 .ok_or(AuthError::OAuth(OAuthError::AccountNotFound))?;
626
627 let user = self
628 .users
629 .find_by_id(user_id)
630 .await?
631 .ok_or(AuthError::UserNotFound)?;
632
633 if user.password_hash.is_none() && accounts.len() == 1 {
634 return Err(AuthError::OAuth(OAuthError::LastAuthMethod));
635 }
636
637 let unlinked_provider_id = target.provider_id.clone();
638
639 self.accounts.delete_account(target.id).await?;
640
641 self.events
642 .emit(AuthEvent::OAuthAccountUnlinked {
643 user_id,
644 provider_id: unlinked_provider_id,
645 })
646 .await;
647
648 Ok(UnlinkAccountResult::default())
649 }
650
651 pub async fn refresh_oauth_token(
653 &self,
654 user_id: i64,
655 account_id: i64,
656 provider_config: &OAuthProviderConfig,
657 ) -> Result<RefreshTokenResult, AuthError> {
658 let accounts = self.accounts.find_by_user_id(user_id).await?;
659 let account = accounts
660 .iter()
661 .find(|a| a.id == account_id)
662 .ok_or(AuthError::OAuth(OAuthError::AccountNotFound))?;
663
664 let refresh_token_str = account
665 .refresh_token
666 .as_deref()
667 .ok_or(AuthError::OAuth(OAuthError::NoRefreshToken))?;
668
669 let tokens = client::refresh_access_token(provider_config, refresh_token_str).await?;
670
671 let expires_at = tokens.expires_in.map(|d| OffsetDateTime::now_utc() + d);
672 self.accounts
673 .update_account(
674 account.id,
675 tokens.access_token.clone(),
676 tokens.refresh_token.clone(),
677 expires_at,
678 tokens.scope.clone(),
679 )
680 .await?;
681
682 Ok(RefreshTokenResult { tokens })
683 }
684
685 async fn create_session_internal(
686 &self,
687 user_id: i64,
688 ip: Option<String>,
689 user_agent: Option<String>,
690 ) -> Result<(Session, String), AuthError> {
691 let raw_token = token::generate_token(self.config.token_length);
692 let session = self
693 .sessions
694 .create_session(NewSession {
695 token_hash: token::hash_token(&raw_token),
696 user_id,
697 expires_at: OffsetDateTime::now_utc() + self.config.session_ttl,
698 ip_address: ip.clone(),
699 user_agent,
700 })
701 .await?;
702
703 self.events
704 .emit(AuthEvent::SessionCreated {
705 user_id,
706 session_id: session.id,
707 ip: ip.clone(),
708 })
709 .await;
710
711 Ok((session, raw_token))
712 }
713
714 async fn lookup_verification(
715 &self,
716 raw_token: &str,
717 prefix: &str,
718 ) -> Result<Verification, AuthError> {
719 let verification = self
720 .verifications
721 .find_by_token_hash(&token::hash_token(raw_token))
722 .await?
723 .ok_or(AuthError::InvalidToken)?;
724
725 if !verification.identifier.starts_with(prefix) {
726 return Err(AuthError::InvalidToken);
727 }
728
729 if verification.expires_at < OffsetDateTime::now_utc() {
730 self.verifications
731 .delete_verification(verification.id)
732 .await?;
733 return Err(AuthError::InvalidToken);
734 }
735
736 Ok(verification)
737 }
738}
739
740#[cfg(test)]
741mod tests {
742 use std::collections::HashMap;
743 use std::sync::{Arc, Mutex};
744
745 use async_trait::async_trait;
746 use time::OffsetDateTime;
747
748 use super::AuthService;
749 use crate::config::AuthConfig;
750 use crate::email::EmailSender;
751 use crate::error::{AuthError, OAuthError};
752 use crate::oauth::{OAuthProviderConfig, OAuthTokens};
753 use crate::store::{AccountStore, OAuthStateStore, SessionStore, UserStore, VerificationStore};
754 use crate::types::{
755 Account, NewAccount, NewOAuthState, NewSession, NewUser, NewVerification, OAuthIntent,
756 OAuthState, Session, User, Verification,
757 };
758
759 #[derive(Default)]
760 struct MemoryState {
761 next_user_id: i64,
762 next_session_id: i64,
763 next_verification_id: i64,
764 next_account_id: i64,
765 next_oauth_state_id: i64,
766 users: HashMap<i64, User>,
767 sessions: HashMap<i64, Session>,
768 verifications: HashMap<i64, Verification>,
769 accounts: HashMap<i64, Account>,
770 oauth_states: HashMap<i64, OAuthState>,
771 }
772
773 #[derive(Clone, Default)]
774 struct MemoryStore {
775 inner: Arc<Mutex<MemoryState>>,
776 }
777
778 #[async_trait]
779 impl UserStore for MemoryStore {
780 async fn create_user(
781 &self,
782 email: &str,
783 name: Option<&str>,
784 password_hash: Option<&str>,
785 ) -> Result<User, AuthError> {
786 let mut state = self.inner.lock().unwrap();
787 if state.users.values().any(|user| user.email == email) {
788 return Err(AuthError::EmailTaken);
789 }
790
791 state.next_user_id += 1;
792 let now = OffsetDateTime::now_utc();
793 let user = User {
794 id: state.next_user_id,
795 email: email.to_string(),
796 name: name.map(str::to_owned),
797 password_hash: password_hash.map(str::to_owned),
798 email_verified_at: None,
799 image: None,
800 created_at: now,
801 updated_at: now,
802 };
803 state.users.insert(user.id, user.clone());
804 Ok(user)
805 }
806
807 async fn find_by_email(&self, email: &str) -> Result<Option<User>, AuthError> {
808 let state = self.inner.lock().unwrap();
809 Ok(state
810 .users
811 .values()
812 .find(|user| user.email == email)
813 .cloned())
814 }
815
816 async fn find_by_id(&self, id: i64) -> Result<Option<User>, AuthError> {
817 Ok(self.inner.lock().unwrap().users.get(&id).cloned())
818 }
819
820 async fn set_email_verified(&self, user_id: i64) -> Result<(), AuthError> {
821 let mut state = self.inner.lock().unwrap();
822 let user = state
823 .users
824 .get_mut(&user_id)
825 .ok_or(AuthError::UserNotFound)?;
826 user.email_verified_at = Some(OffsetDateTime::now_utc());
827 user.updated_at = OffsetDateTime::now_utc();
828 Ok(())
829 }
830
831 async fn update_password(
832 &self,
833 user_id: i64,
834 password_hash: &str,
835 ) -> Result<(), AuthError> {
836 let mut state = self.inner.lock().unwrap();
837 let user = state
838 .users
839 .get_mut(&user_id)
840 .ok_or(AuthError::UserNotFound)?;
841 user.password_hash = Some(password_hash.to_string());
842 user.updated_at = OffsetDateTime::now_utc();
843 Ok(())
844 }
845
846 async fn delete_user(&self, user_id: i64) -> Result<(), AuthError> {
847 self.inner.lock().unwrap().users.remove(&user_id);
848 Ok(())
849 }
850 }
851
852 #[async_trait]
853 impl SessionStore for MemoryStore {
854 async fn create_session(&self, session: NewSession) -> Result<Session, AuthError> {
855 let mut state = self.inner.lock().unwrap();
856 state.next_session_id += 1;
857 let now = OffsetDateTime::now_utc();
858 let session = Session {
859 id: state.next_session_id,
860 token_hash: session.token_hash,
861 user_id: session.user_id,
862 expires_at: session.expires_at,
863 ip_address: session.ip_address,
864 user_agent: session.user_agent,
865 created_at: now,
866 updated_at: now,
867 };
868 state.sessions.insert(session.id, session.clone());
869 Ok(session)
870 }
871
872 async fn find_by_token_hash(&self, token_hash: &str) -> Result<Option<Session>, AuthError> {
873 let state = self.inner.lock().unwrap();
874 Ok(state
875 .sessions
876 .values()
877 .find(|session| session.token_hash == token_hash)
878 .cloned())
879 }
880
881 async fn find_by_user_id(&self, user_id: i64) -> Result<Vec<Session>, AuthError> {
882 let state = self.inner.lock().unwrap();
883 Ok(state
884 .sessions
885 .values()
886 .filter(|session| session.user_id == user_id)
887 .cloned()
888 .collect())
889 }
890
891 async fn delete_session(&self, id: i64) -> Result<(), AuthError> {
892 self.inner.lock().unwrap().sessions.remove(&id);
893 Ok(())
894 }
895
896 async fn delete_by_user_id(&self, user_id: i64) -> Result<(), AuthError> {
897 self.inner
898 .lock()
899 .unwrap()
900 .sessions
901 .retain(|_, session| session.user_id != user_id);
902 Ok(())
903 }
904
905 async fn delete_expired(&self) -> Result<u64, AuthError> {
906 let now = OffsetDateTime::now_utc();
907 let mut state = self.inner.lock().unwrap();
908 let before = state.sessions.len();
909 state
910 .sessions
911 .retain(|_, session| session.expires_at >= now);
912 Ok((before - state.sessions.len()) as u64)
913 }
914 }
915
916 #[async_trait]
917 impl VerificationStore for MemoryStore {
918 async fn create_verification(
919 &self,
920 verification: NewVerification,
921 ) -> Result<Verification, AuthError> {
922 let mut state = self.inner.lock().unwrap();
923 state.next_verification_id += 1;
924 let now = OffsetDateTime::now_utc();
925 let verification = Verification {
926 id: state.next_verification_id,
927 identifier: verification.identifier,
928 token_hash: verification.token_hash,
929 expires_at: verification.expires_at,
930 created_at: now,
931 updated_at: now,
932 };
933 state
934 .verifications
935 .insert(verification.id, verification.clone());
936 Ok(verification)
937 }
938
939 async fn find_by_identifier(
940 &self,
941 identifier: &str,
942 ) -> Result<Option<Verification>, AuthError> {
943 let state = self.inner.lock().unwrap();
944 Ok(state
945 .verifications
946 .values()
947 .find(|verification| verification.identifier == identifier)
948 .cloned())
949 }
950
951 async fn find_by_token_hash(
952 &self,
953 token_hash: &str,
954 ) -> Result<Option<Verification>, AuthError> {
955 let state = self.inner.lock().unwrap();
956 Ok(state
957 .verifications
958 .values()
959 .find(|verification| verification.token_hash == token_hash)
960 .cloned())
961 }
962
963 async fn delete_verification(&self, id: i64) -> Result<(), AuthError> {
964 self.inner.lock().unwrap().verifications.remove(&id);
965 Ok(())
966 }
967
968 async fn delete_by_identifier(&self, identifier: &str) -> Result<(), AuthError> {
969 self.inner
970 .lock()
971 .unwrap()
972 .verifications
973 .retain(|_, verification| verification.identifier != identifier);
974 Ok(())
975 }
976
977 async fn delete_expired(&self) -> Result<u64, AuthError> {
978 let now = OffsetDateTime::now_utc();
979 let mut state = self.inner.lock().unwrap();
980 let before = state.verifications.len();
981 state
982 .verifications
983 .retain(|_, verification| verification.expires_at >= now);
984 Ok((before - state.verifications.len()) as u64)
985 }
986 }
987
988 #[async_trait]
989 impl AccountStore for MemoryStore {
990 async fn create_account(&self, account: NewAccount) -> Result<Account, AuthError> {
991 let mut state = self.inner.lock().unwrap();
992 state.next_account_id += 1;
993 let now = OffsetDateTime::now_utc();
994 let account = Account {
995 id: state.next_account_id,
996 user_id: account.user_id,
997 provider_id: account.provider_id,
998 account_id: account.account_id,
999 access_token: account.access_token,
1000 refresh_token: account.refresh_token,
1001 access_token_expires_at: account.access_token_expires_at,
1002 scope: account.scope,
1003 created_at: now,
1004 updated_at: now,
1005 };
1006 state.accounts.insert(account.id, account.clone());
1007 Ok(account)
1008 }
1009
1010 async fn find_by_provider(
1011 &self,
1012 provider_id: &str,
1013 account_id: &str,
1014 ) -> Result<Option<Account>, AuthError> {
1015 let state = self.inner.lock().unwrap();
1016 Ok(state
1017 .accounts
1018 .values()
1019 .find(|account| {
1020 account.provider_id == provider_id && account.account_id == account_id
1021 })
1022 .cloned())
1023 }
1024
1025 async fn find_by_user_id(&self, user_id: i64) -> Result<Vec<Account>, AuthError> {
1026 let state = self.inner.lock().unwrap();
1027 Ok(state
1028 .accounts
1029 .values()
1030 .filter(|account| account.user_id == user_id)
1031 .cloned()
1032 .collect())
1033 }
1034
1035 async fn delete_account(&self, id: i64) -> Result<(), AuthError> {
1036 self.inner.lock().unwrap().accounts.remove(&id);
1037 Ok(())
1038 }
1039
1040 async fn update_account(
1041 &self,
1042 id: i64,
1043 access_token: Option<String>,
1044 refresh_token: Option<String>,
1045 access_token_expires_at: Option<OffsetDateTime>,
1046 scope: Option<String>,
1047 ) -> Result<(), AuthError> {
1048 let mut state = self.inner.lock().unwrap();
1049 let account = state
1050 .accounts
1051 .get_mut(&id)
1052 .ok_or(AuthError::OAuth(OAuthError::AccountNotFound))?;
1053 account.access_token = access_token;
1054 account.refresh_token = refresh_token;
1055 account.access_token_expires_at = access_token_expires_at;
1056 account.scope = scope;
1057 account.updated_at = OffsetDateTime::now_utc();
1058 Ok(())
1059 }
1060 }
1061
1062 #[async_trait]
1063 impl OAuthStateStore for MemoryStore {
1064 async fn create_oauth_state(
1065 &self,
1066 new_state: NewOAuthState,
1067 ) -> Result<OAuthState, AuthError> {
1068 let mut state = self.inner.lock().unwrap();
1069 state.next_oauth_state_id += 1;
1070 let now = OffsetDateTime::now_utc();
1071 let oauth_state = OAuthState {
1072 id: state.next_oauth_state_id,
1073 provider_id: new_state.provider_id,
1074 csrf_state: new_state.csrf_state,
1075 pkce_verifier: new_state.pkce_verifier,
1076 intent: new_state.intent,
1077 link_user_id: new_state.link_user_id,
1078 expires_at: new_state.expires_at,
1079 created_at: now,
1080 };
1081 state
1082 .oauth_states
1083 .insert(oauth_state.id, oauth_state.clone());
1084 Ok(oauth_state)
1085 }
1086
1087 async fn find_by_csrf_state(
1088 &self,
1089 csrf_state: &str,
1090 ) -> Result<Option<OAuthState>, AuthError> {
1091 let state = self.inner.lock().unwrap();
1092 Ok(state
1093 .oauth_states
1094 .values()
1095 .find(|s| s.csrf_state == csrf_state)
1096 .cloned())
1097 }
1098
1099 async fn delete_oauth_state(&self, id: i64) -> Result<(), AuthError> {
1100 self.inner.lock().unwrap().oauth_states.remove(&id);
1101 Ok(())
1102 }
1103
1104 async fn delete_expired_oauth_states(&self) -> Result<u64, AuthError> {
1105 let now = OffsetDateTime::now_utc();
1106 let mut state = self.inner.lock().unwrap();
1107 let before = state.oauth_states.len();
1108 state.oauth_states.retain(|_, s| s.expires_at >= now);
1109 Ok((before - state.oauth_states.len()) as u64)
1110 }
1111 }
1112
1113 #[derive(Clone, Default)]
1114 struct TestEmailSender {
1115 verification_tokens: Arc<Mutex<Vec<String>>>,
1116 reset_tokens: Arc<Mutex<Vec<String>>>,
1117 }
1118
1119 #[async_trait]
1120 impl EmailSender for TestEmailSender {
1121 async fn send_verification_email(
1122 &self,
1123 _user: &User,
1124 token: &str,
1125 ) -> Result<(), AuthError> {
1126 self.verification_tokens
1127 .lock()
1128 .unwrap()
1129 .push(token.to_string());
1130 Ok(())
1131 }
1132
1133 async fn send_password_reset_email(
1134 &self,
1135 _user: &User,
1136 token: &str,
1137 ) -> Result<(), AuthError> {
1138 self.reset_tokens.lock().unwrap().push(token.to_string());
1139 Ok(())
1140 }
1141 }
1142
1143 #[tokio::test]
1144 async fn signup_verify_login_and_reset_flow_works() {
1145 let store = MemoryStore::default();
1146 let email = TestEmailSender::default();
1147 let service = AuthService::new(
1148 AuthConfig::default(),
1149 store.clone(),
1150 store.clone(),
1151 store.clone(),
1152 store.clone(),
1153 store.clone(),
1154 email.clone(),
1155 );
1156
1157 let signup = service
1158 .signup(
1159 NewUser {
1160 email: "test@example.com".to_string(),
1161 name: Some("Test".to_string()),
1162 password: "supersecret".to_string(),
1163 },
1164 Some("127.0.0.1".to_string()),
1165 Some("test-agent".to_string()),
1166 )
1167 .await
1168 .unwrap();
1169
1170 assert_eq!(signup.user.email, "test@example.com");
1171 assert!(signup.session.is_some());
1172 assert_eq!(email.verification_tokens.lock().unwrap().len(), 1);
1173
1174 let verification_token = email.verification_tokens.lock().unwrap()[0].clone();
1175 let verify = service
1176 .verify_email(&verification_token, None, None)
1177 .await
1178 .unwrap();
1179 assert!(verify.user.is_verified());
1180
1181 let login = service
1182 .login("test@example.com", "supersecret", None, None)
1183 .await
1184 .unwrap();
1185 assert_eq!(login.user.email, "test@example.com");
1186
1187 service
1188 .request_password_reset("test@example.com")
1189 .await
1190 .unwrap();
1191 let reset_token = email.reset_tokens.lock().unwrap()[0].clone();
1192 service
1193 .reset_password(&reset_token, "newpassword")
1194 .await
1195 .unwrap();
1196
1197 let login = service
1198 .login("test@example.com", "newpassword", None, None)
1199 .await
1200 .unwrap();
1201 assert_eq!(login.user.email, "test@example.com");
1202 }
1203
1204 #[tokio::test]
1205 async fn oauth_callback_creates_new_user_and_account() {
1206 let store = MemoryStore::default();
1207 let email = TestEmailSender::default();
1208 let service = AuthService::new(
1209 AuthConfig::default(),
1210 store.clone(),
1211 store.clone(),
1212 store.clone(),
1213 store.clone(),
1214 store.clone(),
1215 email.clone(),
1216 );
1217
1218 let oauth_info = crate::oauth::OAuthUserInfo {
1219 provider_id: "google".to_string(),
1220 account_id: "google-123".to_string(),
1221 email: "oauth@example.com".to_string(),
1222 name: Some("OAuth User".to_string()),
1223 image: Some("https://example.com/avatar.jpg".to_string()),
1224 };
1225
1226 let result = service
1227 .oauth_callback(
1228 oauth_info,
1229 OAuthTokens::default(),
1230 Some("127.0.0.1".to_string()),
1231 Some("test-agent".to_string()),
1232 )
1233 .await
1234 .unwrap();
1235
1236 assert_eq!(result.user.email, "oauth@example.com");
1237 assert_eq!(result.user.name, Some("OAuth User".to_string()));
1238 assert!(result.user.is_verified());
1239 assert!(result.user.password_hash.is_none());
1240
1241 let accounts = AccountStore::find_by_user_id(&store, result.user.id)
1243 .await
1244 .unwrap();
1245 assert_eq!(accounts.len(), 1);
1246 assert_eq!(accounts[0].provider_id, "google");
1247 assert_eq!(accounts[0].account_id, "google-123");
1248 }
1249
1250 #[tokio::test]
1251 async fn oauth_callback_links_existing_user_by_email() {
1252 let store = MemoryStore::default();
1253 let email = TestEmailSender::default();
1254 let service = AuthService::new(
1255 AuthConfig::default(),
1256 store.clone(),
1257 store.clone(),
1258 store.clone(),
1259 store.clone(),
1260 store.clone(),
1261 email.clone(),
1262 );
1263
1264 let existing_user = store
1266 .create_user(
1267 "existing@example.com",
1268 Some("Existing User"),
1269 Some("hash123"),
1270 )
1271 .await
1272 .unwrap();
1273
1274 let oauth_info = crate::oauth::OAuthUserInfo {
1275 provider_id: "github".to_string(),
1276 account_id: "github-456".to_string(),
1277 email: "existing@example.com".to_string(),
1278 name: Some("GitHub User".to_string()),
1279 image: None,
1280 };
1281
1282 let result = service
1283 .oauth_callback(
1284 oauth_info,
1285 OAuthTokens::default(),
1286 Some("127.0.0.1".to_string()),
1287 Some("test-agent".to_string()),
1288 )
1289 .await
1290 .unwrap();
1291
1292 assert_eq!(result.user.id, existing_user.id);
1294 assert_eq!(result.user.email, "existing@example.com");
1295 assert!(result.user.is_verified());
1296
1297 let accounts = AccountStore::find_by_user_id(&store, result.user.id)
1299 .await
1300 .unwrap();
1301 assert_eq!(accounts.len(), 1);
1302 assert_eq!(accounts[0].provider_id, "github");
1303 assert_eq!(accounts[0].account_id, "github-456");
1304 }
1305
1306 #[tokio::test]
1307 async fn oauth_callback_logs_in_existing_account() {
1308 let store = MemoryStore::default();
1309 let email = TestEmailSender::default();
1310 let service = AuthService::new(
1311 AuthConfig::default(),
1312 store.clone(),
1313 store.clone(),
1314 store.clone(),
1315 store.clone(),
1316 store.clone(),
1317 email.clone(),
1318 );
1319
1320 let user = store
1322 .create_user("oauth@example.com", Some("OAuth User"), None)
1323 .await
1324 .unwrap();
1325 store
1326 .create_account(crate::types::NewAccount {
1327 user_id: user.id,
1328 provider_id: "google".to_string(),
1329 account_id: "google-789".to_string(),
1330 access_token: None,
1331 refresh_token: None,
1332 access_token_expires_at: None,
1333 scope: None,
1334 })
1335 .await
1336 .unwrap();
1337
1338 let oauth_info = crate::oauth::OAuthUserInfo {
1339 provider_id: "google".to_string(),
1340 account_id: "google-789".to_string(),
1341 email: "oauth@example.com".to_string(),
1342 name: Some("OAuth User".to_string()),
1343 image: None,
1344 };
1345
1346 let result = service
1347 .oauth_callback(
1348 oauth_info,
1349 OAuthTokens::default(),
1350 Some("127.0.0.1".to_string()),
1351 Some("test-agent".to_string()),
1352 )
1353 .await
1354 .unwrap();
1355
1356 assert_eq!(result.user.id, user.id);
1358 assert_eq!(result.user.email, "oauth@example.com");
1359
1360 let accounts = AccountStore::find_by_user_id(&store, result.user.id)
1362 .await
1363 .unwrap();
1364 assert_eq!(accounts.len(), 1);
1365 }
1366
1367 #[tokio::test]
1368 async fn oauth_callback_respects_linking_policy() {
1369 let store = MemoryStore::default();
1370 let email = TestEmailSender::default();
1371 let mut config = AuthConfig::default();
1372 config.oauth.allow_implicit_account_linking = false;
1373
1374 let service = AuthService::new(
1375 config,
1376 store.clone(),
1377 store.clone(),
1378 store.clone(),
1379 store.clone(),
1380 store.clone(),
1381 email.clone(),
1382 );
1383
1384 store
1386 .create_user(
1387 "existing@example.com",
1388 Some("Existing User"),
1389 Some("hash123"),
1390 )
1391 .await
1392 .unwrap();
1393
1394 let oauth_info = crate::oauth::OAuthUserInfo {
1395 provider_id: "google".to_string(),
1396 account_id: "google-999".to_string(),
1397 email: "existing@example.com".to_string(),
1398 name: Some("OAuth User".to_string()),
1399 image: None,
1400 };
1401
1402 let result = service
1404 .oauth_callback(
1405 oauth_info,
1406 OAuthTokens::default(),
1407 Some("127.0.0.1".to_string()),
1408 Some("test-agent".to_string()),
1409 )
1410 .await;
1411
1412 assert!(result.is_err());
1413 match result {
1414 Err(AuthError::OAuth(OAuthError::LinkingDisabled)) => {}
1415 _ => panic!("Expected OAuth linking-disabled error"),
1416 }
1417 }
1418
1419 #[tokio::test]
1420 async fn cleanup_expired_removes_sessions_verifications_and_oauth_states() {
1421 let store = MemoryStore::default();
1422 let email = TestEmailSender::default();
1423 let service = AuthService::new(
1424 AuthConfig::default(),
1425 store.clone(),
1426 store.clone(),
1427 store.clone(),
1428 store.clone(),
1429 store.clone(),
1430 email,
1431 );
1432
1433 store
1434 .create_session(NewSession {
1435 token_hash: "expired-session".to_string(),
1436 user_id: 1,
1437 expires_at: OffsetDateTime::now_utc() - time::Duration::hours(1),
1438 ip_address: None,
1439 user_agent: None,
1440 })
1441 .await
1442 .unwrap();
1443 store
1444 .create_verification(NewVerification {
1445 identifier: "email-verify:test@example.com".to_string(),
1446 token_hash: "expired-verification".to_string(),
1447 expires_at: OffsetDateTime::now_utc() - time::Duration::hours(1),
1448 })
1449 .await
1450 .unwrap();
1451 store
1452 .create_oauth_state(NewOAuthState {
1453 provider_id: "google".to_string(),
1454 csrf_state: "expired-oauth-state".to_string(),
1455 pkce_verifier: "pkce-verifier".to_string(),
1456 intent: OAuthIntent::Login,
1457 link_user_id: None,
1458 expires_at: OffsetDateTime::now_utc() - time::Duration::hours(1),
1459 })
1460 .await
1461 .unwrap();
1462
1463 let deleted = service.cleanup_expired().await.unwrap();
1464
1465 assert_eq!(deleted, (1, 1, 1));
1466 }
1467
1468 #[tokio::test]
1469 async fn list_accounts_returns_empty_for_user_with_no_accounts() {
1470 let store = MemoryStore::default();
1471 let email = TestEmailSender::default();
1472 let service = AuthService::new(
1473 AuthConfig::default(),
1474 store.clone(),
1475 store.clone(),
1476 store.clone(),
1477 store.clone(),
1478 store.clone(),
1479 email,
1480 );
1481
1482 let user = store
1483 .create_user("test@example.com", Some("Test User"), Some("hash123"))
1484 .await
1485 .unwrap();
1486
1487 let accounts = service.list_accounts(user.id).await.unwrap();
1488
1489 assert!(accounts.is_empty());
1490 }
1491
1492 #[tokio::test]
1493 async fn list_accounts_returns_public_accounts_without_tokens() {
1494 let store = MemoryStore::default();
1495 let email = TestEmailSender::default();
1496 let service = AuthService::new(
1497 AuthConfig::default(),
1498 store.clone(),
1499 store.clone(),
1500 store.clone(),
1501 store.clone(),
1502 store.clone(),
1503 email,
1504 );
1505
1506 let user = store
1507 .create_user("test@example.com", Some("Test User"), Some("hash123"))
1508 .await
1509 .unwrap();
1510
1511 store
1512 .create_account(NewAccount {
1513 user_id: user.id,
1514 provider_id: "google".to_string(),
1515 account_id: "google-123".to_string(),
1516 access_token: Some("secret-token".to_string()),
1517 refresh_token: Some("refresh-secret".to_string()),
1518 access_token_expires_at: None,
1519 scope: Some("openid,email".to_string()),
1520 })
1521 .await
1522 .unwrap();
1523
1524 store
1525 .create_account(NewAccount {
1526 user_id: user.id,
1527 provider_id: "github".to_string(),
1528 account_id: "github-456".to_string(),
1529 access_token: Some("another-token".to_string()),
1530 refresh_token: None,
1531 access_token_expires_at: None,
1532 scope: None,
1533 })
1534 .await
1535 .unwrap();
1536
1537 let accounts = service.list_accounts(user.id).await.unwrap();
1538
1539 assert_eq!(accounts.len(), 2);
1540 let provider_ids: Vec<&str> = accounts.iter().map(|a| a.provider_id.as_str()).collect();
1541 assert!(provider_ids.contains(&"google"));
1542 assert!(provider_ids.contains(&"github"));
1543 assert!(!format!("{:?}", accounts).contains("secret"));
1545 }
1546
1547 #[tokio::test]
1548 async fn link_account_creates_account_for_authenticated_user() {
1549 let store = MemoryStore::default();
1550 let email = TestEmailSender::default();
1551 let service = AuthService::new(
1552 AuthConfig::default(),
1553 store.clone(),
1554 store.clone(),
1555 store.clone(),
1556 store.clone(),
1557 store.clone(),
1558 email,
1559 );
1560
1561 let user = store
1562 .create_user("test@example.com", Some("Test User"), Some("hash123"))
1563 .await
1564 .unwrap();
1565
1566 let oauth_info = crate::oauth::OAuthUserInfo {
1567 provider_id: "google".to_string(),
1568 account_id: "google-123".to_string(),
1569 email: "test@example.com".to_string(),
1570 name: None,
1571 image: None,
1572 };
1573
1574 let result = service
1575 .link_account(user.id, oauth_info, OAuthTokens::default())
1576 .await;
1577
1578 assert!(result.is_ok());
1579
1580 let accounts = AccountStore::find_by_user_id(&store, user.id)
1581 .await
1582 .unwrap();
1583 assert_eq!(accounts.len(), 1);
1584 assert_eq!(accounts[0].provider_id, "google");
1585 assert_eq!(accounts[0].account_id, "google-123");
1586 }
1587
1588 #[tokio::test]
1589 async fn link_account_is_idempotent_when_already_linked_to_same_user() {
1590 let store = MemoryStore::default();
1591 let email = TestEmailSender::default();
1592 let service = AuthService::new(
1593 AuthConfig::default(),
1594 store.clone(),
1595 store.clone(),
1596 store.clone(),
1597 store.clone(),
1598 store.clone(),
1599 email,
1600 );
1601
1602 let user = store
1603 .create_user("test@example.com", Some("Test User"), Some("hash123"))
1604 .await
1605 .unwrap();
1606
1607 store
1608 .create_account(NewAccount {
1609 user_id: user.id,
1610 provider_id: "google".to_string(),
1611 account_id: "google-123".to_string(),
1612 access_token: Some("old-token".to_string()),
1613 refresh_token: Some("old-refresh".to_string()),
1614 access_token_expires_at: None,
1615 scope: None,
1616 })
1617 .await
1618 .unwrap();
1619
1620 let oauth_info = crate::oauth::OAuthUserInfo {
1621 provider_id: "google".to_string(),
1622 account_id: "google-123".to_string(),
1623 email: "test@example.com".to_string(),
1624 name: None,
1625 image: None,
1626 };
1627
1628 let result = service
1629 .link_account(user.id, oauth_info, OAuthTokens::default())
1630 .await;
1631
1632 assert!(result.is_ok());
1633
1634 let accounts = AccountStore::find_by_user_id(&store, user.id)
1635 .await
1636 .unwrap();
1637 assert_eq!(accounts.len(), 1);
1638 }
1639
1640 #[tokio::test]
1641 async fn link_account_rejects_when_already_linked_to_different_user() {
1642 let store = MemoryStore::default();
1643 let email = TestEmailSender::default();
1644 let service = AuthService::new(
1645 AuthConfig::default(),
1646 store.clone(),
1647 store.clone(),
1648 store.clone(),
1649 store.clone(),
1650 store.clone(),
1651 email,
1652 );
1653
1654 let user_a = store
1655 .create_user("usera@example.com", Some("User A"), Some("hash123"))
1656 .await
1657 .unwrap();
1658
1659 let user_b = store
1660 .create_user("userb@example.com", Some("User B"), Some("hash456"))
1661 .await
1662 .unwrap();
1663
1664 store
1665 .create_account(NewAccount {
1666 user_id: user_a.id,
1667 provider_id: "google".to_string(),
1668 account_id: "google-123".to_string(),
1669 access_token: None,
1670 refresh_token: None,
1671 access_token_expires_at: None,
1672 scope: None,
1673 })
1674 .await
1675 .unwrap();
1676
1677 let oauth_info = crate::oauth::OAuthUserInfo {
1678 provider_id: "google".to_string(),
1679 account_id: "google-123".to_string(),
1680 email: "userb@example.com".to_string(),
1681 name: None,
1682 image: None,
1683 };
1684
1685 let result = service
1686 .link_account(user_b.id, oauth_info, OAuthTokens::default())
1687 .await;
1688
1689 assert!(result.is_err());
1690 match result {
1691 Err(AuthError::OAuth(OAuthError::AccountAlreadyLinked)) => {}
1692 _ => panic!("Expected AccountAlreadyLinked error"),
1693 }
1694 }
1695
1696 #[tokio::test]
1697 async fn link_account_updates_existing_account_tokens() {
1698 let store = MemoryStore::default();
1699 let email = TestEmailSender::default();
1700 let service = AuthService::new(
1701 AuthConfig::default(),
1702 store.clone(),
1703 store.clone(),
1704 store.clone(),
1705 store.clone(),
1706 store.clone(),
1707 email,
1708 );
1709
1710 let user = store
1711 .create_user("test@example.com", Some("Test User"), Some("hash123"))
1712 .await
1713 .unwrap();
1714
1715 store
1716 .create_account(NewAccount {
1717 user_id: user.id,
1718 provider_id: "google".to_string(),
1719 account_id: "google-123".to_string(),
1720 access_token: Some("old-token".to_string()),
1721 refresh_token: Some("old-refresh".to_string()),
1722 access_token_expires_at: None,
1723 scope: None,
1724 })
1725 .await
1726 .unwrap();
1727
1728 let oauth_info = crate::oauth::OAuthUserInfo {
1729 provider_id: "google".to_string(),
1730 account_id: "google-123".to_string(),
1731 email: "test@example.com".to_string(),
1732 name: None,
1733 image: None,
1734 };
1735
1736 let tokens = OAuthTokens {
1737 access_token: Some("new-token".to_string()),
1738 refresh_token: Some("new-refresh".to_string()),
1739 ..Default::default()
1740 };
1741
1742 service
1743 .link_account(user.id, oauth_info, tokens)
1744 .await
1745 .unwrap();
1746
1747 let accounts = AccountStore::find_by_user_id(&store, user.id)
1748 .await
1749 .unwrap();
1750 assert_eq!(accounts.len(), 1);
1751 assert_eq!(accounts[0].access_token, Some("new-token".to_string()));
1752 assert_eq!(accounts[0].refresh_token, Some("new-refresh".to_string()));
1753 }
1754
1755 #[tokio::test]
1756 async fn unlink_account_removes_account() {
1757 let store = MemoryStore::default();
1758 let email = TestEmailSender::default();
1759 let service = AuthService::new(
1760 AuthConfig::default(),
1761 store.clone(),
1762 store.clone(),
1763 store.clone(),
1764 store.clone(),
1765 store.clone(),
1766 email,
1767 );
1768
1769 let user = store
1770 .create_user("test@example.com", Some("Test User"), Some("password-hash"))
1771 .await
1772 .unwrap();
1773
1774 let account1 = store
1775 .create_account(NewAccount {
1776 user_id: user.id,
1777 provider_id: "google".to_string(),
1778 account_id: "google-123".to_string(),
1779 access_token: None,
1780 refresh_token: None,
1781 access_token_expires_at: None,
1782 scope: None,
1783 })
1784 .await
1785 .unwrap();
1786
1787 store
1788 .create_account(NewAccount {
1789 user_id: user.id,
1790 provider_id: "github".to_string(),
1791 account_id: "github-456".to_string(),
1792 access_token: None,
1793 refresh_token: None,
1794 access_token_expires_at: None,
1795 scope: None,
1796 })
1797 .await
1798 .unwrap();
1799
1800 service.unlink_account(user.id, account1.id).await.unwrap();
1801
1802 let accounts = AccountStore::find_by_user_id(&store, user.id)
1803 .await
1804 .unwrap();
1805 assert_eq!(accounts.len(), 1);
1806 assert_eq!(accounts[0].provider_id, "github");
1807 }
1808
1809 #[tokio::test]
1810 async fn unlink_account_rejects_last_auth_method() {
1811 let store = MemoryStore::default();
1812 let email = TestEmailSender::default();
1813 let service = AuthService::new(
1814 AuthConfig::default(),
1815 store.clone(),
1816 store.clone(),
1817 store.clone(),
1818 store.clone(),
1819 store.clone(),
1820 email,
1821 );
1822
1823 let user = store
1825 .create_user("test@example.com", Some("Test User"), None)
1826 .await
1827 .unwrap();
1828
1829 let account = store
1830 .create_account(NewAccount {
1831 user_id: user.id,
1832 provider_id: "google".to_string(),
1833 account_id: "google-123".to_string(),
1834 access_token: None,
1835 refresh_token: None,
1836 access_token_expires_at: None,
1837 scope: None,
1838 })
1839 .await
1840 .unwrap();
1841
1842 let result = service.unlink_account(user.id, account.id).await;
1843
1844 assert!(result.is_err());
1845 match result {
1846 Err(AuthError::OAuth(OAuthError::LastAuthMethod)) => {}
1847 _ => panic!("Expected LastAuthMethod error"),
1848 }
1849 }
1850
1851 #[tokio::test]
1852 async fn refresh_oauth_token_rejects_when_no_refresh_token() {
1853 let store = MemoryStore::default();
1854 let email = TestEmailSender::default();
1855 let service = AuthService::new(
1856 AuthConfig::default(),
1857 store.clone(),
1858 store.clone(),
1859 store.clone(),
1860 store.clone(),
1861 store.clone(),
1862 email,
1863 );
1864
1865 let user = store
1866 .create_user("test@example.com", Some("Test User"), None)
1867 .await
1868 .unwrap();
1869
1870 let account = store
1871 .create_account(NewAccount {
1872 user_id: user.id,
1873 provider_id: "google".to_string(),
1874 account_id: "google-123".to_string(),
1875 access_token: Some("old-token".to_string()),
1876 refresh_token: None, access_token_expires_at: None,
1878 scope: None,
1879 })
1880 .await
1881 .unwrap();
1882
1883 let provider_config = OAuthProviderConfig {
1884 provider_id: "google".to_string(),
1885 client_id: "test".to_string(),
1886 client_secret: "test".to_string(),
1887 auth_url: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
1888 token_url: "https://oauth2.googleapis.com/token".to_string(),
1889 userinfo_url: "https://www.googleapis.com/oauth2/v2/userinfo".to_string(),
1890 redirect_url: "https://localhost/callback".to_string(),
1891 scopes: vec!["openid".to_string()],
1892 };
1893
1894 let result = service
1895 .refresh_oauth_token(user.id, account.id, &provider_config)
1896 .await;
1897
1898 assert!(result.is_err());
1899 match result {
1900 Err(AuthError::OAuth(OAuthError::NoRefreshToken)) => {}
1901 _ => panic!("Expected NoRefreshToken error"),
1902 }
1903 }
1904
1905 #[tokio::test]
1906 async fn refresh_oauth_token_rejects_for_wrong_account_id() {
1907 let store = MemoryStore::default();
1908 let email = TestEmailSender::default();
1909 let service = AuthService::new(
1910 AuthConfig::default(),
1911 store.clone(),
1912 store.clone(),
1913 store.clone(),
1914 store.clone(),
1915 store.clone(),
1916 email,
1917 );
1918
1919 let user = store
1920 .create_user("test@example.com", Some("Test User"), None)
1921 .await
1922 .unwrap();
1923
1924 store
1925 .create_account(NewAccount {
1926 user_id: user.id,
1927 provider_id: "google".to_string(),
1928 account_id: "google-123".to_string(),
1929 access_token: Some("old-token".to_string()),
1930 refresh_token: Some("refresh-token".to_string()),
1931 access_token_expires_at: None,
1932 scope: None,
1933 })
1934 .await
1935 .unwrap();
1936
1937 let provider_config = OAuthProviderConfig {
1938 provider_id: "google".to_string(),
1939 client_id: "test".to_string(),
1940 client_secret: "test".to_string(),
1941 auth_url: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
1942 token_url: "https://oauth2.googleapis.com/token".to_string(),
1943 userinfo_url: "https://www.googleapis.com/oauth2/v2/userinfo".to_string(),
1944 redirect_url: "https://localhost/callback".to_string(),
1945 scopes: vec!["openid".to_string()],
1946 };
1947
1948 let result = service
1949 .refresh_oauth_token(user.id, 9999, &provider_config)
1950 .await;
1951
1952 assert!(result.is_err());
1953 match result {
1954 Err(AuthError::OAuth(OAuthError::AccountNotFound)) => {}
1955 _ => panic!("Expected AccountNotFound error"),
1956 }
1957 }
1958}