1use std::sync::Arc;
2
3use time::OffsetDateTime;
4
5use crate::config::AuthConfig;
6use crate::crypto::{hash, token};
7use crate::email::EmailSender;
8use crate::error::AuthError;
9use crate::oauth::{OAuthTokens, OAuthUserInfo};
10use crate::store::{AccountStore, SessionStore, UserStore, VerificationStore};
11use crate::types::{NewAccount, NewSession, NewUser, NewVerification, Session, User, Verification};
12
/// Outcome of [`AuthService::signup`].
#[derive(Debug)]
pub struct SignupResult {
    /// The user record that was just created.
    pub user: User,
    /// Session opened for the new user; `None` unless
    /// `config.email.auto_sign_in_after_signup` is enabled.
    pub session: Option<Session>,
    /// Raw (unhashed) token for `session`; present exactly when `session` is.
    pub session_token: Option<String>,
    /// Raw email-verification token that was also handed to the email sender;
    /// `None` unless `config.email.send_verification_on_signup` is enabled.
    pub verification_token: Option<String>,
}
25
/// Outcome of a successful sign-in, returned by [`AuthService::login`] and
/// [`AuthService::oauth_callback`].
#[derive(Debug)]
pub struct LoginResult {
    /// The authenticated user.
    pub user: User,
    /// The session created for this sign-in.
    pub session: Session,
    /// Raw (unhashed) session token; only its hash is handed to the session store.
    pub session_token: String,
}
36
/// Outcome of [`AuthService::verify_email`].
#[derive(Debug)]
pub struct VerifyEmailResult {
    /// The user, re-read from the store after being marked as verified.
    pub user: User,
    /// Session opened after verification; `None` unless
    /// `config.email.auto_sign_in_after_verification` is enabled.
    pub session: Option<Session>,
    /// Raw (unhashed) token for `session`; present exactly when `session` is.
    pub session_token: Option<String>,
}
47
/// Outcome of [`AuthService::request_password_reset`].
///
/// Carries no data: the same default value is returned whether or not the
/// email matched an account.
#[derive(Debug, Default)]
pub struct RequestResetResult {
    /// Unit placeholder field; holds no information.
    // NOTE(review): presumably reserved so fields can be added later — confirm intent.
    pub _private: (),
}
54
/// Outcome of [`AuthService::reset_password`].
#[derive(Debug)]
pub struct ResetPasswordResult {
    /// The user, re-read from the store after the password was replaced.
    pub user: User,
}
61
/// Outcome of [`AuthService::get_session`]: a live session and its owner.
#[derive(Debug)]
pub struct SessionResult {
    /// The user referenced by `session.user_id`.
    pub user: User,
    /// The session matching the presented token; not expired at lookup time.
    pub session: Session,
}
70
/// Authentication service combining user, session, verification and OAuth
/// account storage with an email sender.
///
/// Each backend is held behind an [`Arc`]; all fields are public.
pub struct AuthService<U, S, V, A, E>
where
    U: UserStore,
    S: SessionStore,
    V: VerificationStore,
    A: AccountStore,
    E: EmailSender,
{
    /// Behavioural settings: token length, TTLs, email and OAuth policy flags.
    pub config: AuthConfig,
    /// Storage for user records.
    pub users: Arc<U>,
    /// Storage for sessions, looked up by token hash.
    pub sessions: Arc<S>,
    /// Storage for email-verification and password-reset tokens.
    pub verifications: Arc<V>,
    /// Storage for linked OAuth provider accounts.
    pub accounts: Arc<A>,
    /// Sender used for verification and password-reset emails.
    pub email: Arc<E>,
}
93
94impl<U, S, V, A, E> AuthService<U, S, V, A, E>
95where
96 U: UserStore,
97 S: SessionStore,
98 V: VerificationStore,
99 A: AccountStore,
100 E: EmailSender,
101{
102 pub fn new(
104 config: AuthConfig,
105 users: U,
106 sessions: S,
107 verifications: V,
108 accounts: A,
109 email: E,
110 ) -> Self {
111 Self {
112 config,
113 users: Arc::new(users),
114 sessions: Arc::new(sessions),
115 verifications: Arc::new(verifications),
116 accounts: Arc::new(accounts),
117 email: Arc::new(email),
118 }
119 }
120
121 pub async fn signup(
123 &self,
124 input: NewUser,
125 ip: Option<String>,
126 user_agent: Option<String>,
127 ) -> Result<SignupResult, AuthError> {
128 if input.password.len() < 8 {
129 return Err(AuthError::WeakPassword(8));
130 }
131
132 let email = input.email.trim().to_lowercase();
133 if self.users.find_by_email(&email).await?.is_some() {
134 return Err(AuthError::EmailTaken);
135 }
136
137 let password_hash = hash::hash_password(&input.password)?;
138 let user = self
139 .users
140 .create_user(&email, input.name.as_deref(), Some(&password_hash))
141 .await?;
142
143 let verification_token = if self.config.email.send_verification_on_signup {
144 let identifier = format!("email-verify:{}", user.email.to_lowercase());
145 let raw_token = token::generate_token(self.config.token_length);
146
147 let _ = self.verifications.delete_by_identifier(&identifier).await;
148 self.verifications
149 .create_verification(NewVerification {
150 identifier,
151 token_hash: token::hash_token(&raw_token),
152 expires_at: OffsetDateTime::now_utc() + self.config.verification_ttl,
153 })
154 .await?;
155 self.email
156 .send_verification_email(&user, &raw_token)
157 .await?;
158 Some(raw_token)
159 } else {
160 None
161 };
162
163 let (session, session_token) = if self.config.email.auto_sign_in_after_signup {
164 let (session, raw_token) = self
165 .create_session_internal(user.id, ip, user_agent)
166 .await?;
167 (Some(session), Some(raw_token))
168 } else {
169 (None, None)
170 };
171
172 Ok(SignupResult {
173 user,
174 session,
175 session_token,
176 verification_token,
177 })
178 }
179
180 pub async fn login(
182 &self,
183 email: &str,
184 password: &str,
185 ip: Option<String>,
186 user_agent: Option<String>,
187 ) -> Result<LoginResult, AuthError> {
188 let user = self
189 .users
190 .find_by_email(&email.trim().to_lowercase())
191 .await?
192 .ok_or(AuthError::InvalidCredentials)?;
193
194 let password_hash = user
195 .password_hash
196 .as_deref()
197 .ok_or(AuthError::InvalidCredentials)?;
198 if !hash::verify_password(password, password_hash)? {
199 return Err(AuthError::InvalidCredentials);
200 }
201
202 if self.config.email.require_verification_to_login && !user.is_verified() {
203 return Err(AuthError::EmailNotVerified);
204 }
205
206 let (session, session_token) = self
207 .create_session_internal(user.id, ip, user_agent)
208 .await?;
209
210 Ok(LoginResult {
211 user,
212 session,
213 session_token,
214 })
215 }
216
217 pub async fn logout(&self, session_id: i64) -> Result<(), AuthError> {
219 self.sessions.delete_session(session_id).await
220 }
221
222 pub async fn logout_all(&self, user_id: i64) -> Result<(), AuthError> {
224 self.sessions.delete_by_user_id(user_id).await
225 }
226
227 pub async fn get_session(&self, raw_token: &str) -> Result<SessionResult, AuthError> {
229 let session = self
230 .sessions
231 .find_by_token_hash(&token::hash_token(raw_token))
232 .await?
233 .ok_or(AuthError::SessionNotFound)?;
234
235 if session.expires_at < OffsetDateTime::now_utc() {
236 self.sessions.delete_session(session.id).await?;
237 return Err(AuthError::SessionNotFound);
238 }
239
240 let user = self
241 .users
242 .find_by_id(session.user_id)
243 .await?
244 .ok_or(AuthError::UserNotFound)?;
245
246 Ok(SessionResult { user, session })
247 }
248
249 pub async fn list_sessions(&self, user_id: i64) -> Result<Vec<Session>, AuthError> {
251 self.sessions.find_by_user_id(user_id).await
252 }
253
254 pub async fn verify_email(
256 &self,
257 raw_token: &str,
258 ip: Option<String>,
259 user_agent: Option<String>,
260 ) -> Result<VerifyEmailResult, AuthError> {
261 let verification = self.lookup_verification(raw_token, "email-verify:").await?;
262 let email = verification
263 .identifier
264 .strip_prefix("email-verify:")
265 .ok_or(AuthError::InvalidToken)?;
266
267 let user = self
268 .users
269 .find_by_email(email)
270 .await?
271 .ok_or(AuthError::UserNotFound)?;
272 self.users.set_email_verified(user.id).await?;
273 self.verifications
274 .delete_verification(verification.id)
275 .await?;
276
277 let user = self
278 .users
279 .find_by_id(user.id)
280 .await?
281 .ok_or(AuthError::UserNotFound)?;
282
283 let (session, session_token) = if self.config.email.auto_sign_in_after_verification {
284 let (session, raw_token) = self
285 .create_session_internal(user.id, ip, user_agent)
286 .await?;
287 (Some(session), Some(raw_token))
288 } else {
289 (None, None)
290 };
291
292 Ok(VerifyEmailResult {
293 user,
294 session,
295 session_token,
296 })
297 }
298
299 pub async fn request_password_reset(
301 &self,
302 email: &str,
303 ) -> Result<RequestResetResult, AuthError> {
304 let email = email.trim().to_lowercase();
305 if let Some(user) = self.users.find_by_email(&email).await? {
306 let identifier = format!("password-reset:{}", user.email.to_lowercase());
307 let _ = self.verifications.delete_by_identifier(&identifier).await;
308
309 let raw_token = token::generate_token(self.config.token_length);
310 self.verifications
311 .create_verification(NewVerification {
312 identifier,
313 token_hash: token::hash_token(&raw_token),
314 expires_at: OffsetDateTime::now_utc() + self.config.reset_ttl,
315 })
316 .await?;
317 self.email
318 .send_password_reset_email(&user, &raw_token)
319 .await?;
320 }
321
322 Ok(RequestResetResult::default())
323 }
324
325 pub async fn reset_password(
327 &self,
328 raw_token: &str,
329 new_password: &str,
330 ) -> Result<ResetPasswordResult, AuthError> {
331 if new_password.len() < 8 {
332 return Err(AuthError::WeakPassword(8));
333 }
334
335 let verification = self
336 .lookup_verification(raw_token, "password-reset:")
337 .await?;
338 let email = verification
339 .identifier
340 .strip_prefix("password-reset:")
341 .ok_or(AuthError::InvalidToken)?;
342 let user = self
343 .users
344 .find_by_email(email)
345 .await?
346 .ok_or(AuthError::UserNotFound)?;
347
348 self.users
349 .update_password(user.id, &hash::hash_password(new_password)?)
350 .await?;
351 self.sessions.delete_by_user_id(user.id).await?;
352 self.verifications
353 .delete_verification(verification.id)
354 .await?;
355
356 let user = self
357 .users
358 .find_by_id(user.id)
359 .await?
360 .ok_or(AuthError::UserNotFound)?;
361
362 Ok(ResetPasswordResult { user })
363 }
364
365 pub async fn cleanup_expired(&self) -> Result<(u64, u64), AuthError> {
367 let sessions_deleted = self.sessions.delete_expired().await?;
368 let verifications_deleted = self.verifications.delete_expired().await?;
369 Ok((sessions_deleted, verifications_deleted))
370 }
371
372 pub async fn oauth_callback(
379 &self,
380 info: OAuthUserInfo,
381 tokens: OAuthTokens,
382 ip: Option<String>,
383 user_agent: Option<String>,
384 ) -> Result<LoginResult, AuthError> {
385 if let Some(account) = self
387 .accounts
388 .find_by_provider(&info.provider_id, &info.account_id)
389 .await?
390 {
391 let user = self
393 .users
394 .find_by_id(account.user_id)
395 .await?
396 .ok_or(AuthError::UserNotFound)?;
397 let (session, session_token) = self
398 .create_session_internal(user.id, ip, user_agent)
399 .await?;
400 return Ok(LoginResult {
401 user,
402 session,
403 session_token,
404 });
405 }
406
407 let user = if let Some(existing_user) = self.users.find_by_email(&info.email).await? {
409 if !self.config.oauth.allow_implicit_account_linking {
411 return Err(AuthError::OAuth(
412 "Account linking by email is disabled. Please sign in with your password first.".to_string()
413 ));
414 }
415 existing_user
416 } else {
417 self.users
419 .create_user(&info.email, info.name.as_deref(), None)
420 .await?
421 };
422
423 let access_token_expires_at = tokens.expires_in.map(|d| OffsetDateTime::now_utc() + d);
425 self.accounts
426 .create_account(NewAccount {
427 user_id: user.id,
428 provider_id: info.provider_id,
429 account_id: info.account_id,
430 access_token: tokens.access_token,
431 refresh_token: tokens.refresh_token,
432 access_token_expires_at,
433 scope: tokens.scope,
434 })
435 .await?;
436
437 if !user.is_verified() {
439 self.users.set_email_verified(user.id).await?;
440 }
441
442 let user = self
444 .users
445 .find_by_id(user.id)
446 .await?
447 .ok_or(AuthError::UserNotFound)?;
448 let (session, session_token) = self
449 .create_session_internal(user.id, ip, user_agent)
450 .await?;
451 Ok(LoginResult {
452 user,
453 session,
454 session_token,
455 })
456 }
457
458 async fn create_session_internal(
459 &self,
460 user_id: i64,
461 ip: Option<String>,
462 user_agent: Option<String>,
463 ) -> Result<(Session, String), AuthError> {
464 let raw_token = token::generate_token(self.config.token_length);
465 let session = self
466 .sessions
467 .create_session(NewSession {
468 token_hash: token::hash_token(&raw_token),
469 user_id,
470 expires_at: OffsetDateTime::now_utc() + self.config.session_ttl,
471 ip_address: ip,
472 user_agent,
473 })
474 .await?;
475
476 Ok((session, raw_token))
477 }
478
479 async fn lookup_verification(
480 &self,
481 raw_token: &str,
482 prefix: &str,
483 ) -> Result<Verification, AuthError> {
484 let verification = self
485 .verifications
486 .find_by_token_hash(&token::hash_token(raw_token))
487 .await?
488 .ok_or(AuthError::InvalidToken)?;
489
490 if !verification.identifier.starts_with(prefix) {
491 return Err(AuthError::InvalidToken);
492 }
493
494 if verification.expires_at < OffsetDateTime::now_utc() {
495 self.verifications
496 .delete_verification(verification.id)
497 .await?;
498 return Err(AuthError::InvalidToken);
499 }
500
501 Ok(verification)
502 }
503}
504
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    use async_trait::async_trait;
    use time::OffsetDateTime;

    use super::AuthService;
    use crate::config::AuthConfig;
    use crate::email::EmailSender;
    use crate::error::AuthError;
    use crate::oauth::OAuthTokens;
    use crate::store::{AccountStore, SessionStore, UserStore, VerificationStore};
    use crate::types::{
        Account, NewAccount, NewSession, NewUser, NewVerification, Session, User, Verification,
    };

    /// Backing tables of the in-memory store. Id counters start at zero and
    /// are incremented before use, so the first id handed out is 1.
    #[derive(Default)]
    struct MemoryState {
        next_user_id: i64,
        next_session_id: i64,
        next_verification_id: i64,
        next_account_id: i64,
        users: HashMap<i64, User>,
        sessions: HashMap<i64, Session>,
        verifications: HashMap<i64, Verification>,
        accounts: HashMap<i64, Account>,
    }

    /// Cloneable handle to one shared `MemoryState`; implements every store
    /// trait so a single value can back all four store slots of the service.
    #[derive(Clone, Default)]
    struct MemoryStore {
        inner: Arc<Mutex<MemoryState>>,
    }

    impl MemoryStore {
        /// Locks the shared state, runs `f` on it and returns the result.
        fn with_state<R>(&self, f: impl FnOnce(&mut MemoryState) -> R) -> R {
            let mut guard = self.inner.lock().unwrap();
            f(&mut *guard)
        }
    }

    #[async_trait]
    impl UserStore for MemoryStore {
        async fn create_user(
            &self,
            email: &str,
            name: Option<&str>,
            password_hash: Option<&str>,
        ) -> Result<User, AuthError> {
            self.with_state(|state| {
                // Emails are unique; the comparison is exact.
                let taken = state
                    .users
                    .values()
                    .any(|existing| existing.email == email);
                if taken {
                    return Err(AuthError::EmailTaken);
                }

                state.next_user_id += 1;
                let created_at = OffsetDateTime::now_utc();
                let record = User {
                    id: state.next_user_id,
                    email: email.to_owned(),
                    name: name.map(String::from),
                    password_hash: password_hash.map(String::from),
                    email_verified_at: None,
                    image: None,
                    created_at,
                    updated_at: created_at,
                };
                state.users.insert(record.id, record.clone());
                Ok(record)
            })
        }

        async fn find_by_email(&self, email: &str) -> Result<Option<User>, AuthError> {
            let found = self.with_state(|state| {
                state
                    .users
                    .values()
                    .find(|candidate| candidate.email == email)
                    .cloned()
            });
            Ok(found)
        }

        async fn find_by_id(&self, id: i64) -> Result<Option<User>, AuthError> {
            let found = self.with_state(|state| state.users.get(&id).cloned());
            Ok(found)
        }

        async fn set_email_verified(&self, user_id: i64) -> Result<(), AuthError> {
            self.with_state(|state| match state.users.get_mut(&user_id) {
                Some(record) => {
                    record.email_verified_at = Some(OffsetDateTime::now_utc());
                    record.updated_at = OffsetDateTime::now_utc();
                    Ok(())
                }
                None => Err(AuthError::UserNotFound),
            })
        }

        async fn update_password(
            &self,
            user_id: i64,
            password_hash: &str,
        ) -> Result<(), AuthError> {
            self.with_state(|state| match state.users.get_mut(&user_id) {
                Some(record) => {
                    record.password_hash = Some(password_hash.to_owned());
                    record.updated_at = OffsetDateTime::now_utc();
                    Ok(())
                }
                None => Err(AuthError::UserNotFound),
            })
        }

        async fn delete_user(&self, user_id: i64) -> Result<(), AuthError> {
            self.with_state(|state| {
                state.users.remove(&user_id);
            });
            Ok(())
        }
    }

    #[async_trait]
    impl SessionStore for MemoryStore {
        async fn create_session(&self, session: NewSession) -> Result<Session, AuthError> {
            let stored = self.with_state(|state| {
                state.next_session_id += 1;
                let created_at = OffsetDateTime::now_utc();
                let stored = Session {
                    id: state.next_session_id,
                    token_hash: session.token_hash,
                    user_id: session.user_id,
                    expires_at: session.expires_at,
                    ip_address: session.ip_address,
                    user_agent: session.user_agent,
                    created_at,
                    updated_at: created_at,
                };
                state.sessions.insert(stored.id, stored.clone());
                stored
            });
            Ok(stored)
        }

        async fn find_by_token_hash(&self, token_hash: &str) -> Result<Option<Session>, AuthError> {
            let found = self.with_state(|state| {
                state
                    .sessions
                    .values()
                    .find(|entry| entry.token_hash == token_hash)
                    .cloned()
            });
            Ok(found)
        }

        async fn find_by_user_id(&self, user_id: i64) -> Result<Vec<Session>, AuthError> {
            let owned = self.with_state(|state| {
                state
                    .sessions
                    .values()
                    .filter(|entry| entry.user_id == user_id)
                    .cloned()
                    .collect()
            });
            Ok(owned)
        }

        async fn delete_session(&self, id: i64) -> Result<(), AuthError> {
            self.with_state(|state| {
                state.sessions.remove(&id);
            });
            Ok(())
        }

        async fn delete_by_user_id(&self, user_id: i64) -> Result<(), AuthError> {
            self.with_state(|state| {
                state.sessions.retain(|_, entry| entry.user_id != user_id);
            });
            Ok(())
        }

        async fn delete_expired(&self) -> Result<u64, AuthError> {
            let cutoff = OffsetDateTime::now_utc();
            let removed = self.with_state(|state| {
                let before = state.sessions.len();
                // Entries expiring exactly at the cutoff are kept.
                state.sessions.retain(|_, entry| entry.expires_at >= cutoff);
                before - state.sessions.len()
            });
            Ok(removed as u64)
        }
    }

    #[async_trait]
    impl VerificationStore for MemoryStore {
        async fn create_verification(
            &self,
            verification: NewVerification,
        ) -> Result<Verification, AuthError> {
            let stored = self.with_state(|state| {
                state.next_verification_id += 1;
                let created_at = OffsetDateTime::now_utc();
                let stored = Verification {
                    id: state.next_verification_id,
                    identifier: verification.identifier,
                    token_hash: verification.token_hash,
                    expires_at: verification.expires_at,
                    created_at,
                    updated_at: created_at,
                };
                state.verifications.insert(stored.id, stored.clone());
                stored
            });
            Ok(stored)
        }

        async fn find_by_identifier(
            &self,
            identifier: &str,
        ) -> Result<Option<Verification>, AuthError> {
            let found = self.with_state(|state| {
                state
                    .verifications
                    .values()
                    .find(|entry| entry.identifier == identifier)
                    .cloned()
            });
            Ok(found)
        }

        async fn find_by_token_hash(
            &self,
            token_hash: &str,
        ) -> Result<Option<Verification>, AuthError> {
            let found = self.with_state(|state| {
                state
                    .verifications
                    .values()
                    .find(|entry| entry.token_hash == token_hash)
                    .cloned()
            });
            Ok(found)
        }

        async fn delete_verification(&self, id: i64) -> Result<(), AuthError> {
            self.with_state(|state| {
                state.verifications.remove(&id);
            });
            Ok(())
        }

        async fn delete_by_identifier(&self, identifier: &str) -> Result<(), AuthError> {
            self.with_state(|state| {
                state
                    .verifications
                    .retain(|_, entry| entry.identifier != identifier);
            });
            Ok(())
        }

        async fn delete_expired(&self) -> Result<u64, AuthError> {
            let cutoff = OffsetDateTime::now_utc();
            let removed = self.with_state(|state| {
                let before = state.verifications.len();
                // Entries expiring exactly at the cutoff are kept.
                state
                    .verifications
                    .retain(|_, entry| entry.expires_at >= cutoff);
                before - state.verifications.len()
            });
            Ok(removed as u64)
        }
    }

    #[async_trait]
    impl AccountStore for MemoryStore {
        async fn create_account(&self, account: NewAccount) -> Result<Account, AuthError> {
            let stored = self.with_state(|state| {
                state.next_account_id += 1;
                let created_at = OffsetDateTime::now_utc();
                let stored = Account {
                    id: state.next_account_id,
                    user_id: account.user_id,
                    provider_id: account.provider_id,
                    account_id: account.account_id,
                    access_token: account.access_token,
                    refresh_token: account.refresh_token,
                    access_token_expires_at: account.access_token_expires_at,
                    scope: account.scope,
                    created_at,
                    updated_at: created_at,
                };
                state.accounts.insert(stored.id, stored.clone());
                stored
            });
            Ok(stored)
        }

        async fn find_by_provider(
            &self,
            provider_id: &str,
            account_id: &str,
        ) -> Result<Option<Account>, AuthError> {
            let found = self.with_state(|state| {
                state
                    .accounts
                    .values()
                    .find(|entry| entry.provider_id == provider_id && entry.account_id == account_id)
                    .cloned()
            });
            Ok(found)
        }

        async fn find_by_user_id(&self, user_id: i64) -> Result<Vec<Account>, AuthError> {
            let owned = self.with_state(|state| {
                state
                    .accounts
                    .values()
                    .filter(|entry| entry.user_id == user_id)
                    .cloned()
                    .collect()
            });
            Ok(owned)
        }

        async fn delete_account(&self, id: i64) -> Result<(), AuthError> {
            self.with_state(|state| {
                state.accounts.remove(&id);
            });
            Ok(())
        }
    }

    /// Email sender that records the raw tokens it is asked to deliver.
    #[derive(Clone, Default)]
    struct TestEmailSender {
        verification_tokens: Arc<Mutex<Vec<String>>>,
        reset_tokens: Arc<Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl EmailSender for TestEmailSender {
        async fn send_verification_email(
            &self,
            _user: &User,
            token: &str,
        ) -> Result<(), AuthError> {
            let mut sent = self.verification_tokens.lock().unwrap();
            sent.push(token.to_owned());
            Ok(())
        }

        async fn send_password_reset_email(
            &self,
            _user: &User,
            token: &str,
        ) -> Result<(), AuthError> {
            let mut sent = self.reset_tokens.lock().unwrap();
            sent.push(token.to_owned());
            Ok(())
        }
    }

    /// Service type under test: one `MemoryStore` fills every store slot.
    type TestService =
        AuthService<MemoryStore, MemoryStore, MemoryStore, MemoryStore, TestEmailSender>;

    /// Builds a service whose stores all share `store`'s state and whose
    /// sender shares `mailer`'s token lists.
    fn build_service(
        config: AuthConfig,
        store: &MemoryStore,
        mailer: &TestEmailSender,
    ) -> TestService {
        AuthService::new(
            config,
            store.clone(),
            store.clone(),
            store.clone(),
            store.clone(),
            mailer.clone(),
        )
    }

    /// Builds the provider identity passed to `oauth_callback`.
    fn oauth_info(
        provider_id: &str,
        account_id: &str,
        email: &str,
        name: &str,
        image: Option<&str>,
    ) -> crate::oauth::OAuthUserInfo {
        crate::oauth::OAuthUserInfo {
            provider_id: provider_id.to_string(),
            account_id: account_id.to_string(),
            email: email.to_string(),
            name: Some(name.to_string()),
            image: image.map(|url| url.to_string()),
        }
    }

    /// Runs `oauth_callback` with default tokens and fixed client metadata.
    async fn oauth_sign_in(
        service: &TestService,
        info: crate::oauth::OAuthUserInfo,
    ) -> Result<super::LoginResult, AuthError> {
        service
            .oauth_callback(
                info,
                OAuthTokens::default(),
                Some("127.0.0.1".to_string()),
                Some("test-agent".to_string()),
            )
            .await
    }

    #[tokio::test]
    async fn signup_verify_login_and_reset_flow_works() {
        let store = MemoryStore::default();
        let mailer = TestEmailSender::default();
        let service = build_service(AuthConfig::default(), &store, &mailer);

        // Sign up: expects a session and exactly one verification email.
        let new_user = NewUser {
            email: "test@example.com".to_string(),
            name: Some("Test".to_string()),
            password: "supersecret".to_string(),
        };
        let signup = service
            .signup(
                new_user,
                Some("127.0.0.1".to_string()),
                Some("test-agent".to_string()),
            )
            .await
            .unwrap();

        assert_eq!(signup.user.email, "test@example.com");
        assert!(signup.session.is_some());
        assert_eq!(mailer.verification_tokens.lock().unwrap().len(), 1);

        // Verify the address with the emailed token.
        let verification_token = mailer.verification_tokens.lock().unwrap()[0].clone();
        let verified = service
            .verify_email(&verification_token, None, None)
            .await
            .unwrap();
        assert!(verified.user.is_verified());

        // Log in with the original password.
        let first_login = service
            .login("test@example.com", "supersecret", None, None)
            .await
            .unwrap();
        assert_eq!(first_login.user.email, "test@example.com");

        // Reset the password with the emailed token.
        service
            .request_password_reset("test@example.com")
            .await
            .unwrap();
        let reset_token = mailer.reset_tokens.lock().unwrap()[0].clone();
        service
            .reset_password(&reset_token, "newpassword")
            .await
            .unwrap();

        // The new password now works.
        let second_login = service
            .login("test@example.com", "newpassword", None, None)
            .await
            .unwrap();
        assert_eq!(second_login.user.email, "test@example.com");
    }

    #[tokio::test]
    async fn oauth_callback_creates_new_user_and_account() {
        let store = MemoryStore::default();
        let mailer = TestEmailSender::default();
        let service = build_service(AuthConfig::default(), &store, &mailer);

        let info = oauth_info(
            "google",
            "google-123",
            "oauth@example.com",
            "OAuth User",
            Some("https://example.com/avatar.jpg"),
        );
        let outcome = oauth_sign_in(&service, info).await.unwrap();

        assert_eq!(outcome.user.email, "oauth@example.com");
        assert_eq!(outcome.user.name, Some("OAuth User".to_string()));
        assert!(outcome.user.is_verified());
        assert!(outcome.user.password_hash.is_none());

        // Qualified call: `SessionStore` also has a `find_by_user_id`.
        let linked = AccountStore::find_by_user_id(&store, outcome.user.id)
            .await
            .unwrap();
        assert_eq!(linked.len(), 1);
        assert_eq!(linked[0].provider_id, "google");
        assert_eq!(linked[0].account_id, "google-123");
    }

    #[tokio::test]
    async fn oauth_callback_links_existing_user_by_email() {
        let store = MemoryStore::default();
        let mailer = TestEmailSender::default();
        let service = build_service(AuthConfig::default(), &store, &mailer);

        let existing_user = store
            .create_user(
                "existing@example.com",
                Some("Existing User"),
                Some("hash123"),
            )
            .await
            .unwrap();

        let info = oauth_info(
            "github",
            "github-456",
            "existing@example.com",
            "GitHub User",
            None,
        );
        let outcome = oauth_sign_in(&service, info).await.unwrap();

        assert_eq!(outcome.user.id, existing_user.id);
        assert_eq!(outcome.user.email, "existing@example.com");
        assert!(outcome.user.is_verified());

        let linked = AccountStore::find_by_user_id(&store, outcome.user.id)
            .await
            .unwrap();
        assert_eq!(linked.len(), 1);
        assert_eq!(linked[0].provider_id, "github");
        assert_eq!(linked[0].account_id, "github-456");
    }

    #[tokio::test]
    async fn oauth_callback_logs_in_existing_account() {
        let store = MemoryStore::default();
        let mailer = TestEmailSender::default();
        let service = build_service(AuthConfig::default(), &store, &mailer);

        // Seed a user that already has the provider account linked.
        let user = store
            .create_user("oauth@example.com", Some("OAuth User"), None)
            .await
            .unwrap();
        let seeded_account = crate::types::NewAccount {
            user_id: user.id,
            provider_id: "google".to_string(),
            account_id: "google-789".to_string(),
            access_token: None,
            refresh_token: None,
            access_token_expires_at: None,
            scope: None,
        };
        store.create_account(seeded_account).await.unwrap();

        let info = oauth_info(
            "google",
            "google-789",
            "oauth@example.com",
            "OAuth User",
            None,
        );
        let outcome = oauth_sign_in(&service, info).await.unwrap();

        assert_eq!(outcome.user.id, user.id);
        assert_eq!(outcome.user.email, "oauth@example.com");

        // No second account may be created for an already-linked identity.
        let linked = AccountStore::find_by_user_id(&store, outcome.user.id)
            .await
            .unwrap();
        assert_eq!(linked.len(), 1);
    }

    #[tokio::test]
    async fn oauth_callback_respects_linking_policy() {
        let store = MemoryStore::default();
        let mailer = TestEmailSender::default();
        let mut config = AuthConfig::default();
        config.oauth.allow_implicit_account_linking = false;

        let service = build_service(config, &store, &mailer);

        store
            .create_user(
                "existing@example.com",
                Some("Existing User"),
                Some("hash123"),
            )
            .await
            .unwrap();

        let info = oauth_info(
            "google",
            "google-999",
            "existing@example.com",
            "OAuth User",
            None,
        );
        let outcome = oauth_sign_in(&service, info).await;

        assert!(outcome.is_err());
        match outcome {
            Err(AuthError::OAuth(message)) => {
                assert!(message.contains("Account linking by email is disabled"));
            }
            _ => panic!("Expected OAuth error"),
        }
    }
}