// better_auth_api/plugins/password_management/mod.rs

1use async_trait::async_trait;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
7use better_auth_core::{AuthError, AuthResult};
8use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
9use better_auth_core::{AuthSession, DatabaseAdapter};
10
11use better_auth_core::utils::password::PasswordHasher;
12
13use super::StatusResponse;
14
15pub(super) mod handlers;
16pub(super) mod types;
17
18#[cfg(test)]
19mod tests;
20
21use handlers::*;
22use types::*;
23
/// Async callback invoked after a password is successfully reset.
///
/// The callback receives the user as a serialized `serde_json::Value` and returns a
/// pinned, boxed `Send` future resolving to `AuthResult<()>`. The callback itself must
/// be `Send + Sync` so it can be shared (it is stored as `Option<Arc<...>>` in
/// `PasswordManagementConfig::on_password_reset`).
///
/// The alias exists to keep Clippy happy — presumably the `type_complexity` lint that
/// spelling this trait-object type inline would trigger.
pub type OnPasswordResetCallback =
    dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = AuthResult<()>> + Send>> + Send + Sync;
27
/// Trait for sending password reset emails.
///
/// When set in `PasswordManagementConfig::send_reset_password`, this overrides the
/// default `EmailProvider`-based reset email sending. The user is provided as a
/// serialized `serde_json::Value` since `AuthUser` is not object-safe.
///
/// Implementations must be `Send + Sync` because the configuration stores them as
/// `Arc<dyn SendResetPassword>` and shares them across requests.
#[async_trait]
pub trait SendResetPassword: Send + Sync {
    /// Send a password reset notification.
    ///
    /// * `user` - The user as a serialized JSON value (from `serde_json::to_value`)
    /// * `url` - The full reset URL including the token
    /// * `token` - The raw reset token
    ///
    /// # Errors
    ///
    /// Returns an error (`AuthResult`) when the notification cannot be delivered.
    /// NOTE(review): how the calling handler reacts to that error is decided in the
    /// handlers module, which is not visible here.
    async fn send(&self, user: &serde_json::Value, url: &str, token: &str) -> AuthResult<()>;
}
42
/// Password management plugin for password reset and change functionality.
///
/// Serves `POST /forget-password`, `POST /reset-password`,
/// `GET /reset-password/{token}`, `POST /change-password` and `POST /set-password`.
pub struct PasswordManagementPlugin {
    /// Settings and hooks read by every route handler of this plugin.
    config: PasswordManagementConfig,
}
47
/// Configuration for `PasswordManagementPlugin`.
///
/// Field defaults are declared with `#[config(default = ...)]` and applied by the
/// `better_auth_core::PluginConfig` derive. `Debug` is implemented by hand because the
/// hook fields hold trait objects; it reports only whether each hook is set.
#[derive(Clone, better_auth_core::PluginConfig)]
#[plugin(name = "PasswordManagementPlugin")]
pub struct PasswordManagementConfig {
    /// Validity window for password-reset tokens, in hours (default: 24).
    /// NOTE(review): consumed by the handlers module, which is not visible here.
    #[config(default = 24)]
    pub reset_token_expiry_hours: i64,
    /// Whether changing a password requires supplying the current one (default: true).
    /// NOTE(review): presumably enforced by `change_password_core` — confirm in handlers.
    #[config(default = true)]
    pub require_current_password: bool,
    /// Whether email notifications are sent (default: true).
    /// NOTE(review): which events trigger them is decided in the handlers module.
    #[config(default = true)]
    pub send_email_notifications: bool,
    /// When true, all existing sessions are revoked on password reset (default: true).
    #[config(default = true)]
    pub revoke_sessions_on_password_reset: bool,
    /// Custom password reset email sender. When set, overrides the default `EmailProvider`.
    #[config(default = None)]
    pub send_reset_password: Option<Arc<dyn SendResetPassword>>,
    /// Callback invoked after a password is successfully reset.
    /// The user is provided as a serialized `serde_json::Value`.
    #[config(default = None)]
    pub on_password_reset: Option<Arc<OnPasswordResetCallback>>,
    /// Custom password hasher. When `None`, the default Argon2 hasher is used.
    #[config(default = None)]
    pub password_hasher: Option<Arc<dyn PasswordHasher>>,
}
71
72impl std::fmt::Debug for PasswordManagementConfig {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        f.debug_struct("PasswordManagementConfig")
75            .field("reset_token_expiry_hours", &self.reset_token_expiry_hours)
76            .field("require_current_password", &self.require_current_password)
77            .field("send_email_notifications", &self.send_email_notifications)
78            .field(
79                "revoke_sessions_on_password_reset",
80                &self.revoke_sessions_on_password_reset,
81            )
82            .field(
83                "send_reset_password",
84                &self.send_reset_password.as_ref().map(|_| "custom"),
85            )
86            .field(
87                "on_password_reset",
88                &self.on_password_reset.as_ref().map(|_| "custom"),
89            )
90            .field(
91                "password_hasher",
92                &self.password_hasher.as_ref().map(|_| "custom"),
93            )
94            .finish()
95    }
96}
97
98#[async_trait]
99impl<DB: DatabaseAdapter> AuthPlugin<DB> for PasswordManagementPlugin {
100    fn name(&self) -> &'static str {
101        "password-management"
102    }
103
104    fn routes(&self) -> Vec<AuthRoute> {
105        vec![
106            AuthRoute::post("/forget-password", "forget_password"),
107            AuthRoute::post("/reset-password", "reset_password"),
108            AuthRoute::get("/reset-password/{token}", "reset_password_token"),
109            AuthRoute::post("/change-password", "change_password"),
110            AuthRoute::post("/set-password", "set_password"),
111        ]
112    }
113
114    async fn on_request(
115        &self,
116        req: &AuthRequest,
117        ctx: &AuthContext<DB>,
118    ) -> AuthResult<Option<AuthResponse>> {
119        match (req.method(), req.path()) {
120            (HttpMethod::Post, "/forget-password") => {
121                Ok(Some(self.handle_forget_password(req, ctx).await?))
122            }
123            (HttpMethod::Post, "/reset-password") => {
124                Ok(Some(self.handle_reset_password(req, ctx).await?))
125            }
126            (HttpMethod::Post, "/change-password") => {
127                Ok(Some(self.handle_change_password(req, ctx).await?))
128            }
129            (HttpMethod::Post, "/set-password") => {
130                Ok(Some(self.handle_set_password(req, ctx).await?))
131            }
132            (HttpMethod::Get, path) if path.starts_with("/reset-password/") => {
133                let token = &path[16..]; // Remove "/reset-password/" prefix
134                Ok(Some(
135                    self.handle_reset_password_token(token, req, ctx).await?,
136                ))
137            }
138            _ => Ok(None),
139        }
140    }
141}
142
143// Implementation methods outside the trait
144impl PasswordManagementPlugin {
145    async fn handle_forget_password<DB: DatabaseAdapter>(
146        &self,
147        req: &AuthRequest,
148        ctx: &AuthContext<DB>,
149    ) -> AuthResult<AuthResponse> {
150        let body: ForgetPasswordRequest = match better_auth_core::validate_request_body(req) {
151            Ok(v) => v,
152            Err(resp) => return Ok(resp),
153        };
154        let response = forget_password_core(&body, &self.config, ctx).await?;
155        Ok(AuthResponse::json(200, &response)?)
156    }
157
158    async fn handle_reset_password<DB: DatabaseAdapter>(
159        &self,
160        req: &AuthRequest,
161        ctx: &AuthContext<DB>,
162    ) -> AuthResult<AuthResponse> {
163        let body: ResetPasswordRequest = match better_auth_core::validate_request_body(req) {
164            Ok(v) => v,
165            Err(resp) => return Ok(resp),
166        };
167        let response = reset_password_core(&body, &self.config, ctx).await?;
168        Ok(AuthResponse::json(200, &response)?)
169    }
170
171    async fn handle_change_password<DB: DatabaseAdapter>(
172        &self,
173        req: &AuthRequest,
174        ctx: &AuthContext<DB>,
175    ) -> AuthResult<AuthResponse> {
176        let body: ChangePasswordRequest = match better_auth_core::validate_request_body(req) {
177            Ok(v) => v,
178            Err(resp) => return Ok(resp),
179        };
180
181        // Get current user from session
182        let user = self
183            .get_current_user(req, ctx)
184            .await?
185            .ok_or(AuthError::Unauthenticated)?;
186
187        let (response, new_token) = change_password_core(&body, &user, &self.config, ctx).await?;
188
189        let auth_response = AuthResponse::json(200, &response)?;
190
191        // Set session cookie if a new session was created
192        if let Some(token) = new_token {
193            let cookie_header =
194                better_auth_core::utils::cookie_utils::create_session_cookie(&token, &ctx.config);
195            Ok(auth_response.with_header("Set-Cookie", cookie_header))
196        } else {
197            Ok(auth_response)
198        }
199    }
200
201    async fn handle_set_password<DB: DatabaseAdapter>(
202        &self,
203        req: &AuthRequest,
204        ctx: &AuthContext<DB>,
205    ) -> AuthResult<AuthResponse> {
206        let body: SetPasswordRequest = match better_auth_core::validate_request_body(req) {
207            Ok(v) => v,
208            Err(resp) => return Ok(resp),
209        };
210
211        // Authenticate user
212        let user = self
213            .get_current_user(req, ctx)
214            .await?
215            .ok_or(AuthError::Unauthenticated)?;
216
217        let response = set_password_core(&body, &user, &self.config, ctx).await?;
218        Ok(AuthResponse::json(200, &response)?)
219    }
220
221    async fn handle_reset_password_token<DB: DatabaseAdapter>(
222        &self,
223        token: &str,
224        req: &AuthRequest,
225        ctx: &AuthContext<DB>,
226    ) -> AuthResult<AuthResponse> {
227        let query = ResetPasswordTokenQuery {
228            callback_url: req.query.get("callbackURL").cloned(),
229        };
230        match reset_password_token_core(token, &query, ctx).await? {
231            ResetPasswordTokenResult::Redirect(url) => {
232                let mut headers = std::collections::HashMap::new();
233                headers.insert("Location".to_string(), url);
234                Ok(AuthResponse {
235                    status: 302,
236                    headers,
237                    body: Vec::new(),
238                })
239            }
240            ResetPasswordTokenResult::Json(data) => Ok(AuthResponse::json(200, &data)?),
241        }
242    }
243
244    async fn get_current_user<DB: DatabaseAdapter>(
245        &self,
246        req: &AuthRequest,
247        ctx: &AuthContext<DB>,
248    ) -> AuthResult<Option<DB::User>> {
249        let session_manager = ctx.session_manager();
250
251        if let Some(token) = session_manager.extract_session_token(req)
252            && let Some(session) = session_manager.get_session(&token).await?
253        {
254            return ctx.database.get_user_by_id(session.user_id()).await;
255        }
256
257        Ok(None)
258    }
259}
260
// Test-only thin wrappers over `better_auth_core::utils::password`, wired to this
// plugin's configured hasher and the context's password settings.
#[cfg(test)]
impl PasswordManagementPlugin {
    /// Checks `password` against `ctx.config.password.min_length`.
    ///
    /// `usize::MAX` is passed as the third argument, which presumably disables any
    /// upper length limit — confirm against `validate_password`'s signature.
    fn validate_password<DB: DatabaseAdapter>(
        &self,
        password: &str,
        ctx: &AuthContext<DB>,
    ) -> AuthResult<()> {
        use better_auth_core::utils::password as pw;

        let min_len = ctx.config.password.min_length;
        pw::validate_password(password, min_len, usize::MAX, ctx)
    }

