1use argon2::password_hash::{SaltString, rand_core::OsRng};
2use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use totp_rs::{Algorithm, Secret, TOTP};
6use validator::Validate;
7
8use better_auth_core::adapters::DatabaseAdapter;
9use better_auth_core::entity::{AuthSession, AuthTwoFactor, AuthUser, AuthVerification};
10use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
11use better_auth_core::{AuthError, AuthResult};
12use better_auth_core::{
13 AuthRequest, AuthResponse, CreateTwoFactor, CreateVerification, HttpMethod, UpdateUser,
14};
15
16pub struct TwoFactorPlugin {
18 config: TwoFactorConfig,
19}
20
/// Tunable settings for [`TwoFactorPlugin`].
#[derive(Clone, Debug)]
pub struct TwoFactorConfig {
    /// Issuer label embedded in generated TOTP provisioning URIs.
    pub issuer: String,
    /// How many backup codes are generated at a time.
    pub backup_code_count: usize,
    /// Number of characters in each backup code.
    pub backup_code_length: usize,
    /// TOTP time step passed to the TOTP generator (seconds, per totp-rs).
    pub totp_period: u64,
    /// Number of digits in each TOTP code.
    pub totp_digits: usize,
}
29
30impl Default for TwoFactorConfig {
31 fn default() -> Self {
32 Self {
33 issuer: "BetterAuth".to_string(),
34 backup_code_count: 10,
35 backup_code_length: 8,
36 totp_period: 30,
37 totp_digits: 6,
38 }
39 }
40}
41
42#[derive(Debug, Deserialize, Validate)]
45struct EnableRequest {
46 password: String,
47 issuer: Option<String>,
48}
49
50#[derive(Debug, Deserialize, Validate)]
51struct DisableRequest {
52 password: String,
53}
54
55#[derive(Debug, Deserialize, Validate)]
56struct GetTotpUriRequest {
57 password: String,
58}
59
60#[derive(Debug, Deserialize, Validate)]
61struct VerifyTotpRequest {
62 code: String,
63 #[serde(rename = "trustDevice")]
64 #[allow(dead_code)]
65 trust_device: Option<String>,
66}
67
68#[derive(Debug, Deserialize, Validate)]
69struct VerifyOtpRequest {
70 code: String,
71 #[serde(rename = "trustDevice")]
72 #[allow(dead_code)]
73 trust_device: Option<String>,
74}
75
76#[derive(Debug, Deserialize, Validate)]
77struct GenerateBackupCodesRequest {
78 password: String,
79}
80
81#[derive(Debug, Deserialize, Validate)]
82struct VerifyBackupCodeRequest {
83 code: String,
84 #[serde(rename = "disableSession")]
85 #[allow(dead_code)]
86 disable_session: Option<String>,
87 #[serde(rename = "trustDevice")]
88 #[allow(dead_code)]
89 trust_device: Option<String>,
90}
91
92#[derive(Debug, Serialize)]
95struct EnableResponse {
96 #[serde(rename = "totpURI")]
97 totp_uri: String,
98 #[serde(rename = "backupCodes")]
99 backup_codes: Vec<String>,
100}
101
102#[derive(Debug, Serialize)]
103struct TotpUriResponse {
104 #[serde(rename = "totpURI")]
105 totp_uri: String,
106}
107
108#[derive(Debug, Serialize)]
109struct VerifyTotpResponse<U: Serialize> {
110 status: bool,
111 token: String,
112 user: U,
113}
114
115#[derive(Debug, Serialize)]
116struct VerifyBackupCodeResponse<U: Serialize, S: Serialize> {
117 user: U,
118 session: S,
119}
120
121#[derive(Debug, Serialize)]
122struct StatusResponse {
123 status: bool,
124 #[serde(rename = "backupCodes", skip_serializing_if = "Option::is_none")]
125 backup_codes: Option<Vec<String>>,
126}
127
128impl TwoFactorPlugin {
129 pub fn new() -> Self {
130 Self {
131 config: TwoFactorConfig::default(),
132 }
133 }
134
135 pub fn with_config(config: TwoFactorConfig) -> Self {
136 Self { config }
137 }
138
139 pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
140 self.config.issuer = issuer.into();
141 self
142 }
143
144 fn generate_backup_codes(&self) -> Vec<String> {
147 use rand::Rng;
148 (0..self.config.backup_code_count)
149 .map(|_| {
150 rand::thread_rng()
151 .sample_iter(&rand::distributions::Alphanumeric)
152 .take(self.config.backup_code_length)
153 .map(char::from)
154 .collect::<String>()
155 .to_uppercase()
156 })
157 .collect()
158 }
159
160 fn hash_backup_codes(codes: &[String]) -> AuthResult<String> {
161 let argon2 = Argon2::default();
162 let mut hashed = Vec::with_capacity(codes.len());
163 for code in codes {
164 let salt = SaltString::generate(&mut OsRng);
165 let hash = argon2
166 .hash_password(code.as_bytes(), &salt)
167 .map_err(|e| AuthError::internal(format!("Failed to hash backup code: {}", e)))?;
168 hashed.push(hash.to_string());
169 }
170 serde_json::to_string(&hashed).map_err(|e| AuthError::internal(e.to_string()))
171 }
172
173 fn build_totp(&self, secret: &[u8], email: &str, issuer: &str) -> AuthResult<TOTP> {
174 TOTP::new(
175 Algorithm::SHA1,
176 self.config.totp_digits,
177 1,
178 self.config.totp_period,
179 secret.to_vec(),
180 Some(issuer.to_string()),
181 email.to_string(),
182 )
183 .map_err(|e| AuthError::internal(format!("Failed to create TOTP: {}", e)))
184 }
185
186 async fn get_authenticated_user<DB: DatabaseAdapter>(
189 req: &AuthRequest,
190 ctx: &AuthContext<DB>,
191 ) -> AuthResult<(DB::User, DB::Session)> {
192 let token = req
193 .headers
194 .get("authorization")
195 .and_then(|v| v.strip_prefix("Bearer "))
196 .ok_or(AuthError::Unauthenticated)?;
197
198 let session = ctx
199 .database
200 .get_session(token)
201 .await?
202 .ok_or(AuthError::Unauthenticated)?;
203
204 if session.expires_at() < chrono::Utc::now() {
205 return Err(AuthError::Unauthenticated);
206 }
207
208 let user = ctx
209 .database
210 .get_user_by_id(session.user_id())
211 .await?
212 .ok_or(AuthError::UserNotFound)?;
213
214 Ok((user, session))
215 }
216
217 async fn get_pending_2fa_user<DB: DatabaseAdapter>(
219 req: &AuthRequest,
220 ctx: &AuthContext<DB>,
221 ) -> AuthResult<(DB::User, String)> {
222 let token = req
223 .headers
224 .get("authorization")
225 .and_then(|v| v.strip_prefix("Bearer "))
226 .ok_or(AuthError::Unauthenticated)?;
227
228 if !token.starts_with("2fa_") {
229 return Err(AuthError::bad_request("Invalid 2FA pending token"));
230 }
231
232 let identifier = format!("2fa_pending:{}", token);
233 let verification = ctx
234 .database
235 .get_verification_by_identifier(&identifier)
236 .await?
237 .ok_or_else(|| AuthError::bad_request("Invalid or expired 2FA token"))?;
238
239 if verification.expires_at() < chrono::Utc::now() {
240 return Err(AuthError::bad_request("2FA token expired"));
241 }
242
243 let user_id = verification.value();
244 let user = ctx
245 .database
246 .get_user_by_id(user_id)
247 .await?
248 .ok_or(AuthError::UserNotFound)?;
249
250 Ok((user, verification.id().to_string()))
251 }
252
253 fn verify_user_password<U: AuthUser>(user: &U, password: &str) -> AuthResult<()> {
254 let stored_hash = user
255 .metadata()
256 .get("password_hash")
257 .and_then(|v| v.as_str())
258 .ok_or(AuthError::InvalidCredentials)?;
259
260 let parsed_hash = PasswordHash::new(stored_hash)
261 .map_err(|e| AuthError::internal(format!("Invalid password hash: {}", e)))?;
262
263 Argon2::default()
264 .verify_password(password.as_bytes(), &parsed_hash)
265 .map_err(|_| AuthError::InvalidCredentials)?;
266
267 Ok(())
268 }
269
270 fn create_session_cookie<DB: DatabaseAdapter>(token: &str, ctx: &AuthContext<DB>) -> String {
271 let session_config = &ctx.config.session;
272 let secure = if session_config.cookie_secure {
273 "; Secure"
274 } else {
275 ""
276 };
277 let http_only = if session_config.cookie_http_only {
278 "; HttpOnly"
279 } else {
280 ""
281 };
282 let same_site = match session_config.cookie_same_site {
283 better_auth_core::config::SameSite::Strict => "; SameSite=Strict",
284 better_auth_core::config::SameSite::Lax => "; SameSite=Lax",
285 better_auth_core::config::SameSite::None => "; SameSite=None",
286 };
287
288 let expires = chrono::Utc::now() + session_config.expires_in;
289 let expires_str = expires.format("%a, %d %b %Y %H:%M:%S GMT");
290
291 format!(
292 "{}={}; Path=/; Expires={}{}{}{}",
293 session_config.cookie_name, token, expires_str, secure, http_only, same_site
294 )
295 }
296
297 async fn handle_enable<DB: DatabaseAdapter>(
300 &self,
301 req: &AuthRequest,
302 ctx: &AuthContext<DB>,
303 ) -> AuthResult<AuthResponse> {
304 let (user, _session) = Self::get_authenticated_user(req, ctx).await?;
305
306 let enable_req: EnableRequest = match better_auth_core::validate_request_body(req) {
307 Ok(v) => v,
308 Err(resp) => return Ok(resp),
309 };
310
311 Self::verify_user_password(&user, &enable_req.password)?;
312
313 let secret = Secret::generate_secret();
315 let secret_encoded = secret.to_encoded().to_string();
316 let secret_bytes = secret.to_bytes().map_err(|e| {
317 AuthError::internal(format!("Failed to convert secret to bytes: {}", e))
318 })?;
319
320 let issuer = enable_req.issuer.as_deref().unwrap_or(&self.config.issuer);
321 let email = user.email().unwrap_or("user");
322
323 let totp = self.build_totp(&secret_bytes, email, issuer)?;
324 let totp_uri = totp.get_url();
325
326 let backup_codes = self.generate_backup_codes();
328 let hashed_codes = Self::hash_backup_codes(&backup_codes)?;
329
330 ctx.database
332 .create_two_factor(CreateTwoFactor {
333 user_id: user.id().to_string(),
334 secret: secret_encoded,
335 backup_codes: Some(hashed_codes),
336 })
337 .await?;
338
339 ctx.database
341 .update_user(
342 user.id(),
343 UpdateUser {
344 email: None,
345 name: None,
346 image: None,
347 email_verified: None,
348 username: None,
349 display_username: None,
350 role: None,
351 banned: None,
352 ban_reason: None,
353 ban_expires: None,
354 two_factor_enabled: Some(true),
355 metadata: None,
356 },
357 )
358 .await?;
359
360 let response = EnableResponse {
361 totp_uri,
362 backup_codes,
363 };
364 AuthResponse::json(200, &response).map_err(AuthError::from)
365 }
366
367 async fn handle_disable<DB: DatabaseAdapter>(
368 &self,
369 req: &AuthRequest,
370 ctx: &AuthContext<DB>,
371 ) -> AuthResult<AuthResponse> {
372 let (user, _session) = Self::get_authenticated_user(req, ctx).await?;
373
374 let disable_req: DisableRequest = match better_auth_core::validate_request_body(req) {
375 Ok(v) => v,
376 Err(resp) => return Ok(resp),
377 };
378
379 Self::verify_user_password(&user, &disable_req.password)?;
380
381 ctx.database.delete_two_factor(user.id()).await?;
382
383 ctx.database
384 .update_user(
385 user.id(),
386 UpdateUser {
387 email: None,
388 name: None,
389 image: None,
390 email_verified: None,
391 username: None,
392 display_username: None,
393 role: None,
394 banned: None,
395 ban_reason: None,
396 ban_expires: None,
397 two_factor_enabled: Some(false),
398 metadata: None,
399 },
400 )
401 .await?;
402
403 let response = StatusResponse {
404 status: true,
405 backup_codes: None,
406 };
407 AuthResponse::json(200, &response).map_err(AuthError::from)
408 }
409
410 async fn handle_get_totp_uri<DB: DatabaseAdapter>(
411 &self,
412 req: &AuthRequest,
413 ctx: &AuthContext<DB>,
414 ) -> AuthResult<AuthResponse> {
415 let (user, _session) = Self::get_authenticated_user(req, ctx).await?;
416
417 let uri_req: GetTotpUriRequest = match better_auth_core::validate_request_body(req) {
418 Ok(v) => v,
419 Err(resp) => return Ok(resp),
420 };
421
422 Self::verify_user_password(&user, &uri_req.password)?;
423
424 let two_factor = ctx
425 .database
426 .get_two_factor_by_user_id(user.id())
427 .await?
428 .ok_or_else(|| AuthError::not_found("Two-factor authentication not enabled"))?;
429
430 let secret = Secret::Encoded(two_factor.secret().to_string());
431 let secret_bytes = secret
432 .to_bytes()
433 .map_err(|e| AuthError::internal(format!("Failed to decode secret: {}", e)))?;
434
435 let email = user.email().unwrap_or("user");
436 let totp = self.build_totp(&secret_bytes, email, &self.config.issuer)?;
437
438 let response = TotpUriResponse {
439 totp_uri: totp.get_url(),
440 };
441 AuthResponse::json(200, &response).map_err(AuthError::from)
442 }
443
444 async fn handle_verify_totp<DB: DatabaseAdapter>(
445 &self,
446 req: &AuthRequest,
447 ctx: &AuthContext<DB>,
448 ) -> AuthResult<AuthResponse> {
449 let (user, verification_id) = Self::get_pending_2fa_user(req, ctx).await?;
450
451 let verify_req: VerifyTotpRequest = match better_auth_core::validate_request_body(req) {
452 Ok(v) => v,
453 Err(resp) => return Ok(resp),
454 };
455
456 let two_factor = ctx
457 .database
458 .get_two_factor_by_user_id(user.id())
459 .await?
460 .ok_or_else(|| AuthError::not_found("Two-factor authentication not enabled"))?;
461
462 let secret = Secret::Encoded(two_factor.secret().to_string());
463 let secret_bytes = secret
464 .to_bytes()
465 .map_err(|e| AuthError::internal(format!("Failed to decode secret: {}", e)))?;
466
467 let email = user.email().unwrap_or("user");
468 let totp = self.build_totp(&secret_bytes, email, &self.config.issuer)?;
469
470 if !totp
471 .check_current(&verify_req.code)
472 .map_err(|e| AuthError::internal(format!("TOTP check error: {}", e)))?
473 {
474 return Err(AuthError::bad_request("Invalid TOTP code"));
475 }
476
477 let session_manager =
479 better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
480 let session = session_manager.create_session(&user, None, None).await?;
481
482 ctx.database.delete_verification(&verification_id).await?;
484
485 let cookie_header = Self::create_session_cookie(session.token(), ctx);
486 let response = VerifyTotpResponse {
487 status: true,
488 token: session.token().to_string(),
489 user,
490 };
491
492 Ok(AuthResponse::json(200, &response)?.with_header("Set-Cookie", cookie_header))
493 }
494
495 async fn handle_send_otp<DB: DatabaseAdapter>(
496 &self,
497 req: &AuthRequest,
498 ctx: &AuthContext<DB>,
499 ) -> AuthResult<AuthResponse> {
500 let (user, _verification_id) = Self::get_pending_2fa_user(req, ctx).await?;
501
502 use rand::Rng;
504 let otp: String = format!("{:06}", rand::thread_rng().gen_range(0..1_000_000u32));
505
506 let expires_at = chrono::Utc::now() + chrono::Duration::minutes(5);
508 ctx.database
509 .create_verification(CreateVerification {
510 identifier: format!("2fa_otp:{}", user.id()),
511 value: otp.clone(),
512 expires_at,
513 })
514 .await?;
515
516 if let Some(email) = user.email()
518 && let Ok(provider) = ctx.email_provider()
519 {
520 let body = format!("Your 2FA verification code is: {}", otp);
521 let _ = provider
522 .send(email, "Your verification code", &body, &body)
523 .await;
524 }
525
526 let response = StatusResponse {
527 status: true,
528 backup_codes: None,
529 };
530 AuthResponse::json(200, &response).map_err(AuthError::from)
531 }
532
533 async fn handle_verify_otp<DB: DatabaseAdapter>(
534 &self,
535 req: &AuthRequest,
536 ctx: &AuthContext<DB>,
537 ) -> AuthResult<AuthResponse> {
538 let (user, pending_verification_id) = Self::get_pending_2fa_user(req, ctx).await?;
539
540 let verify_req: VerifyOtpRequest = match better_auth_core::validate_request_body(req) {
541 Ok(v) => v,
542 Err(resp) => return Ok(resp),
543 };
544
545 let otp_identifier = format!("2fa_otp:{}", user.id());
547 let otp_verification = ctx
548 .database
549 .get_verification_by_identifier(&otp_identifier)
550 .await?
551 .ok_or_else(|| AuthError::bad_request("No OTP found. Please request a new one."))?;
552
553 if otp_verification.expires_at() < chrono::Utc::now() {
554 return Err(AuthError::bad_request("OTP has expired"));
555 }
556
557 if otp_verification.value() != verify_req.code {
558 return Err(AuthError::bad_request("Invalid OTP code"));
559 }
560
561 let session_manager =
563 better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
564 let session = session_manager.create_session(&user, None, None).await?;
565
566 ctx.database
568 .delete_verification(otp_verification.id())
569 .await?;
570 ctx.database
571 .delete_verification(&pending_verification_id)
572 .await?;
573
574 let cookie_header = Self::create_session_cookie(session.token(), ctx);
575 let response = VerifyTotpResponse {
576 status: true,
577 token: session.token().to_string(),
578 user,
579 };
580
581 Ok(AuthResponse::json(200, &response)?.with_header("Set-Cookie", cookie_header))
582 }
583
584 async fn handle_generate_backup_codes<DB: DatabaseAdapter>(
585 &self,
586 req: &AuthRequest,
587 ctx: &AuthContext<DB>,
588 ) -> AuthResult<AuthResponse> {
589 let (user, _session) = Self::get_authenticated_user(req, ctx).await?;
590
591 let gen_req: GenerateBackupCodesRequest = match better_auth_core::validate_request_body(req)
592 {
593 Ok(v) => v,
594 Err(resp) => return Ok(resp),
595 };
596
597 Self::verify_user_password(&user, &gen_req.password)?;
598
599 let backup_codes = self.generate_backup_codes();
601 let hashed_codes = Self::hash_backup_codes(&backup_codes)?;
602
603 ctx.database
604 .update_two_factor_backup_codes(user.id(), &hashed_codes)
605 .await?;
606
607 let response = StatusResponse {
608 status: true,
609 backup_codes: Some(backup_codes),
610 };
611 AuthResponse::json(200, &response).map_err(AuthError::from)
612 }
613
614 async fn handle_verify_backup_code<DB: DatabaseAdapter>(
615 &self,
616 req: &AuthRequest,
617 ctx: &AuthContext<DB>,
618 ) -> AuthResult<AuthResponse> {
619 let (user, pending_verification_id) = Self::get_pending_2fa_user(req, ctx).await?;
620
621 let verify_req: VerifyBackupCodeRequest = match better_auth_core::validate_request_body(req)
622 {
623 Ok(v) => v,
624 Err(resp) => return Ok(resp),
625 };
626
627 let two_factor = ctx
628 .database
629 .get_two_factor_by_user_id(user.id())
630 .await?
631 .ok_or_else(|| AuthError::not_found("Two-factor authentication not enabled"))?;
632
633 let codes_json = two_factor
634 .backup_codes()
635 .ok_or_else(|| AuthError::bad_request("No backup codes available"))?;
636
637 let hashed_codes: Vec<String> = serde_json::from_str(codes_json)
638 .map_err(|e| AuthError::internal(format!("Failed to parse backup codes: {}", e)))?;
639
640 let argon2 = Argon2::default();
642 let mut matched_index: Option<usize> = None;
643
644 for (i, hash_str) in hashed_codes.iter().enumerate() {
645 if let Ok(parsed_hash) = PasswordHash::new(hash_str)
646 && argon2
647 .verify_password(verify_req.code.as_bytes(), &parsed_hash)
648 .is_ok()
649 {
650 matched_index = Some(i);
651 break;
652 }
653 }
654
655 let idx = matched_index.ok_or_else(|| AuthError::bad_request("Invalid backup code"))?;
656
657 let mut remaining_codes = hashed_codes;
659 remaining_codes.remove(idx);
660
661 let updated_codes_json = serde_json::to_string(&remaining_codes)
662 .map_err(|e| AuthError::internal(e.to_string()))?;
663
664 ctx.database
665 .update_two_factor_backup_codes(user.id(), &updated_codes_json)
666 .await?;
667
668 let session_manager =
670 better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
671 let session = session_manager.create_session(&user, None, None).await?;
672
673 ctx.database
675 .delete_verification(&pending_verification_id)
676 .await?;
677
678 let cookie_header = Self::create_session_cookie(session.token(), ctx);
679 let response = VerifyBackupCodeResponse { user, session };
680
681 Ok(AuthResponse::json(200, &response)?.with_header("Set-Cookie", cookie_header))
682 }
683}
684
685impl Default for TwoFactorPlugin {
686 fn default() -> Self {
687 Self::new()
688 }
689}
690
691#[async_trait]
692impl<DB: DatabaseAdapter> AuthPlugin<DB> for TwoFactorPlugin {
693 fn name(&self) -> &'static str {
694 "two-factor"
695 }
696
697 fn routes(&self) -> Vec<AuthRoute> {
698 vec![
699 AuthRoute::post("/two-factor/enable", "enable_two_factor"),
700 AuthRoute::post("/two-factor/disable", "disable_two_factor"),
701 AuthRoute::post("/two-factor/get-totp-uri", "get_totp_uri"),
702 AuthRoute::post("/two-factor/verify-totp", "verify_totp"),
703 AuthRoute::post("/two-factor/send-otp", "send_otp"),
704 AuthRoute::post("/two-factor/verify-otp", "verify_otp"),
705 AuthRoute::post("/two-factor/generate-backup-codes", "generate_backup_codes"),
706 AuthRoute::post("/two-factor/verify-backup-code", "verify_backup_code"),
707 ]
708 }
709
710 async fn on_request(
711 &self,
712 req: &AuthRequest,
713 ctx: &AuthContext<DB>,
714 ) -> AuthResult<Option<AuthResponse>> {
715 match (req.method(), req.path()) {
716 (HttpMethod::Post, "/two-factor/enable") => {
717 Ok(Some(self.handle_enable(req, ctx).await?))
718 }
719 (HttpMethod::Post, "/two-factor/disable") => {
720 Ok(Some(self.handle_disable(req, ctx).await?))
721 }
722 (HttpMethod::Post, "/two-factor/get-totp-uri") => {
723 Ok(Some(self.handle_get_totp_uri(req, ctx).await?))
724 }
725 (HttpMethod::Post, "/two-factor/verify-totp") => {
726 Ok(Some(self.handle_verify_totp(req, ctx).await?))
727 }
728 (HttpMethod::Post, "/two-factor/send-otp") => {
729 Ok(Some(self.handle_send_otp(req, ctx).await?))
730 }
731 (HttpMethod::Post, "/two-factor/verify-otp") => {
732 Ok(Some(self.handle_verify_otp(req, ctx).await?))
733 }
734 (HttpMethod::Post, "/two-factor/generate-backup-codes") => {
735 Ok(Some(self.handle_generate_backup_codes(req, ctx).await?))
736 }
737 (HttpMethod::Post, "/two-factor/verify-backup-code") => {
738 Ok(Some(self.handle_verify_backup_code(req, ctx).await?))
739 }
740 _ => Ok(None),
741 }
742 }
743}