1use async_trait::async_trait;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
7use better_auth_core::{AuthError, AuthResult};
8use better_auth_core::{AuthRequest, AuthResponse, HttpMethod};
9use better_auth_core::{AuthSession, DatabaseAdapter};
10
11use better_auth_core::utils::password::PasswordHasher;
12
13use super::StatusResponse;
14
15pub(super) mod handlers;
16pub(super) mod types;
17
18#[cfg(test)]
19mod tests;
20
21use handlers::*;
22use types::*;
23
24pub type OnPasswordResetCallback =
26 dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = AuthResult<()>> + Send>> + Send + Sync;
27
28#[async_trait]
34pub trait SendResetPassword: Send + Sync {
35 async fn send(&self, user: &serde_json::Value, url: &str, token: &str) -> AuthResult<()>;
41}
42
/// Plugin serving the password-management endpoints: `/forget-password`,
/// `/reset-password`, `/reset-password/{token}`, `/change-password` and `/set-password`.
///
/// NOTE(review): no constructor is visible in this file; presumably it is generated by the
/// `PluginConfig` derive on `PasswordManagementConfig` (`#[plugin(name = ...)]`) — confirm.
pub struct PasswordManagementPlugin {
    // Settings passed to every handler; cloned into the axum router's `PluginState`
    // when the `axum` feature is enabled.
    config: PasswordManagementConfig,
}
47
/// Configuration for [`PasswordManagementPlugin`].
///
/// Field defaults are declared with `#[config(default = ...)]` and applied by the
/// `better_auth_core::PluginConfig` derive. `Debug` is implemented by hand so the
/// callback/hasher fields are redacted.
#[derive(Clone, better_auth_core::PluginConfig)]
#[plugin(name = "PasswordManagementPlugin")]
pub struct PasswordManagementConfig {
    /// Reset-token lifetime; the field name indicates hours. Default: `24`.
    #[config(default = 24)]
    pub reset_token_expiry_hours: i64,
    /// Whether the current password must be supplied to change it. Default: `true`.
    /// (Enforcement lives in the `handlers` module — not visible here.)
    #[config(default = true)]
    pub require_current_password: bool,
    /// Whether email notifications are sent. Default: `true`.
    /// NOTE(review): which events trigger them is decided in `handlers` — confirm there.
    #[config(default = true)]
    pub send_email_notifications: bool,
    /// Whether existing sessions are revoked after a password reset. Default: `true`.
    #[config(default = true)]
    pub revoke_sessions_on_password_reset: bool,
    /// Optional [`SendResetPassword`] hook used to deliver reset links. Default: `None`.
    #[config(default = None)]
    pub send_reset_password: Option<Arc<dyn SendResetPassword>>,
    /// Optional [`OnPasswordResetCallback`], presumably invoked after a successful reset.
    /// Default: `None`.
    #[config(default = None)]
    pub on_password_reset: Option<Arc<OnPasswordResetCallback>>,
    /// Optional custom [`PasswordHasher`]; `None` lets the
    /// `better_auth_core::utils::password` helpers use their built-in hashing.
    #[config(default = None)]
    pub password_hasher: Option<Arc<dyn PasswordHasher>>,
}
71
72impl std::fmt::Debug for PasswordManagementConfig {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 f.debug_struct("PasswordManagementConfig")
75 .field("reset_token_expiry_hours", &self.reset_token_expiry_hours)
76 .field("require_current_password", &self.require_current_password)
77 .field("send_email_notifications", &self.send_email_notifications)
78 .field(
79 "revoke_sessions_on_password_reset",
80 &self.revoke_sessions_on_password_reset,
81 )
82 .field(
83 "send_reset_password",
84 &self.send_reset_password.as_ref().map(|_| "custom"),
85 )
86 .field(
87 "on_password_reset",
88 &self.on_password_reset.as_ref().map(|_| "custom"),
89 )
90 .field(
91 "password_hasher",
92 &self.password_hasher.as_ref().map(|_| "custom"),
93 )
94 .finish()
95 }
96}
97
98#[async_trait]
99impl<DB: DatabaseAdapter> AuthPlugin<DB> for PasswordManagementPlugin {
100 fn name(&self) -> &'static str {
101 "password-management"
102 }
103
104 fn routes(&self) -> Vec<AuthRoute> {
105 vec![
106 AuthRoute::post("/forget-password", "forget_password"),
107 AuthRoute::post("/reset-password", "reset_password"),
108 AuthRoute::get("/reset-password/{token}", "reset_password_token"),
109 AuthRoute::post("/change-password", "change_password"),
110 AuthRoute::post("/set-password", "set_password"),
111 ]
112 }
113
114 async fn on_request(
115 &self,
116 req: &AuthRequest,
117 ctx: &AuthContext<DB>,
118 ) -> AuthResult<Option<AuthResponse>> {
119 match (req.method(), req.path()) {
120 (HttpMethod::Post, "/forget-password") => {
121 Ok(Some(self.handle_forget_password(req, ctx).await?))
122 }
123 (HttpMethod::Post, "/reset-password") => {
124 Ok(Some(self.handle_reset_password(req, ctx).await?))
125 }
126 (HttpMethod::Post, "/change-password") => {
127 Ok(Some(self.handle_change_password(req, ctx).await?))
128 }
129 (HttpMethod::Post, "/set-password") => {
130 Ok(Some(self.handle_set_password(req, ctx).await?))
131 }
132 (HttpMethod::Get, path) if path.starts_with("/reset-password/") => {
133 let token = &path[16..]; Ok(Some(
135 self.handle_reset_password_token(token, req, ctx).await?,
136 ))
137 }
138 _ => Ok(None),
139 }
140 }
141}
142
143impl PasswordManagementPlugin {
145 async fn handle_forget_password<DB: DatabaseAdapter>(
146 &self,
147 req: &AuthRequest,
148 ctx: &AuthContext<DB>,
149 ) -> AuthResult<AuthResponse> {
150 let body: ForgetPasswordRequest = match better_auth_core::validate_request_body(req) {
151 Ok(v) => v,
152 Err(resp) => return Ok(resp),
153 };
154 let response = forget_password_core(&body, &self.config, ctx).await?;
155 Ok(AuthResponse::json(200, &response)?)
156 }
157
158 async fn handle_reset_password<DB: DatabaseAdapter>(
159 &self,
160 req: &AuthRequest,
161 ctx: &AuthContext<DB>,
162 ) -> AuthResult<AuthResponse> {
163 let body: ResetPasswordRequest = match better_auth_core::validate_request_body(req) {
164 Ok(v) => v,
165 Err(resp) => return Ok(resp),
166 };
167 let response = reset_password_core(&body, &self.config, ctx).await?;
168 Ok(AuthResponse::json(200, &response)?)
169 }
170
171 async fn handle_change_password<DB: DatabaseAdapter>(
172 &self,
173 req: &AuthRequest,
174 ctx: &AuthContext<DB>,
175 ) -> AuthResult<AuthResponse> {
176 let body: ChangePasswordRequest = match better_auth_core::validate_request_body(req) {
177 Ok(v) => v,
178 Err(resp) => return Ok(resp),
179 };
180
181 let user = self
183 .get_current_user(req, ctx)
184 .await?
185 .ok_or(AuthError::Unauthenticated)?;
186
187 let (response, new_token) = change_password_core(&body, &user, &self.config, ctx).await?;
188
189 let auth_response = AuthResponse::json(200, &response)?;
190
191 if let Some(token) = new_token {
193 let cookie_header =
194 better_auth_core::utils::cookie_utils::create_session_cookie(&token, &ctx.config);
195 Ok(auth_response.with_header("Set-Cookie", cookie_header))
196 } else {
197 Ok(auth_response)
198 }
199 }
200
201 async fn handle_set_password<DB: DatabaseAdapter>(
202 &self,
203 req: &AuthRequest,
204 ctx: &AuthContext<DB>,
205 ) -> AuthResult<AuthResponse> {
206 let body: SetPasswordRequest = match better_auth_core::validate_request_body(req) {
207 Ok(v) => v,
208 Err(resp) => return Ok(resp),
209 };
210
211 let user = self
213 .get_current_user(req, ctx)
214 .await?
215 .ok_or(AuthError::Unauthenticated)?;
216
217 let response = set_password_core(&body, &user, &self.config, ctx).await?;
218 Ok(AuthResponse::json(200, &response)?)
219 }
220
221 async fn handle_reset_password_token<DB: DatabaseAdapter>(
222 &self,
223 token: &str,
224 req: &AuthRequest,
225 ctx: &AuthContext<DB>,
226 ) -> AuthResult<AuthResponse> {
227 let query = ResetPasswordTokenQuery {
228 callback_url: req.query.get("callbackURL").cloned(),
229 };
230 match reset_password_token_core(token, &query, ctx).await? {
231 ResetPasswordTokenResult::Redirect(url) => {
232 let mut headers = std::collections::HashMap::new();
233 headers.insert("Location".to_string(), url);
234 Ok(AuthResponse {
235 status: 302,
236 headers,
237 body: Vec::new(),
238 })
239 }
240 ResetPasswordTokenResult::Json(data) => Ok(AuthResponse::json(200, &data)?),
241 }
242 }
243
244 async fn get_current_user<DB: DatabaseAdapter>(
245 &self,
246 req: &AuthRequest,
247 ctx: &AuthContext<DB>,
248 ) -> AuthResult<Option<DB::User>> {
249 let session_manager = ctx.session_manager();
250
251 if let Some(token) = session_manager.extract_session_token(req)
252 && let Some(session) = session_manager.get_session(&token).await?
253 {
254 return ctx.database.get_user_by_id(session.user_id()).await;
255 }
256
257 Ok(None)
258 }
259}
260
/// Test-only conveniences that forward to `better_auth_core::utils::password`, wired to this
/// plugin's configuration.
#[cfg(test)]
impl PasswordManagementPlugin {
    /// Validate `password` against `ctx.config.password.min_length`.
    ///
    /// `usize::MAX` is passed as the other length bound (presumably the maximum length, so it
    /// is effectively unbounded here — confirm against the core helper).
    fn validate_password<DB: DatabaseAdapter>(
        &self,
        password: &str,
        ctx: &AuthContext<DB>,
    ) -> AuthResult<()> {
        use better_auth_core::utils::password as pw;

        let min_len = ctx.config.password.min_length;
        pw::validate_password(password, min_len, usize::MAX, ctx)
    }