    /// Hashes `password` with the configured custom hasher, if any.
    /// When none is configured, the core helper's default hasher is used.
    async fn hash_password(&self, password: &str) -> AuthResult<String> {
        use better_auth_core::utils::password as pw;

        let hasher = self.config.password_hasher.as_ref();
        pw::hash_password(hasher, password).await
    }

    /// Verifies `password` against `hash` using the configured custom hasher,
    /// or the core helper's default when none is set.
    async fn verify_password(&self, password: &str, hash: &str) -> AuthResult<()> {
        use better_auth_core::utils::password as pw;

        let hasher = self.config.password_hasher.as_ref();
        pw::verify_password(hasher, password, hash).await
    }
}
293
/// Axum integration. It exposes the same five endpoints as `AuthPlugin::routes`
/// as an `axum::Router`, delegating to the shared `*_core` functions.
#[cfg(feature = "axum")]
mod axum_impl {
    use super::*;
    use std::sync::Arc;

    use axum::extract::{Extension, Path, Query, State};
    use axum::response::IntoResponse;
    use axum::{Json, http::header};
    use better_auth_core::{AuthState, CurrentSession, ValidatedJson};

    /// Router-scoped state: a clone of the plugin configuration, injected into
    /// handlers through an `Extension` layer.
    #[derive(Clone)]
    struct PluginState {
        config: PasswordManagementConfig,
    }

