1use async_trait::async_trait;
2use chrono::{Duration, Utc};
3use serde::{Deserialize, Deserializer, Serialize};
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use uuid::Uuid;
8use validator::Validate;
9
10use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
11use better_auth_core::{AuthError, AuthResult};
12use better_auth_core::{AuthRequest, AuthResponse, CreateVerification, HttpMethod};
13use better_auth_core::{AuthSession, AuthUser, AuthVerification, DatabaseAdapter};
14
15use better_auth_core::utils::password::{self as password_utils, PasswordHasher};
16
17use super::StatusResponse;
18
19pub type OnPasswordResetCallback =
21 dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = AuthResult<()>> + Send>> + Send + Sync;
22
23#[async_trait]
29pub trait SendResetPassword: Send + Sync {
30 async fn send(&self, user: &serde_json::Value, url: &str, token: &str) -> AuthResult<()>;
36}
37
38pub struct PasswordManagementPlugin {
40 config: PasswordManagementConfig,
41}
42
43#[derive(Clone)]
44pub struct PasswordManagementConfig {
45 pub reset_token_expiry_hours: i64,
46 pub require_current_password: bool,
47 pub send_email_notifications: bool,
48 pub revoke_sessions_on_password_reset: bool,
50 pub send_reset_password: Option<Arc<dyn SendResetPassword>>,
52 pub on_password_reset: Option<Arc<OnPasswordResetCallback>>,
55 pub password_hasher: Option<Arc<dyn PasswordHasher>>,
57}
58
59impl std::fmt::Debug for PasswordManagementConfig {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 f.debug_struct("PasswordManagementConfig")
62 .field("reset_token_expiry_hours", &self.reset_token_expiry_hours)
63 .field("require_current_password", &self.require_current_password)
64 .field("send_email_notifications", &self.send_email_notifications)
65 .field(
66 "revoke_sessions_on_password_reset",
67 &self.revoke_sessions_on_password_reset,
68 )
69 .field(
70 "send_reset_password",
71 &self.send_reset_password.as_ref().map(|_| "custom"),
72 )
73 .field(
74 "on_password_reset",
75 &self.on_password_reset.as_ref().map(|_| "custom"),
76 )
77 .field(
78 "password_hasher",
79 &self.password_hasher.as_ref().map(|_| "custom"),
80 )
81 .finish()
82 }
83}
84
85#[derive(Debug, Deserialize, Validate)]
87struct ForgetPasswordRequest {
88 #[validate(email(message = "Invalid email address"))]
89 email: String,
90 #[serde(rename = "redirectTo")]
91 redirect_to: Option<String>,
92}
93
94#[derive(Debug, Deserialize, Validate)]
95struct ResetPasswordRequest {
96 #[serde(rename = "newPassword")]
97 #[validate(length(min = 1, message = "New password is required"))]
98 new_password: String,
99 token: Option<String>,
100}
101
102#[derive(Debug, Deserialize, Validate)]
103struct SetPasswordRequest {
104 #[serde(rename = "newPassword")]
105 #[validate(length(min = 1, message = "New password is required"))]
106 new_password: String,
107}
108
109#[derive(Debug, Deserialize, Validate)]
110struct ChangePasswordRequest {
111 #[serde(rename = "newPassword")]
112 #[validate(length(min = 1, message = "New password is required"))]
113 new_password: String,
114 #[serde(rename = "currentPassword")]
115 #[validate(length(min = 1, message = "Current password is required"))]
116 current_password: String,
117 #[serde(
118 default,
119 rename = "revokeOtherSessions",
120 deserialize_with = "deserialize_bool_or_string"
121 )]
122 revoke_other_sessions: Option<bool>,
123}
124
125fn deserialize_bool_or_string<'de, D>(deserializer: D) -> Result<Option<bool>, D::Error>
129where
130 D: Deserializer<'de>,
131{
132 let value: Option<serde_json::Value> = Option::deserialize(deserializer)?;
133 match value {
134 None => Ok(None),
135 Some(serde_json::Value::Bool(b)) => Ok(Some(b)),
136 Some(serde_json::Value::String(s)) => match s.to_lowercase().as_str() {
137 "true" => Ok(Some(true)),
138 "false" => Ok(Some(false)),
139 _ => Err(serde::de::Error::custom(format!(
140 "invalid value for revokeOtherSessions: {}",
141 s
142 ))),
143 },
144 Some(other) => Err(serde::de::Error::custom(format!(
145 "invalid type for revokeOtherSessions: {}",
146 other
147 ))),
148 }
149}
150
151#[derive(Debug, Serialize)]
153struct ChangePasswordResponse<U: Serialize> {
154 token: Option<String>,
155 user: U,
156}
157
158#[derive(Debug, Serialize, Deserialize)]
159struct ResetPasswordTokenResponse {
160 token: String,
161}
162
163impl PasswordManagementPlugin {
164 pub fn new() -> Self {
165 Self {
166 config: PasswordManagementConfig::default(),
167 }
168 }
169
170 pub fn with_config(config: PasswordManagementConfig) -> Self {
171 Self { config }
172 }
173
174 pub fn reset_token_expiry_hours(mut self, hours: i64) -> Self {
175 self.config.reset_token_expiry_hours = hours;
176 self
177 }
178
179 pub fn require_current_password(mut self, require: bool) -> Self {
180 self.config.require_current_password = require;
181 self
182 }
183
184 pub fn send_email_notifications(mut self, send: bool) -> Self {
185 self.config.send_email_notifications = send;
186 self
187 }
188
189 pub fn revoke_sessions_on_password_reset(mut self, revoke: bool) -> Self {
190 self.config.revoke_sessions_on_password_reset = revoke;
191 self
192 }
193
194 pub fn send_reset_password(mut self, sender: Arc<dyn SendResetPassword>) -> Self {
195 self.config.send_reset_password = Some(sender);
196 self
197 }
198
199 pub fn on_password_reset(mut self, callback: Arc<OnPasswordResetCallback>) -> Self {
200 self.config.on_password_reset = Some(callback);
201 self
202 }
203
204 pub fn password_hasher(mut self, hasher: Arc<dyn PasswordHasher>) -> Self {
205 self.config.password_hasher = Some(hasher);
206 self
207 }
208}
209
210impl Default for PasswordManagementPlugin {
211 fn default() -> Self {
212 Self::new()
213 }
214}
215
216impl Default for PasswordManagementConfig {
217 fn default() -> Self {
218 Self {
219 reset_token_expiry_hours: 24, require_current_password: true,
221 send_email_notifications: true,
222 revoke_sessions_on_password_reset: true,
223 send_reset_password: None,
224 on_password_reset: None,
225 password_hasher: None,
226 }
227 }
228}
229
230#[async_trait]
231impl<DB: DatabaseAdapter> AuthPlugin<DB> for PasswordManagementPlugin {
232 fn name(&self) -> &'static str {
233 "password-management"
234 }
235
236 fn routes(&self) -> Vec<AuthRoute> {
237 vec![
238 AuthRoute::post("/forget-password", "forget_password"),
239 AuthRoute::post("/reset-password", "reset_password"),
240 AuthRoute::get("/reset-password/{token}", "reset_password_token"),
241 AuthRoute::post("/change-password", "change_password"),
242 AuthRoute::post("/set-password", "set_password"),
243 ]
244 }
245
246 async fn on_request(
247 &self,
248 req: &AuthRequest,
249 ctx: &AuthContext<DB>,
250 ) -> AuthResult<Option<AuthResponse>> {
251 match (req.method(), req.path()) {
252 (HttpMethod::Post, "/forget-password") => {
253 Ok(Some(self.handle_forget_password(req, ctx).await?))
254 }
255 (HttpMethod::Post, "/reset-password") => {
256 Ok(Some(self.handle_reset_password(req, ctx).await?))
257 }
258 (HttpMethod::Post, "/change-password") => {
259 Ok(Some(self.handle_change_password(req, ctx).await?))
260 }
261 (HttpMethod::Post, "/set-password") => {
262 Ok(Some(self.handle_set_password(req, ctx).await?))
263 }
264 (HttpMethod::Get, path) if path.starts_with("/reset-password/") => {
265 let token = &path[16..]; Ok(Some(
267 self.handle_reset_password_token(token, req, ctx).await?,
268 ))
269 }
270 _ => Ok(None),
271 }
272 }
273}
274
275impl PasswordManagementPlugin {
277 async fn handle_forget_password<DB: DatabaseAdapter>(
278 &self,
279 req: &AuthRequest,
280 ctx: &AuthContext<DB>,
281 ) -> AuthResult<AuthResponse> {
282 let forget_req: ForgetPasswordRequest = match better_auth_core::validate_request_body(req) {
283 Ok(v) => v,
284 Err(resp) => return Ok(resp),
285 };
286
287 let user = match ctx.database.get_user_by_email(&forget_req.email).await? {
289 Some(user) => user,
290 None => {
291 let response = StatusResponse { status: true };
293 return Ok(AuthResponse::json(200, &response)?);
294 }
295 };
296
297 let reset_token = format!("reset_{}", Uuid::new_v4());
299 let expires_at = Utc::now() + Duration::hours(self.config.reset_token_expiry_hours);
300
301 let create_verification = CreateVerification {
303 identifier: user.email().unwrap_or_default().to_string(),
304 value: reset_token.clone(),
305 expires_at,
306 };
307
308 ctx.database
309 .create_verification(create_verification)
310 .await?;
311
312 let reset_url = if let Some(redirect_to) = &forget_req.redirect_to {
316 if redirect_to.starts_with('/') || redirect_to.starts_with(&ctx.config.base_url) {
317 format!("{}?token={}", redirect_to, reset_token)
318 } else {
319 tracing::warn!(
321 redirect_to = %redirect_to,
322 "Ignoring untrusted redirect_to"
323 );
324 format!(
325 "{}/reset-password?token={}",
326 ctx.config.base_url, reset_token
327 )
328 }
329 } else {
330 format!(
331 "{}/reset-password?token={}",
332 ctx.config.base_url, reset_token
333 )
334 };
335
336 if self.config.send_email_notifications {
337 if let Some(sender) = &self.config.send_reset_password {
338 let user_value = password_utils::serialize_to_value(&user)?;
339 if let Err(e) = sender.send(&user_value, &reset_url, &reset_token).await {
340 tracing::warn!(
341 email = %forget_req.email,
342 error = %e,
343 "Custom send_reset_password callback failed"
344 );
345 }
346 } else if let Ok(provider) = ctx.email_provider() {
347 let subject = "Reset your password";
348 let html = format!(
349 "<p>Click the link below to reset your password:</p>\
350 <p><a href=\"{url}\">Reset Password</a></p>",
351 url = reset_url
352 );
353 let text = format!("Reset your password: {}", reset_url);
354
355 if let Err(e) = provider
356 .send(&forget_req.email, subject, &html, &text)
357 .await
358 {
359 tracing::warn!(
360 email = %forget_req.email,
361 error = %e,
362 "Failed to send password reset email"
363 );
364 }
365 } else {
366 tracing::warn!(
367 email = %forget_req.email,
368 "No email provider configured, skipping password reset email"
369 );
370 }
371 } let response = StatusResponse { status: true };
374 Ok(AuthResponse::json(200, &response)?)
375 }
376
377 async fn handle_reset_password<DB: DatabaseAdapter>(
378 &self,
379 req: &AuthRequest,
380 ctx: &AuthContext<DB>,
381 ) -> AuthResult<AuthResponse> {
382 let reset_req: ResetPasswordRequest = match better_auth_core::validate_request_body(req) {
383 Ok(v) => v,
384 Err(resp) => return Ok(resp),
385 };
386
387 self.validate_password(&reset_req.new_password, ctx)?;
389
390 let token = reset_req.token.as_deref().unwrap_or("");
392 if token.is_empty() {
393 return Err(AuthError::bad_request("Reset token is required"));
394 }
395
396 let (user, verification) = self
397 .find_user_by_reset_token(token, ctx)
398 .await?
399 .ok_or_else(|| AuthError::bad_request("Invalid or expired reset token"))?;
400
401 let password_hash = self.hash_password(&reset_req.new_password).await?;
403
404 let mut metadata = user.metadata().clone();
406 metadata["password_hash"] = serde_json::Value::String(password_hash);
407
408 ctx.database
409 .update_user(user.id(), password_utils::update_user_metadata(metadata))
410 .await?;
411
412 ctx.database.delete_verification(verification.id()).await?;
414
415 if self.config.revoke_sessions_on_password_reset {
417 ctx.database.delete_user_sessions(user.id()).await?;
418 }
419
420 if let Some(callback) = &self.config.on_password_reset {
424 match password_utils::serialize_to_value(&user) {
425 Ok(user_value) => {
426 if let Err(e) = callback(user_value).await {
427 tracing::warn!(
428 error = %e,
429 "on_password_reset callback failed"
430 );
431 }
432 }
433 Err(e) => {
434 tracing::warn!(
435 error = %e,
436 "Failed to serialize user for on_password_reset callback"
437 );
438 }
439 }
440 }
441
442 let response = StatusResponse { status: true };
443 Ok(AuthResponse::json(200, &response)?)
444 }
445
446 async fn handle_change_password<DB: DatabaseAdapter>(
447 &self,
448 req: &AuthRequest,
449 ctx: &AuthContext<DB>,
450 ) -> AuthResult<AuthResponse> {
451 let change_req: ChangePasswordRequest = match better_auth_core::validate_request_body(req) {
452 Ok(v) => v,
453 Err(resp) => return Ok(resp),
454 };
455
456 let user = self
458 .get_current_user(req, ctx)
459 .await?
460 .ok_or(AuthError::Unauthenticated)?;
461
462 if self.config.require_current_password {
464 let stored_hash = user
465 .metadata()
466 .get("password_hash")
467 .and_then(|v| v.as_str())
468 .ok_or_else(|| AuthError::bad_request("No password set for this user"))?;
469
470 self.verify_password(&change_req.current_password, stored_hash)
471 .await
472 .map_err(|_| AuthError::InvalidCredentials)?;
473 }
474
475 self.validate_password(&change_req.new_password, ctx)?;
477
478 let password_hash = self.hash_password(&change_req.new_password).await?;
480
481 let mut metadata = user.metadata().clone();
483 metadata["password_hash"] = serde_json::Value::String(password_hash);
484
485 let updated_user = ctx
486 .database
487 .update_user(user.id(), password_utils::update_user_metadata(metadata))
488 .await?;
489
490 let new_token = if change_req.revoke_other_sessions == Some(true) {
492 ctx.database.delete_user_sessions(user.id()).await?;
494
495 let session_manager =
497 better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
498 let session = session_manager
499 .create_session(&updated_user, None, None)
500 .await?;
501 Some(session.token().to_string())
502 } else {
503 None
504 };
505
506 let response = ChangePasswordResponse {
507 token: new_token.clone(),
508 user: updated_user,
509 };
510
511 let auth_response = AuthResponse::json(200, &response)?;
512
513 if let Some(token) = new_token {
515 let cookie_header =
516 better_auth_core::utils::cookie_utils::create_session_cookie(&token, ctx);
517 Ok(auth_response.with_header("Set-Cookie", cookie_header))
518 } else {
519 Ok(auth_response)
520 }
521 }
522
523 async fn handle_set_password<DB: DatabaseAdapter>(
524 &self,
525 req: &AuthRequest,
526 ctx: &AuthContext<DB>,
527 ) -> AuthResult<AuthResponse> {
528 let set_req: SetPasswordRequest = match better_auth_core::validate_request_body(req) {
529 Ok(v) => v,
530 Err(resp) => return Ok(resp),
531 };
532
533 let user = self
535 .get_current_user(req, ctx)
536 .await?
537 .ok_or(AuthError::Unauthenticated)?;
538
539 if user
541 .metadata()
542 .get("password_hash")
543 .and_then(|v| v.as_str())
544 .is_some()
545 {
546 return Err(AuthError::bad_request(
547 "User already has a password. Use /change-password instead.",
548 ));
549 }
550
551 self.validate_password(&set_req.new_password, ctx)?;
553
554 let password_hash = self.hash_password(&set_req.new_password).await?;
556
557 let mut metadata = user.metadata().clone();
558 metadata["password_hash"] = serde_json::Value::String(password_hash);
559
560 ctx.database
561 .update_user(user.id(), password_utils::update_user_metadata(metadata))
562 .await?;
563
564 let response = StatusResponse { status: true };
565 Ok(AuthResponse::json(200, &response)?)
566 }
567
568 async fn handle_reset_password_token<DB: DatabaseAdapter>(
569 &self,
570 token: &str,
571 _req: &AuthRequest,
572 ctx: &AuthContext<DB>,
573 ) -> AuthResult<AuthResponse> {
574 let callback_url = _req.query.get("callbackURL").cloned();
576
577 let (_user, _verification) = match self.find_user_by_reset_token(token, ctx).await? {
579 Some((user, verification)) => (user, verification),
580 None => {
581 if let Some(callback_url) = callback_url {
583 let redirect_url = format!("{}?error=INVALID_TOKEN", callback_url);
584 let mut headers = std::collections::HashMap::new();
585 headers.insert("Location".to_string(), redirect_url);
586 return Ok(AuthResponse {
587 status: 302,
588 headers,
589 body: Vec::new(),
590 });
591 }
592
593 return Err(AuthError::bad_request("Invalid or expired reset token"));
594 }
595 };
596
597 if let Some(callback_url) = callback_url {
599 let redirect_url = format!("{}?token={}", callback_url, token);
600 let mut headers = std::collections::HashMap::new();
601 headers.insert("Location".to_string(), redirect_url);
602 return Ok(AuthResponse {
603 status: 302,
604 headers,
605 body: Vec::new(),
606 });
607 }
608
609 let response = ResetPasswordTokenResponse {
611 token: token.to_string(),
612 };
613 Ok(AuthResponse::json(200, &response)?)
614 }
615
616 async fn find_user_by_reset_token<DB: DatabaseAdapter>(
617 &self,
618 token: &str,
619 ctx: &AuthContext<DB>,
620 ) -> AuthResult<Option<(DB::User, DB::Verification)>> {
621 let verification = match ctx.database.get_verification_by_value(token).await? {
623 Some(verification) => verification,
624 None => return Ok(None),
625 };
626
627 let user = match ctx
629 .database
630 .get_user_by_email(verification.identifier())
631 .await?
632 {
633 Some(user) => user,
634 None => return Ok(None),
635 };
636
637 Ok(Some((user, verification)))
638 }
639
640 async fn get_current_user<DB: DatabaseAdapter>(
641 &self,
642 req: &AuthRequest,
643 ctx: &AuthContext<DB>,
644 ) -> AuthResult<Option<DB::User>> {
645 let session_manager =
646 better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
647
648 if let Some(token) = session_manager.extract_session_token(req)
649 && let Some(session) = session_manager.get_session(&token).await?
650 {
651 return ctx.database.get_user_by_id(session.user_id()).await;
652 }
653
654 Ok(None)
655 }
656
657 fn validate_password<DB: DatabaseAdapter>(
658 &self,
659 password: &str,
660 ctx: &AuthContext<DB>,
661 ) -> AuthResult<()> {
662 password_utils::validate_password(password, ctx.config.password.min_length, usize::MAX, ctx)
666 }
667
668 async fn hash_password(&self, password: &str) -> AuthResult<String> {
669 password_utils::hash_password(self.config.password_hasher.as_ref(), password).await
670 }
671
672 async fn verify_password(&self, password: &str, hash: &str) -> AuthResult<()> {
673 password_utils::verify_password(self.config.password_hasher.as_ref(), password, hash).await
674 }
675}
676
677#[cfg(test)]
678mod tests {
679 use super::*;
680 use crate::plugins::test_helpers;
681 use better_auth_core::adapters::{MemoryDatabaseAdapter, SessionOps, UserOps, VerificationOps};
682 use better_auth_core::config::{Argon2Config, AuthConfig, PasswordConfig};
683 use better_auth_core::{CreateUser, CreateVerification, Session, User};
684 use chrono::Duration;
685 use std::collections::HashMap;
686
687 async fn create_test_context_with_user() -> (AuthContext<MemoryDatabaseAdapter>, User, Session)
688 {
689 let mut config = AuthConfig::new("test-secret-key-at-least-32-chars-long");
690 config.password = PasswordConfig {
691 min_length: 8,
692 require_uppercase: true,
693 require_lowercase: true,
694 require_numbers: true,
695 require_special: true,
696 argon2_config: Argon2Config::default(),
697 };
698
699 let ctx = test_helpers::create_test_context_with_config(config);
700
701 let plugin = PasswordManagementPlugin::new();
703 let password_hash = plugin.hash_password("Password123!").await.unwrap();
704
705 let metadata = serde_json::json!({
706 "password_hash": password_hash,
707 });
708
709 let create_user = CreateUser::new()
710 .with_email("test@example.com")
711 .with_name("Test User")
712 .with_metadata(metadata);
713 let user = test_helpers::create_user(&ctx, create_user).await;
714 let session =
715 test_helpers::create_session(&ctx, user.id.clone(), Duration::hours(24)).await;
716
717 (ctx, user, session)
718 }
719
720 async fn create_reset_token(ctx: &AuthContext<MemoryDatabaseAdapter>, email: &str) -> String {
723 let reset_token = format!("reset_{}", uuid::Uuid::new_v4());
724 let create_verification = CreateVerification {
725 identifier: email.to_string(),
726 value: reset_token.clone(),
727 expires_at: Utc::now() + Duration::hours(24),
728 };
729 ctx.database
730 .create_verification(create_verification)
731 .await
732 .unwrap();
733 reset_token
734 }
735
736 #[tokio::test]
737 async fn test_forget_password_success() {
738 let plugin = PasswordManagementPlugin::new();
739 let (ctx, _user, _session) = create_test_context_with_user().await;
740
741 let body = serde_json::json!({
742 "email": "test@example.com",
743 "redirectTo": "http://localhost:3000/reset"
744 });
745
746 let req = test_helpers::create_auth_request_no_query(
747 HttpMethod::Post,
748 "/forget-password",
749 None,
750 Some(body.to_string().into_bytes()),
751 );
752
753 let response = plugin.handle_forget_password(&req, &ctx).await.unwrap();
754 assert_eq!(response.status, 200);
755
756 let body_str = String::from_utf8(response.body).unwrap();
757 let response_data: StatusResponse = serde_json::from_str(&body_str).unwrap();
758 assert!(response_data.status);
759 }
760
761 #[tokio::test]
762 async fn test_forget_password_unknown_email() {
763 let plugin = PasswordManagementPlugin::new();
764 let (ctx, _user, _session) = create_test_context_with_user().await;
765
766 let body = serde_json::json!({
767 "email": "unknown@example.com"
768 });
769
770 let req = test_helpers::create_auth_request_no_query(
771 HttpMethod::Post,
772 "/forget-password",
773 None,
774 Some(body.to_string().into_bytes()),
775 );
776
777 let response = plugin.handle_forget_password(&req, &ctx).await.unwrap();
778 assert_eq!(response.status, 200);
779
780 let body_str = String::from_utf8(response.body).unwrap();
782 let response_data: StatusResponse = serde_json::from_str(&body_str).unwrap();
783 assert!(response_data.status);
784 }
785
786 #[tokio::test]
787 async fn test_reset_password_success() {
788 let plugin = PasswordManagementPlugin::new();
789 let (ctx, user, _session) = create_test_context_with_user().await;
790
791 let reset_token = create_reset_token(&ctx, user.email.as_deref().unwrap()).await;
792
793 let body = serde_json::json!({
794 "newPassword": "NewPassword123!",
795 "token": reset_token
796 });
797
798 let req = test_helpers::create_auth_request_no_query(
799 HttpMethod::Post,
800 "/reset-password",
801 None,
802 Some(body.to_string().into_bytes()),
803 );
804
805 let response = plugin.handle_reset_password(&req, &ctx).await.unwrap();
806 assert_eq!(response.status, 200);
807
808 let body_str = String::from_utf8(response.body).unwrap();
809 let response_data: StatusResponse = serde_json::from_str(&body_str).unwrap();
810 assert!(response_data.status);
811
812 let updated_user = ctx
814 .database
815 .get_user_by_id(&user.id)
816 .await
817 .unwrap()
818 .unwrap();
819 let stored_hash = updated_user
820 .metadata
821 .get("password_hash")
822 .unwrap()
823 .as_str()
824 .unwrap();
825 assert!(
826 plugin
827 .verify_password("NewPassword123!", stored_hash)
828 .await
829 .is_ok()
830 );
831
832 let verification_check = ctx
834 .database
835 .get_verification_by_value(&reset_token)
836 .await
837 .unwrap();
838 assert!(verification_check.is_none());
839 }
840
841 #[tokio::test]
842 async fn test_reset_password_invalid_token() {
843 let plugin = PasswordManagementPlugin::new();
844 let (ctx, _user, _session) = create_test_context_with_user().await;
845
846 let body = serde_json::json!({
847 "newPassword": "NewPassword123!",
848 "token": "invalid_token"
849 });
850
851 let req = test_helpers::create_auth_request_no_query(
852 HttpMethod::Post,
853 "/reset-password",
854 None,
855 Some(body.to_string().into_bytes()),
856 );
857
858 let err = plugin.handle_reset_password(&req, &ctx).await.unwrap_err();
859 assert_eq!(err.status_code(), 400);
860 }
861
862 #[tokio::test]
863 async fn test_reset_password_weak_password() {
864 let plugin = PasswordManagementPlugin::new();
865 let (ctx, user, _session) = create_test_context_with_user().await;
866
867 let reset_token = create_reset_token(&ctx, user.email.as_deref().unwrap()).await;
868
869 let body = serde_json::json!({
870 "newPassword": "weak",
871 "token": reset_token
872 });
873
874 let req = test_helpers::create_auth_request_no_query(
875 HttpMethod::Post,
876 "/reset-password",
877 None,
878 Some(body.to_string().into_bytes()),
879 );
880
881 let err = plugin.handle_reset_password(&req, &ctx).await.unwrap_err();
882 assert_eq!(err.status_code(), 400);
883 }
884
885 #[tokio::test]
886 async fn test_change_password_success() {
887 let plugin = PasswordManagementPlugin::new();
888 let (ctx, _user, session) = create_test_context_with_user().await;
889
890 let body = serde_json::json!({
891 "currentPassword": "Password123!",
892 "newPassword": "NewPassword123!",
893 "revokeOtherSessions": "false"
894 });
895
896 let req = test_helpers::create_auth_request_no_query(
897 HttpMethod::Post,
898 "/change-password",
899 Some(&session.token),
900 Some(body.to_string().into_bytes()),
901 );
902
903 let response = plugin.handle_change_password(&req, &ctx).await.unwrap();
904 assert_eq!(response.status, 200);
905
906 let body_str = String::from_utf8(response.body).unwrap();
907 let response_data: serde_json::Value = serde_json::from_str(&body_str).unwrap();
908 assert!(response_data["token"].is_null()); let user_id = response_data["user"]["id"].as_str().unwrap();
912 let updated_user = ctx.database.get_user_by_id(user_id).await.unwrap().unwrap();
913 let stored_hash = updated_user
914 .metadata
915 .get("password_hash")
916 .unwrap()
917 .as_str()
918 .unwrap();
919 assert!(
920 plugin
921 .verify_password("NewPassword123!", stored_hash)
922 .await
923 .is_ok()
924 );
925 }
926
927 #[tokio::test]
928 async fn test_change_password_with_session_revocation() {
929 let plugin = PasswordManagementPlugin::new();
930 let (ctx, _user, session) = create_test_context_with_user().await;
931
932 let body = serde_json::json!({
933 "currentPassword": "Password123!",
934 "newPassword": "NewPassword123!",
935 "revokeOtherSessions": "true"
936 });
937
938 let req = test_helpers::create_auth_request_no_query(
939 HttpMethod::Post,
940 "/change-password",
941 Some(&session.token),
942 Some(body.to_string().into_bytes()),
943 );
944
945 let response = plugin.handle_change_password(&req, &ctx).await.unwrap();
946 assert_eq!(response.status, 200);
947
948 let body_str = String::from_utf8(response.body).unwrap();
949 let response_data: serde_json::Value = serde_json::from_str(&body_str).unwrap();
950 assert!(response_data["token"].is_string()); }
952
953 #[tokio::test]
954 async fn test_change_password_sets_cookie_on_session_revocation() {
955 let plugin = PasswordManagementPlugin::new();
956 let (ctx, _user, session) = create_test_context_with_user().await;
957
958 let body = serde_json::json!({
959 "currentPassword": "Password123!",
960 "newPassword": "NewPassword123!",
961 "revokeOtherSessions": true
962 });
963
964 let req = test_helpers::create_auth_request_no_query(
965 HttpMethod::Post,
966 "/change-password",
967 Some(&session.token),
968 Some(body.to_string().into_bytes()),
969 );
970
971 let response = plugin.handle_change_password(&req, &ctx).await.unwrap();
972 assert_eq!(response.status, 200);
973
974 let set_cookie = response.headers.get("Set-Cookie");
976 assert!(
977 set_cookie.is_some(),
978 "Set-Cookie header must be set when revokeOtherSessions is true"
979 );
980
981 let cookie_value = set_cookie.unwrap();
982 assert!(
983 cookie_value.contains(&ctx.config.session.cookie_name),
984 "Cookie must contain the session cookie name"
985 );
986 assert!(
987 cookie_value.contains("Path=/"),
988 "Cookie must contain Path=/"
989 );
990 assert!(
991 cookie_value.contains("Expires="),
992 "Cookie must contain an expiration"
993 );
994 }
995
996 #[tokio::test]
997 async fn test_change_password_no_cookie_without_revocation() {
998 let plugin = PasswordManagementPlugin::new();
999 let (ctx, _user, session) = create_test_context_with_user().await;
1000
1001 let body = serde_json::json!({
1002 "currentPassword": "Password123!",
1003 "newPassword": "NewPassword123!"
1004 });
1005
1006 let req = test_helpers::create_auth_request_no_query(
1007 HttpMethod::Post,
1008 "/change-password",
1009 Some(&session.token),
1010 Some(body.to_string().into_bytes()),
1011 );
1012
1013 let response = plugin.handle_change_password(&req, &ctx).await.unwrap();
1014 assert_eq!(response.status, 200);
1015
1016 let set_cookie = response.headers.get("Set-Cookie");
1018 assert!(
1019 set_cookie.is_none(),
1020 "Set-Cookie header must not be set when revokeOtherSessions is not true"
1021 );
1022 }
1023
1024 #[tokio::test]
1025 async fn test_change_password_revoke_with_boolean() {
1026 let plugin = PasswordManagementPlugin::new();
1027 let (ctx, _user, session) = create_test_context_with_user().await;
1028
1029 let body = serde_json::json!({
1031 "currentPassword": "Password123!",
1032 "newPassword": "NewPassword123!",
1033 "revokeOtherSessions": true
1034 });
1035
1036 let req = test_helpers::create_auth_request_no_query(
1037 HttpMethod::Post,
1038 "/change-password",
1039 Some(&session.token),
1040 Some(body.to_string().into_bytes()),
1041 );
1042
1043 let response = plugin.handle_change_password(&req, &ctx).await.unwrap();
1044 assert_eq!(response.status, 200);
1045
1046 let body_str = String::from_utf8(response.body).unwrap();
1047 let response_data: serde_json::Value = serde_json::from_str(&body_str).unwrap();
1048 assert!(
1049 response_data["token"].is_string(),
1050 "New token must be returned when revokeOtherSessions is boolean true"
1051 );
1052 }
1053
1054 #[tokio::test]
1055 async fn test_change_password_wrong_current_password() {
1056 let plugin = PasswordManagementPlugin::new();
1057 let (ctx, _user, session) = create_test_context_with_user().await;
1058
1059 let body = serde_json::json!({
1060 "currentPassword": "WrongPassword123!",
1061 "newPassword": "NewPassword123!"
1062 });
1063
1064 let req = test_helpers::create_auth_request_no_query(
1065 HttpMethod::Post,
1066 "/change-password",
1067 Some(&session.token),
1068 Some(body.to_string().into_bytes()),
1069 );
1070
1071 let err = plugin.handle_change_password(&req, &ctx).await.unwrap_err();
1072 assert_eq!(err.status_code(), 401);
1073 }
1074
1075 #[tokio::test]
1076 async fn test_change_password_unauthorized() {
1077 let plugin = PasswordManagementPlugin::new();
1078 let (ctx, _user, _session) = create_test_context_with_user().await;
1079
1080 let body = serde_json::json!({
1081 "currentPassword": "Password123!",
1082 "newPassword": "NewPassword123!"
1083 });
1084
1085 let req = test_helpers::create_auth_request_no_query(
1086 HttpMethod::Post,
1087 "/change-password",
1088 None,
1089 Some(body.to_string().into_bytes()),
1090 );
1091
1092 let err = plugin.handle_change_password(&req, &ctx).await.unwrap_err();
1093 assert_eq!(err.status_code(), 401);
1094 }
1095
1096 #[tokio::test]
1097 async fn test_reset_password_token_endpoint_success() {
1098 let plugin = PasswordManagementPlugin::new();
1099 let (ctx, user, _session) = create_test_context_with_user().await;
1100
1101 let reset_token = create_reset_token(&ctx, user.email.as_deref().unwrap()).await;
1102
1103 let req = test_helpers::create_auth_request_no_query(
1104 HttpMethod::Get,
1105 "/reset-password/token",
1106 None,
1107 None,
1108 );
1109
1110 let response = plugin
1111 .handle_reset_password_token(&reset_token, &req, &ctx)
1112 .await
1113 .unwrap();
1114 assert_eq!(response.status, 200);
1115
1116 let body_str = String::from_utf8(response.body).unwrap();
1117 let response_data: ResetPasswordTokenResponse = serde_json::from_str(&body_str).unwrap();
1118 assert_eq!(response_data.token, reset_token);
1119 }
1120
1121 #[tokio::test]
1122 async fn test_reset_password_token_endpoint_with_callback() {
1123 let plugin = PasswordManagementPlugin::new();
1124 let (ctx, user, _session) = create_test_context_with_user().await;
1125
1126 let reset_token = create_reset_token(&ctx, user.email.as_deref().unwrap()).await;
1127
1128 let mut query = HashMap::new();
1129 query.insert(
1130 "callbackURL".to_string(),
1131 "http://localhost:3000/reset".to_string(),
1132 );
1133
1134 let req = AuthRequest::from_parts(
1135 HttpMethod::Get,
1136 "/reset-password/token".to_string(),
1137 HashMap::new(),
1138 None,
1139 query,
1140 );
1141
1142 let response = plugin
1143 .handle_reset_password_token(&reset_token, &req, &ctx)
1144 .await
1145 .unwrap();
1146 assert_eq!(response.status, 302);
1147
1148 let location_header = response
1150 .headers
1151 .iter()
1152 .find(|(key, _)| *key == "Location")
1153 .map(|(_, value)| value);
1154 assert!(location_header.is_some());
1155 assert!(
1156 location_header
1157 .unwrap()
1158 .contains("http://localhost:3000/reset")
1159 );
1160 assert!(location_header.unwrap().contains(&reset_token));
1161 }
1162
1163 #[tokio::test]
1164 async fn test_reset_password_token_endpoint_invalid_token() {
1165 let plugin = PasswordManagementPlugin::new();
1166 let (ctx, _user, _session) = create_test_context_with_user().await;
1167
1168 let req = test_helpers::create_auth_request_no_query(
1169 HttpMethod::Get,
1170 "/reset-password/token",
1171 None,
1172 None,
1173 );
1174
1175 let err = plugin
1176 .handle_reset_password_token("invalid_token", &req, &ctx)
1177 .await
1178 .unwrap_err();
1179 assert_eq!(err.status_code(), 400);
1180 }
1181
1182 #[tokio::test]
1183 async fn test_password_validation() {
1184 let plugin = PasswordManagementPlugin::new();
1185 let mut config = AuthConfig::new("test-secret");
1186 config.password = PasswordConfig {
1187 min_length: 8,
1188 require_uppercase: true,
1189 require_lowercase: true,
1190 require_numbers: true,
1191 require_special: true,
1192 argon2_config: Argon2Config::default(),
1193 };
1194 let ctx = AuthContext::new(Arc::new(config), Arc::new(MemoryDatabaseAdapter::new()));
1195
1196 assert!(plugin.validate_password("Password123!", &ctx).is_ok());
1198
1199 assert!(plugin.validate_password("Pass1!", &ctx).is_err());
1201
1202 assert!(plugin.validate_password("password123!", &ctx).is_err());
1204
1205 assert!(plugin.validate_password("PASSWORD123!", &ctx).is_err());
1207
1208 assert!(plugin.validate_password("Password!", &ctx).is_err());
1210
1211 assert!(plugin.validate_password("Password123", &ctx).is_err());
1213 }
1214
1215 #[tokio::test]
1216 async fn test_password_hashing_and_verification() {
1217 let plugin = PasswordManagementPlugin::new();
1218
1219 let password = "TestPassword123!";
1220 let hash = plugin.hash_password(password).await.unwrap();
1221
1222 assert!(plugin.verify_password(password, &hash).await.is_ok());
1224
1225 assert!(
1227 plugin
1228 .verify_password("WrongPassword123!", &hash)
1229 .await
1230 .is_err()
1231 );
1232 }
1233
1234 #[tokio::test]
1235 async fn test_plugin_routes() {
1236 let plugin = PasswordManagementPlugin::new();
1237 let routes = AuthPlugin::<MemoryDatabaseAdapter>::routes(&plugin);
1238
1239 assert_eq!(routes.len(), 5);
1240 assert!(
1241 routes
1242 .iter()
1243 .any(|r| r.path == "/forget-password" && r.method == HttpMethod::Post)
1244 );
1245 assert!(
1246 routes
1247 .iter()
1248 .any(|r| r.path == "/reset-password" && r.method == HttpMethod::Post)
1249 );
1250 assert!(
1251 routes
1252 .iter()
1253 .any(|r| r.path == "/reset-password/{token}" && r.method == HttpMethod::Get)
1254 );
1255 assert!(
1256 routes
1257 .iter()
1258 .any(|r| r.path == "/change-password" && r.method == HttpMethod::Post)
1259 );
1260 }
1261
1262 #[tokio::test]
1263 async fn test_plugin_on_request_routing() {
1264 let plugin = PasswordManagementPlugin::new();
1265 let (ctx, _user, session) = create_test_context_with_user().await;
1266
1267 let body = serde_json::json!({"email": "test@example.com"});
1269 let req = test_helpers::create_auth_request_no_query(
1270 HttpMethod::Post,
1271 "/forget-password",
1272 None,
1273 Some(body.to_string().into_bytes()),
1274 );
1275 let response = plugin.on_request(&req, &ctx).await.unwrap();
1276 assert!(response.is_some());
1277 assert_eq!(response.unwrap().status, 200);
1278
1279 let body = serde_json::json!({
1281 "currentPassword": "Password123!",
1282 "newPassword": "NewPassword123!"
1283 });
1284 let req = test_helpers::create_auth_request_no_query(
1285 HttpMethod::Post,
1286 "/change-password",
1287 Some(&session.token),
1288 Some(body.to_string().into_bytes()),
1289 );
1290 let response = plugin.on_request(&req, &ctx).await.unwrap();
1291 assert!(response.is_some());
1292 assert_eq!(response.unwrap().status, 200);
1293
1294 let req = test_helpers::create_auth_request_no_query(
1296 HttpMethod::Get,
1297 "/invalid-route",
1298 None,
1299 None,
1300 );
1301 let response = plugin.on_request(&req, &ctx).await.unwrap();
1302 assert!(response.is_none());
1303 }
1304
1305 #[tokio::test]
1306 async fn test_configuration() {
1307 let config = PasswordManagementConfig {
1308 reset_token_expiry_hours: 48,
1309 require_current_password: false,
1310 send_email_notifications: false,
1311 ..Default::default()
1312 };
1313
1314 let plugin = PasswordManagementPlugin::with_config(config);
1315 assert_eq!(plugin.config.reset_token_expiry_hours, 48);
1316 assert!(!plugin.config.require_current_password);
1317 assert!(!plugin.config.send_email_notifications);
1318 }
1319
1320 #[tokio::test]
1321 async fn test_send_reset_password_custom_sender() {
1322 use std::sync::atomic::{AtomicBool, Ordering};
1323
1324 struct TestSender {
1326 called: Arc<AtomicBool>,
1327 }
1328
1329 #[async_trait]
1330 impl SendResetPassword for TestSender {
1331 async fn send(
1332 &self,
1333 _user: &serde_json::Value,
1334 _url: &str,
1335 _token: &str,
1336 ) -> AuthResult<()> {
1337 self.called.store(true, Ordering::SeqCst);
1338 Ok(())
1339 }
1340 }
1341
1342 let called = Arc::new(AtomicBool::new(false));
1343 let sender: Arc<dyn SendResetPassword> = Arc::new(TestSender {
1344 called: called.clone(),
1345 });
1346
1347 let plugin = PasswordManagementPlugin::new().send_reset_password(sender);
1348 let (ctx, _user, _session) = create_test_context_with_user().await;
1349
1350 let body = serde_json::json!({
1351 "email": "test@example.com",
1352 "redirectTo": "http://localhost:3000/reset"
1353 });
1354 let req = test_helpers::create_auth_request_no_query(
1355 HttpMethod::Post,
1356 "/forget-password",
1357 None,
1358 Some(body.to_string().into_bytes()),
1359 );
1360
1361 let response = plugin.handle_forget_password(&req, &ctx).await.unwrap();
1362 assert_eq!(response.status, 200);
1363
1364 assert!(
1366 called.load(Ordering::SeqCst),
1367 "Custom send_reset_password should be invoked"
1368 );
1369 }
1370
1371 #[tokio::test]
1372 async fn test_on_password_reset_callback() {
1373 use std::sync::atomic::{AtomicBool, Ordering};
1374
1375 let callback_called = Arc::new(AtomicBool::new(false));
1376 let called_clone = callback_called.clone();
1377
1378 let callback: Arc<OnPasswordResetCallback> = Arc::new(move |_user_value| {
1379 let called = called_clone.clone();
1380 Box::pin(async move {
1381 called.store(true, Ordering::SeqCst);
1382 Ok(())
1383 })
1384 });
1385
1386 let plugin = PasswordManagementPlugin::new().on_password_reset(callback);
1387 let (ctx, user, _session) = create_test_context_with_user().await;
1388
1389 let reset_token = create_reset_token(&ctx, user.email.as_deref().unwrap()).await;
1390
1391 let body = serde_json::json!({
1392 "newPassword": "NewPassword123!",
1393 "token": reset_token
1394 });
1395 let req = test_helpers::create_auth_request_no_query(
1396 HttpMethod::Post,
1397 "/reset-password",
1398 None,
1399 Some(body.to_string().into_bytes()),
1400 );
1401
1402 let response = plugin.handle_reset_password(&req, &ctx).await.unwrap();
1403 assert_eq!(response.status, 200);
1404
1405 assert!(
1407 callback_called.load(Ordering::SeqCst),
1408 "on_password_reset callback should be invoked after password reset"
1409 );
1410 }
1411
1412 #[tokio::test]
1413 async fn test_revoke_sessions_on_password_reset_false() {
1414 let plugin = PasswordManagementPlugin::new().revoke_sessions_on_password_reset(false);
1415 let (ctx, user, session) = create_test_context_with_user().await;
1416
1417 let reset_token = create_reset_token(&ctx, user.email.as_deref().unwrap()).await;
1418
1419 let body = serde_json::json!({
1420 "newPassword": "NewPassword123!",
1421 "token": reset_token
1422 });
1423 let req = test_helpers::create_auth_request_no_query(
1424 HttpMethod::Post,
1425 "/reset-password",
1426 None,
1427 Some(body.to_string().into_bytes()),
1428 );
1429
1430 let response = plugin.handle_reset_password(&req, &ctx).await.unwrap();
1431 assert_eq!(response.status, 200);
1432
1433 let sessions = ctx.database.get_user_sessions(&user.id).await.unwrap();
1435 assert!(
1436 !sessions.is_empty(),
1437 "Sessions should remain when revoke_sessions_on_password_reset=false"
1438 );
1439 assert!(
1440 sessions.iter().any(|s| s.token == session.token),
1441 "The original session should still exist"
1442 );
1443 }
1444
1445 #[tokio::test]
1446 async fn test_revoke_sessions_on_password_reset_true() {
1447 let plugin = PasswordManagementPlugin::new();
1449 let (ctx, user, _session) = create_test_context_with_user().await;
1450
1451 let reset_token = create_reset_token(&ctx, user.email.as_deref().unwrap()).await;
1452
1453 let body = serde_json::json!({
1454 "newPassword": "NewPassword123!",
1455 "token": reset_token
1456 });
1457 let req = test_helpers::create_auth_request_no_query(
1458 HttpMethod::Post,
1459 "/reset-password",
1460 None,
1461 Some(body.to_string().into_bytes()),
1462 );
1463
1464 let response = plugin.handle_reset_password(&req, &ctx).await.unwrap();
1465 assert_eq!(response.status, 200);
1466
1467 let sessions = ctx.database.get_user_sessions(&user.id).await.unwrap();
1469 assert!(
1470 sessions.is_empty(),
1471 "Sessions should be revoked when revoke_sessions_on_password_reset=true"
1472 );
1473 }
1474}