Skip to main content

better_auth_api/plugins/
password_management.rs

1use async_trait::async_trait;
2use chrono::{Duration, Utc};
3use serde::{Deserialize, Deserializer, Serialize};
4use uuid::Uuid;
5use validator::Validate;
6
7use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
8use better_auth_core::{AuthError, AuthResult};
9use better_auth_core::{AuthRequest, AuthResponse, CreateVerification, HttpMethod, UpdateUser};
10use better_auth_core::{AuthSession, AuthUser, AuthVerification, DatabaseAdapter};
11
/// Password management plugin for password reset and change functionality.
///
/// Serves the `/forget-password`, `/reset-password`, `/reset-password/{token}`,
/// `/change-password` and `/set-password` routes.
pub struct PasswordManagementPlugin {
    /// Settings consulted by the request handlers.
    config: PasswordManagementConfig,
}
16
/// Settings for the password management plugin.
#[derive(Debug, Clone)]
pub struct PasswordManagementConfig {
    /// Lifetime, in hours, of a reset token issued by `/forget-password`.
    pub reset_token_expiry_hours: i64,
    /// When `true`, `/change-password` verifies the caller's current password first.
    pub require_current_password: bool,
    /// When `true`, `/forget-password` attempts to email the reset link.
    pub send_email_notifications: bool,
}
23
24// Request structures for password endpoints
/// JSON body of `POST /forget-password`.
#[derive(Debug, Deserialize, Validate)]
struct ForgetPasswordRequest {
    /// Address of the account whose password should be reset.
    #[validate(email(message = "Invalid email address"))]
    email: String,
    /// Optional URL the emailed reset link should point at (sent as `redirectTo`).
    #[serde(rename = "redirectTo")]
    redirect_to: Option<String>,
}
32
/// JSON body of `POST /reset-password`.
#[derive(Debug, Deserialize, Validate)]
struct ResetPasswordRequest {
    /// Replacement password (sent as `newPassword`); must be non-empty.
    #[serde(rename = "newPassword")]
    #[validate(length(min = 1, message = "New password is required"))]
    new_password: String,
    /// Reset token; optional at the parsing level, but the handler rejects a
    /// missing or empty value.
    token: Option<String>,
}
40
/// JSON body of `POST /set-password`.
#[derive(Debug, Deserialize, Validate)]
struct SetPasswordRequest {
    /// Password to store (sent as `newPassword`); must be non-empty.
    #[serde(rename = "newPassword")]
    #[validate(length(min = 1, message = "New password is required"))]
    new_password: String,
}
47
/// JSON body of `POST /change-password`.
#[derive(Debug, Deserialize, Validate)]
struct ChangePasswordRequest {
    /// Replacement password (sent as `newPassword`); must be non-empty.
    #[serde(rename = "newPassword")]
    #[validate(length(min = 1, message = "New password is required"))]
    new_password: String,
    /// The caller's existing password (sent as `currentPassword`); must be non-empty.
    #[serde(rename = "currentPassword")]
    #[validate(length(min = 1, message = "Current password is required"))]
    current_password: String,
    /// Optional flag (sent as `revokeOtherSessions`); accepts a JSON boolean or the
    /// strings "true"/"false", and defaults to `None` when the key is absent.
    #[serde(
        default,
        rename = "revokeOtherSessions",
        deserialize_with = "deserialize_bool_or_string"
    )]
    revoke_other_sessions: Option<bool>,
}
63
64/// Deserialize a value that can be either a boolean or a string ("true"/"false") into Option<bool>.
65/// This is needed because the better-auth TypeScript SDK sends `revokeOtherSessions` as a boolean,
66/// while some clients may send it as a string.
67fn deserialize_bool_or_string<'de, D>(deserializer: D) -> Result<Option<bool>, D::Error>
68where
69    D: Deserializer<'de>,
70{
71    let value: Option<serde_json::Value> = Option::deserialize(deserializer)?;
72    match value {
73        None => Ok(None),
74        Some(serde_json::Value::Bool(b)) => Ok(Some(b)),
75        Some(serde_json::Value::String(s)) => match s.to_lowercase().as_str() {
76            "true" => Ok(Some(true)),
77            "false" => Ok(Some(false)),
78            _ => Err(serde::de::Error::custom(format!(
79                "invalid value for revokeOtherSessions: {}",
80                s
81            ))),
82        },
83        Some(other) => Err(serde::de::Error::custom(format!(
84            "invalid type for revokeOtherSessions: {}",
85            other
86        ))),
87    }
88}
89
90// Response structures
/// Minimal `{"status": <bool>}` body returned by the forget-, reset- and
/// set-password handlers.
#[derive(Debug, Serialize, Deserialize)]
struct StatusResponse {
    /// The handlers in this file always set this to `true` on success.
    status: bool,
}
95
/// Body returned by `POST /change-password`.
#[derive(Debug, Serialize)]
struct ChangePasswordResponse<U: Serialize> {
    /// Token of the newly created session, present only when the request asked
    /// for the existing sessions to be revoked.
    token: Option<String>,
    /// The user record as returned by the database after the update.
    user: U,
}
101
/// Body returned by `GET /reset-password/{token}` when no callback URL is given.
#[derive(Debug, Serialize, Deserialize)]
struct ResetPasswordTokenResponse {
    /// The reset token taken from the request path, echoed back to the caller.
    token: String,
}
106
107impl PasswordManagementPlugin {
108    pub fn new() -> Self {
109        Self {
110            config: PasswordManagementConfig::default(),
111        }
112    }
113
114    pub fn with_config(config: PasswordManagementConfig) -> Self {
115        Self { config }
116    }
117
118    pub fn reset_token_expiry_hours(mut self, hours: i64) -> Self {
119        self.config.reset_token_expiry_hours = hours;
120        self
121    }
122
123    pub fn require_current_password(mut self, require: bool) -> Self {
124        self.config.require_current_password = require;
125        self
126    }
127
128    pub fn send_email_notifications(mut self, send: bool) -> Self {
129        self.config.send_email_notifications = send;
130        self
131    }
132}
133
134impl Default for PasswordManagementPlugin {
135    fn default() -> Self {
136        Self::new()
137    }
138}
139
140impl Default for PasswordManagementConfig {
141    fn default() -> Self {
142        Self {
143            reset_token_expiry_hours: 24, // 24 hours default expiry
144            require_current_password: true,
145            send_email_notifications: true,
146        }
147    }
148}
149
150#[async_trait]
151impl<DB: DatabaseAdapter> AuthPlugin<DB> for PasswordManagementPlugin {
152    fn name(&self) -> &'static str {
153        "password-management"
154    }
155
156    fn routes(&self) -> Vec<AuthRoute> {
157        vec![
158            AuthRoute::post("/forget-password", "forget_password"),
159            AuthRoute::post("/reset-password", "reset_password"),
160            AuthRoute::get("/reset-password/{token}", "reset_password_token"),
161            AuthRoute::post("/change-password", "change_password"),
162            AuthRoute::post("/set-password", "set_password"),
163        ]
164    }
165
166    async fn on_request(
167        &self,
168        req: &AuthRequest,
169        ctx: &AuthContext<DB>,
170    ) -> AuthResult<Option<AuthResponse>> {
171        match (req.method(), req.path()) {
172            (HttpMethod::Post, "/forget-password") => {
173                Ok(Some(self.handle_forget_password(req, ctx).await?))
174            }
175            (HttpMethod::Post, "/reset-password") => {
176                Ok(Some(self.handle_reset_password(req, ctx).await?))
177            }
178            (HttpMethod::Post, "/change-password") => {
179                Ok(Some(self.handle_change_password(req, ctx).await?))
180            }
181            (HttpMethod::Post, "/set-password") => {
182                Ok(Some(self.handle_set_password(req, ctx).await?))
183            }
184            (HttpMethod::Get, path) if path.starts_with("/reset-password/") => {
185                let token = &path[16..]; // Remove "/reset-password/" prefix
186                Ok(Some(
187                    self.handle_reset_password_token(token, req, ctx).await?,
188                ))
189            }
190            _ => Ok(None),
191        }
192    }
193}
194
195// Implementation methods outside the trait
196impl PasswordManagementPlugin {
197    async fn handle_forget_password<DB: DatabaseAdapter>(
198        &self,
199        req: &AuthRequest,
200        ctx: &AuthContext<DB>,
201    ) -> AuthResult<AuthResponse> {
202        let forget_req: ForgetPasswordRequest = match better_auth_core::validate_request_body(req) {
203            Ok(v) => v,
204            Err(resp) => return Ok(resp),
205        };
206
207        // Check if user exists
208        let user = match ctx.database.get_user_by_email(&forget_req.email).await? {
209            Some(user) => user,
210            None => {
211                // Don't reveal whether email exists or not for security
212                let response = StatusResponse { status: true };
213                return Ok(AuthResponse::json(200, &response)?);
214            }
215        };
216
217        // Generate password reset token
218        let reset_token = format!("reset_{}", Uuid::new_v4());
219        let expires_at = Utc::now() + Duration::hours(self.config.reset_token_expiry_hours);
220
221        // Create verification token
222        let create_verification = CreateVerification {
223            identifier: user.email().unwrap_or_default().to_string(),
224            value: reset_token.clone(),
225            expires_at,
226        };
227
228        ctx.database
229            .create_verification(create_verification)
230            .await?;
231
232        // Send email with reset link
233        if self.config.send_email_notifications {
234            let reset_url = if let Some(redirect_to) = &forget_req.redirect_to {
235                format!("{}?token={}", redirect_to, reset_token)
236            } else {
237                format!(
238                    "{}/reset-password?token={}",
239                    ctx.config.base_url, reset_token
240                )
241            };
242
243            if let Ok(provider) = ctx.email_provider() {
244                let subject = "Reset your password";
245                let html = format!(
246                    "<p>Click the link below to reset your password:</p>\
247                     <p><a href=\"{url}\">Reset Password</a></p>",
248                    url = reset_url
249                );
250                let text = format!("Reset your password: {}", reset_url);
251
252                if let Err(e) = provider
253                    .send(&forget_req.email, subject, &html, &text)
254                    .await
255                {
256                    eprintln!(
257                        "[password-management] Failed to send reset email to {}: {}",
258                        forget_req.email, e
259                    );
260                }
261            } else {
262                eprintln!(
263                    "[password-management] No email provider configured, skipping password reset email for {}",
264                    forget_req.email
265                );
266            }
267        }
268
269        let response = StatusResponse { status: true };
270        Ok(AuthResponse::json(200, &response)?)
271    }
272
273    async fn handle_reset_password<DB: DatabaseAdapter>(
274        &self,
275        req: &AuthRequest,
276        ctx: &AuthContext<DB>,
277    ) -> AuthResult<AuthResponse> {
278        let reset_req: ResetPasswordRequest = match better_auth_core::validate_request_body(req) {
279            Ok(v) => v,
280            Err(resp) => return Ok(resp),
281        };
282
283        // Validate password
284        self.validate_password(&reset_req.new_password, ctx)?;
285
286        // Find user by reset token
287        let token = reset_req.token.as_deref().unwrap_or("");
288        if token.is_empty() {
289            return Err(AuthError::bad_request("Reset token is required"));
290        }
291
292        let (user, verification) = self
293            .find_user_by_reset_token(token, ctx)
294            .await?
295            .ok_or_else(|| AuthError::bad_request("Invalid or expired reset token"))?;
296
297        // Hash new password
298        let password_hash = self.hash_password(&reset_req.new_password)?;
299
300        // Update user password
301        let mut metadata = user.metadata().clone();
302        metadata["password_hash"] = serde_json::Value::String(password_hash);
303
304        let update_user = UpdateUser {
305            email: None,
306            name: None,
307            image: None,
308            email_verified: None,
309            username: None,
310            display_username: None,
311            role: None,
312            banned: None,
313            ban_reason: None,
314            ban_expires: None,
315            two_factor_enabled: None,
316            metadata: Some(metadata),
317        };
318
319        ctx.database.update_user(user.id(), update_user).await?;
320
321        // Delete the used verification token
322        ctx.database.delete_verification(verification.id()).await?;
323
324        // Revoke all existing sessions for security
325        ctx.database.delete_user_sessions(user.id()).await?;
326
327        let response = StatusResponse { status: true };
328        Ok(AuthResponse::json(200, &response)?)
329    }
330
331    async fn handle_change_password<DB: DatabaseAdapter>(
332        &self,
333        req: &AuthRequest,
334        ctx: &AuthContext<DB>,
335    ) -> AuthResult<AuthResponse> {
336        let change_req: ChangePasswordRequest = match better_auth_core::validate_request_body(req) {
337            Ok(v) => v,
338            Err(resp) => return Ok(resp),
339        };
340
341        // Get current user from session
342        let user = self
343            .get_current_user(req, ctx)
344            .await?
345            .ok_or(AuthError::Unauthenticated)?;
346
347        // Verify current password
348        if self.config.require_current_password {
349            let stored_hash = user
350                .metadata()
351                .get("password_hash")
352                .and_then(|v| v.as_str())
353                .ok_or_else(|| AuthError::bad_request("No password set for this user"))?;
354
355            self.verify_password(&change_req.current_password, stored_hash)
356                .map_err(|_| AuthError::InvalidCredentials)?;
357        }
358
359        // Validate new password
360        self.validate_password(&change_req.new_password, ctx)?;
361
362        // Hash new password
363        let password_hash = self.hash_password(&change_req.new_password)?;
364
365        // Update user password
366        let mut metadata = user.metadata().clone();
367        metadata["password_hash"] = serde_json::Value::String(password_hash);
368
369        let update_user = UpdateUser {
370            email: None,
371            name: None,
372            image: None,
373            email_verified: None,
374            username: None,
375            display_username: None,
376            role: None,
377            banned: None,
378            ban_reason: None,
379            ban_expires: None,
380            two_factor_enabled: None,
381            metadata: Some(metadata),
382        };
383
384        let updated_user = ctx.database.update_user(user.id(), update_user).await?;
385
386        // Handle session revocation
387        let new_token = if change_req.revoke_other_sessions == Some(true) {
388            // Revoke all sessions except current one
389            ctx.database.delete_user_sessions(user.id()).await?;
390
391            // Create new session
392            let session_manager =
393                better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
394            let session = session_manager
395                .create_session(&updated_user, None, None)
396                .await?;
397            Some(session.token().to_string())
398        } else {
399            None
400        };
401
402        let response = ChangePasswordResponse {
403            token: new_token.clone(),
404            user: updated_user,
405        };
406
407        let auth_response = AuthResponse::json(200, &response)?;
408
409        // Set session cookie if a new session was created
410        if let Some(token) = new_token {
411            let cookie_header = self.create_session_cookie(&token, ctx);
412            Ok(auth_response.with_header("Set-Cookie", cookie_header))
413        } else {
414            Ok(auth_response)
415        }
416    }
417
418    async fn handle_set_password<DB: DatabaseAdapter>(
419        &self,
420        req: &AuthRequest,
421        ctx: &AuthContext<DB>,
422    ) -> AuthResult<AuthResponse> {
423        let set_req: SetPasswordRequest = match better_auth_core::validate_request_body(req) {
424            Ok(v) => v,
425            Err(resp) => return Ok(resp),
426        };
427
428        // Authenticate user
429        let user = self
430            .get_current_user(req, ctx)
431            .await?
432            .ok_or(AuthError::Unauthenticated)?;
433
434        // Verify the user does NOT already have a password
435        if user
436            .metadata()
437            .get("password_hash")
438            .and_then(|v| v.as_str())
439            .is_some()
440        {
441            return Err(AuthError::bad_request(
442                "User already has a password. Use /change-password instead.",
443            ));
444        }
445
446        // Validate new password
447        self.validate_password(&set_req.new_password, ctx)?;
448
449        // Hash and store the new password
450        let password_hash = self.hash_password(&set_req.new_password)?;
451
452        let mut metadata = user.metadata().clone();
453        metadata["password_hash"] = serde_json::Value::String(password_hash);
454
455        let update_user = UpdateUser {
456            email: None,
457            name: None,
458            image: None,
459            email_verified: None,
460            username: None,
461            display_username: None,
462            role: None,
463            banned: None,
464            ban_reason: None,
465            ban_expires: None,
466            two_factor_enabled: None,
467            metadata: Some(metadata),
468        };
469
470        ctx.database.update_user(user.id(), update_user).await?;
471
472        let response = StatusResponse { status: true };
473        Ok(AuthResponse::json(200, &response)?)
474    }
475
476    async fn handle_reset_password_token<DB: DatabaseAdapter>(
477        &self,
478        token: &str,
479        _req: &AuthRequest,
480        ctx: &AuthContext<DB>,
481    ) -> AuthResult<AuthResponse> {
482        // Check if token is valid and get callback URL from query parameters
483        let callback_url = _req.query.get("callbackURL").cloned();
484
485        // Validate the reset token exists and is not expired
486        let (_user, _verification) = match self.find_user_by_reset_token(token, ctx).await? {
487            Some((user, verification)) => (user, verification),
488            None => {
489                // Redirect to callback URL with error if provided
490                if let Some(callback_url) = callback_url {
491                    let redirect_url = format!("{}?error=INVALID_TOKEN", callback_url);
492                    let mut headers = std::collections::HashMap::new();
493                    headers.insert("Location".to_string(), redirect_url);
494                    return Ok(AuthResponse {
495                        status: 302,
496                        headers,
497                        body: Vec::new(),
498                    });
499                }
500
501                return Err(AuthError::bad_request("Invalid or expired reset token"));
502            }
503        };
504
505        // If callback URL is provided, redirect with valid token
506        if let Some(callback_url) = callback_url {
507            let redirect_url = format!("{}?token={}", callback_url, token);
508            let mut headers = std::collections::HashMap::new();
509            headers.insert("Location".to_string(), redirect_url);
510            return Ok(AuthResponse {
511                status: 302,
512                headers,
513                body: Vec::new(),
514            });
515        }
516
517        // Otherwise return the token directly
518        let response = ResetPasswordTokenResponse {
519            token: token.to_string(),
520        };
521        Ok(AuthResponse::json(200, &response)?)
522    }
523
524    async fn find_user_by_reset_token<DB: DatabaseAdapter>(
525        &self,
526        token: &str,
527        ctx: &AuthContext<DB>,
528    ) -> AuthResult<Option<(DB::User, DB::Verification)>> {
529        // Find verification token by value
530        let verification = match ctx.database.get_verification_by_value(token).await? {
531            Some(verification) => verification,
532            None => return Ok(None),
533        };
534
535        // Get user by email (stored in identifier field)
536        let user = match ctx
537            .database
538            .get_user_by_email(verification.identifier())
539            .await?
540        {
541            Some(user) => user,
542            None => return Ok(None),
543        };
544
545        Ok(Some((user, verification)))
546    }
547
548    async fn get_current_user<DB: DatabaseAdapter>(
549        &self,
550        req: &AuthRequest,
551        ctx: &AuthContext<DB>,
552    ) -> AuthResult<Option<DB::User>> {
553        let session_manager =
554            better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
555
556        if let Some(token) = session_manager.extract_session_token(req)
557            && let Some(session) = session_manager.get_session(&token).await?
558        {
559            return ctx.database.get_user_by_id(session.user_id()).await;
560        }
561
562        Ok(None)
563    }
564
565    fn validate_password<DB: DatabaseAdapter>(
566        &self,
567        password: &str,
568        ctx: &AuthContext<DB>,
569    ) -> AuthResult<()> {
570        let config = &ctx.config.password;
571
572        if password.len() < config.min_length {
573            return Err(AuthError::bad_request(format!(
574                "Password must be at least {} characters long",
575                config.min_length
576            )));
577        }
578
579        if config.require_uppercase && !password.chars().any(|c| c.is_uppercase()) {
580            return Err(AuthError::bad_request(
581                "Password must contain at least one uppercase letter",
582            ));
583        }
584
585        if config.require_lowercase && !password.chars().any(|c| c.is_lowercase()) {
586            return Err(AuthError::bad_request(
587                "Password must contain at least one lowercase letter",
588            ));
589        }
590
591        if config.require_numbers && !password.chars().any(|c| c.is_ascii_digit()) {
592            return Err(AuthError::bad_request(
593                "Password must contain at least one number",
594            ));
595        }
596
597        if config.require_special
598            && !password
599                .chars()
600                .any(|c| "!@#$%^&*()_+-=[]{}|;:,.<>?".contains(c))
601        {
602            return Err(AuthError::bad_request(
603                "Password must contain at least one special character",
604            ));
605        }
606
607        Ok(())
608    }
609
610    fn hash_password(&self, password: &str) -> AuthResult<String> {
611        use argon2::password_hash::{SaltString, rand_core::OsRng};
612        use argon2::{Argon2, PasswordHasher};
613
614        let salt = SaltString::generate(&mut OsRng);
615        let argon2 = Argon2::default();
616
617        let password_hash = argon2
618            .hash_password(password.as_bytes(), &salt)
619            .map_err(|e| AuthError::PasswordHash(format!("Failed to hash password: {}", e)))?;
620
621        Ok(password_hash.to_string())
622    }
623
624    fn create_session_cookie<DB: DatabaseAdapter>(
625        &self,
626        token: &str,
627        ctx: &AuthContext<DB>,
628    ) -> String {
629        let session_config = &ctx.config.session;
630        let secure = if session_config.cookie_secure {
631            "; Secure"
632        } else {
633            ""
634        };
635        let http_only = if session_config.cookie_http_only {
636            "; HttpOnly"
637        } else {
638            ""
639        };
640        let same_site = match session_config.cookie_same_site {
641            better_auth_core::config::SameSite::Strict => "; SameSite=Strict",
642            better_auth_core::config::SameSite::Lax => "; SameSite=Lax",
643            better_auth_core::config::SameSite::None => "; SameSite=None",
644        };
645
646        let expires = chrono::Utc::now() + session_config.expires_in;
647        let expires_str = expires.format("%a, %d %b %Y %H:%M:%S GMT");
648
649        format!(
650            "{}={}; Path=/; Expires={}{}{}{}",
651            session_config.cookie_name, token, expires_str, secure, http_only, same_site
652        )
653    }
654
655    fn verify_password(&self, password: &str, hash: &str) -> AuthResult<()> {
656        use argon2::password_hash::PasswordHash;
657        use argon2::{Argon2, PasswordVerifier};
658
659        let parsed_hash = PasswordHash::new(hash)
660            .map_err(|e| AuthError::PasswordHash(format!("Invalid password hash: {}", e)))?;
661
662        let argon2 = Argon2::default();
663        argon2
664            .verify_password(password.as_bytes(), &parsed_hash)
665            .map_err(|_| AuthError::InvalidCredentials)?;
666
667        Ok(())
668    }
669}
670
671#[cfg(test)]
672mod tests {
673    use super::*;
674    use better_auth_core::adapters::{MemoryDatabaseAdapter, SessionOps, UserOps, VerificationOps};
675    use better_auth_core::config::{Argon2Config, AuthConfig, PasswordConfig};
676    use better_auth_core::{CreateSession, CreateUser, CreateVerification, Session, User};
677    use chrono::{Duration, Utc};
678    use std::collections::HashMap;
679    use std::sync::Arc;
680
681    async fn create_test_context_with_user() -> (AuthContext<MemoryDatabaseAdapter>, User, Session)
682    {
683        let mut config = AuthConfig::new("test-secret-key-at-least-32-chars-long");
684        config.password = PasswordConfig {
685            min_length: 8,
686            require_uppercase: true,
687            require_lowercase: true,
688            require_numbers: true,
689            require_special: true,
690            argon2_config: Argon2Config::default(),
691        };
692
693        let config = Arc::new(config);
694        let database = Arc::new(MemoryDatabaseAdapter::new());
695        let ctx = AuthContext::new(config.clone(), database.clone());
696
697        // Create test user with hashed password
698        let plugin = PasswordManagementPlugin::new();
699        let password_hash = plugin.hash_password("Password123!").unwrap();
700
701        let metadata = serde_json::json!({
702            "password_hash": password_hash,
703        });
704
705        let create_user = CreateUser::new()
706            .with_email("test@example.com")
707            .with_name("Test User")
708            .with_metadata(metadata);
709        let user = database.create_user(create_user).await.unwrap();
710
711        // Create test session
712        let create_session = CreateSession {
713            user_id: user.id.clone(),
714            expires_at: Utc::now() + Duration::hours(24),
715            ip_address: Some("127.0.0.1".to_string()),
716            user_agent: Some("test-agent".to_string()),
717            impersonated_by: None,
718            active_organization_id: None,
719        };
720        let session = database.create_session(create_session).await.unwrap();
721
722        (ctx, user, session)
723    }
724
725    fn create_auth_request(
726        method: HttpMethod,
727        path: &str,
728        token: Option<&str>,
729        body: Option<Vec<u8>>,
730    ) -> AuthRequest {
731        let mut headers = HashMap::new();
732        if let Some(token) = token {
733            headers.insert("authorization".to_string(), format!("Bearer {}", token));
734        }
735
736        AuthRequest {
737            method,
738            path: path.to_string(),
739            headers,
740            body,
741            query: HashMap::new(),
742        }
743    }
744
745    #[tokio::test]
746    async fn test_forget_password_success() {
747        let plugin = PasswordManagementPlugin::new();
748        let (ctx, _user, _session) = create_test_context_with_user().await;
749
750        let body = serde_json::json!({
751            "email": "test@example.com",
752            "redirectTo": "http://localhost:3000/reset"
753        });
754
755        let req = create_auth_request(
756            HttpMethod::Post,
757            "/forget-password",
758            None,
759            Some(body.to_string().into_bytes()),
760        );
761
762        let response = plugin.handle_forget_password(&req, &ctx).await.unwrap();
763        assert_eq!(response.status, 200);
764
765        let body_str = String::from_utf8(response.body).unwrap();
766        let response_data: StatusResponse = serde_json::from_str(&body_str).unwrap();
767        assert!(response_data.status);
768    }
769
770    #[tokio::test]
771    async fn test_forget_password_unknown_email() {
772        let plugin = PasswordManagementPlugin::new();
773        let (ctx, _user, _session) = create_test_context_with_user().await;
774
775        let body = serde_json::json!({
776            "email": "unknown@example.com"
777        });
778
779        let req = create_auth_request(
780            HttpMethod::Post,
781            "/forget-password",
782            None,
783            Some(body.to_string().into_bytes()),
784        );
785
786        let response = plugin.handle_forget_password(&req, &ctx).await.unwrap();
787        assert_eq!(response.status, 200);
788
789        // Should return success even for unknown emails (security)
790        let body_str = String::from_utf8(response.body).unwrap();
791        let response_data: StatusResponse = serde_json::from_str(&body_str).unwrap();
792        assert!(response_data.status);
793    }
794
795    #[tokio::test]
796    async fn test_reset_password_success() {
797        let plugin = PasswordManagementPlugin::new();
798        let (ctx, user, _session) = create_test_context_with_user().await;
799
800        // Create verification token
801        let reset_token = format!("reset_{}", uuid::Uuid::new_v4());
802        let create_verification = CreateVerification {
803            identifier: user.email.clone().unwrap(),
804            value: reset_token.clone(),
805            expires_at: Utc::now() + Duration::hours(24),
806        };
807        ctx.database
808            .create_verification(create_verification)
809            .await
810            .unwrap();
811
812        let body = serde_json::json!({
813            "newPassword": "NewPassword123!",
814            "token": reset_token
815        });
816
817        let req = create_auth_request(
818            HttpMethod::Post,
819            "/reset-password",
820            None,
821            Some(body.to_string().into_bytes()),
822        );
823
824        let response = plugin.handle_reset_password(&req, &ctx).await.unwrap();
825        assert_eq!(response.status, 200);
826
827        let body_str = String::from_utf8(response.body).unwrap();
828        let response_data: StatusResponse = serde_json::from_str(&body_str).unwrap();
829        assert!(response_data.status);
830
831        // Verify password was updated
832        let updated_user = ctx
833            .database
834            .get_user_by_id(&user.id)
835            .await
836            .unwrap()
837            .unwrap();
838        let stored_hash = updated_user
839            .metadata
840            .get("password_hash")
841            .unwrap()
842            .as_str()
843            .unwrap();
844        assert!(
845            plugin
846                .verify_password("NewPassword123!", stored_hash)
847                .is_ok()
848        );
849
850        // Verify token was deleted
851        let verification_check = ctx
852            .database
853            .get_verification_by_value(&reset_token)
854            .await
855            .unwrap();
856        assert!(verification_check.is_none());
857    }
858
859    #[tokio::test]
860    async fn test_reset_password_invalid_token() {
861        let plugin = PasswordManagementPlugin::new();
862        let (ctx, _user, _session) = create_test_context_with_user().await;
863
864        let body = serde_json::json!({
865            "newPassword": "NewPassword123!",
866            "token": "invalid_token"
867        });
868
869        let req = create_auth_request(
870            HttpMethod::Post,
871            "/reset-password",
872            None,
873            Some(body.to_string().into_bytes()),
874        );
875
876        let err = plugin.handle_reset_password(&req, &ctx).await.unwrap_err();
877        assert_eq!(err.status_code(), 400);
878    }
879
880    #[tokio::test]
881    async fn test_reset_password_weak_password() {
882        let plugin = PasswordManagementPlugin::new();
883        let (ctx, user, _session) = create_test_context_with_user().await;
884
885        // Create verification token
886        let reset_token = format!("reset_{}", uuid::Uuid::new_v4());
887        let create_verification = CreateVerification {
888            identifier: user.email.clone().unwrap(),
889            value: reset_token.clone(),
890            expires_at: Utc::now() + Duration::hours(24),
891        };
892        ctx.database
893            .create_verification(create_verification)
894            .await
895            .unwrap();
896
897        let body = serde_json::json!({
898            "newPassword": "weak",
899            "token": reset_token
900        });
901
902        let req = create_auth_request(
903            HttpMethod::Post,
904            "/reset-password",
905            None,
906            Some(body.to_string().into_bytes()),
907        );
908
909        let err = plugin.handle_reset_password(&req, &ctx).await.unwrap_err();
910        assert_eq!(err.status_code(), 400);
911    }
912
913    #[tokio::test]
914    async fn test_change_password_success() {
915        let plugin = PasswordManagementPlugin::new();
916        let (ctx, _user, session) = create_test_context_with_user().await;
917
918        let body = serde_json::json!({
919            "currentPassword": "Password123!",
920            "newPassword": "NewPassword123!",
921            "revokeOtherSessions": "false"
922        });
923
924        let req = create_auth_request(
925            HttpMethod::Post,
926            "/change-password",
927            Some(&session.token),
928            Some(body.to_string().into_bytes()),
929        );
930
931        let response = plugin.handle_change_password(&req, &ctx).await.unwrap();
932        assert_eq!(response.status, 200);
933
934        let body_str = String::from_utf8(response.body).unwrap();
935        let response_data: serde_json::Value = serde_json::from_str(&body_str).unwrap();
936        assert!(response_data["token"].is_null()); // No new token when not revoking sessions
937
938        // Verify password was updated by checking the database directly
939        let user_id = response_data["user"]["id"].as_str().unwrap();
940        let updated_user = ctx.database.get_user_by_id(user_id).await.unwrap().unwrap();
941        let stored_hash = updated_user
942            .metadata
943            .get("password_hash")
944            .unwrap()
945            .as_str()
946            .unwrap();
947        assert!(
948            plugin
949                .verify_password("NewPassword123!", stored_hash)
950                .is_ok()
951        );
952    }
953
954    #[tokio::test]
955    async fn test_change_password_with_session_revocation() {
956        let plugin = PasswordManagementPlugin::new();
957        let (ctx, _user, session) = create_test_context_with_user().await;
958
959        let body = serde_json::json!({
960            "currentPassword": "Password123!",
961            "newPassword": "NewPassword123!",
962            "revokeOtherSessions": "true"
963        });
964
965        let req = create_auth_request(
966            HttpMethod::Post,
967            "/change-password",
968            Some(&session.token),
969            Some(body.to_string().into_bytes()),
970        );
971
972        let response = plugin.handle_change_password(&req, &ctx).await.unwrap();
973        assert_eq!(response.status, 200);
974
975        let body_str = String::from_utf8(response.body).unwrap();
976        let response_data: serde_json::Value = serde_json::from_str(&body_str).unwrap();
977        assert!(response_data["token"].is_string()); // New token when revoking sessions
978    }
979
980    #[tokio::test]
981    async fn test_change_password_sets_cookie_on_session_revocation() {
982        let plugin = PasswordManagementPlugin::new();
983        let (ctx, _user, session) = create_test_context_with_user().await;
984
985        let body = serde_json::json!({
986            "currentPassword": "Password123!",
987            "newPassword": "NewPassword123!",
988            "revokeOtherSessions": true
989        });
990
991        let req = create_auth_request(
992            HttpMethod::Post,
993            "/change-password",
994            Some(&session.token),
995            Some(body.to_string().into_bytes()),
996        );
997
998        let response = plugin.handle_change_password(&req, &ctx).await.unwrap();
999        assert_eq!(response.status, 200);
1000
1001        // Verify Set-Cookie header is present
1002        let set_cookie = response.headers.get("Set-Cookie");
1003        assert!(
1004            set_cookie.is_some(),
1005            "Set-Cookie header must be set when revokeOtherSessions is true"
1006        );
1007
1008        let cookie_value = set_cookie.unwrap();
1009        assert!(
1010            cookie_value.contains(&ctx.config.session.cookie_name),
1011            "Cookie must contain the session cookie name"
1012        );
1013        assert!(
1014            cookie_value.contains("Path=/"),
1015            "Cookie must contain Path=/"
1016        );
1017        assert!(
1018            cookie_value.contains("Expires="),
1019            "Cookie must contain an expiration"
1020        );
1021    }
1022
1023    #[tokio::test]
1024    async fn test_change_password_no_cookie_without_revocation() {
1025        let plugin = PasswordManagementPlugin::new();
1026        let (ctx, _user, session) = create_test_context_with_user().await;
1027
1028        let body = serde_json::json!({
1029            "currentPassword": "Password123!",
1030            "newPassword": "NewPassword123!"
1031        });
1032
1033        let req = create_auth_request(
1034            HttpMethod::Post,
1035            "/change-password",
1036            Some(&session.token),
1037            Some(body.to_string().into_bytes()),
1038        );
1039
1040        let response = plugin.handle_change_password(&req, &ctx).await.unwrap();
1041        assert_eq!(response.status, 200);
1042
1043        // Verify Set-Cookie header is NOT present when not revoking sessions
1044        let set_cookie = response.headers.get("Set-Cookie");
1045        assert!(
1046            set_cookie.is_none(),
1047            "Set-Cookie header must not be set when revokeOtherSessions is not true"
1048        );
1049    }
1050
1051    #[tokio::test]
1052    async fn test_change_password_revoke_with_boolean() {
1053        let plugin = PasswordManagementPlugin::new();
1054        let (ctx, _user, session) = create_test_context_with_user().await;
1055
1056        // Send revokeOtherSessions as a boolean (as better-auth TS SDK does)
1057        let body = serde_json::json!({
1058            "currentPassword": "Password123!",
1059            "newPassword": "NewPassword123!",
1060            "revokeOtherSessions": true
1061        });
1062
1063        let req = create_auth_request(
1064            HttpMethod::Post,
1065            "/change-password",
1066            Some(&session.token),
1067            Some(body.to_string().into_bytes()),
1068        );
1069
1070        let response = plugin.handle_change_password(&req, &ctx).await.unwrap();
1071        assert_eq!(response.status, 200);
1072
1073        let body_str = String::from_utf8(response.body).unwrap();
1074        let response_data: serde_json::Value = serde_json::from_str(&body_str).unwrap();
1075        assert!(
1076            response_data["token"].is_string(),
1077            "New token must be returned when revokeOtherSessions is boolean true"
1078        );
1079    }
1080
1081    #[tokio::test]
1082    async fn test_change_password_wrong_current_password() {
1083        let plugin = PasswordManagementPlugin::new();
1084        let (ctx, _user, session) = create_test_context_with_user().await;
1085
1086        let body = serde_json::json!({
1087            "currentPassword": "WrongPassword123!",
1088            "newPassword": "NewPassword123!"
1089        });
1090
1091        let req = create_auth_request(
1092            HttpMethod::Post,
1093            "/change-password",
1094            Some(&session.token),
1095            Some(body.to_string().into_bytes()),
1096        );
1097
1098        let err = plugin.handle_change_password(&req, &ctx).await.unwrap_err();
1099        assert_eq!(err.status_code(), 401);
1100    }
1101
1102    #[tokio::test]
1103    async fn test_change_password_unauthorized() {
1104        let plugin = PasswordManagementPlugin::new();
1105        let (ctx, _user, _session) = create_test_context_with_user().await;
1106
1107        let body = serde_json::json!({
1108            "currentPassword": "Password123!",
1109            "newPassword": "NewPassword123!"
1110        });
1111
1112        let req = create_auth_request(
1113            HttpMethod::Post,
1114            "/change-password",
1115            None,
1116            Some(body.to_string().into_bytes()),
1117        );
1118
1119        let err = plugin.handle_change_password(&req, &ctx).await.unwrap_err();
1120        assert_eq!(err.status_code(), 401);
1121    }
1122
1123    #[tokio::test]
1124    async fn test_reset_password_token_endpoint_success() {
1125        let plugin = PasswordManagementPlugin::new();
1126        let (ctx, user, _session) = create_test_context_with_user().await;
1127
1128        // Create verification token
1129        let reset_token = format!("reset_{}", uuid::Uuid::new_v4());
1130        let create_verification = CreateVerification {
1131            identifier: user.email.clone().unwrap(),
1132            value: reset_token.clone(),
1133            expires_at: Utc::now() + Duration::hours(24),
1134        };
1135        ctx.database
1136            .create_verification(create_verification)
1137            .await
1138            .unwrap();
1139
1140        let req = create_auth_request(HttpMethod::Get, "/reset-password/token", None, None);
1141
1142        let response = plugin
1143            .handle_reset_password_token(&reset_token, &req, &ctx)
1144            .await
1145            .unwrap();
1146        assert_eq!(response.status, 200);
1147
1148        let body_str = String::from_utf8(response.body).unwrap();
1149        let response_data: ResetPasswordTokenResponse = serde_json::from_str(&body_str).unwrap();
1150        assert_eq!(response_data.token, reset_token);
1151    }
1152
1153    #[tokio::test]
1154    async fn test_reset_password_token_endpoint_with_callback() {
1155        let plugin = PasswordManagementPlugin::new();
1156        let (ctx, user, _session) = create_test_context_with_user().await;
1157
1158        // Create verification token
1159        let reset_token = format!("reset_{}", uuid::Uuid::new_v4());
1160        let create_verification = CreateVerification {
1161            identifier: user.email.clone().unwrap(),
1162            value: reset_token.clone(),
1163            expires_at: Utc::now() + Duration::hours(24),
1164        };
1165        ctx.database
1166            .create_verification(create_verification)
1167            .await
1168            .unwrap();
1169
1170        let mut query = HashMap::new();
1171        query.insert(
1172            "callbackURL".to_string(),
1173            "http://localhost:3000/reset".to_string(),
1174        );
1175
1176        let req = AuthRequest {
1177            method: HttpMethod::Get,
1178            path: "/reset-password/token".to_string(),
1179            headers: HashMap::new(),
1180            body: None,
1181            query,
1182        };
1183
1184        let response = plugin
1185            .handle_reset_password_token(&reset_token, &req, &ctx)
1186            .await
1187            .unwrap();
1188        assert_eq!(response.status, 302);
1189
1190        // Check redirect URL
1191        let location_header = response
1192            .headers
1193            .iter()
1194            .find(|(key, _)| *key == "Location")
1195            .map(|(_, value)| value);
1196        assert!(location_header.is_some());
1197        assert!(
1198            location_header
1199                .unwrap()
1200                .contains("http://localhost:3000/reset")
1201        );
1202        assert!(location_header.unwrap().contains(&reset_token));
1203    }
1204
1205    #[tokio::test]
1206    async fn test_reset_password_token_endpoint_invalid_token() {
1207        let plugin = PasswordManagementPlugin::new();
1208        let (ctx, _user, _session) = create_test_context_with_user().await;
1209
1210        let req = create_auth_request(HttpMethod::Get, "/reset-password/token", None, None);
1211
1212        let err = plugin
1213            .handle_reset_password_token("invalid_token", &req, &ctx)
1214            .await
1215            .unwrap_err();
1216        assert_eq!(err.status_code(), 400);
1217    }
1218
1219    #[tokio::test]
1220    async fn test_password_validation() {
1221        let plugin = PasswordManagementPlugin::new();
1222        let mut config = AuthConfig::new("test-secret");
1223        config.password = PasswordConfig {
1224            min_length: 8,
1225            require_uppercase: true,
1226            require_lowercase: true,
1227            require_numbers: true,
1228            require_special: true,
1229            argon2_config: Argon2Config::default(),
1230        };
1231        let ctx = AuthContext::new(Arc::new(config), Arc::new(MemoryDatabaseAdapter::new()));
1232
1233        // Test valid password
1234        assert!(plugin.validate_password("Password123!", &ctx).is_ok());
1235
1236        // Test too short
1237        assert!(plugin.validate_password("Pass1!", &ctx).is_err());
1238
1239        // Test missing uppercase
1240        assert!(plugin.validate_password("password123!", &ctx).is_err());
1241
1242        // Test missing lowercase
1243        assert!(plugin.validate_password("PASSWORD123!", &ctx).is_err());
1244
1245        // Test missing number
1246        assert!(plugin.validate_password("Password!", &ctx).is_err());
1247
1248        // Test missing special character
1249        assert!(plugin.validate_password("Password123", &ctx).is_err());
1250    }
1251
1252    #[tokio::test]
1253    async fn test_password_hashing_and_verification() {
1254        let plugin = PasswordManagementPlugin::new();
1255
1256        let password = "TestPassword123!";
1257        let hash = plugin.hash_password(password).unwrap();
1258
1259        // Should verify correctly
1260        assert!(plugin.verify_password(password, &hash).is_ok());
1261
1262        // Should fail with wrong password
1263        assert!(plugin.verify_password("WrongPassword123!", &hash).is_err());
1264    }
1265
1266    #[tokio::test]
1267    async fn test_plugin_routes() {
1268        let plugin = PasswordManagementPlugin::new();
1269        let routes = AuthPlugin::<MemoryDatabaseAdapter>::routes(&plugin);
1270
1271        assert_eq!(routes.len(), 5);
1272        assert!(
1273            routes
1274                .iter()
1275                .any(|r| r.path == "/forget-password" && r.method == HttpMethod::Post)
1276        );
1277        assert!(
1278            routes
1279                .iter()
1280                .any(|r| r.path == "/reset-password" && r.method == HttpMethod::Post)
1281        );
1282        assert!(
1283            routes
1284                .iter()
1285                .any(|r| r.path == "/reset-password/{token}" && r.method == HttpMethod::Get)
1286        );
1287        assert!(
1288            routes
1289                .iter()
1290                .any(|r| r.path == "/change-password" && r.method == HttpMethod::Post)
1291        );
1292    }
1293
1294    #[tokio::test]
1295    async fn test_plugin_on_request_routing() {
1296        let plugin = PasswordManagementPlugin::new();
1297        let (ctx, _user, session) = create_test_context_with_user().await;
1298
1299        // Test forget password
1300        let body = serde_json::json!({"email": "test@example.com"});
1301        let req = create_auth_request(
1302            HttpMethod::Post,
1303            "/forget-password",
1304            None,
1305            Some(body.to_string().into_bytes()),
1306        );
1307        let response = plugin.on_request(&req, &ctx).await.unwrap();
1308        assert!(response.is_some());
1309        assert_eq!(response.unwrap().status, 200);
1310
1311        // Test change password
1312        let body = serde_json::json!({
1313            "currentPassword": "Password123!",
1314            "newPassword": "NewPassword123!"
1315        });
1316        let req = create_auth_request(
1317            HttpMethod::Post,
1318            "/change-password",
1319            Some(&session.token),
1320            Some(body.to_string().into_bytes()),
1321        );
1322        let response = plugin.on_request(&req, &ctx).await.unwrap();
1323        assert!(response.is_some());
1324        assert_eq!(response.unwrap().status, 200);
1325
1326        // Test invalid route
1327        let req = create_auth_request(HttpMethod::Get, "/invalid-route", None, None);
1328        let response = plugin.on_request(&req, &ctx).await.unwrap();
1329        assert!(response.is_none());
1330    }
1331
1332    #[tokio::test]
1333    async fn test_configuration() {
1334        let config = PasswordManagementConfig {
1335            reset_token_expiry_hours: 48,
1336            require_current_password: false,
1337            send_email_notifications: false,
1338        };
1339
1340        let plugin = PasswordManagementPlugin::with_config(config);
1341        assert_eq!(plugin.config.reset_token_expiry_hours, 48);
1342        assert!(!plugin.config.require_current_password);
1343        assert!(!plugin.config.send_email_notifications);
1344    }
1345}