1use argon2::password_hash::{SaltString, rand_core::OsRng};
2use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use validator::Validate;
6
7use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
8use better_auth_core::{AuthError, AuthResult};
9use better_auth_core::{AuthRequest, AuthResponse, CreateUser, HttpMethod, User};
10
11pub struct EmailPasswordPlugin {
13 config: EmailPasswordConfig,
14}
15
/// Tunable settings for the email/password plugin.
#[derive(Clone, Debug)]
pub struct EmailPasswordConfig {
    /// Whether the `/sign-up/email` route is registered and served.
    pub enable_signup: bool,
    /// Flag requesting that users verify their email address.
    pub require_email_verification: bool,
    /// Minimum password length configured for this plugin.
    pub password_min_length: usize,
}
22
23#[derive(Debug, Deserialize, Validate)]
24#[allow(dead_code)]
25struct SignUpRequest {
26 #[validate(length(min = 1, message = "Name is required"))]
27 name: String,
28 #[validate(email(message = "Invalid email address"))]
29 email: String,
30 #[validate(length(min = 1, message = "Password is required"))]
31 password: String,
32 username: Option<String>,
33 #[serde(rename = "displayUsername")]
34 display_username: Option<String>,
35 #[serde(rename = "callbackURL")]
36 callback_url: Option<String>,
37}
38
39#[derive(Debug, Deserialize, Validate)]
40#[allow(dead_code)]
41struct SignInRequest {
42 #[validate(email(message = "Invalid email address"))]
43 email: String,
44 #[validate(length(min = 1, message = "Password is required"))]
45 password: String,
46 #[serde(rename = "callbackURL")]
47 callback_url: Option<String>,
48 #[serde(rename = "rememberMe")]
49 remember_me: Option<bool>,
50}
51
52#[derive(Debug, Deserialize, Validate)]
53#[allow(dead_code)]
54struct SignInUsernameRequest {
55 #[validate(length(min = 1, message = "Username is required"))]
56 username: String,
57 #[validate(length(min = 1, message = "Password is required"))]
58 password: String,
59 #[serde(rename = "rememberMe")]
60 remember_me: Option<bool>,
61}
62
63#[derive(Debug, Serialize)]
64struct SignUpResponse {
65 token: Option<String>,
66 user: User,
67}
68
69#[derive(Debug, Serialize)]
70struct SignInResponse {
71 redirect: bool,
72 token: String,
73 url: Option<String>,
74 user: User,
75}
76
77impl EmailPasswordPlugin {
78 #[allow(clippy::new_without_default)]
79 pub fn new() -> Self {
80 Self {
81 config: EmailPasswordConfig::default(),
82 }
83 }
84
85 pub fn with_config(config: EmailPasswordConfig) -> Self {
86 Self { config }
87 }
88
89 pub fn enable_signup(mut self, enable: bool) -> Self {
90 self.config.enable_signup = enable;
91 self
92 }
93
94 pub fn require_email_verification(mut self, require: bool) -> Self {
95 self.config.require_email_verification = require;
96 self
97 }
98
99 pub fn password_min_length(mut self, length: usize) -> Self {
100 self.config.password_min_length = length;
101 self
102 }
103
104 async fn handle_sign_up(
105 &self,
106 req: &AuthRequest,
107 ctx: &AuthContext,
108 ) -> AuthResult<AuthResponse> {
109 if !self.config.enable_signup {
110 return Err(AuthError::forbidden("User registration is not enabled"));
111 }
112
113 let signup_req: SignUpRequest = match better_auth_core::validate_request_body(req) {
114 Ok(v) => v,
115 Err(resp) => return Ok(resp),
116 };
117
118 self.validate_password(&signup_req.password, ctx)?;
120
121 if ctx
123 .database
124 .get_user_by_email(&signup_req.email)
125 .await?
126 .is_some()
127 {
128 return Err(AuthError::conflict("A user with this email already exists"));
129 }
130
131 let password_hash = self.hash_password(&signup_req.password)?;
133
134 let mut metadata = std::collections::HashMap::new();
136 metadata.insert(
137 "password_hash".to_string(),
138 serde_json::Value::String(password_hash),
139 );
140
141 let mut create_user = CreateUser::new()
142 .with_email(&signup_req.email)
143 .with_name(&signup_req.name);
144 if let Some(username) = signup_req.username {
145 create_user = create_user.with_username(username);
146 }
147 if let Some(display_username) = signup_req.display_username {
148 create_user.display_username = Some(display_username);
149 }
150 create_user.metadata = Some(metadata);
151
152 let user = ctx.database.create_user(create_user).await?;
153
154 let session_manager =
156 better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
157 let session = session_manager.create_session(&user, None, None).await?;
158
159 let response = SignUpResponse {
160 token: Some(session.token.clone()),
161 user,
162 };
163
164 let cookie_header = self.create_session_cookie(&session.token, ctx);
166
167 Ok(AuthResponse::json(200, &response)?.with_header("Set-Cookie", cookie_header))
168 }
169
170 async fn handle_sign_in(
171 &self,
172 req: &AuthRequest,
173 ctx: &AuthContext,
174 ) -> AuthResult<AuthResponse> {
175 let signin_req: SignInRequest = match better_auth_core::validate_request_body(req) {
176 Ok(v) => v,
177 Err(resp) => return Ok(resp),
178 };
179
180 let user = ctx
182 .database
183 .get_user_by_email(&signin_req.email)
184 .await?
185 .ok_or(AuthError::InvalidCredentials)?;
186
187 let stored_hash = user
189 .metadata
190 .get("password_hash")
191 .and_then(|v| v.as_str())
192 .ok_or(AuthError::InvalidCredentials)?;
193
194 self.verify_password(&signin_req.password, stored_hash)?;
195
196 let session_manager =
198 better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
199 let session = session_manager.create_session(&user, None, None).await?;
200
201 let response = SignInResponse {
202 redirect: false,
203 token: session.token.clone(),
204 url: None,
205 user,
206 };
207
208 let cookie_header = self.create_session_cookie(&session.token, ctx);
210
211 Ok(AuthResponse::json(200, &response)?.with_header("Set-Cookie", cookie_header))
212 }
213
214 async fn handle_sign_in_username(
215 &self,
216 req: &AuthRequest,
217 ctx: &AuthContext,
218 ) -> AuthResult<AuthResponse> {
219 let signin_req: SignInUsernameRequest = match better_auth_core::validate_request_body(req) {
220 Ok(v) => v,
221 Err(resp) => return Ok(resp),
222 };
223
224 let user = ctx
226 .database
227 .get_user_by_username(&signin_req.username)
228 .await?
229 .ok_or(AuthError::InvalidCredentials)?;
230
231 let stored_hash = user
233 .metadata
234 .get("password_hash")
235 .and_then(|v| v.as_str())
236 .ok_or(AuthError::InvalidCredentials)?;
237
238 self.verify_password(&signin_req.password, stored_hash)?;
239
240 let session_manager =
242 better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
243 let session = session_manager.create_session(&user, None, None).await?;
244
245 let response = SignInResponse {
246 redirect: false,
247 token: session.token.clone(),
248 url: None,
249 user,
250 };
251
252 let cookie_header = self.create_session_cookie(&session.token, ctx);
254
255 Ok(AuthResponse::json(200, &response)?.with_header("Set-Cookie", cookie_header))
256 }
257
258 fn validate_password(&self, password: &str, ctx: &AuthContext) -> AuthResult<()> {
259 if password.len() < ctx.config.password.min_length {
260 return Err(AuthError::bad_request(format!(
261 "Password must be at least {} characters long",
262 ctx.config.password.min_length
263 )));
264 }
265 Ok(())
266 }
267
268 fn hash_password(&self, password: &str) -> AuthResult<String> {
269 let salt = SaltString::generate(&mut OsRng);
270 let argon2 = Argon2::default();
271
272 let password_hash = argon2
273 .hash_password(password.as_bytes(), &salt)
274 .map_err(|e| AuthError::PasswordHash(format!("Failed to hash password: {}", e)))?;
275
276 Ok(password_hash.to_string())
277 }
278
279 fn create_session_cookie(&self, token: &str, ctx: &AuthContext) -> String {
280 let session_config = &ctx.config.session;
281 let secure = if session_config.cookie_secure {
282 "; Secure"
283 } else {
284 ""
285 };
286 let http_only = if session_config.cookie_http_only {
287 "; HttpOnly"
288 } else {
289 ""
290 };
291 let same_site = match session_config.cookie_same_site {
292 better_auth_core::config::SameSite::Strict => "; SameSite=Strict",
293 better_auth_core::config::SameSite::Lax => "; SameSite=Lax",
294 better_auth_core::config::SameSite::None => "; SameSite=None",
295 };
296
297 let expires = chrono::Utc::now() + session_config.expires_in;
298 let expires_str = expires.format("%a, %d %b %Y %H:%M:%S GMT");
299
300 format!(
301 "{}={}; Path=/; Expires={}{}{}{}",
302 session_config.cookie_name, token, expires_str, secure, http_only, same_site
303 )
304 }
305
306 fn verify_password(&self, password: &str, hash: &str) -> AuthResult<()> {
307 let parsed_hash = PasswordHash::new(hash)
308 .map_err(|e| AuthError::PasswordHash(format!("Invalid password hash: {}", e)))?;
309
310 let argon2 = Argon2::default();
311 argon2
312 .verify_password(password.as_bytes(), &parsed_hash)
313 .map_err(|_| AuthError::InvalidCredentials)?;
314
315 Ok(())
316 }
317}
318
319impl Default for EmailPasswordConfig {
320 fn default() -> Self {
321 Self {
322 enable_signup: true,
323 require_email_verification: false,
324 password_min_length: 8,
325 }
326 }
327}
328
329#[async_trait]
330impl AuthPlugin for EmailPasswordPlugin {
331 fn name(&self) -> &'static str {
332 "email-password"
333 }
334
335 fn routes(&self) -> Vec<AuthRoute> {
336 let mut routes = vec![
337 AuthRoute::post("/sign-in/email", "sign_in_email"),
338 AuthRoute::post("/sign-in/username", "sign_in_username"),
339 ];
340
341 if self.config.enable_signup {
342 routes.push(AuthRoute::post("/sign-up/email", "sign_up_email"));
343 }
344
345 routes
346 }
347
348 async fn on_request(
349 &self,
350 req: &AuthRequest,
351 ctx: &AuthContext,
352 ) -> AuthResult<Option<AuthResponse>> {
353 match (req.method(), req.path()) {
354 (HttpMethod::Post, "/sign-up/email") if self.config.enable_signup => {
355 Ok(Some(self.handle_sign_up(req, ctx).await?))
356 }
357 (HttpMethod::Post, "/sign-in/email") => Ok(Some(self.handle_sign_in(req, ctx).await?)),
358 (HttpMethod::Post, "/sign-in/username") => {
359 Ok(Some(self.handle_sign_in_username(req, ctx).await?))
360 }
361 _ => Ok(None),
362 }
363 }
364
365 async fn on_user_created(&self, user: &User, _ctx: &AuthContext) -> AuthResult<()> {
366 if self.config.require_email_verification
367 && !user.email_verified
368 && let Some(email) = &user.email
369 {
370 println!("Email verification required for user: {}", email);
371 }
372 Ok(())
373 }
374}