    /// Hash `password` with the configured custom hasher, if any.
    async fn hash_password(&self, password: &str) -> AuthResult<String> {
        use better_auth_core::utils::password as pw;

        let hasher = self.config.password_hasher.as_ref();
        pw::hash_password(hasher, password).await
    }

    /// Check `password` against `hash` with the configured custom hasher, if any.
    async fn verify_password(&self, password: &str, hash: &str) -> AuthResult<()> {
        use better_auth_core::utils::password as pw;

        let hasher = self.config.password_hasher.as_ref();
        pw::verify_password(hasher, password, hash).await
    }
}
293
#[cfg(feature = "axum")]
mod axum_impl {
    //! axum integration: the same endpoints as `AuthPlugin::on_request`, exposed as an
    //! `axum::Router` whose handlers share the `*_core` functions.

    use super::*;
    use std::sync::Arc;

    use axum::Json;
    use axum::extract::{Extension, Path, Query, State};
    use axum::http::{StatusCode, header};
    use axum::response::IntoResponse;
    use better_auth_core::{AuthState, CurrentSession, ValidatedJson};

    /// Router-scoped state: a clone of the plugin configuration, handed to handlers through
    /// an `Extension` layer.
    #[derive(Clone)]
    struct PluginState {
        config: PasswordManagementConfig,
    }