    /// `POST /forget-password`: delegates to `forget_password_core`.
    async fn handle_forget_password<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(ps): Extension<Arc<PluginState>>,
        ValidatedJson(body): ValidatedJson<ForgetPasswordRequest>,
    ) -> Result<Json<StatusResponse>, AuthError> {
        let ctx = state.to_context();
        let response = forget_password_core(&body, &ps.config, &ctx).await?;
        Ok(Json(response))
    }

    /// `POST /reset-password`: delegates to `reset_password_core`.
    async fn handle_reset_password<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(ps): Extension<Arc<PluginState>>,
        ValidatedJson(body): ValidatedJson<ResetPasswordRequest>,
    ) -> Result<Json<StatusResponse>, AuthError> {
        let ctx = state.to_context();
        let response = reset_password_core(&body, &ps.config, &ctx).await?;
        Ok(Json(response))
    }

    /// `GET /reset-password/{token}`: delegates to `reset_password_token_core`.
    ///
    /// A `Redirect` result is answered with an axum redirect response; a `Json`
    /// result is serialized as the body.
    async fn handle_reset_password_token<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Path(token): Path<String>,
        Query(query): Query<ResetPasswordTokenQuery>,
    ) -> Result<axum::response::Response, AuthError> {
        let ctx = state.to_context();
        match reset_password_token_core(&token, &query, &ctx).await? {
            ResetPasswordTokenResult::Redirect(url) => {
                Ok(axum::response::Redirect::to(&url).into_response())
            }
            ResetPasswordTokenResult::Json(data) => Ok(Json(data).into_response()),
        }
    }

    /// `POST /change-password`: requires a session (`CurrentSession` extractor).
    ///
    /// When `change_password_core` returns a new session token, it is sent back as
    /// a `Set-Cookie` header built by `AuthState::session_cookie`.
    async fn handle_change_password<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(ps): Extension<Arc<PluginState>>,
        CurrentSession { user, .. }: CurrentSession<DB>,
        ValidatedJson(body): ValidatedJson<ChangePasswordRequest>,
    ) -> Result<axum::response::Response, AuthError> {
        let ctx = state.to_context();
        let (response, new_token) = change_password_core(&body, &user, &ps.config, &ctx).await?;

        if let Some(ref token) = new_token {
            let cookie = state.session_cookie(token);
            Ok(([(header::SET_COOKIE, cookie)], Json(response)).into_response())
        } else {
            Ok(Json(response).into_response())
        }
    }

    /// `POST /set-password`: requires a session; delegates to `set_password_core`.
    async fn handle_set_password<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(ps): Extension<Arc<PluginState>>,
        CurrentSession { user, .. }: CurrentSession<DB>,
        ValidatedJson(body): ValidatedJson<SetPasswordRequest>,
    ) -> Result<Json<StatusResponse>, AuthError> {
        let ctx = state.to_context();
        let response = set_password_core(&body, &user, &ps.config, &ctx).await?;
        Ok(Json(response))
    }

    impl<DB: DatabaseAdapter> better_auth_core::AxumPlugin<DB> for PasswordManagementPlugin {
        fn name(&self) -> &'static str {
            "password-management"
        }

        /// Builds the plugin router with the configuration attached as an
        /// `Extension<Arc<PluginState>>`.
        fn router(&self) -> axum::Router<AuthState<DB>> {
            use axum::routing::{get, post};

            let plugin_state = Arc::new(PluginState {
                config: self.config.clone(),
            });

            axum::Router::new()
                .route("/forget-password", post(handle_forget_password::<DB>))
                .route("/reset-password", post(handle_reset_password::<DB>))
                // Path captures use axum 0.8's `{name}` syntax, matching the
                // `/reset-password/{token}` pattern declared by `AuthPlugin::routes`.
                // The legacy `:name` form makes `Router::route` panic in axum 0.8.
                // NOTE(review): assumes axum >= 0.8 (the crate targets the 2024 edition).
                .route(
                    "/reset-password/{token}",
                    get(handle_reset_password_token::<DB>),
                )
                .route("/change-password", post(handle_change_password::<DB>))
                .route("/set-password", post(handle_set_password::<DB>))
                .layer(Extension(plugin_state))
        }
    }
}