1use async_trait::async_trait;
2use chrono::{Duration, Utc};
3use serde::{Deserialize, Deserializer, Serialize};
4use uuid::Uuid;
5use validator::Validate;
6
7use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
8use better_auth_core::{AuthError, AuthResult};
9use better_auth_core::{AuthRequest, AuthResponse, CreateVerification, HttpMethod, UpdateUser};
10use better_auth_core::{AuthSession, AuthUser, AuthVerification, DatabaseAdapter};
11
/// Auth plugin serving the password lifecycle endpoints: requesting a reset
/// (`/forget-password`), redeeming a reset token (`/reset-password`,
/// `/reset-password/{token}`), changing a password and setting a first password.
#[derive(Debug, Clone)]
pub struct PasswordManagementPlugin {
    /// Behavioural settings; see [`PasswordManagementConfig`].
    config: PasswordManagementConfig,
}

/// Settings for [`PasswordManagementPlugin`].
#[derive(Debug, Clone)]
pub struct PasswordManagementConfig {
    /// Number of hours a freshly issued password-reset token stays valid.
    pub reset_token_expiry_hours: i64,
    /// When `true`, `/change-password` verifies `currentPassword` against the stored hash.
    pub require_current_password: bool,
    /// When `true`, `/forget-password` attempts to email the reset link.
    pub send_email_notifications: bool,
}
23
/// JSON body of `POST /forget-password`.
#[derive(Debug, Deserialize, Validate)]
struct ForgetPasswordRequest {
    /// Address of the account requesting a reset; must be a syntactically valid email.
    #[validate(email(message = "Invalid email address"))]
    email: String,
    /// Optional URL the emailed reset link should point at (JSON key `redirectTo`).
    /// NOTE(review): caller-controlled; not validated against trusted origins here.
    #[serde(rename = "redirectTo")]
    redirect_to: Option<String>,
}

/// JSON body of `POST /reset-password`.
#[derive(Debug, Deserialize, Validate)]
struct ResetPasswordRequest {
    /// Replacement password (JSON key `newPassword`); must be non-empty.
    #[serde(rename = "newPassword")]
    #[validate(length(min = 1, message = "New password is required"))]
    new_password: String,
    /// Reset token previously issued by `/forget-password`. Optional at the
    /// deserialization level; the handler rejects a missing/empty value.
    token: Option<String>,
}

/// JSON body of `POST /set-password` (for users that do not yet have a password).
#[derive(Debug, Deserialize, Validate)]
struct SetPasswordRequest {
    /// Password to store (JSON key `newPassword`); must be non-empty.
    #[serde(rename = "newPassword")]
    #[validate(length(min = 1, message = "New password is required"))]
    new_password: String,
}

