Skip to main content

rs_auth_core/
service.rs

1use std::sync::Arc;
2
3use time::OffsetDateTime;
4
5use crate::config::AuthConfig;
6use crate::crypto::{hash, token};
7use crate::email::EmailSender;
8use crate::error::{AuthError, OAuthError};
9use crate::oauth::{OAuthTokens, OAuthUserInfo};
10use crate::store::{AccountStore, OAuthStateStore, SessionStore, UserStore, VerificationStore};
11use crate::types::{NewAccount, NewSession, NewUser, NewVerification, Session, User, Verification};
12
13/// Result returned by [`AuthService::signup`].
14#[derive(Debug)]
15pub struct SignupResult {
16    /// The newly created user.
17    pub user: User,
18    /// Session created if auto-sign-in is enabled.
19    pub session: Option<Session>,
20    /// Raw session token if auto-sign-in is enabled.
21    pub session_token: Option<String>,
22    /// Raw verification token if email verification is enabled.
23    pub verification_token: Option<String>,
24}
25
26/// Result returned by [`AuthService::login`].
27#[derive(Debug)]
28pub struct LoginResult {
29    /// The authenticated user.
30    pub user: User,
31    /// The newly created session.
32    pub session: Session,
33    /// Raw session token to be stored in a cookie.
34    pub session_token: String,
35}
36
37/// Result returned by [`AuthService::verify_email`].
38#[derive(Debug)]
39pub struct VerifyEmailResult {
40    /// The user whose email was verified.
41    pub user: User,
42    /// Session created if auto-sign-in after verification is enabled.
43    pub session: Option<Session>,
44    /// Raw session token if auto-sign-in after verification is enabled.
45    pub session_token: Option<String>,
46}
47
/// Result returned by [`AuthService::request_password_reset`].
///
/// Carries no data. The service returns this same default value whether or not the
/// email belongs to a registered user, so a successful result does not reveal which
/// addresses exist.
#[derive(Debug, Default)]
pub struct RequestResetResult {
    /// Placeholder field.
    ///
    /// The field is `pub` and the type derives `Default`, so this does NOT prevent
    /// construction outside this crate. It presumably exists to keep the struct non-unit
    /// so fields can be added later (NOTE(review): confirm intent).
    pub _private: (),
}
54
55/// Result returned by [`AuthService::reset_password`].
56#[derive(Debug)]
57pub struct ResetPasswordResult {
58    /// The user whose password was reset.
59    pub user: User,
60}
61
62/// Result returned by [`AuthService::get_session`].
63#[derive(Debug)]
64pub struct SessionResult {
65    /// The user associated with the session.
66    pub user: User,
67    /// The session details.
68    pub session: Session,
69}
70
71/// Core authentication service. Generic over storage backends and email sender.
72pub struct AuthService<U, S, V, A, O, E>
73where
74    U: UserStore,
75    S: SessionStore,
76    V: VerificationStore,
77    A: AccountStore,
78    O: OAuthStateStore,
79    E: EmailSender,
80{
81    /// Authentication configuration.
82    pub config: AuthConfig,
83    /// User storage backend.
84    pub users: Arc<U>,
85    /// Session storage backend.
86    pub sessions: Arc<S>,
87    /// Verification token storage backend.
88    pub verifications: Arc<V>,
89    /// OAuth account storage backend.
90    pub accounts: Arc<A>,
91    /// OAuth state storage backend.
92    pub oauth_states: Arc<O>,
93    /// Email sender implementation.
94    pub email: Arc<E>,
95}
96
97impl<U, S, V, A, O, E> AuthService<U, S, V, A, O, E>
98where
99    U: UserStore,
100    S: SessionStore,
101    V: VerificationStore,
102    A: AccountStore,
103    O: OAuthStateStore,
104    E: EmailSender,
105{
106    /// Create a new authentication service with the given configuration and backends.
107    pub fn new(
108        config: AuthConfig,
109        users: U,
110        sessions: S,
111        verifications: V,
112        accounts: A,
113        oauth_states: O,
114        email: E,
115    ) -> Self {
116        Self {
117            config,
118            users: Arc::new(users),
119            sessions: Arc::new(sessions),
120            verifications: Arc::new(verifications),
121            accounts: Arc::new(accounts),
122            oauth_states: Arc::new(oauth_states),
123            email: Arc::new(email),
124        }
125    }
126
127    /// Register a new user with email and password.
128    pub async fn signup(
129        &self,
130        input: NewUser,
131        ip: Option<String>,
132        user_agent: Option<String>,
133    ) -> Result<SignupResult, AuthError> {
134        if input.password.len() < 8 {
135            return Err(AuthError::WeakPassword(8));
136        }
137
138        let email = input.email.trim().to_lowercase();
139        if self.users.find_by_email(&email).await?.is_some() {
140            return Err(AuthError::EmailTaken);
141        }
142
143        let password_hash = hash::hash_password(&input.password)?;
144        let user = self
145            .users
146            .create_user(&email, input.name.as_deref(), Some(&password_hash))
147            .await?;
148
149        let verification_token = if self.config.email.send_verification_on_signup {
150            let identifier = format!("email-verify:{}", user.email.to_lowercase());
151            let raw_token = token::generate_token(self.config.token_length);
152
153            let _ = self.verifications.delete_by_identifier(&identifier).await;
154            self.verifications
155                .create_verification(NewVerification {
156                    identifier,
157                    token_hash: token::hash_token(&raw_token),
158                    expires_at: OffsetDateTime::now_utc() + self.config.verification_ttl,
159                })
160                .await?;
161            self.email
162                .send_verification_email(&user, &raw_token)
163                .await?;
164            Some(raw_token)
165        } else {
166            None
167        };
168
169        let (session, session_token) = if self.config.email.auto_sign_in_after_signup {
170            let (session, raw_token) = self
171                .create_session_internal(user.id, ip, user_agent)
172                .await?;
173            (Some(session), Some(raw_token))
174        } else {
175            (None, None)
176        };
177
178        Ok(SignupResult {
179            user,
180            session,
181            session_token,
182            verification_token,
183        })
184    }
185
186    /// Authenticate a user with email and password.
187    pub async fn login(
188        &self,
189        email: &str,
190        password: &str,
191        ip: Option<String>,
192        user_agent: Option<String>,
193    ) -> Result<LoginResult, AuthError> {
194        let user = self
195            .users
196            .find_by_email(&email.trim().to_lowercase())
197            .await?
198            .ok_or(AuthError::InvalidCredentials)?;
199
200        let password_hash = user
201            .password_hash
202            .as_deref()
203            .ok_or(AuthError::InvalidCredentials)?;
204        if !hash::verify_password(password, password_hash)? {
205            return Err(AuthError::InvalidCredentials);
206        }
207
208        if self.config.email.require_verification_to_login && !user.is_verified() {
209            return Err(AuthError::EmailNotVerified);
210        }
211
212        let (session, session_token) = self
213            .create_session_internal(user.id, ip, user_agent)
214            .await?;
215
216        Ok(LoginResult {
217            user,
218            session,
219            session_token,
220        })
221    }
222
223    /// Delete a single session by ID.
224    pub async fn logout(&self, session_id: i64) -> Result<(), AuthError> {
225        self.sessions.delete_session(session_id).await
226    }
227
228    /// Delete all sessions for a user.
229    pub async fn logout_all(&self, user_id: i64) -> Result<(), AuthError> {
230        self.sessions.delete_by_user_id(user_id).await
231    }
232
233    /// Retrieve a session and its associated user by raw token.
234    pub async fn get_session(&self, raw_token: &str) -> Result<SessionResult, AuthError> {
235        let session = self
236            .sessions
237            .find_by_token_hash(&token::hash_token(raw_token))
238            .await?
239            .ok_or(AuthError::SessionNotFound)?;
240
241        if session.expires_at < OffsetDateTime::now_utc() {
242            self.sessions.delete_session(session.id).await?;
243            return Err(AuthError::SessionNotFound);
244        }
245
246        let user = self
247            .users
248            .find_by_id(session.user_id)
249            .await?
250            .ok_or(AuthError::UserNotFound)?;
251
252        Ok(SessionResult { user, session })
253    }
254
255    /// List all active sessions for a user.
256    pub async fn list_sessions(&self, user_id: i64) -> Result<Vec<Session>, AuthError> {
257        self.sessions.find_by_user_id(user_id).await
258    }
259
260    /// Verify a user's email address using a verification token.
261    pub async fn verify_email(
262        &self,
263        raw_token: &str,
264        ip: Option<String>,
265        user_agent: Option<String>,
266    ) -> Result<VerifyEmailResult, AuthError> {
267        let verification = self.lookup_verification(raw_token, "email-verify:").await?;
268        let email = verification
269            .identifier
270            .strip_prefix("email-verify:")
271            .ok_or(AuthError::InvalidToken)?;
272
273        let user = self
274            .users
275            .find_by_email(email)
276            .await?
277            .ok_or(AuthError::UserNotFound)?;
278        self.users.set_email_verified(user.id).await?;
279        self.verifications
280            .delete_verification(verification.id)
281            .await?;
282
283        let user = self
284            .users
285            .find_by_id(user.id)
286            .await?
287            .ok_or(AuthError::UserNotFound)?;
288
289        let (session, session_token) = if self.config.email.auto_sign_in_after_verification {
290            let (session, raw_token) = self
291                .create_session_internal(user.id, ip, user_agent)
292                .await?;
293            (Some(session), Some(raw_token))
294        } else {
295            (None, None)
296        };
297
298        Ok(VerifyEmailResult {
299            user,
300            session,
301            session_token,
302        })
303    }
304
305    /// Request a password reset token for a user by email.
306    pub async fn request_password_reset(
307        &self,
308        email: &str,
309    ) -> Result<RequestResetResult, AuthError> {
310        let email = email.trim().to_lowercase();
311        if let Some(user) = self.users.find_by_email(&email).await? {
312            let identifier = format!("password-reset:{}", user.email.to_lowercase());
313            let _ = self.verifications.delete_by_identifier(&identifier).await;
314
315            let raw_token = token::generate_token(self.config.token_length);
316            self.verifications
317                .create_verification(NewVerification {
318                    identifier,
319                    token_hash: token::hash_token(&raw_token),
320                    expires_at: OffsetDateTime::now_utc() + self.config.reset_ttl,
321                })
322                .await?;
323            self.email
324                .send_password_reset_email(&user, &raw_token)
325                .await?;
326        }
327
328        Ok(RequestResetResult::default())
329    }
330
331    /// Reset a user's password using a reset token.
332    pub async fn reset_password(
333        &self,
334        raw_token: &str,
335        new_password: &str,
336    ) -> Result<ResetPasswordResult, AuthError> {
337        if new_password.len() < 8 {
338            return Err(AuthError::WeakPassword(8));
339        }
340
341        let verification = self
342            .lookup_verification(raw_token, "password-reset:")
343            .await?;
344        let email = verification
345            .identifier
346            .strip_prefix("password-reset:")
347            .ok_or(AuthError::InvalidToken)?;
348        let user = self
349            .users
350            .find_by_email(email)
351            .await?
352            .ok_or(AuthError::UserNotFound)?;
353
354        self.users
355            .update_password(user.id, &hash::hash_password(new_password)?)
356            .await?;
357        self.sessions.delete_by_user_id(user.id).await?;
358        self.verifications
359            .delete_verification(verification.id)
360            .await?;
361
362        let user = self
363            .users
364            .find_by_id(user.id)
365            .await?
366            .ok_or(AuthError::UserNotFound)?;
367
368        Ok(ResetPasswordResult { user })
369    }
370
371    /// Delete expired sessions, verification tokens, and OAuth states.
372    /// Returns (sessions_deleted, verifications_deleted, oauth_states_deleted).
373    pub async fn cleanup_expired(&self) -> Result<(u64, u64, u64), AuthError> {
374        let sessions_deleted = self.sessions.delete_expired().await?;
375        let verifications_deleted = self.verifications.delete_expired().await?;
376        let oauth_states_deleted = self.oauth_states.delete_expired_oauth_states().await?;
377        Ok((
378            sessions_deleted,
379            verifications_deleted,
380            oauth_states_deleted,
381        ))
382    }
383
384    /// Handle OAuth callback - find or create user from OAuth info.
385    ///
386    /// NOTE: OAuth state verification happens in the handler layer before calling this method.
387    /// CSRF state and PKCE verifier are stored in the dedicated `oauth_states` table.
388    pub async fn oauth_callback(
389        &self,
390        info: OAuthUserInfo,
391        tokens: OAuthTokens,
392        ip: Option<String>,
393        user_agent: Option<String>,
394    ) -> Result<LoginResult, AuthError> {
395        // 1. Check if account already exists for this provider
396        if let Some(account) = self
397            .accounts
398            .find_by_provider(&info.provider_id, &info.account_id)
399            .await?
400        {
401            // Existing account - just create session
402            let user = self
403                .users
404                .find_by_id(account.user_id)
405                .await?
406                .ok_or(AuthError::UserNotFound)?;
407            let (session, session_token) = self
408                .create_session_internal(user.id, ip, user_agent)
409                .await?;
410            return Ok(LoginResult {
411                user,
412                session,
413                session_token,
414            });
415        }
416
417        // 2. Check if user with this email already exists
418        let user = if let Some(existing_user) = self.users.find_by_email(&info.email).await? {
419            // Check if implicit account linking is allowed
420            if !self.config.oauth.allow_implicit_account_linking {
421                return Err(AuthError::OAuth(OAuthError::LinkingDisabled));
422            }
423            existing_user
424        } else {
425            // 3. Create new user (no password for OAuth-only users)
426            self.users
427                .create_user(&info.email, info.name.as_deref(), None)
428                .await?
429        };
430
431        // 4. Link account with tokens
432        let access_token_expires_at = tokens.expires_in.map(|d| OffsetDateTime::now_utc() + d);
433        self.accounts
434            .create_account(NewAccount {
435                user_id: user.id,
436                provider_id: info.provider_id,
437                account_id: info.account_id,
438                access_token: tokens.access_token,
439                refresh_token: tokens.refresh_token,
440                access_token_expires_at,
441                scope: tokens.scope,
442            })
443            .await?;
444
445        // 5. Mark email as verified (OAuth providers verify emails)
446        if !user.is_verified() {
447            self.users.set_email_verified(user.id).await?;
448        }
449
450        // 6. Reload user and create session
451        let user = self
452            .users
453            .find_by_id(user.id)
454            .await?
455            .ok_or(AuthError::UserNotFound)?;
456        let (session, session_token) = self
457            .create_session_internal(user.id, ip, user_agent)
458            .await?;
459        Ok(LoginResult {
460            user,
461            session,
462            session_token,
463        })
464    }
465
466    async fn create_session_internal(
467        &self,
468        user_id: i64,
469        ip: Option<String>,
470        user_agent: Option<String>,
471    ) -> Result<(Session, String), AuthError> {
472        let raw_token = token::generate_token(self.config.token_length);
473        let session = self
474            .sessions
475            .create_session(NewSession {
476                token_hash: token::hash_token(&raw_token),
477                user_id,
478                expires_at: OffsetDateTime::now_utc() + self.config.session_ttl,
479                ip_address: ip,
480                user_agent,
481            })
482            .await?;
483
484        Ok((session, raw_token))
485    }
486
487    async fn lookup_verification(
488        &self,
489        raw_token: &str,
490        prefix: &str,
491    ) -> Result<Verification, AuthError> {
492        let verification = self
493            .verifications
494            .find_by_token_hash(&token::hash_token(raw_token))
495            .await?
496            .ok_or(AuthError::InvalidToken)?;
497
498        if !verification.identifier.starts_with(prefix) {
499            return Err(AuthError::InvalidToken);
500        }
501
502        if verification.expires_at < OffsetDateTime::now_utc() {
503            self.verifications
504                .delete_verification(verification.id)
505                .await?;
506            return Err(AuthError::InvalidToken);
507        }
508
509        Ok(verification)
510    }
511}
512
513#[cfg(test)]
514mod tests {
515    use std::collections::HashMap;
516    use std::sync::{Arc, Mutex};
517
518    use async_trait::async_trait;
519    use time::OffsetDateTime;
520
521    use super::AuthService;
522    use crate::config::AuthConfig;
523    use crate::email::EmailSender;
524    use crate::error::{AuthError, OAuthError};
525    use crate::oauth::OAuthTokens;
526    use crate::store::{AccountStore, OAuthStateStore, SessionStore, UserStore, VerificationStore};
527    use crate::types::{
528        Account, NewAccount, NewOAuthState, NewSession, NewUser, NewVerification, OAuthState,
529        Session, User, Verification,
530    };
531
532    #[derive(Default)]
533    struct MemoryState {
534        next_user_id: i64,
535        next_session_id: i64,
536        next_verification_id: i64,
537        next_account_id: i64,
538        next_oauth_state_id: i64,
539        users: HashMap<i64, User>,
540        sessions: HashMap<i64, Session>,
541        verifications: HashMap<i64, Verification>,
542        accounts: HashMap<i64, Account>,
543        oauth_states: HashMap<i64, OAuthState>,
544    }
545
546    #[derive(Clone, Default)]
547    struct MemoryStore {
548        inner: Arc<Mutex<MemoryState>>,
549    }
550
551    #[async_trait]
552    impl UserStore for MemoryStore {
553        async fn create_user(
554            &self,
555            email: &str,
556            name: Option<&str>,
557            password_hash: Option<&str>,
558        ) -> Result<User, AuthError> {
559            let mut state = self.inner.lock().unwrap();
560            if state.users.values().any(|user| user.email == email) {
561                return Err(AuthError::EmailTaken);
562            }
563
564            state.next_user_id += 1;
565            let now = OffsetDateTime::now_utc();
566            let user = User {
567                id: state.next_user_id,
568                email: email.to_string(),
569                name: name.map(str::to_owned),
570                password_hash: password_hash.map(str::to_owned),
571                email_verified_at: None,
572                image: None,
573                created_at: now,
574                updated_at: now,
575            };
576            state.users.insert(user.id, user.clone());
577            Ok(user)
578        }
579
580        async fn find_by_email(&self, email: &str) -> Result<Option<User>, AuthError> {
581            let state = self.inner.lock().unwrap();
582            Ok(state
583                .users
584                .values()
585                .find(|user| user.email == email)
586                .cloned())
587        }
588
589        async fn find_by_id(&self, id: i64) -> Result<Option<User>, AuthError> {
590            Ok(self.inner.lock().unwrap().users.get(&id).cloned())
591        }
592
593        async fn set_email_verified(&self, user_id: i64) -> Result<(), AuthError> {
594            let mut state = self.inner.lock().unwrap();
595            let user = state
596                .users
597                .get_mut(&user_id)
598                .ok_or(AuthError::UserNotFound)?;
599            user.email_verified_at = Some(OffsetDateTime::now_utc());
600            user.updated_at = OffsetDateTime::now_utc();
601            Ok(())
602        }
603
604        async fn update_password(
605            &self,
606            user_id: i64,
607            password_hash: &str,
608        ) -> Result<(), AuthError> {
609            let mut state = self.inner.lock().unwrap();
610            let user = state
611                .users
612                .get_mut(&user_id)
613                .ok_or(AuthError::UserNotFound)?;
614            user.password_hash = Some(password_hash.to_string());
615            user.updated_at = OffsetDateTime::now_utc();
616            Ok(())
617        }
618
619        async fn delete_user(&self, user_id: i64) -> Result<(), AuthError> {
620            self.inner.lock().unwrap().users.remove(&user_id);
621            Ok(())
622        }
623    }
624
625    #[async_trait]
626    impl SessionStore for MemoryStore {
627        async fn create_session(&self, session: NewSession) -> Result<Session, AuthError> {
628            let mut state = self.inner.lock().unwrap();
629            state.next_session_id += 1;
630            let now = OffsetDateTime::now_utc();
631            let session = Session {
632                id: state.next_session_id,
633                token_hash: session.token_hash,
634                user_id: session.user_id,
635                expires_at: session.expires_at,
636                ip_address: session.ip_address,
637                user_agent: session.user_agent,
638                created_at: now,
639                updated_at: now,
640            };
641            state.sessions.insert(session.id, session.clone());
642            Ok(session)
643        }
644
645        async fn find_by_token_hash(&self, token_hash: &str) -> Result<Option<Session>, AuthError> {
646            let state = self.inner.lock().unwrap();
647            Ok(state
648                .sessions
649                .values()
650                .find(|session| session.token_hash == token_hash)
651                .cloned())
652        }
653
654        async fn find_by_user_id(&self, user_id: i64) -> Result<Vec<Session>, AuthError> {
655            let state = self.inner.lock().unwrap();
656            Ok(state
657                .sessions
658                .values()
659                .filter(|session| session.user_id == user_id)
660                .cloned()
661                .collect())
662        }
663
664        async fn delete_session(&self, id: i64) -> Result<(), AuthError> {
665            self.inner.lock().unwrap().sessions.remove(&id);
666            Ok(())
667        }
668
669        async fn delete_by_user_id(&self, user_id: i64) -> Result<(), AuthError> {
670            self.inner
671                .lock()
672                .unwrap()
673                .sessions
674                .retain(|_, session| session.user_id != user_id);
675            Ok(())
676        }
677
678        async fn delete_expired(&self) -> Result<u64, AuthError> {
679            let now = OffsetDateTime::now_utc();
680            let mut state = self.inner.lock().unwrap();
681            let before = state.sessions.len();
682            state
683                .sessions
684                .retain(|_, session| session.expires_at >= now);
685            Ok((before - state.sessions.len()) as u64)
686        }
687    }
688
689    #[async_trait]
690    impl VerificationStore for MemoryStore {
691        async fn create_verification(
692            &self,
693            verification: NewVerification,
694        ) -> Result<Verification, AuthError> {
695            let mut state = self.inner.lock().unwrap();
696            state.next_verification_id += 1;
697            let now = OffsetDateTime::now_utc();
698            let verification = Verification {
699                id: state.next_verification_id,
700                identifier: verification.identifier,
701                token_hash: verification.token_hash,
702                expires_at: verification.expires_at,
703                created_at: now,
704                updated_at: now,
705            };
706            state
707                .verifications
708                .insert(verification.id, verification.clone());
709            Ok(verification)
710        }
711
712        async fn find_by_identifier(
713            &self,
714            identifier: &str,
715        ) -> Result<Option<Verification>, AuthError> {
716            let state = self.inner.lock().unwrap();
717            Ok(state
718                .verifications
719                .values()
720                .find(|verification| verification.identifier == identifier)
721                .cloned())
722        }
723
724        async fn find_by_token_hash(
725            &self,
726            token_hash: &str,
727        ) -> Result<Option<Verification>, AuthError> {
728            let state = self.inner.lock().unwrap();
729            Ok(state
730                .verifications
731                .values()
732                .find(|verification| verification.token_hash == token_hash)
733                .cloned())
734        }
735
736        async fn delete_verification(&self, id: i64) -> Result<(), AuthError> {
737            self.inner.lock().unwrap().verifications.remove(&id);
738            Ok(())
739        }
740
741        async fn delete_by_identifier(&self, identifier: &str) -> Result<(), AuthError> {
742            self.inner
743                .lock()
744                .unwrap()
745                .verifications
746                .retain(|_, verification| verification.identifier != identifier);
747            Ok(())
748        }
749
750        async fn delete_expired(&self) -> Result<u64, AuthError> {
751            let now = OffsetDateTime::now_utc();
752            let mut state = self.inner.lock().unwrap();
753            let before = state.verifications.len();
754            state
755                .verifications
756                .retain(|_, verification| verification.expires_at >= now);
757            Ok((before - state.verifications.len()) as u64)
758        }
759    }
760
761    #[async_trait]
762    impl AccountStore for MemoryStore {
763        async fn create_account(&self, account: NewAccount) -> Result<Account, AuthError> {
764            let mut state = self.inner.lock().unwrap();
765            state.next_account_id += 1;
766            let now = OffsetDateTime::now_utc();
767            let account = Account {
768                id: state.next_account_id,
769                user_id: account.user_id,
770                provider_id: account.provider_id,
771                account_id: account.account_id,
772                access_token: account.access_token,
773                refresh_token: account.refresh_token,
774                access_token_expires_at: account.access_token_expires_at,
775                scope: account.scope,
776                created_at: now,
777                updated_at: now,
778            };
779            state.accounts.insert(account.id, account.clone());
780            Ok(account)
781        }
782
783        async fn find_by_provider(
784            &self,
785            provider_id: &str,
786            account_id: &str,
787        ) -> Result<Option<Account>, AuthError> {
788            let state = self.inner.lock().unwrap();
789            Ok(state
790                .accounts
791                .values()
792                .find(|account| {
793                    account.provider_id == provider_id && account.account_id == account_id
794                })
795                .cloned())
796        }
797
798        async fn find_by_user_id(&self, user_id: i64) -> Result<Vec<Account>, AuthError> {
799            let state = self.inner.lock().unwrap();
800            Ok(state
801                .accounts
802                .values()
803                .filter(|account| account.user_id == user_id)
804                .cloned()
805                .collect())
806        }
807
808        async fn delete_account(&self, id: i64) -> Result<(), AuthError> {
809            self.inner.lock().unwrap().accounts.remove(&id);
810            Ok(())
811        }
812    }
813
814    #[async_trait]
815    impl OAuthStateStore for MemoryStore {
816        async fn create_oauth_state(
817            &self,
818            new_state: NewOAuthState,
819        ) -> Result<OAuthState, AuthError> {
820            let mut state = self.inner.lock().unwrap();
821            state.next_oauth_state_id += 1;
822            let now = OffsetDateTime::now_utc();
823            let oauth_state = OAuthState {
824                id: state.next_oauth_state_id,
825                provider_id: new_state.provider_id,
826                csrf_state: new_state.csrf_state,
827                pkce_verifier: new_state.pkce_verifier,
828                expires_at: new_state.expires_at,
829                created_at: now,
830            };
831            state
832                .oauth_states
833                .insert(oauth_state.id, oauth_state.clone());
834            Ok(oauth_state)
835        }
836
837        async fn find_by_csrf_state(
838            &self,
839            csrf_state: &str,
840        ) -> Result<Option<OAuthState>, AuthError> {
841            let state = self.inner.lock().unwrap();
842            Ok(state
843                .oauth_states
844                .values()
845                .find(|s| s.csrf_state == csrf_state)
846                .cloned())
847        }
848
849        async fn delete_oauth_state(&self, id: i64) -> Result<(), AuthError> {
850            self.inner.lock().unwrap().oauth_states.remove(&id);
851            Ok(())
852        }
853
854        async fn delete_expired_oauth_states(&self) -> Result<u64, AuthError> {
855            let now = OffsetDateTime::now_utc();
856            let mut state = self.inner.lock().unwrap();
857            let before = state.oauth_states.len();
858            state.oauth_states.retain(|_, s| s.expires_at >= now);
859            Ok((before - state.oauth_states.len()) as u64)
860        }
861    }
862
863    #[derive(Clone, Default)]
864    struct TestEmailSender {
865        verification_tokens: Arc<Mutex<Vec<String>>>,
866        reset_tokens: Arc<Mutex<Vec<String>>>,
867    }
868
869    #[async_trait]
870    impl EmailSender for TestEmailSender {
871        async fn send_verification_email(
872            &self,
873            _user: &User,
874            token: &str,
875        ) -> Result<(), AuthError> {
876            self.verification_tokens
877                .lock()
878                .unwrap()
879                .push(token.to_string());
880            Ok(())
881        }
882
883        async fn send_password_reset_email(
884            &self,
885            _user: &User,
886            token: &str,
887        ) -> Result<(), AuthError> {
888            self.reset_tokens.lock().unwrap().push(token.to_string());
889            Ok(())
890        }
891    }
892
893    #[tokio::test]
894    async fn signup_verify_login_and_reset_flow_works() {
895        let store = MemoryStore::default();
896        let email = TestEmailSender::default();
897        let service = AuthService::new(
898            AuthConfig::default(),
899            store.clone(),
900            store.clone(),
901            store.clone(),
902            store.clone(),
903            store.clone(),
904            email.clone(),
905        );
906
907        let signup = service
908            .signup(
909                NewUser {
910                    email: "test@example.com".to_string(),
911                    name: Some("Test".to_string()),
912                    password: "supersecret".to_string(),
913                },
914                Some("127.0.0.1".to_string()),
915                Some("test-agent".to_string()),
916            )
917            .await
918            .unwrap();
919
920        assert_eq!(signup.user.email, "test@example.com");
921        assert!(signup.session.is_some());
922        assert_eq!(email.verification_tokens.lock().unwrap().len(), 1);
923
924        let verification_token = email.verification_tokens.lock().unwrap()[0].clone();
925        let verify = service
926            .verify_email(&verification_token, None, None)
927            .await
928            .unwrap();
929        assert!(verify.user.is_verified());
930
931        let login = service
932            .login("test@example.com", "supersecret", None, None)
933            .await
934            .unwrap();
935        assert_eq!(login.user.email, "test@example.com");
936
937        service
938            .request_password_reset("test@example.com")
939            .await
940            .unwrap();
941        let reset_token = email.reset_tokens.lock().unwrap()[0].clone();
942        service
943            .reset_password(&reset_token, "newpassword")
944            .await
945            .unwrap();
946
947        let login = service
948            .login("test@example.com", "newpassword", None, None)
949            .await
950            .unwrap();
951        assert_eq!(login.user.email, "test@example.com");
952    }
953
954    #[tokio::test]
955    async fn oauth_callback_creates_new_user_and_account() {
956        let store = MemoryStore::default();
957        let email = TestEmailSender::default();
958        let service = AuthService::new(
959            AuthConfig::default(),
960            store.clone(),
961            store.clone(),
962            store.clone(),
963            store.clone(),
964            store.clone(),
965            email.clone(),
966        );
967
968        let oauth_info = crate::oauth::OAuthUserInfo {
969            provider_id: "google".to_string(),
970            account_id: "google-123".to_string(),
971            email: "oauth@example.com".to_string(),
972            name: Some("OAuth User".to_string()),
973            image: Some("https://example.com/avatar.jpg".to_string()),
974        };
975
976        let result = service
977            .oauth_callback(
978                oauth_info,
979                OAuthTokens::default(),
980                Some("127.0.0.1".to_string()),
981                Some("test-agent".to_string()),
982            )
983            .await
984            .unwrap();
985
986        assert_eq!(result.user.email, "oauth@example.com");
987        assert_eq!(result.user.name, Some("OAuth User".to_string()));
988        assert!(result.user.is_verified());
989        assert!(result.user.password_hash.is_none());
990
991        // Verify account was created
992        let accounts = AccountStore::find_by_user_id(&store, result.user.id)
993            .await
994            .unwrap();
995        assert_eq!(accounts.len(), 1);
996        assert_eq!(accounts[0].provider_id, "google");
997        assert_eq!(accounts[0].account_id, "google-123");
998    }
999
1000    #[tokio::test]
1001    async fn oauth_callback_links_existing_user_by_email() {
1002        let store = MemoryStore::default();
1003        let email = TestEmailSender::default();
1004        let service = AuthService::new(
1005            AuthConfig::default(),
1006            store.clone(),
1007            store.clone(),
1008            store.clone(),
1009            store.clone(),
1010            store.clone(),
1011            email.clone(),
1012        );
1013
1014        // Create existing user with password
1015        let existing_user = store
1016            .create_user(
1017                "existing@example.com",
1018                Some("Existing User"),
1019                Some("hash123"),
1020            )
1021            .await
1022            .unwrap();
1023
1024        let oauth_info = crate::oauth::OAuthUserInfo {
1025            provider_id: "github".to_string(),
1026            account_id: "github-456".to_string(),
1027            email: "existing@example.com".to_string(),
1028            name: Some("GitHub User".to_string()),
1029            image: None,
1030        };
1031
1032        let result = service
1033            .oauth_callback(
1034                oauth_info,
1035                OAuthTokens::default(),
1036                Some("127.0.0.1".to_string()),
1037                Some("test-agent".to_string()),
1038            )
1039            .await
1040            .unwrap();
1041
1042        // Should link to existing user
1043        assert_eq!(result.user.id, existing_user.id);
1044        assert_eq!(result.user.email, "existing@example.com");
1045        assert!(result.user.is_verified());
1046
1047        // Verify account was linked
1048        let accounts = AccountStore::find_by_user_id(&store, result.user.id)
1049            .await
1050            .unwrap();
1051        assert_eq!(accounts.len(), 1);
1052        assert_eq!(accounts[0].provider_id, "github");
1053        assert_eq!(accounts[0].account_id, "github-456");
1054    }
1055
1056    #[tokio::test]
1057    async fn oauth_callback_logs_in_existing_account() {
1058        let store = MemoryStore::default();
1059        let email = TestEmailSender::default();
1060        let service = AuthService::new(
1061            AuthConfig::default(),
1062            store.clone(),
1063            store.clone(),
1064            store.clone(),
1065            store.clone(),
1066            store.clone(),
1067            email.clone(),
1068        );
1069
1070        // Create user and account
1071        let user = store
1072            .create_user("oauth@example.com", Some("OAuth User"), None)
1073            .await
1074            .unwrap();
1075        store
1076            .create_account(crate::types::NewAccount {
1077                user_id: user.id,
1078                provider_id: "google".to_string(),
1079                account_id: "google-789".to_string(),
1080                access_token: None,
1081                refresh_token: None,
1082                access_token_expires_at: None,
1083                scope: None,
1084            })
1085            .await
1086            .unwrap();
1087
1088        let oauth_info = crate::oauth::OAuthUserInfo {
1089            provider_id: "google".to_string(),
1090            account_id: "google-789".to_string(),
1091            email: "oauth@example.com".to_string(),
1092            name: Some("OAuth User".to_string()),
1093            image: None,
1094        };
1095
1096        let result = service
1097            .oauth_callback(
1098                oauth_info,
1099                OAuthTokens::default(),
1100                Some("127.0.0.1".to_string()),
1101                Some("test-agent".to_string()),
1102            )
1103            .await
1104            .unwrap();
1105
1106        // Should log in existing user
1107        assert_eq!(result.user.id, user.id);
1108        assert_eq!(result.user.email, "oauth@example.com");
1109
1110        // Should not create duplicate account
1111        let accounts = AccountStore::find_by_user_id(&store, result.user.id)
1112            .await
1113            .unwrap();
1114        assert_eq!(accounts.len(), 1);
1115    }
1116
1117    #[tokio::test]
1118    async fn oauth_callback_respects_linking_policy() {
1119        let store = MemoryStore::default();
1120        let email = TestEmailSender::default();
1121        let mut config = AuthConfig::default();
1122        config.oauth.allow_implicit_account_linking = false;
1123
1124        let service = AuthService::new(
1125            config,
1126            store.clone(),
1127            store.clone(),
1128            store.clone(),
1129            store.clone(),
1130            store.clone(),
1131            email.clone(),
1132        );
1133
1134        // Create existing user
1135        store
1136            .create_user(
1137                "existing@example.com",
1138                Some("Existing User"),
1139                Some("hash123"),
1140            )
1141            .await
1142            .unwrap();
1143
1144        let oauth_info = crate::oauth::OAuthUserInfo {
1145            provider_id: "google".to_string(),
1146            account_id: "google-999".to_string(),
1147            email: "existing@example.com".to_string(),
1148            name: Some("OAuth User".to_string()),
1149            image: None,
1150        };
1151
1152        // Should fail because linking is disabled
1153        let result = service
1154            .oauth_callback(
1155                oauth_info,
1156                OAuthTokens::default(),
1157                Some("127.0.0.1".to_string()),
1158                Some("test-agent".to_string()),
1159            )
1160            .await;
1161
1162        assert!(result.is_err());
1163        match result {
1164            Err(AuthError::OAuth(OAuthError::LinkingDisabled)) => {}
1165            _ => panic!("Expected OAuth linking-disabled error"),
1166        }
1167    }
1168
1169    #[tokio::test]
1170    async fn cleanup_expired_removes_sessions_verifications_and_oauth_states() {
1171        let store = MemoryStore::default();
1172        let email = TestEmailSender::default();
1173        let service = AuthService::new(
1174            AuthConfig::default(),
1175            store.clone(),
1176            store.clone(),
1177            store.clone(),
1178            store.clone(),
1179            store.clone(),
1180            email,
1181        );
1182
1183        store
1184            .create_session(NewSession {
1185                token_hash: "expired-session".to_string(),
1186                user_id: 1,
1187                expires_at: OffsetDateTime::now_utc() - time::Duration::hours(1),
1188                ip_address: None,
1189                user_agent: None,
1190            })
1191            .await
1192            .unwrap();
1193        store
1194            .create_verification(NewVerification {
1195                identifier: "email-verify:test@example.com".to_string(),
1196                token_hash: "expired-verification".to_string(),
1197                expires_at: OffsetDateTime::now_utc() - time::Duration::hours(1),
1198            })
1199            .await
1200            .unwrap();
1201        store
1202            .create_oauth_state(NewOAuthState {
1203                provider_id: "google".to_string(),
1204                csrf_state: "expired-oauth-state".to_string(),
1205                pkce_verifier: "pkce-verifier".to_string(),
1206                expires_at: OffsetDateTime::now_utc() - time::Duration::hours(1),
1207            })
1208            .await
1209            .unwrap();
1210
1211        let deleted = service.cleanup_expired().await.unwrap();
1212
1213        assert_eq!(deleted, (1, 1, 1));
1214    }
1215}