Skip to main content

rs_auth_core/
service.rs

1use std::sync::Arc;
2
3use time::OffsetDateTime;
4
5use crate::config::AuthConfig;
6use crate::crypto::{hash, token};
7use crate::email::EmailSender;
8use crate::error::AuthError;
9use crate::oauth::{OAuthTokens, OAuthUserInfo};
10use crate::store::{AccountStore, SessionStore, UserStore, VerificationStore};
11use crate::types::{NewAccount, NewSession, NewUser, NewVerification, Session, User, Verification};
12
13/// Result returned by [`AuthService::signup`].
14#[derive(Debug)]
15pub struct SignupResult {
16    /// The newly created user.
17    pub user: User,
18    /// Session created if auto-sign-in is enabled.
19    pub session: Option<Session>,
20    /// Raw session token if auto-sign-in is enabled.
21    pub session_token: Option<String>,
22    /// Raw verification token if email verification is enabled.
23    pub verification_token: Option<String>,
24}
25
26/// Result returned by [`AuthService::login`].
27#[derive(Debug)]
28pub struct LoginResult {
29    /// The authenticated user.
30    pub user: User,
31    /// The newly created session.
32    pub session: Session,
33    /// Raw session token to be stored in a cookie.
34    pub session_token: String,
35}
36
37/// Result returned by [`AuthService::verify_email`].
38#[derive(Debug)]
39pub struct VerifyEmailResult {
40    /// The user whose email was verified.
41    pub user: User,
42    /// Session created if auto-sign-in after verification is enabled.
43    pub session: Option<Session>,
44    /// Raw session token if auto-sign-in after verification is enabled.
45    pub session_token: Option<String>,
46}
47
/// Result returned by [`AuthService::request_password_reset`].
///
/// Deliberately carries no data: the same value is returned whether or not the email
/// belongs to a registered user.
#[derive(Debug, Default)]
pub struct RequestResetResult {
    /// Crate-private placeholder. It stops other crates from building this type with a
    /// struct literal, so fields can be added later without a breaking change. Other
    /// crates can still obtain a value through `Default`.
    pub(crate) _private: (),
}
54
55/// Result returned by [`AuthService::reset_password`].
56#[derive(Debug)]
57pub struct ResetPasswordResult {
58    /// The user whose password was reset.
59    pub user: User,
60}
61
62/// Result returned by [`AuthService::get_session`].
63#[derive(Debug)]
64pub struct SessionResult {
65    /// The user associated with the session.
66    pub user: User,
67    /// The session details.
68    pub session: Session,
69}
70
71/// Core authentication service. Generic over storage backends and email sender.
72pub struct AuthService<U, S, V, A, E>
73where
74    U: UserStore,
75    S: SessionStore,
76    V: VerificationStore,
77    A: AccountStore,
78    E: EmailSender,
79{
80    /// Authentication configuration.
81    pub config: AuthConfig,
82    /// User storage backend.
83    pub users: Arc<U>,
84    /// Session storage backend.
85    pub sessions: Arc<S>,
86    /// Verification token storage backend.
87    pub verifications: Arc<V>,
88    /// OAuth account storage backend.
89    pub accounts: Arc<A>,
90    /// Email sender implementation.
91    pub email: Arc<E>,
92}
93
94impl<U, S, V, A, E> AuthService<U, S, V, A, E>
95where
96    U: UserStore,
97    S: SessionStore,
98    V: VerificationStore,
99    A: AccountStore,
100    E: EmailSender,
101{
102    /// Create a new authentication service with the given configuration and backends.
103    pub fn new(
104        config: AuthConfig,
105        users: U,
106        sessions: S,
107        verifications: V,
108        accounts: A,
109        email: E,
110    ) -> Self {
111        Self {
112            config,
113            users: Arc::new(users),
114            sessions: Arc::new(sessions),
115            verifications: Arc::new(verifications),
116            accounts: Arc::new(accounts),
117            email: Arc::new(email),
118        }
119    }
120
121    /// Register a new user with email and password.
122    pub async fn signup(
123        &self,
124        input: NewUser,
125        ip: Option<String>,
126        user_agent: Option<String>,
127    ) -> Result<SignupResult, AuthError> {
128        if input.password.len() < 8 {
129            return Err(AuthError::WeakPassword(8));
130        }
131
132        let email = input.email.trim().to_lowercase();
133        if self.users.find_by_email(&email).await?.is_some() {
134            return Err(AuthError::EmailTaken);
135        }
136
137        let password_hash = hash::hash_password(&input.password)?;
138        let user = self
139            .users
140            .create_user(&email, input.name.as_deref(), Some(&password_hash))
141            .await?;
142
143        let verification_token = if self.config.email.send_verification_on_signup {
144            let identifier = format!("email-verify:{}", user.email.to_lowercase());
145            let raw_token = token::generate_token(self.config.token_length);
146
147            let _ = self.verifications.delete_by_identifier(&identifier).await;
148            self.verifications
149                .create_verification(NewVerification {
150                    identifier,
151                    token_hash: token::hash_token(&raw_token),
152                    expires_at: OffsetDateTime::now_utc() + self.config.verification_ttl,
153                })
154                .await?;
155            self.email
156                .send_verification_email(&user, &raw_token)
157                .await?;
158            Some(raw_token)
159        } else {
160            None
161        };
162
163        let (session, session_token) = if self.config.email.auto_sign_in_after_signup {
164            let (session, raw_token) = self
165                .create_session_internal(user.id, ip, user_agent)
166                .await?;
167            (Some(session), Some(raw_token))
168        } else {
169            (None, None)
170        };
171
172        Ok(SignupResult {
173            user,
174            session,
175            session_token,
176            verification_token,
177        })
178    }
179
180    /// Authenticate a user with email and password.
181    pub async fn login(
182        &self,
183        email: &str,
184        password: &str,
185        ip: Option<String>,
186        user_agent: Option<String>,
187    ) -> Result<LoginResult, AuthError> {
188        let user = self
189            .users
190            .find_by_email(&email.trim().to_lowercase())
191            .await?
192            .ok_or(AuthError::InvalidCredentials)?;
193
194        let password_hash = user
195            .password_hash
196            .as_deref()
197            .ok_or(AuthError::InvalidCredentials)?;
198        if !hash::verify_password(password, password_hash)? {
199            return Err(AuthError::InvalidCredentials);
200        }
201
202        if self.config.email.require_verification_to_login && !user.is_verified() {
203            return Err(AuthError::EmailNotVerified);
204        }
205
206        let (session, session_token) = self
207            .create_session_internal(user.id, ip, user_agent)
208            .await?;
209
210        Ok(LoginResult {
211            user,
212            session,
213            session_token,
214        })
215    }
216
217    /// Delete a single session by ID.
218    pub async fn logout(&self, session_id: i64) -> Result<(), AuthError> {
219        self.sessions.delete_session(session_id).await
220    }
221
222    /// Delete all sessions for a user.
223    pub async fn logout_all(&self, user_id: i64) -> Result<(), AuthError> {
224        self.sessions.delete_by_user_id(user_id).await
225    }
226
227    /// Retrieve a session and its associated user by raw token.
228    pub async fn get_session(&self, raw_token: &str) -> Result<SessionResult, AuthError> {
229        let session = self
230            .sessions
231            .find_by_token_hash(&token::hash_token(raw_token))
232            .await?
233            .ok_or(AuthError::SessionNotFound)?;
234
235        if session.expires_at < OffsetDateTime::now_utc() {
236            self.sessions.delete_session(session.id).await?;
237            return Err(AuthError::SessionNotFound);
238        }
239
240        let user = self
241            .users
242            .find_by_id(session.user_id)
243            .await?
244            .ok_or(AuthError::UserNotFound)?;
245
246        Ok(SessionResult { user, session })
247    }
248
249    /// List all active sessions for a user.
250    pub async fn list_sessions(&self, user_id: i64) -> Result<Vec<Session>, AuthError> {
251        self.sessions.find_by_user_id(user_id).await
252    }
253
254    /// Verify a user's email address using a verification token.
255    pub async fn verify_email(
256        &self,
257        raw_token: &str,
258        ip: Option<String>,
259        user_agent: Option<String>,
260    ) -> Result<VerifyEmailResult, AuthError> {
261        let verification = self.lookup_verification(raw_token, "email-verify:").await?;
262        let email = verification
263            .identifier
264            .strip_prefix("email-verify:")
265            .ok_or(AuthError::InvalidToken)?;
266
267        let user = self
268            .users
269            .find_by_email(email)
270            .await?
271            .ok_or(AuthError::UserNotFound)?;
272        self.users.set_email_verified(user.id).await?;
273        self.verifications
274            .delete_verification(verification.id)
275            .await?;
276
277        let user = self
278            .users
279            .find_by_id(user.id)
280            .await?
281            .ok_or(AuthError::UserNotFound)?;
282
283        let (session, session_token) = if self.config.email.auto_sign_in_after_verification {
284            let (session, raw_token) = self
285                .create_session_internal(user.id, ip, user_agent)
286                .await?;
287            (Some(session), Some(raw_token))
288        } else {
289            (None, None)
290        };
291
292        Ok(VerifyEmailResult {
293            user,
294            session,
295            session_token,
296        })
297    }
298
299    /// Request a password reset token for a user by email.
300    pub async fn request_password_reset(
301        &self,
302        email: &str,
303    ) -> Result<RequestResetResult, AuthError> {
304        let email = email.trim().to_lowercase();
305        if let Some(user) = self.users.find_by_email(&email).await? {
306            let identifier = format!("password-reset:{}", user.email.to_lowercase());
307            let _ = self.verifications.delete_by_identifier(&identifier).await;
308
309            let raw_token = token::generate_token(self.config.token_length);
310            self.verifications
311                .create_verification(NewVerification {
312                    identifier,
313                    token_hash: token::hash_token(&raw_token),
314                    expires_at: OffsetDateTime::now_utc() + self.config.reset_ttl,
315                })
316                .await?;
317            self.email
318                .send_password_reset_email(&user, &raw_token)
319                .await?;
320        }
321
322        Ok(RequestResetResult::default())
323    }
324
325    /// Reset a user's password using a reset token.
326    pub async fn reset_password(
327        &self,
328        raw_token: &str,
329        new_password: &str,
330    ) -> Result<ResetPasswordResult, AuthError> {
331        if new_password.len() < 8 {
332            return Err(AuthError::WeakPassword(8));
333        }
334
335        let verification = self
336            .lookup_verification(raw_token, "password-reset:")
337            .await?;
338        let email = verification
339            .identifier
340            .strip_prefix("password-reset:")
341            .ok_or(AuthError::InvalidToken)?;
342        let user = self
343            .users
344            .find_by_email(email)
345            .await?
346            .ok_or(AuthError::UserNotFound)?;
347
348        self.users
349            .update_password(user.id, &hash::hash_password(new_password)?)
350            .await?;
351        self.sessions.delete_by_user_id(user.id).await?;
352        self.verifications
353            .delete_verification(verification.id)
354            .await?;
355
356        let user = self
357            .users
358            .find_by_id(user.id)
359            .await?
360            .ok_or(AuthError::UserNotFound)?;
361
362        Ok(ResetPasswordResult { user })
363    }
364
365    /// Delete expired sessions and verification tokens. Returns (sessions_deleted, verifications_deleted).
366    pub async fn cleanup_expired(&self) -> Result<(u64, u64), AuthError> {
367        let sessions_deleted = self.sessions.delete_expired().await?;
368        let verifications_deleted = self.verifications.delete_expired().await?;
369        Ok((sessions_deleted, verifications_deleted))
370    }
371
372    /// Handle OAuth callback - find or create user from OAuth info.
373    ///
374    /// NOTE: OAuth state verification happens in the handler layer before calling this method.
375    /// The CSRF state and PKCE verifier are stored in the `verifications` table with the
376    /// identifier format `oauth-state:{csrf_token}`. This reuses existing infrastructure
377    /// rather than requiring a dedicated OAuth state table.
378    pub async fn oauth_callback(
379        &self,
380        info: OAuthUserInfo,
381        tokens: OAuthTokens,
382        ip: Option<String>,
383        user_agent: Option<String>,
384    ) -> Result<LoginResult, AuthError> {
385        // 1. Check if account already exists for this provider
386        if let Some(account) = self
387            .accounts
388            .find_by_provider(&info.provider_id, &info.account_id)
389            .await?
390        {
391            // Existing account - just create session
392            let user = self
393                .users
394                .find_by_id(account.user_id)
395                .await?
396                .ok_or(AuthError::UserNotFound)?;
397            let (session, session_token) = self
398                .create_session_internal(user.id, ip, user_agent)
399                .await?;
400            return Ok(LoginResult {
401                user,
402                session,
403                session_token,
404            });
405        }
406
407        // 2. Check if user with this email already exists
408        let user = if let Some(existing_user) = self.users.find_by_email(&info.email).await? {
409            // Check if implicit account linking is allowed
410            if !self.config.oauth.allow_implicit_account_linking {
411                return Err(AuthError::OAuth(
412                    "Account linking by email is disabled. Please sign in with your password first.".to_string()
413                ));
414            }
415            existing_user
416        } else {
417            // 3. Create new user (no password for OAuth-only users)
418            self.users
419                .create_user(&info.email, info.name.as_deref(), None)
420                .await?
421        };
422
423        // 4. Link account with tokens
424        let access_token_expires_at = tokens.expires_in.map(|d| OffsetDateTime::now_utc() + d);
425        self.accounts
426            .create_account(NewAccount {
427                user_id: user.id,
428                provider_id: info.provider_id,
429                account_id: info.account_id,
430                access_token: tokens.access_token,
431                refresh_token: tokens.refresh_token,
432                access_token_expires_at,
433                scope: tokens.scope,
434            })
435            .await?;
436
437        // 5. Mark email as verified (OAuth providers verify emails)
438        if !user.is_verified() {
439            self.users.set_email_verified(user.id).await?;
440        }
441
442        // 6. Reload user and create session
443        let user = self
444            .users
445            .find_by_id(user.id)
446            .await?
447            .ok_or(AuthError::UserNotFound)?;
448        let (session, session_token) = self
449            .create_session_internal(user.id, ip, user_agent)
450            .await?;
451        Ok(LoginResult {
452            user,
453            session,
454            session_token,
455        })
456    }
457
458    async fn create_session_internal(
459        &self,
460        user_id: i64,
461        ip: Option<String>,
462        user_agent: Option<String>,
463    ) -> Result<(Session, String), AuthError> {
464        let raw_token = token::generate_token(self.config.token_length);
465        let session = self
466            .sessions
467            .create_session(NewSession {
468                token_hash: token::hash_token(&raw_token),
469                user_id,
470                expires_at: OffsetDateTime::now_utc() + self.config.session_ttl,
471                ip_address: ip,
472                user_agent,
473            })
474            .await?;
475
476        Ok((session, raw_token))
477    }
478
479    async fn lookup_verification(
480        &self,
481        raw_token: &str,
482        prefix: &str,
483    ) -> Result<Verification, AuthError> {
484        let verification = self
485            .verifications
486            .find_by_token_hash(&token::hash_token(raw_token))
487            .await?
488            .ok_or(AuthError::InvalidToken)?;
489
490        if !verification.identifier.starts_with(prefix) {
491            return Err(AuthError::InvalidToken);
492        }
493
494        if verification.expires_at < OffsetDateTime::now_utc() {
495            self.verifications
496                .delete_verification(verification.id)
497                .await?;
498            return Err(AuthError::InvalidToken);
499        }
500
501        Ok(verification)
502    }
503}
504
505#[cfg(test)]
506mod tests {
507    use std::collections::HashMap;
508    use std::sync::{Arc, Mutex};
509
510    use async_trait::async_trait;
511    use time::OffsetDateTime;
512
513    use super::AuthService;
514    use crate::config::AuthConfig;
515    use crate::email::EmailSender;
516    use crate::error::AuthError;
517    use crate::oauth::OAuthTokens;
518    use crate::store::{AccountStore, SessionStore, UserStore, VerificationStore};
519    use crate::types::{
520        Account, NewAccount, NewSession, NewUser, NewVerification, Session, User, Verification,
521    };
522
523    #[derive(Default)]
524    struct MemoryState {
525        next_user_id: i64,
526        next_session_id: i64,
527        next_verification_id: i64,
528        next_account_id: i64,
529        users: HashMap<i64, User>,
530        sessions: HashMap<i64, Session>,
531        verifications: HashMap<i64, Verification>,
532        accounts: HashMap<i64, Account>,
533    }
534
535    #[derive(Clone, Default)]
536    struct MemoryStore {
537        inner: Arc<Mutex<MemoryState>>,
538    }
539
540    #[async_trait]
541    impl UserStore for MemoryStore {
542        async fn create_user(
543            &self,
544            email: &str,
545            name: Option<&str>,
546            password_hash: Option<&str>,
547        ) -> Result<User, AuthError> {
548            let mut state = self.inner.lock().unwrap();
549            if state.users.values().any(|user| user.email == email) {
550                return Err(AuthError::EmailTaken);
551            }
552
553            state.next_user_id += 1;
554            let now = OffsetDateTime::now_utc();
555            let user = User {
556                id: state.next_user_id,
557                email: email.to_string(),
558                name: name.map(str::to_owned),
559                password_hash: password_hash.map(str::to_owned),
560                email_verified_at: None,
561                image: None,
562                created_at: now,
563                updated_at: now,
564            };
565            state.users.insert(user.id, user.clone());
566            Ok(user)
567        }
568
569        async fn find_by_email(&self, email: &str) -> Result<Option<User>, AuthError> {
570            let state = self.inner.lock().unwrap();
571            Ok(state
572                .users
573                .values()
574                .find(|user| user.email == email)
575                .cloned())
576        }
577
578        async fn find_by_id(&self, id: i64) -> Result<Option<User>, AuthError> {
579            Ok(self.inner.lock().unwrap().users.get(&id).cloned())
580        }
581
582        async fn set_email_verified(&self, user_id: i64) -> Result<(), AuthError> {
583            let mut state = self.inner.lock().unwrap();
584            let user = state
585                .users
586                .get_mut(&user_id)
587                .ok_or(AuthError::UserNotFound)?;
588            user.email_verified_at = Some(OffsetDateTime::now_utc());
589            user.updated_at = OffsetDateTime::now_utc();
590            Ok(())
591        }
592
593        async fn update_password(
594            &self,
595            user_id: i64,
596            password_hash: &str,
597        ) -> Result<(), AuthError> {
598            let mut state = self.inner.lock().unwrap();
599            let user = state
600                .users
601                .get_mut(&user_id)
602                .ok_or(AuthError::UserNotFound)?;
603            user.password_hash = Some(password_hash.to_string());
604            user.updated_at = OffsetDateTime::now_utc();
605            Ok(())
606        }
607
608        async fn delete_user(&self, user_id: i64) -> Result<(), AuthError> {
609            self.inner.lock().unwrap().users.remove(&user_id);
610            Ok(())
611        }
612    }
613
614    #[async_trait]
615    impl SessionStore for MemoryStore {
616        async fn create_session(&self, session: NewSession) -> Result<Session, AuthError> {
617            let mut state = self.inner.lock().unwrap();
618            state.next_session_id += 1;
619            let now = OffsetDateTime::now_utc();
620            let session = Session {
621                id: state.next_session_id,
622                token_hash: session.token_hash,
623                user_id: session.user_id,
624                expires_at: session.expires_at,
625                ip_address: session.ip_address,
626                user_agent: session.user_agent,
627                created_at: now,
628                updated_at: now,
629            };
630            state.sessions.insert(session.id, session.clone());
631            Ok(session)
632        }
633
634        async fn find_by_token_hash(&self, token_hash: &str) -> Result<Option<Session>, AuthError> {
635            let state = self.inner.lock().unwrap();
636            Ok(state
637                .sessions
638                .values()
639                .find(|session| session.token_hash == token_hash)
640                .cloned())
641        }
642
643        async fn find_by_user_id(&self, user_id: i64) -> Result<Vec<Session>, AuthError> {
644            let state = self.inner.lock().unwrap();
645            Ok(state
646                .sessions
647                .values()
648                .filter(|session| session.user_id == user_id)
649                .cloned()
650                .collect())
651        }
652
653        async fn delete_session(&self, id: i64) -> Result<(), AuthError> {
654            self.inner.lock().unwrap().sessions.remove(&id);
655            Ok(())
656        }
657
658        async fn delete_by_user_id(&self, user_id: i64) -> Result<(), AuthError> {
659            self.inner
660                .lock()
661                .unwrap()
662                .sessions
663                .retain(|_, session| session.user_id != user_id);
664            Ok(())
665        }
666
667        async fn delete_expired(&self) -> Result<u64, AuthError> {
668            let now = OffsetDateTime::now_utc();
669            let mut state = self.inner.lock().unwrap();
670            let before = state.sessions.len();
671            state
672                .sessions
673                .retain(|_, session| session.expires_at >= now);
674            Ok((before - state.sessions.len()) as u64)
675        }
676    }
677
678    #[async_trait]
679    impl VerificationStore for MemoryStore {
680        async fn create_verification(
681            &self,
682            verification: NewVerification,
683        ) -> Result<Verification, AuthError> {
684            let mut state = self.inner.lock().unwrap();
685            state.next_verification_id += 1;
686            let now = OffsetDateTime::now_utc();
687            let verification = Verification {
688                id: state.next_verification_id,
689                identifier: verification.identifier,
690                token_hash: verification.token_hash,
691                expires_at: verification.expires_at,
692                created_at: now,
693                updated_at: now,
694            };
695            state
696                .verifications
697                .insert(verification.id, verification.clone());
698            Ok(verification)
699        }
700
701        async fn find_by_identifier(
702            &self,
703            identifier: &str,
704        ) -> Result<Option<Verification>, AuthError> {
705            let state = self.inner.lock().unwrap();
706            Ok(state
707                .verifications
708                .values()
709                .find(|verification| verification.identifier == identifier)
710                .cloned())
711        }
712
713        async fn find_by_token_hash(
714            &self,
715            token_hash: &str,
716        ) -> Result<Option<Verification>, AuthError> {
717            let state = self.inner.lock().unwrap();
718            Ok(state
719                .verifications
720                .values()
721                .find(|verification| verification.token_hash == token_hash)
722                .cloned())
723        }
724
725        async fn delete_verification(&self, id: i64) -> Result<(), AuthError> {
726            self.inner.lock().unwrap().verifications.remove(&id);
727            Ok(())
728        }
729
730        async fn delete_by_identifier(&self, identifier: &str) -> Result<(), AuthError> {
731            self.inner
732                .lock()
733                .unwrap()
734                .verifications
735                .retain(|_, verification| verification.identifier != identifier);
736            Ok(())
737        }
738
739        async fn delete_expired(&self) -> Result<u64, AuthError> {
740            let now = OffsetDateTime::now_utc();
741            let mut state = self.inner.lock().unwrap();
742            let before = state.verifications.len();
743            state
744                .verifications
745                .retain(|_, verification| verification.expires_at >= now);
746            Ok((before - state.verifications.len()) as u64)
747        }
748    }
749
750    #[async_trait]
751    impl AccountStore for MemoryStore {
752        async fn create_account(&self, account: NewAccount) -> Result<Account, AuthError> {
753            let mut state = self.inner.lock().unwrap();
754            state.next_account_id += 1;
755            let now = OffsetDateTime::now_utc();
756            let account = Account {
757                id: state.next_account_id,
758                user_id: account.user_id,
759                provider_id: account.provider_id,
760                account_id: account.account_id,
761                access_token: account.access_token,
762                refresh_token: account.refresh_token,
763                access_token_expires_at: account.access_token_expires_at,
764                scope: account.scope,
765                created_at: now,
766                updated_at: now,
767            };
768            state.accounts.insert(account.id, account.clone());
769            Ok(account)
770        }
771
772        async fn find_by_provider(
773            &self,
774            provider_id: &str,
775            account_id: &str,
776        ) -> Result<Option<Account>, AuthError> {
777            let state = self.inner.lock().unwrap();
778            Ok(state
779                .accounts
780                .values()
781                .find(|account| {
782                    account.provider_id == provider_id && account.account_id == account_id
783                })
784                .cloned())
785        }
786
787        async fn find_by_user_id(&self, user_id: i64) -> Result<Vec<Account>, AuthError> {
788            let state = self.inner.lock().unwrap();
789            Ok(state
790                .accounts
791                .values()
792                .filter(|account| account.user_id == user_id)
793                .cloned()
794                .collect())
795        }
796
797        async fn delete_account(&self, id: i64) -> Result<(), AuthError> {
798            self.inner.lock().unwrap().accounts.remove(&id);
799            Ok(())
800        }
801    }
802
803    #[derive(Clone, Default)]
804    struct TestEmailSender {
805        verification_tokens: Arc<Mutex<Vec<String>>>,
806        reset_tokens: Arc<Mutex<Vec<String>>>,
807    }
808
809    #[async_trait]
810    impl EmailSender for TestEmailSender {
811        async fn send_verification_email(
812            &self,
813            _user: &User,
814            token: &str,
815        ) -> Result<(), AuthError> {
816            self.verification_tokens
817                .lock()
818                .unwrap()
819                .push(token.to_string());
820            Ok(())
821        }
822
823        async fn send_password_reset_email(
824            &self,
825            _user: &User,
826            token: &str,
827        ) -> Result<(), AuthError> {
828            self.reset_tokens.lock().unwrap().push(token.to_string());
829            Ok(())
830        }
831    }
832
833    #[tokio::test]
834    async fn signup_verify_login_and_reset_flow_works() {
835        let store = MemoryStore::default();
836        let email = TestEmailSender::default();
837        let service = AuthService::new(
838            AuthConfig::default(),
839            store.clone(),
840            store.clone(),
841            store.clone(),
842            store.clone(),
843            email.clone(),
844        );
845
846        let signup = service
847            .signup(
848                NewUser {
849                    email: "test@example.com".to_string(),
850                    name: Some("Test".to_string()),
851                    password: "supersecret".to_string(),
852                },
853                Some("127.0.0.1".to_string()),
854                Some("test-agent".to_string()),
855            )
856            .await
857            .unwrap();
858
859        assert_eq!(signup.user.email, "test@example.com");
860        assert!(signup.session.is_some());
861        assert_eq!(email.verification_tokens.lock().unwrap().len(), 1);
862
863        let verification_token = email.verification_tokens.lock().unwrap()[0].clone();
864        let verify = service
865            .verify_email(&verification_token, None, None)
866            .await
867            .unwrap();
868        assert!(verify.user.is_verified());
869
870        let login = service
871            .login("test@example.com", "supersecret", None, None)
872            .await
873            .unwrap();
874        assert_eq!(login.user.email, "test@example.com");
875
876        service
877            .request_password_reset("test@example.com")
878            .await
879            .unwrap();
880        let reset_token = email.reset_tokens.lock().unwrap()[0].clone();
881        service
882            .reset_password(&reset_token, "newpassword")
883            .await
884            .unwrap();
885
886        let login = service
887            .login("test@example.com", "newpassword", None, None)
888            .await
889            .unwrap();
890        assert_eq!(login.user.email, "test@example.com");
891    }
892
893    #[tokio::test]
894    async fn oauth_callback_creates_new_user_and_account() {
895        let store = MemoryStore::default();
896        let email = TestEmailSender::default();
897        let service = AuthService::new(
898            AuthConfig::default(),
899            store.clone(),
900            store.clone(),
901            store.clone(),
902            store.clone(),
903            email.clone(),
904        );
905
906        let oauth_info = crate::oauth::OAuthUserInfo {
907            provider_id: "google".to_string(),
908            account_id: "google-123".to_string(),
909            email: "oauth@example.com".to_string(),
910            name: Some("OAuth User".to_string()),
911            image: Some("https://example.com/avatar.jpg".to_string()),
912        };
913
914        let result = service
915            .oauth_callback(
916                oauth_info,
917                OAuthTokens::default(),
918                Some("127.0.0.1".to_string()),
919                Some("test-agent".to_string()),
920            )
921            .await
922            .unwrap();
923
924        assert_eq!(result.user.email, "oauth@example.com");
925        assert_eq!(result.user.name, Some("OAuth User".to_string()));
926        assert!(result.user.is_verified());
927        assert!(result.user.password_hash.is_none());
928
929        // Verify account was created
930        let accounts = AccountStore::find_by_user_id(&store, result.user.id)
931            .await
932            .unwrap();
933        assert_eq!(accounts.len(), 1);
934        assert_eq!(accounts[0].provider_id, "google");
935        assert_eq!(accounts[0].account_id, "google-123");
936    }
937
938    #[tokio::test]
939    async fn oauth_callback_links_existing_user_by_email() {
940        let store = MemoryStore::default();
941        let email = TestEmailSender::default();
942        let service = AuthService::new(
943            AuthConfig::default(),
944            store.clone(),
945            store.clone(),
946            store.clone(),
947            store.clone(),
948            email.clone(),
949        );
950
951        // Create existing user with password
952        let existing_user = store
953            .create_user(
954                "existing@example.com",
955                Some("Existing User"),
956                Some("hash123"),
957            )
958            .await
959            .unwrap();
960
961        let oauth_info = crate::oauth::OAuthUserInfo {
962            provider_id: "github".to_string(),
963            account_id: "github-456".to_string(),
964            email: "existing@example.com".to_string(),
965            name: Some("GitHub User".to_string()),
966            image: None,
967        };
968
969        let result = service
970            .oauth_callback(
971                oauth_info,
972                OAuthTokens::default(),
973                Some("127.0.0.1".to_string()),
974                Some("test-agent".to_string()),
975            )
976            .await
977            .unwrap();
978
979        // Should link to existing user
980        assert_eq!(result.user.id, existing_user.id);
981        assert_eq!(result.user.email, "existing@example.com");
982        assert!(result.user.is_verified());
983
984        // Verify account was linked
985        let accounts = AccountStore::find_by_user_id(&store, result.user.id)
986            .await
987            .unwrap();
988        assert_eq!(accounts.len(), 1);
989        assert_eq!(accounts[0].provider_id, "github");
990        assert_eq!(accounts[0].account_id, "github-456");
991    }
992
993    #[tokio::test]
994    async fn oauth_callback_logs_in_existing_account() {
995        let store = MemoryStore::default();
996        let email = TestEmailSender::default();
997        let service = AuthService::new(
998            AuthConfig::default(),
999            store.clone(),
1000            store.clone(),
1001            store.clone(),
1002            store.clone(),
1003            email.clone(),
1004        );
1005
1006        // Create user and account
1007        let user = store
1008            .create_user("oauth@example.com", Some("OAuth User"), None)
1009            .await
1010            .unwrap();
1011        store
1012            .create_account(crate::types::NewAccount {
1013                user_id: user.id,
1014                provider_id: "google".to_string(),
1015                account_id: "google-789".to_string(),
1016                access_token: None,
1017                refresh_token: None,
1018                access_token_expires_at: None,
1019                scope: None,
1020            })
1021            .await
1022            .unwrap();
1023
1024        let oauth_info = crate::oauth::OAuthUserInfo {
1025            provider_id: "google".to_string(),
1026            account_id: "google-789".to_string(),
1027            email: "oauth@example.com".to_string(),
1028            name: Some("OAuth User".to_string()),
1029            image: None,
1030        };
1031
1032        let result = service
1033            .oauth_callback(
1034                oauth_info,
1035                OAuthTokens::default(),
1036                Some("127.0.0.1".to_string()),
1037                Some("test-agent".to_string()),
1038            )
1039            .await
1040            .unwrap();
1041
1042        // Should log in existing user
1043        assert_eq!(result.user.id, user.id);
1044        assert_eq!(result.user.email, "oauth@example.com");
1045
1046        // Should not create duplicate account
1047        let accounts = AccountStore::find_by_user_id(&store, result.user.id)
1048            .await
1049            .unwrap();
1050        assert_eq!(accounts.len(), 1);
1051    }
1052
1053    #[tokio::test]
1054    async fn oauth_callback_respects_linking_policy() {
1055        let store = MemoryStore::default();
1056        let email = TestEmailSender::default();
1057        let mut config = AuthConfig::default();
1058        config.oauth.allow_implicit_account_linking = false;
1059
1060        let service = AuthService::new(
1061            config,
1062            store.clone(),
1063            store.clone(),
1064            store.clone(),
1065            store.clone(),
1066            email.clone(),
1067        );
1068
1069        // Create existing user
1070        store
1071            .create_user(
1072                "existing@example.com",
1073                Some("Existing User"),
1074                Some("hash123"),
1075            )
1076            .await
1077            .unwrap();
1078
1079        let oauth_info = crate::oauth::OAuthUserInfo {
1080            provider_id: "google".to_string(),
1081            account_id: "google-999".to_string(),
1082            email: "existing@example.com".to_string(),
1083            name: Some("OAuth User".to_string()),
1084            image: None,
1085        };
1086
1087        // Should fail because linking is disabled
1088        let result = service
1089            .oauth_callback(
1090                oauth_info,
1091                OAuthTokens::default(),
1092                Some("127.0.0.1".to_string()),
1093                Some("test-agent".to_string()),
1094            )
1095            .await;
1096
1097        assert!(result.is_err());
1098        match result {
1099            Err(AuthError::OAuth(msg)) => {
1100                assert!(msg.contains("Account linking by email is disabled"));
1101            }
1102            _ => panic!("Expected OAuth error"),
1103        }
1104    }
1105}