    /// `POST /forget-password`.
    async fn handle_forget_password<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(ps): Extension<Arc<PluginState>>,
        ValidatedJson(body): ValidatedJson<ForgetPasswordRequest>,
    ) -> Result<Json<StatusResponse>, AuthError> {
        let ctx = state.to_context();
        let response = forget_password_core(&body, &ps.config, &ctx).await?;
        Ok(Json(response))
    }

    /// `POST /reset-password`.
    async fn handle_reset_password<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(ps): Extension<Arc<PluginState>>,
        ValidatedJson(body): ValidatedJson<ResetPasswordRequest>,
    ) -> Result<Json<StatusResponse>, AuthError> {
        let ctx = state.to_context();
        let response = reset_password_core(&body, &ps.config, &ctx).await?;
        Ok(Json(response))
    }

    /// `GET /reset-password/{token}`: either redirects (`302 Found`) or returns JSON.
    async fn handle_reset_password_token<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Path(token): Path<String>,
        Query(query): Query<ResetPasswordTokenQuery>,
    ) -> Result<axum::response::Response, AuthError> {
        let ctx = state.to_context();
        match reset_password_token_core(&token, &query, &ctx).await? {
            ResetPasswordTokenResult::Redirect(url) => {
                // 302 Found, the same status the framework-agnostic `on_request` path
                // returns; `Redirect::to` would answer 303 See Other. The header-array form
                // also turns a URL that is not a valid header value into an error response
                // rather than relying on `Redirect`'s constructor to handle it.
                Ok((StatusCode::FOUND, [(header::LOCATION, url)]).into_response())
            }
            ResetPasswordTokenResult::Json(data) => Ok(Json(data).into_response()),
        }
    }

    /// `POST /change-password`. `CurrentSession` runs before the body extractor, so
    /// unauthenticated requests are rejected before the body is validated.
    async fn handle_change_password<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(ps): Extension<Arc<PluginState>>,
        CurrentSession { user, .. }: CurrentSession<DB>,
        ValidatedJson(body): ValidatedJson<ChangePasswordRequest>,
    ) -> Result<axum::response::Response, AuthError> {
        let ctx = state.to_context();
        let (response, new_token) = change_password_core(&body, &user, &ps.config, &ctx).await?;

        // A freshly issued session token is returned as the session cookie.
        if let Some(ref token) = new_token {
            let cookie = state.session_cookie(token);
            Ok(([(header::SET_COOKIE, cookie)], Json(response)).into_response())
        } else {
            Ok(Json(response).into_response())
        }
    }

    /// `POST /set-password` for the signed-in user.
    async fn handle_set_password<DB: DatabaseAdapter>(
        State(state): State<AuthState<DB>>,
        Extension(ps): Extension<Arc<PluginState>>,
        CurrentSession { user, .. }: CurrentSession<DB>,
        ValidatedJson(body): ValidatedJson<SetPasswordRequest>,
    ) -> Result<Json<StatusResponse>, AuthError> {
        let ctx = state.to_context();
        let response = set_password_core(&body, &user, &ps.config, &ctx).await?;
        Ok(Json(response))
    }

    impl<DB: DatabaseAdapter> better_auth_core::AxumPlugin<DB> for PasswordManagementPlugin {
        fn name(&self) -> &'static str {
            "password-management"
        }

        /// Builds the plugin's router, sharing a cloned config with every handler.
        fn router(&self) -> axum::Router<AuthState<DB>> {
            use axum::routing::{get, post};

            let plugin_state = Arc::new(PluginState {
                config: self.config.clone(),
            });

            axum::Router::new()
                .route("/forget-password", post(handle_forget_password::<DB>))
                .route("/reset-password", post(handle_reset_password::<DB>))
                // Brace-style capture, matching `AuthRoute::get("/reset-password/{token}")`.
                // NOTE(review): this is axum 0.8 syntax (0.8 rejects the old `/:token` form);
                // on axum 0.7 the colon form is required — confirm the pinned axum version.
                .route(
                    "/reset-password/{token}",
                    get(handle_reset_password_token::<DB>),
                )
                .route("/change-password", post(handle_change_password::<DB>))
                .route("/set-password", post(handle_set_password::<DB>))
                .layer(Extension(plugin_state))
        }
    }
}