/// JSON body of `POST /change-password`.
#[derive(Debug, Deserialize, Validate)]
struct ChangePasswordRequest {
    /// Replacement password (JSON key `newPassword`); must be non-empty.
    #[serde(rename = "newPassword")]
    #[validate(length(min = 1, message = "New password is required"))]
    new_password: String,
    /// The user's existing password (JSON key `currentPassword`); must be non-empty.
    #[serde(rename = "currentPassword")]
    #[validate(length(min = 1, message = "Current password is required"))]
    current_password: String,
    /// Whether to revoke all sessions and issue a new one (JSON key
    /// `revokeOtherSessions`). Accepts a JSON boolean or a `"true"`/`"false"`
    /// string; absent or `null` becomes `None`.
    #[serde(
        default,
        rename = "revokeOtherSessions",
        deserialize_with = "deserialize_bool_or_string"
    )]
    revoke_other_sessions: Option<bool>,
}
63
64fn deserialize_bool_or_string<'de, D>(deserializer: D) -> Result<Option<bool>, D::Error>
68where
69 D: Deserializer<'de>,
70{
71 let value: Option<serde_json::Value> = Option::deserialize(deserializer)?;
72 match value {
73 None => Ok(None),
74 Some(serde_json::Value::Bool(b)) => Ok(Some(b)),
75 Some(serde_json::Value::String(s)) => match s.to_lowercase().as_str() {
76 "true" => Ok(Some(true)),
77 "false" => Ok(Some(false)),
78 _ => Err(serde::de::Error::custom(format!(
79 "invalid value for revokeOtherSessions: {}",
80 s
81 ))),
82 },
83 Some(other) => Err(serde::de::Error::custom(format!(
84 "invalid type for revokeOtherSessions: {}",
85 other
86 ))),
87 }
88}
89
90#[derive(Debug, Serialize, Deserialize)]
92struct StatusResponse {
93 status: bool,
94}
95
96#[derive(Debug, Serialize)]
97struct ChangePasswordResponse<U: Serialize> {
98 token: Option<String>,
99 user: U,
100}
101
102#[derive(Debug, Serialize, Deserialize)]
103struct ResetPasswordTokenResponse {
104 token: String,
105}
106
107impl PasswordManagementPlugin {
108 pub fn new() -> Self {
109 Self {
110 config: PasswordManagementConfig::default(),
111 }
112 }
113
114 pub fn with_config(config: PasswordManagementConfig) -> Self {
115 Self { config }
116 }
117
118 pub fn reset_token_expiry_hours(mut self, hours: i64) -> Self {
119 self.config.reset_token_expiry_hours = hours;
120 self
121 }
122
123 pub fn require_current_password(mut self, require: bool) -> Self {
124 self.config.require_current_password = require;
125 self
126 }
127
128 pub fn send_email_notifications(mut self, send: bool) -> Self {
129 self.config.send_email_notifications = send;
130 self
131 }
132}
133
134impl Default for PasswordManagementPlugin {
135 fn default() -> Self {
136 Self::new()
137 }
138}
139
140impl Default for PasswordManagementConfig {
141 fn default() -> Self {
142 Self {
143 reset_token_expiry_hours: 24, require_current_password: true,
145 send_email_notifications: true,
146 }
147 }
148}
149
150#[async_trait]
151impl<DB: DatabaseAdapter> AuthPlugin<DB> for PasswordManagementPlugin {
152 fn name(&self) -> &'static str {
153 "password-management"
154 }
155
156 fn routes(&self) -> Vec<AuthRoute> {
157 vec![
158 AuthRoute::post("/forget-password", "forget_password"),
159 AuthRoute::post("/reset-password", "reset_password"),
160 AuthRoute::get("/reset-password/{token}", "reset_password_token"),
161 AuthRoute::post("/change-password", "change_password"),
162 AuthRoute::post("/set-password", "set_password"),
163 ]
164 }
165
166 async fn on_request(
167 &self,
168 req: &AuthRequest,
169 ctx: &AuthContext<DB>,
170 ) -> AuthResult<Option<AuthResponse>> {
171 match (req.method(), req.path()) {
172 (HttpMethod::Post, "/forget-password") => {
173 Ok(Some(self.handle_forget_password(req, ctx).await?))
174 }
175 (HttpMethod::Post, "/reset-password") => {
176 Ok(Some(self.handle_reset_password(req, ctx).await?))
177 }
178 (HttpMethod::Post, "/change-password") => {
179 Ok(Some(self.handle_change_password(req, ctx).await?))
180 }
181 (HttpMethod::Post, "/set-password") => {
182 Ok(Some(self.handle_set_password(req, ctx).await?))
183 }
184 (HttpMethod::Get, path) if path.starts_with("/reset-password/") => {
185 let token = &path[16..]; Ok(Some(
187 self.handle_reset_password_token(token, req, ctx).await?,
188 ))
189 }
190 _ => Ok(None),
191 }
192 }
193}
194
195impl PasswordManagementPlugin {
197 async fn handle_forget_password<DB: DatabaseAdapter>(
198 &self,
199 req: &AuthRequest,
200 ctx: &AuthContext<DB>,
201 ) -> AuthResult<AuthResponse> {
202 let forget_req: ForgetPasswordRequest = match better_auth_core::validate_request_body(req) {
203 Ok(v) => v,
204 Err(resp) => return Ok(resp),
205 };
206
207 let user = match ctx.database.get_user_by_email(&forget_req.email).await? {
209 Some(user) => user,
210 None => {
211 let response = StatusResponse { status: true };
213 return Ok(AuthResponse::json(200, &response)?);
214 }
215 };
216
217 let reset_token = format!("reset_{}", Uuid::new_v4());
219 let expires_at = Utc::now() + Duration::hours(self.config.reset_token_expiry_hours);
220
221 let create_verification = CreateVerification {
223 identifier: user.email().unwrap_or_default().to_string(),
224 value: reset_token.clone(),
225 expires_at,
226 };
227
228 ctx.database
229 .create_verification(create_verification)
230 .await?;
231
232 if self.config.send_email_notifications {
234 let reset_url = if let Some(redirect_to) = &forget_req.redirect_to {
235 format!("{}?token={}", redirect_to, reset_token)
236 } else {
237 format!(
238 "{}/reset-password?token={}",
239 ctx.config.base_url, reset_token
240 )
241 };
242
243 if let Ok(provider) = ctx.email_provider() {
244 let subject = "Reset your password";
245 let html = format!(
246 "<p>Click the link below to reset your password:</p>\
247 <p><a href=\"{url}\">Reset Password</a></p>",
248 url = reset_url
249 );
250 let text = format!("Reset your password: {}", reset_url);
251
252 if let Err(e) = provider
253 .send(&forget_req.email, subject, &html, &text)
254 .await
255 {
256 eprintln!(
257 "[password-management] Failed to send reset email to {}: {}",
258 forget_req.email, e
259 );
260 }
261 } else {
262 eprintln!(
263 "[password-management] No email provider configured, skipping password reset email for {}",
264 forget_req.email
265 );
266 }
267 }
268
269 let response = StatusResponse { status: true };
270 Ok(AuthResponse::json(200, &response)?)
271 }
272
273 async fn handle_reset_password<DB: DatabaseAdapter>(
274 &self,
275 req: &AuthRequest,
276 ctx: &AuthContext<DB>,
277 ) -> AuthResult<AuthResponse> {
278 let reset_req: ResetPasswordRequest = match better_auth_core::validate_request_body(req) {
279 Ok(v) => v,
280 Err(resp) => return Ok(resp),
281 };
282
283 self.validate_password(&reset_req.new_password, ctx)?;
285
286 let token = reset_req.token.as_deref().unwrap_or("");
288 if token.is_empty() {
289 return Err(AuthError::bad_request("Reset token is required"));
290 }
291
292 let (user, verification) = self
293 .find_user_by_reset_token(token, ctx)
294 .await?
295 .ok_or_else(|| AuthError::bad_request("Invalid or expired reset token"))?;
296
297 let password_hash = self.hash_password(&reset_req.new_password)?;
299
300 let mut metadata = user.metadata().clone();
302 metadata["password_hash"] = serde_json::Value::String(password_hash);
303
304 let update_user = UpdateUser {
305 email: None,
306 name: None,
307 image: None,
308 email_verified: None,
309 username: None,
310 display_username: None,
311 role: None,
312 banned: None,
313 ban_reason: None,
314 ban_expires: None,
315 two_factor_enabled: None,
316 metadata: Some(metadata),
317 };
318
319 ctx.database.update_user(user.id(), update_user).await?;
320
321 ctx.database.delete_verification(verification.id()).await?;
323
324 ctx.database.delete_user_sessions(user.id()).await?;
326
327 let response = StatusResponse { status: true };
328 Ok(AuthResponse::json(200, &response)?)
329 }
330
331 async fn handle_change_password<DB: DatabaseAdapter>(
332 &self,
333 req: &AuthRequest,
334 ctx: &AuthContext<DB>,
335 ) -> AuthResult<AuthResponse> {
336 let change_req: ChangePasswordRequest = match better_auth_core::validate_request_body(req) {
337 Ok(v) => v,
338 Err(resp) => return Ok(resp),
339 };
340
341 let user = self
343 .get_current_user(req, ctx)
344 .await?
345 .ok_or(AuthError::Unauthenticated)?;
346
347 if self.config.require_current_password {
349 let stored_hash = user
350 .metadata()
351 .get("password_hash")
352 .and_then(|v| v.as_str())
353 .ok_or_else(|| AuthError::bad_request("No password set for this user"))?;
354
355 self.verify_password(&change_req.current_password, stored_hash)
356 .map_err(|_| AuthError::InvalidCredentials)?;
357 }
358
359 self.validate_password(&change_req.new_password, ctx)?;
361
362 let password_hash = self.hash_password(&change_req.new_password)?;
364
365 let mut metadata = user.metadata().clone();
367 metadata["password_hash"] = serde_json::Value::String(password_hash);
368
369 let update_user = UpdateUser {
370 email: None,
371 name: None,
372 image: None,
373 email_verified: None,
374 username: None,
375 display_username: None,
376 role: None,
377 banned: None,
378 ban_reason: None,
379 ban_expires: None,
380 two_factor_enabled: None,
381 metadata: Some(metadata),
382 };
383
384 let updated_user = ctx.database.update_user(user.id(), update_user).await?;
385
386 let new_token = if change_req.revoke_other_sessions == Some(true) {
388 ctx.database.delete_user_sessions(user.id()).await?;
390
391 let session_manager =
393 better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
394 let session = session_manager
395 .create_session(&updated_user, None, None)
396 .await?;
397 Some(session.token().to_string())
398 } else {
399 None
400 };
401
402 let response = ChangePasswordResponse {
403 token: new_token.clone(),
404 user: updated_user,
405 };
406
407 let auth_response = AuthResponse::json(200, &response)?;
408
409 if let Some(token) = new_token {
411 let cookie_header = self.create_session_cookie(&token, ctx);
412 Ok(auth_response.with_header("Set-Cookie", cookie_header))
413 } else {
414 Ok(auth_response)
415 }
416 }
417
418 async fn handle_set_password<DB: DatabaseAdapter>(
419 &self,
420 req: &AuthRequest,
421 ctx: &AuthContext<DB>,
422 ) -> AuthResult<AuthResponse> {
423 let set_req: SetPasswordRequest = match better_auth_core::validate_request_body(req) {
424 Ok(v) => v,
425 Err(resp) => return Ok(resp),
426 };
427
428 let user = self
430 .get_current_user(req, ctx)
431 .await?
432 .ok_or(AuthError::Unauthenticated)?;
433
434 if user
436 .metadata()
437 .get("password_hash")
438 .and_then(|v| v.as_str())
439 .is_some()
440 {
441 return Err(AuthError::bad_request(
442 "User already has a password. Use /change-password instead.",
443 ));
444 }
445
446 self.validate_password(&set_req.new_password, ctx)?;
448
449 let password_hash = self.hash_password(&set_req.new_password)?;
451
452 let mut metadata = user.metadata().clone();
453 metadata["password_hash"] = serde_json::Value::String(password_hash);
454
455 let update_user = UpdateUser {
456 email: None,
457 name: None,
458 image: None,
459 email_verified: None,
460 username: None,
461 display_username: None,
462 role: None,
463 banned: None,
464 ban_reason: None,
465 ban_expires: None,
466 two_factor_enabled: None,
467 metadata: Some(metadata),
468 };
469
470 ctx.database.update_user(user.id(), update_user).await?;
471
472 let response = StatusResponse { status: true };
473 Ok(AuthResponse::json(200, &response)?)
474 }
475
476 async fn handle_reset_password_token<DB: DatabaseAdapter>(
477 &self,
478 token: &str,
479 _req: &AuthRequest,
480 ctx: &AuthContext<DB>,
481 ) -> AuthResult<AuthResponse> {
482 let callback_url = _req.query.get("callbackURL").cloned();
484
485 let (_user, _verification) = match self.find_user_by_reset_token(token, ctx).await? {
487 Some((user, verification)) => (user, verification),
488 None => {
489 if let Some(callback_url) = callback_url {
491 let redirect_url = format!("{}?error=INVALID_TOKEN", callback_url);
492 let mut headers = std::collections::HashMap::new();
493 headers.insert("Location".to_string(), redirect_url);
494 return Ok(AuthResponse {
495 status: 302,
496 headers,
497 body: Vec::new(),
498 });
499 }
500
501 return Err(AuthError::bad_request("Invalid or expired reset token"));
502 }
503 };
504
505 if let Some(callback_url) = callback_url {
507 let redirect_url = format!("{}?token={}", callback_url, token);
508 let mut headers = std::collections::HashMap::new();
509 headers.insert("Location".to_string(), redirect_url);
510 return Ok(AuthResponse {
511 status: 302,
512 headers,
513 body: Vec::new(),
514 });
515 }
516
517 let response = ResetPasswordTokenResponse {
519 token: token.to_string(),
520 };
521 Ok(AuthResponse::json(200, &response)?)
522 }
523
524 async fn find_user_by_reset_token<DB: DatabaseAdapter>(
525 &self,
526 token: &str,
527 ctx: &AuthContext<DB>,
528 ) -> AuthResult<Option<(DB::User, DB::Verification)>> {
529 let verification = match ctx.database.get_verification_by_value(token).await? {
531 Some(verification) => verification,
532 None => return Ok(None),
533 };
534
535 let user = match ctx
537 .database
538 .get_user_by_email(verification.identifier())
539 .await?
540 {
541 Some(user) => user,
542 None => return Ok(None),
543 };
544
545 Ok(Some((user, verification)))
546 }
547
548 async fn get_current_user<DB: DatabaseAdapter>(
549 &self,
550 req: &AuthRequest,
551 ctx: &AuthContext<DB>,
552 ) -> AuthResult<Option<DB::User>> {
553 let session_manager =
554 better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
555
556 if let Some(token) = session_manager.extract_session_token(req)
557 && let Some(session) = session_manager.get_session(&token).await?
558 {
559 return ctx.database.get_user_by_id(session.user_id()).await;
560 }
561
562 Ok(None)
563 }
564
565 fn validate_password<DB: DatabaseAdapter>(
566 &self,
567 password: &str,
568 ctx: &AuthContext<DB>,
569 ) -> AuthResult<()> {
570 let config = &ctx.config.password;
571
572 if password.len() < config.min_length {
573 return Err(AuthError::bad_request(format!(
574 "Password must be at least {} characters long",
575 config.min_length
576 )));
577 }
578
579 if config.require_uppercase && !password.chars().any(|c| c.is_uppercase()) {
580 return Err(AuthError::bad_request(
581 "Password must contain at least one uppercase letter",
582 ));
583 }
584
585 if config.require_lowercase && !password.chars().any(|c| c.is_lowercase()) {
586 return Err(AuthError::bad_request(
587 "Password must contain at least one lowercase letter",
588 ));
589 }
590
591 if config.require_numbers && !password.chars().any(|c| c.is_ascii_digit()) {
592 return Err(AuthError::bad_request(
593 "Password must contain at least one number",
594 ));
595 }
596
597 if config.require_special
598 && !password
599 .chars()
600 .any(|c| "!@#$%^&*()_+-=[]{}|;:,.<>?".contains(c))
601 {
602 return Err(AuthError::bad_request(
603 "Password must contain at least one special character",
604 ));
605 }
606
607 Ok(())
608 }
609
610 fn hash_password(&self, password: &str) -> AuthResult<String> {
611 use argon2::password_hash::{SaltString, rand_core::OsRng};
612 use argon2::{Argon2, PasswordHasher};
613
614 let salt = SaltString::generate(&mut OsRng);
615 let argon2 = Argon2::default();
616
617 let password_hash = argon2
618 .hash_password(password.as_bytes(), &salt)
619 .map_err(|e| AuthError::PasswordHash(format!("Failed to hash password: {}", e)))?;
620
621 Ok(password_hash.to_string())
622 }
623
624 fn create_session_cookie<DB: DatabaseAdapter>(
625 &self,
626 token: &str,
627 ctx: &AuthContext<DB>,
628 ) -> String {
629 let session_config = &ctx.config.session;
630 let secure = if session_config.cookie_secure {
631 "; Secure"
632 } else {
633 ""
634 };
635 let http_only = if session_config.cookie_http_only {
636 "; HttpOnly"
637 } else {
638 ""
639 };
640 let same_site = match session_config.cookie_same_site {
641 better_auth_core::config::SameSite::Strict => "; SameSite=Strict",
642 better_auth_core::config::SameSite::Lax => "; SameSite=Lax",
643 better_auth_core::config::SameSite::None => "; SameSite=None",
644 };
645
646 let expires = chrono::Utc::now() + session_config.expires_in;
647 let expires_str = expires.format("%a, %d %b %Y %H:%M:%S GMT");
648
649 format!(
650 "{}={}; Path=/; Expires={}{}{}{}",
651 session_config.cookie_name, token, expires_str, secure, http_only, same_site
652 )
653 }
654
655 fn verify_password(&self, password: &str, hash: &str) -> AuthResult<()> {
656 use argon2::password_hash::PasswordHash;
657 use argon2::{Argon2, PasswordVerifier};
658
659 let parsed_hash = PasswordHash::new(hash)
660 .map_err(|e| AuthError::PasswordHash(format!("Invalid password hash: {}", e)))?;
661
662 let argon2 = Argon2::default();
663 argon2
664 .verify_password(password.as_bytes(), &parsed_hash)
665 .map_err(|_| AuthError::InvalidCredentials)?;
666
667 Ok(())
668 }
669}
670
671#[cfg(test)]
672mod tests {
673 use super::*;
674 use better_auth_core::adapters::{MemoryDatabaseAdapter, SessionOps, UserOps, VerificationOps};
675 use better_auth_core::config::{Argon2Config, AuthConfig, PasswordConfig};
676 use better_auth_core::{CreateSession, CreateUser, CreateVerification, Session, User};
677 use chrono::{Duration, Utc};
678 use std::collections::HashMap;
679 use std::sync::Arc;
680
681 async fn create_test_context_with_user() -> (AuthContext<MemoryDatabaseAdapter>, User, Session)
682 {
683 let mut config = AuthConfig::new("test-secret-key-at-least-32-chars-long");
684 config.password = PasswordConfig {
685 min_length: 8,
686 require_uppercase: true,
687 require_lowercase: true,
688 require_numbers: true,
689 require_special: true,
690 argon2_config: Argon2Config::default(),
691 };
692
693 let config = Arc::new(config);
694 let database = Arc::new(MemoryDatabaseAdapter::new());
695 let ctx = AuthContext::new(config.clone(), database.clone());
696
697 let plugin = PasswordManagementPlugin::new();
699 let password_hash = plugin.hash_password("Password123!").unwrap();
700
701 let metadata = serde_json::json!({
702 "password_hash": password_hash,
703 });
704
705 let create_user = CreateUser::new()
706 .with_email("test@example.com")
707 .with_name("Test User")
708 .with_metadata(metadata);
709 let user = database.create_user(create_user).await.unwrap();
710
711 let create_session = CreateSession {
713 user_id: user.id.clone(),
714 expires_at: Utc::now() + Duration::hours(24),
715 ip_address: Some("127.0.0.1".to_string()),
716 user_agent: Some("test-agent".to_string()),
717 impersonated_by: None,
718 active_organization_id: None,
719 };
720 let session = database.create_session(create_session).await.unwrap();
721
722 (ctx, user, session)
723 }
724
725 fn create_auth_request(
726 method: HttpMethod,
727 path: &str,
728 token: Option<&str>,
729 body: Option<Vec<u8>>,
730 ) -> AuthRequest {
731 let mut headers = HashMap::new();
732 if let Some(token) = token {
733 headers.insert("authorization".to_string(), format!("Bearer {}", token));
734 }
735
736 AuthRequest {
737 method,
738 path: path.to_string(),
739 headers,
740 body,
741 query: HashMap::new(),
742 }
743 }
744
745 #[tokio::test]
746 async fn test_forget_password_success() {
747 let plugin = PasswordManagementPlugin::new();
748 let (ctx, _user, _session) = create_test_context_with_user().await;
749
750 let body = serde_json::json!({
751 "email": "test@example.com",
752 "redirectTo": "http://localhost:3000/reset"
753 });
754
755 let req = create_auth_request(
756 HttpMethod::Post,
757 "/forget-password",
758 None,
759 Some(body.to_string().into_bytes()),
760 );
761
762 let response = plugin.handle_forget_password(&req, &ctx).await.unwrap();
763 assert_eq!(response.status, 200);
764
765 let body_str = String::from_utf8(response.body).unwrap();
766 let response_data: StatusResponse = serde_json::from_str(&body_str).unwrap();
767 assert!(response_data.status);
768 }
769
770 #[tokio::test]
771 async fn test_forget_password_unknown_email() {
772 let plugin = PasswordManagementPlugin::new();
773 let (ctx, _user, _session) = create_test_context_with_user().await;
774
775 let body = serde_json::json!({
776 "email": "unknown@example.com"
777 });
778
779 let req = create_auth_request(
780 HttpMethod::Post,
781 "/forget-password",
782 None,
783 Some(body.to_string().into_bytes()),
784 );
785
786 let response = plugin.handle_forget_password(&req, &ctx).await.unwrap();
787 assert_eq!(response.status, 200);
788
789 let body_str = String::from_utf8(response.body).unwrap();
791 let response_data: StatusResponse = serde_json::from_str(&body_str).unwrap();
792 assert!(response_data.status);
793 }
794
795 #[tokio::test]
796 async fn test_reset_password_success() {
797 let plugin = PasswordManagementPlugin::new();
798 let (ctx, user, _session) = create_test_context_with_user().await;
799
800 let reset_token = format!("reset_{}", uuid::Uuid::new_v4());
802 let create_verification = CreateVerification {
803 identifier: user.email.clone().unwrap(),
804 value: reset_token.clone(),
805 expires_at: Utc::now() + Duration::hours(24),
806 };
807 ctx.database
808 .create_verification(create_verification)
809 .await
810 .unwrap();
811
812 let body = serde_json::json!({
813 "newPassword": "NewPassword123!",
814 "token": reset_token
815 });
816
817 let req = create_auth_request(
818 HttpMethod::Post,
819 "/reset-password",
820 None,
821 Some(body.to_string().into_bytes()),
822 );
823
824 let response = plugin.handle_reset_password(&req, &ctx).await.unwrap();
825 assert_eq!(response.status, 200);
826
827 let body_str = String::from_utf8(response.body).unwrap();
828 let response_data: StatusResponse = serde_json::from_str(&body_str).unwrap();
829 assert!(response_data.status);
830
831 let updated_user = ctx
833 .database
834 .get_user_by_id(&user.id)
835 .await
836 .unwrap()
837 .unwrap();
838 let stored_hash = updated_user
839 .metadata
840 .get("password_hash")
841 .unwrap()
842 .as_str()
843 .unwrap();
844 assert!(
845 plugin
846 .verify_password("NewPassword123!", stored_hash)
847 .is_ok()
848 );
849
850 let verification_check = ctx
852 .database
853 .get_verification_by_value(&reset_token)
854 .await
855 .unwrap();
856 assert!(verification_check.is_none());
857 }
858
859 #[tokio::test]
860 async fn test_reset_password_invalid_token() {
861 let plugin = PasswordManagementPlugin::new();
862 let (ctx, _user, _session) = create_test_context_with_user().await;
863
864 let body = serde_json::json!({
865 "newPassword": "NewPassword123!",
866 "token": "invalid_token"
867 });
868
869 let req = create_auth_request(
870 HttpMethod::Post,
871 "/reset-password",
872 None,
873 Some(body.to_string().into_bytes()),
874 );
875
876 let err = plugin.handle_reset_password(&req, &ctx).await.unwrap_err();
877 assert_eq!(err.status_code(), 400);
878 }
879
880 #[tokio::test]
881 async fn test_reset_password_weak_password() {
882 let plugin = PasswordManagementPlugin::new();
883 let (ctx, user, _session) = create_test_context_with_user().await;
884
885 let reset_token = format!("reset_{}", uuid::Uuid::new_v4());
887 let create_verification = CreateVerification {
888 identifier: user.email.clone().unwrap(),
889 value: reset_token.clone(),
890 expires_at: Utc::now() + Duration::hours(24),
891 };
892 ctx.database
893 .create_verification(create_verification)
894 .await
895 .unwrap();
896
897 let body = serde_json::json!({
898 "newPassword": "weak",
899 "token": reset_token
900 });
901
902 let req = create_auth_request(
903 HttpMethod::Post,
904 "/reset-password",
905 None,
906 Some(body.to_string().into_bytes()),
907 );
908
909 let err = plugin.handle_reset_password(&req, &ctx).await.unwrap_err();
910 assert_eq!(err.status_code(), 400);
911 }
912
913 #[tokio::test]
914 async fn test_change_password_success() {
915 let plugin = PasswordManagementPlugin::new();
916 let (ctx, _user, session) = create_test_context_with_user().await;
917
918 let body = serde_json::json!({
919 "currentPassword": "Password123!",
920 "newPassword": "NewPassword123!",
921 "revokeOtherSessions": "false"
922 });
923
924 let req = create_auth_request(
925 HttpMethod::Post,
926 "/change-password",
927 Some(&session.token),
928 Some(body.to_string().into_bytes()),
929 );
930
931 let response = plugin.handle_change_password(&req, &ctx).await.unwrap();
932 assert_eq!(response.status, 200);
933
934 let body_str = String::from_utf8(response.body).unwrap();
935 let response_data: serde_json::Value = serde_json::from_str(&body_str).unwrap();
936 assert!(response_data["token"].is_null()); let user_id = response_data["user"]["id"].as_str().unwrap();
940 let updated_user = ctx.database.get_user_by_id(user_id).await.unwrap().unwrap();
941 let stored_hash = updated_user
942 .metadata
943 .get("password_hash")
944 .unwrap()
945 .as_str()
946 .unwrap();
947 assert!(
948 plugin
949 .verify_password("NewPassword123!", stored_hash)
950 .is_ok()
951 );
952 }
953
954 #[tokio::test]
955 async fn test_change_password_with_session_revocation() {
956 let plugin = PasswordManagementPlugin::new();
957 let (ctx, _user, session) = create_test_context_with_user().await;
958
959 let body = serde_json::json!({
960 "currentPassword": "Password123!",
961 "newPassword": "NewPassword123!",
962 "revokeOtherSessions": "true"
963 });
964
965 let req = create_auth_request(
966 HttpMethod::Post,
967 "/change-password",
968 Some(&session.token),
969 Some(body.to_string().into_bytes()),
970 );
971
972 let response = plugin.handle_change_password(&req, &ctx).await.unwrap();
973 assert_eq!(response.status, 200);
974
975 let body_str = String::from_utf8(response.body).unwrap();
976 let response_data: serde_json::Value = serde_json::from_str(&body_str).unwrap();
977 assert!(response_data["token"].is_string()); }
979
980 #[tokio::test]
981 async fn test_change_password_sets_cookie_on_session_revocation() {
982 let plugin = PasswordManagementPlugin::new();
983 let (ctx, _user, session) = create_test_context_with_user().await;
984
985 let body = serde_json::json!({
986 "currentPassword": "Password123!",
987 "newPassword": "NewPassword123!",
988 "revokeOtherSessions": true
989 });
990
991 let req = create_auth_request(
992 HttpMethod::Post,
993 "/change-password",
994 Some(&session.token),
995 Some(body.to_string().into_bytes()),
996 );
997
998 let response = plugin.handle_change_password(&req, &ctx).await.unwrap();
999 assert_eq!(response.status, 200);
1000
1001 let set_cookie = response.headers.get("Set-Cookie");
1003 assert!(
1004 set_cookie.is_some(),
1005 "Set-Cookie header must be set when revokeOtherSessions is true"
1006 );
1007
1008 let cookie_value = set_cookie.unwrap();
1009 assert!(
1010 cookie_value.contains(&ctx.config.session.cookie_name),
1011 "Cookie must contain the session cookie name"
1012 );
1013 assert!(
1014 cookie_value.contains("Path=/"),
1015 "Cookie must contain Path=/"
1016 );
1017 assert!(
1018 cookie_value.contains("Expires="),
1019 "Cookie must contain an expiration"
1020 );
1021 }
1022
1023 #[tokio::test]
1024 async fn test_change_password_no_cookie_without_revocation() {
1025 let plugin = PasswordManagementPlugin::new();
1026 let (ctx, _user, session) = create_test_context_with_user().await;
1027
1028 let body = serde_json::json!({
1029 "currentPassword": "Password123!",
1030 "newPassword": "NewPassword123!"
1031 });
1032
1033 let req = create_auth_request(
1034 HttpMethod::Post,
1035 "/change-password",
1036 Some(&session.token),
1037 Some(body.to_string().into_bytes()),
1038 );
1039
1040 let response = plugin.handle_change_password(&req, &ctx).await.unwrap();
1041 assert_eq!(response.status, 200);
1042
1043 let set_cookie = response.headers.get("Set-Cookie");
1045 assert!(
1046 set_cookie.is_none(),
1047 "Set-Cookie header must not be set when revokeOtherSessions is not true"
1048 );
1049 }
1050
1051 #[tokio::test]
1052 async fn test_change_password_revoke_with_boolean() {
1053 let plugin = PasswordManagementPlugin::new();
1054 let (ctx, _user, session) = create_test_context_with_user().await;
1055
1056 let body = serde_json::json!({
1058 "currentPassword": "Password123!",
1059 "newPassword": "NewPassword123!",
1060 "revokeOtherSessions": true
1061 });
1062
1063 let req = create_auth_request(
1064 HttpMethod::Post,
1065 "/change-password",
1066 Some(&session.token),
1067 Some(body.to_string().into_bytes()),
1068 );
1069
1070 let response = plugin.handle_change_password(&req, &ctx).await.unwrap();
1071 assert_eq!(response.status, 200);
1072
1073 let body_str = String::from_utf8(response.body).unwrap();
1074 let response_data: serde_json::Value = serde_json::from_str(&body_str).unwrap();
1075 assert!(
1076 response_data["token"].is_string(),
1077 "New token must be returned when revokeOtherSessions is boolean true"
1078 );
1079 }
1080
1081 #[tokio::test]
1082 async fn test_change_password_wrong_current_password() {
1083 let plugin = PasswordManagementPlugin::new();
1084 let (ctx, _user, session) = create_test_context_with_user().await;
1085
1086 let body = serde_json::json!({
1087 "currentPassword": "WrongPassword123!",
1088 "newPassword": "NewPassword123!"
1089 });
1090
1091 let req = create_auth_request(
1092 HttpMethod::Post,
1093 "/change-password",
1094 Some(&session.token),
1095 Some(body.to_string().into_bytes()),
1096 );
1097
1098 let err = plugin.handle_change_password(&req, &ctx).await.unwrap_err();
1099 assert_eq!(err.status_code(), 401);
1100 }
1101
1102 #[tokio::test]
1103 async fn test_change_password_unauthorized() {
1104 let plugin = PasswordManagementPlugin::new();
1105 let (ctx, _user, _session) = create_test_context_with_user().await;
1106
1107 let body = serde_json::json!({
1108 "currentPassword": "Password123!",
1109 "newPassword": "NewPassword123!"
1110 });
1111
1112 let req = create_auth_request(
1113 HttpMethod::Post,
1114 "/change-password",
1115 None,
1116 Some(body.to_string().into_bytes()),
1117 );
1118
1119 let err = plugin.handle_change_password(&req, &ctx).await.unwrap_err();
1120 assert_eq!(err.status_code(), 401);
1121 }
1122
1123 #[tokio::test]
1124 async fn test_reset_password_token_endpoint_success() {
1125 let plugin = PasswordManagementPlugin::new();
1126 let (ctx, user, _session) = create_test_context_with_user().await;
1127
1128 let reset_token = format!("reset_{}", uuid::Uuid::new_v4());
1130 let create_verification = CreateVerification {
1131 identifier: user.email.clone().unwrap(),
1132 value: reset_token.clone(),
1133 expires_at: Utc::now() + Duration::hours(24),
1134 };
1135 ctx.database
1136 .create_verification(create_verification)
1137 .await
1138 .unwrap();
1139
1140 let req = create_auth_request(HttpMethod::Get, "/reset-password/token", None, None);
1141
1142 let response = plugin
1143 .handle_reset_password_token(&reset_token, &req, &ctx)
1144 .await
1145 .unwrap();
1146 assert_eq!(response.status, 200);
1147
1148 let body_str = String::from_utf8(response.body).unwrap();
1149 let response_data: ResetPasswordTokenResponse = serde_json::from_str(&body_str).unwrap();
1150 assert_eq!(response_data.token, reset_token);
1151 }
1152
1153 #[tokio::test]
1154 async fn test_reset_password_token_endpoint_with_callback() {
1155 let plugin = PasswordManagementPlugin::new();
1156 let (ctx, user, _session) = create_test_context_with_user().await;
1157
1158 let reset_token = format!("reset_{}", uuid::Uuid::new_v4());
1160 let create_verification = CreateVerification {
1161 identifier: user.email.clone().unwrap(),
1162 value: reset_token.clone(),
1163 expires_at: Utc::now() + Duration::hours(24),
1164 };
1165 ctx.database
1166 .create_verification(create_verification)
1167 .await
1168 .unwrap();
1169
1170 let mut query = HashMap::new();
1171 query.insert(
1172 "callbackURL".to_string(),
1173 "http://localhost:3000/reset".to_string(),
1174 );
1175
1176 let req = AuthRequest {
1177 method: HttpMethod::Get,
1178 path: "/reset-password/token".to_string(),
1179 headers: HashMap::new(),
1180 body: None,
1181 query,
1182 };
1183
1184 let response = plugin
1185 .handle_reset_password_token(&reset_token, &req, &ctx)
1186 .await
1187 .unwrap();
1188 assert_eq!(response.status, 302);
1189
1190 let location_header = response
1192 .headers
1193 .iter()
1194 .find(|(key, _)| *key == "Location")
1195 .map(|(_, value)| value);
1196 assert!(location_header.is_some());
1197 assert!(
1198 location_header
1199 .unwrap()
1200 .contains("http://localhost:3000/reset")
1201 );
1202 assert!(location_header.unwrap().contains(&reset_token));
1203 }
1204
1205 #[tokio::test]
1206 async fn test_reset_password_token_endpoint_invalid_token() {
1207 let plugin = PasswordManagementPlugin::new();
1208 let (ctx, _user, _session) = create_test_context_with_user().await;
1209
1210 let req = create_auth_request(HttpMethod::Get, "/reset-password/token", None, None);
1211
1212 let err = plugin
1213 .handle_reset_password_token("invalid_token", &req, &ctx)
1214 .await
1215 .unwrap_err();
1216 assert_eq!(err.status_code(), 400);
1217 }
1218
1219 #[tokio::test]
1220 async fn test_password_validation() {
1221 let plugin = PasswordManagementPlugin::new();
1222 let mut config = AuthConfig::new("test-secret");
1223 config.password = PasswordConfig {
1224 min_length: 8,
1225 require_uppercase: true,
1226 require_lowercase: true,
1227 require_numbers: true,
1228 require_special: true,
1229 argon2_config: Argon2Config::default(),
1230 };
1231 let ctx = AuthContext::new(Arc::new(config), Arc::new(MemoryDatabaseAdapter::new()));
1232
1233 assert!(plugin.validate_password("Password123!", &ctx).is_ok());
1235
1236 assert!(plugin.validate_password("Pass1!", &ctx).is_err());
1238
1239 assert!(plugin.validate_password("password123!", &ctx).is_err());
1241
1242 assert!(plugin.validate_password("PASSWORD123!", &ctx).is_err());
1244
1245 assert!(plugin.validate_password("Password!", &ctx).is_err());
1247
1248 assert!(plugin.validate_password("Password123", &ctx).is_err());
1250 }
1251
1252 #[tokio::test]
1253 async fn test_password_hashing_and_verification() {
1254 let plugin = PasswordManagementPlugin::new();
1255
1256 let password = "TestPassword123!";
1257 let hash = plugin.hash_password(password).unwrap();
1258
1259 assert!(plugin.verify_password(password, &hash).is_ok());
1261
1262 assert!(plugin.verify_password("WrongPassword123!", &hash).is_err());
1264 }
1265
1266 #[tokio::test]
1267 async fn test_plugin_routes() {
1268 let plugin = PasswordManagementPlugin::new();
1269 let routes = AuthPlugin::<MemoryDatabaseAdapter>::routes(&plugin);
1270
1271 assert_eq!(routes.len(), 5);
1272 assert!(
1273 routes
1274 .iter()
1275 .any(|r| r.path == "/forget-password" && r.method == HttpMethod::Post)
1276 );
1277 assert!(
1278 routes
1279 .iter()
1280 .any(|r| r.path == "/reset-password" && r.method == HttpMethod::Post)
1281 );
1282 assert!(
1283 routes
1284 .iter()
1285 .any(|r| r.path == "/reset-password/{token}" && r.method == HttpMethod::Get)
1286 );
1287 assert!(
1288 routes
1289 .iter()
1290 .any(|r| r.path == "/change-password" && r.method == HttpMethod::Post)
1291 );
1292 }
1293
1294 #[tokio::test]
1295 async fn test_plugin_on_request_routing() {
1296 let plugin = PasswordManagementPlugin::new();
1297 let (ctx, _user, session) = create_test_context_with_user().await;
1298
1299 let body = serde_json::json!({"email": "test@example.com"});
1301 let req = create_auth_request(
1302 HttpMethod::Post,
1303 "/forget-password",
1304 None,
1305 Some(body.to_string().into_bytes()),
1306 );
1307 let response = plugin.on_request(&req, &ctx).await.unwrap();
1308 assert!(response.is_some());
1309 assert_eq!(response.unwrap().status, 200);
1310
1311 let body = serde_json::json!({
1313 "currentPassword": "Password123!",
1314 "newPassword": "NewPassword123!"
1315 });
1316 let req = create_auth_request(
1317 HttpMethod::Post,
1318 "/change-password",
1319 Some(&session.token),
1320 Some(body.to_string().into_bytes()),
1321 );
1322 let response = plugin.on_request(&req, &ctx).await.unwrap();
1323 assert!(response.is_some());
1324 assert_eq!(response.unwrap().status, 200);
1325
1326 let req = create_auth_request(HttpMethod::Get, "/invalid-route", None, None);
1328 let response = plugin.on_request(&req, &ctx).await.unwrap();
1329 assert!(response.is_none());
1330 }
1331
1332 #[tokio::test]
1333 async fn test_configuration() {
1334 let config = PasswordManagementConfig {
1335 reset_token_expiry_hours: 48,
1336 require_current_password: false,
1337 send_email_notifications: false,
1338 };
1339
1340 let plugin = PasswordManagementPlugin::with_config(config);
1341 assert_eq!(plugin.config.reset_token_expiry_hours, 48);
1342 assert!(!plugin.config.require_current_password);
1343 assert!(!plugin.config.send_email_notifications);
1344 }
1345}