1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use totp_rs::{Algorithm, Secret, TOTP};
4use validator::Validate;
5
6use better_auth_core::adapters::DatabaseAdapter;
7use better_auth_core::entity::{AuthSession, AuthTwoFactor, AuthUser, AuthVerification};
8use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
9use better_auth_core::{AuthError, AuthResult};
10use better_auth_core::{
11 AuthRequest, AuthResponse, CreateTwoFactor, CreateVerification, HttpMethod, UpdateUser,
12};
13
14use better_auth_core::utils::cookie_utils::create_session_cookie;
15
16use super::StatusResponse;
17
18pub struct TwoFactorPlugin {
20 config: TwoFactorConfig,
21}
22
/// Settings for [`TwoFactorPlugin`].
#[derive(Debug, Clone)]
pub struct TwoFactorConfig {
    /// Issuer label written into generated TOTP provisioning URIs.
    pub issuer: String,
    /// How many backup codes are produced per generation.
    pub backup_code_count: usize,
    /// Number of characters in each backup code.
    pub backup_code_length: usize,
    /// TOTP time-step length passed to the TOTP generator.
    pub totp_period: u64,
    /// Number of digits in each TOTP code.
    pub totp_digits: usize,
}

impl Default for TwoFactorConfig {
    /// Issuer `"BetterAuth"`, ten 8-character backup codes, 6-digit codes on a 30 step.
    fn default() -> Self {
        let issuer = String::from("BetterAuth");
        TwoFactorConfig {
            issuer,
            totp_digits: 6,
            totp_period: 30,
            backup_code_length: 8,
            backup_code_count: 10,
        }
    }
}
43
44#[derive(Debug, Deserialize, Validate)]
47struct EnableRequest {
48 password: String,
49 issuer: Option<String>,
50}
51
52#[derive(Debug, Deserialize, Validate)]
53struct DisableRequest {
54 password: String,
55}
56
57#[derive(Debug, Deserialize, Validate)]
58struct GetTotpUriRequest {
59 password: String,
60}
61
62#[derive(Debug, Deserialize, Validate)]
63struct VerifyTotpRequest {
64 code: String,
65 #[serde(rename = "trustDevice")]
66 #[allow(dead_code)]
67 trust_device: Option<String>,
68}
69
70#[derive(Debug, Deserialize, Validate)]
71struct VerifyOtpRequest {
72 code: String,
73 #[serde(rename = "trustDevice")]
74 #[allow(dead_code)]
75 trust_device: Option<String>,
76}
77
78#[derive(Debug, Deserialize, Validate)]
79struct GenerateBackupCodesRequest {
80 password: String,
81}
82
83#[derive(Debug, Deserialize, Validate)]
84struct VerifyBackupCodeRequest {
85 code: String,
86 #[serde(rename = "disableSession")]
87 #[allow(dead_code)]
88 disable_session: Option<String>,
89 #[serde(rename = "trustDevice")]
90 #[allow(dead_code)]
91 trust_device: Option<String>,
92}
93
94#[derive(Debug, Serialize)]
97struct EnableResponse {
98 #[serde(rename = "totpURI")]
99 totp_uri: String,
100 #[serde(rename = "backupCodes")]
101 backup_codes: Vec<String>,
102}
103
104#[derive(Debug, Serialize)]
105struct TotpUriResponse {
106 #[serde(rename = "totpURI")]
107 totp_uri: String,
108}
109
110#[derive(Debug, Serialize)]
111struct VerifyTotpResponse<U: Serialize> {
112 status: bool,
113 token: String,
114 user: U,
115}
116
117#[derive(Debug, Serialize)]
118struct VerifyBackupCodeResponse<U: Serialize, S: Serialize> {
119 user: U,
120 session: S,
121}
122
123#[derive(Debug, Serialize)]
124struct BackupCodesResponse {
125 status: bool,
126 #[serde(rename = "backupCodes")]
127 backup_codes: Vec<String>,
128}
129
130impl TwoFactorPlugin {
131 pub fn new() -> Self {
132 Self {
133 config: TwoFactorConfig::default(),
134 }
135 }
136
137 pub fn with_config(config: TwoFactorConfig) -> Self {
138 Self { config }
139 }
140
141 pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
142 self.config.issuer = issuer.into();
143 self
144 }
145
146 fn generate_backup_codes(&self) -> Vec<String> {
149 use rand::Rng;
150 (0..self.config.backup_code_count)
151 .map(|_| {
152 rand::thread_rng()
153 .sample_iter(&rand::distributions::Alphanumeric)
154 .take(self.config.backup_code_length)
155 .map(char::from)
156 .collect::<String>()
157 .to_uppercase()
158 })
159 .collect()
160 }
161
162 async fn hash_backup_codes(codes: &[String]) -> AuthResult<String> {
163 let mut hashed = Vec::with_capacity(codes.len());
164 for code in codes {
165 hashed.push(better_auth_core::hash_password(None, code).await?);
166 }
167 serde_json::to_string(&hashed).map_err(|e| AuthError::internal(e.to_string()))
168 }
169
170 fn build_totp(&self, secret: &[u8], email: &str, issuer: &str) -> AuthResult<TOTP> {
171 TOTP::new(
172 Algorithm::SHA1,
173 self.config.totp_digits,
174 1,
175 self.config.totp_period,
176 secret.to_vec(),
177 Some(issuer.to_string()),
178 email.to_string(),
179 )
180 .map_err(|e| AuthError::internal(format!("Failed to create TOTP: {}", e)))
181 }
182
183 async fn get_pending_2fa_user<DB: DatabaseAdapter>(
187 req: &AuthRequest,
188 ctx: &AuthContext<DB>,
189 ) -> AuthResult<(DB::User, String)> {
190 let token = req
191 .headers
192 .get("authorization")
193 .and_then(|v| v.strip_prefix("Bearer "))
194 .ok_or(AuthError::Unauthenticated)?;
195
196 if !token.starts_with("2fa_") {
197 return Err(AuthError::bad_request("Invalid 2FA pending token"));
198 }
199
200 let identifier = format!("2fa_pending:{}", token);
201 let verification = ctx
202 .database
203 .get_verification_by_identifier(&identifier)
204 .await?
205 .ok_or_else(|| AuthError::bad_request("Invalid or expired 2FA token"))?;
206
207 if verification.expires_at() < chrono::Utc::now() {
208 return Err(AuthError::bad_request("2FA token expired"));
209 }
210
211 let user_id = verification.value();
212 let user = ctx
213 .database
214 .get_user_by_id(user_id)
215 .await?
216 .ok_or(AuthError::UserNotFound)?;
217
218 Ok((user, verification.id().to_string()))
219 }
220
221 async fn verify_user_password<U: AuthUser>(user: &U, password: &str) -> AuthResult<()> {
222 let stored_hash = user
223 .metadata()
224 .get("password_hash")
225 .and_then(|v| v.as_str())
226 .ok_or(AuthError::InvalidCredentials)?;
227
228 better_auth_core::verify_password(None, password, stored_hash).await
229 }
230
231 async fn handle_enable<DB: DatabaseAdapter>(
234 &self,
235 req: &AuthRequest,
236 ctx: &AuthContext<DB>,
237 ) -> AuthResult<AuthResponse> {
238 let (user, _session) = ctx.require_session(req).await?;
239
240 let enable_req: EnableRequest = match better_auth_core::validate_request_body(req) {
241 Ok(v) => v,
242 Err(resp) => return Ok(resp),
243 };
244
245 Self::verify_user_password(&user, &enable_req.password).await?;
246
247 let secret = Secret::generate_secret();
249 let secret_encoded = secret.to_encoded().to_string();
250 let secret_bytes = secret.to_bytes().map_err(|e| {
251 AuthError::internal(format!("Failed to convert secret to bytes: {}", e))
252 })?;
253
254 let issuer = enable_req.issuer.as_deref().unwrap_or(&self.config.issuer);
255 let email = user.email().unwrap_or("user");
256
257 let totp = self.build_totp(&secret_bytes, email, issuer)?;
258 let totp_uri = totp.get_url();
259
260 let backup_codes = self.generate_backup_codes();
262 let hashed_codes = Self::hash_backup_codes(&backup_codes).await?;
263
264 ctx.database
266 .create_two_factor(CreateTwoFactor {
267 user_id: user.id().to_string(),
268 secret: secret_encoded,
269 backup_codes: Some(hashed_codes),
270 })
271 .await?;
272
273 ctx.database
275 .update_user(
276 user.id(),
277 UpdateUser {
278 two_factor_enabled: Some(true),
279 ..Default::default()
280 },
281 )
282 .await?;
283
284 let response = EnableResponse {
285 totp_uri,
286 backup_codes,
287 };
288 AuthResponse::json(200, &response).map_err(AuthError::from)
289 }
290
291 async fn handle_disable<DB: DatabaseAdapter>(
292 &self,
293 req: &AuthRequest,
294 ctx: &AuthContext<DB>,
295 ) -> AuthResult<AuthResponse> {
296 let (user, _session) = ctx.require_session(req).await?;
297
298 let disable_req: DisableRequest = match better_auth_core::validate_request_body(req) {
299 Ok(v) => v,
300 Err(resp) => return Ok(resp),
301 };
302
303 Self::verify_user_password(&user, &disable_req.password).await?;
304
305 ctx.database.delete_two_factor(user.id()).await?;
306
307 ctx.database
308 .update_user(
309 user.id(),
310 UpdateUser {
311 two_factor_enabled: Some(false),
312 ..Default::default()
313 },
314 )
315 .await?;
316
317 let response = StatusResponse { status: true };
318 AuthResponse::json(200, &response).map_err(AuthError::from)
319 }
320
321 async fn handle_get_totp_uri<DB: DatabaseAdapter>(
322 &self,
323 req: &AuthRequest,
324 ctx: &AuthContext<DB>,
325 ) -> AuthResult<AuthResponse> {
326 let (user, _session) = ctx.require_session(req).await?;
327
328 let uri_req: GetTotpUriRequest = match better_auth_core::validate_request_body(req) {
329 Ok(v) => v,
330 Err(resp) => return Ok(resp),
331 };
332
333 Self::verify_user_password(&user, &uri_req.password).await?;
334
335 let two_factor = ctx
336 .database
337 .get_two_factor_by_user_id(user.id())
338 .await?
339 .ok_or_else(|| AuthError::not_found("Two-factor authentication not enabled"))?;
340
341 let secret = Secret::Encoded(two_factor.secret().to_string());
342 let secret_bytes = secret
343 .to_bytes()
344 .map_err(|e| AuthError::internal(format!("Failed to decode secret: {}", e)))?;
345
346 let email = user.email().unwrap_or("user");
347 let totp = self.build_totp(&secret_bytes, email, &self.config.issuer)?;
348
349 let response = TotpUriResponse {
350 totp_uri: totp.get_url(),
351 };
352 AuthResponse::json(200, &response).map_err(AuthError::from)
353 }
354
355 async fn handle_verify_totp<DB: DatabaseAdapter>(
356 &self,
357 req: &AuthRequest,
358 ctx: &AuthContext<DB>,
359 ) -> AuthResult<AuthResponse> {
360 let (user, verification_id) = Self::get_pending_2fa_user(req, ctx).await?;
361
362 let verify_req: VerifyTotpRequest = match better_auth_core::validate_request_body(req) {
363 Ok(v) => v,
364 Err(resp) => return Ok(resp),
365 };
366
367 let two_factor = ctx
368 .database
369 .get_two_factor_by_user_id(user.id())
370 .await?
371 .ok_or_else(|| AuthError::not_found("Two-factor authentication not enabled"))?;
372
373 let secret = Secret::Encoded(two_factor.secret().to_string());
374 let secret_bytes = secret
375 .to_bytes()
376 .map_err(|e| AuthError::internal(format!("Failed to decode secret: {}", e)))?;
377
378 let email = user.email().unwrap_or("user");
379 let totp = self.build_totp(&secret_bytes, email, &self.config.issuer)?;
380
381 if !totp
382 .check_current(&verify_req.code)
383 .map_err(|e| AuthError::internal(format!("TOTP check error: {}", e)))?
384 {
385 return Err(AuthError::bad_request("Invalid TOTP code"));
386 }
387
388 let session_manager =
390 better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
391 let session = session_manager.create_session(&user, None, None).await?;
392
393 ctx.database.delete_verification(&verification_id).await?;
395
396 let cookie_header = create_session_cookie(session.token(), ctx);
397 let response = VerifyTotpResponse {
398 status: true,
399 token: session.token().to_string(),
400 user,
401 };
402
403 Ok(AuthResponse::json(200, &response)?.with_header("Set-Cookie", cookie_header))
404 }
405
406 async fn handle_send_otp<DB: DatabaseAdapter>(
407 &self,
408 req: &AuthRequest,
409 ctx: &AuthContext<DB>,
410 ) -> AuthResult<AuthResponse> {
411 let (user, _verification_id) = Self::get_pending_2fa_user(req, ctx).await?;
412
413 use rand::Rng;
415 let otp: String = format!("{:06}", rand::thread_rng().gen_range(0..1_000_000u32));
416
417 let expires_at = chrono::Utc::now() + chrono::Duration::minutes(5);
419 ctx.database
420 .create_verification(CreateVerification {
421 identifier: format!("2fa_otp:{}", user.id()),
422 value: otp.clone(),
423 expires_at,
424 })
425 .await?;
426
427 if let Some(email) = user.email()
429 && let Ok(provider) = ctx.email_provider()
430 {
431 let body = format!("Your 2FA verification code is: {}", otp);
432 let _ = provider
433 .send(email, "Your verification code", &body, &body)
434 .await;
435 }
436
437 let response = StatusResponse { status: true };
438 AuthResponse::json(200, &response).map_err(AuthError::from)
439 }
440
441 async fn handle_verify_otp<DB: DatabaseAdapter>(
442 &self,
443 req: &AuthRequest,
444 ctx: &AuthContext<DB>,
445 ) -> AuthResult<AuthResponse> {
446 let (user, pending_verification_id) = Self::get_pending_2fa_user(req, ctx).await?;
447
448 let verify_req: VerifyOtpRequest = match better_auth_core::validate_request_body(req) {
449 Ok(v) => v,
450 Err(resp) => return Ok(resp),
451 };
452
453 let otp_identifier = format!("2fa_otp:{}", user.id());
455 let otp_verification = ctx
456 .database
457 .get_verification_by_identifier(&otp_identifier)
458 .await?
459 .ok_or_else(|| AuthError::bad_request("No OTP found. Please request a new one."))?;
460
461 if otp_verification.expires_at() < chrono::Utc::now() {
462 return Err(AuthError::bad_request("OTP has expired"));
463 }
464
465 if otp_verification.value() != verify_req.code {
466 return Err(AuthError::bad_request("Invalid OTP code"));
467 }
468
469 let session_manager =
471 better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
472 let session = session_manager.create_session(&user, None, None).await?;
473
474 ctx.database
476 .delete_verification(otp_verification.id())
477 .await?;
478 ctx.database
479 .delete_verification(&pending_verification_id)
480 .await?;
481
482 let cookie_header = create_session_cookie(session.token(), ctx);
483 let response = VerifyTotpResponse {
484 status: true,
485 token: session.token().to_string(),
486 user,
487 };
488
489 Ok(AuthResponse::json(200, &response)?.with_header("Set-Cookie", cookie_header))
490 }
491
492 async fn handle_generate_backup_codes<DB: DatabaseAdapter>(
493 &self,
494 req: &AuthRequest,
495 ctx: &AuthContext<DB>,
496 ) -> AuthResult<AuthResponse> {
497 let (user, _session) = ctx.require_session(req).await?;
498
499 let gen_req: GenerateBackupCodesRequest = match better_auth_core::validate_request_body(req)
500 {
501 Ok(v) => v,
502 Err(resp) => return Ok(resp),
503 };
504
505 Self::verify_user_password(&user, &gen_req.password).await?;
506
507 let backup_codes = self.generate_backup_codes();
509 let hashed_codes = Self::hash_backup_codes(&backup_codes).await?;
510
511 ctx.database
512 .update_two_factor_backup_codes(user.id(), &hashed_codes)
513 .await?;
514
515 let response = BackupCodesResponse {
516 status: true,
517 backup_codes,
518 };
519 AuthResponse::json(200, &response).map_err(AuthError::from)
520 }
521
522 async fn handle_verify_backup_code<DB: DatabaseAdapter>(
523 &self,
524 req: &AuthRequest,
525 ctx: &AuthContext<DB>,
526 ) -> AuthResult<AuthResponse> {
527 let (user, pending_verification_id) = Self::get_pending_2fa_user(req, ctx).await?;
528
529 let verify_req: VerifyBackupCodeRequest = match better_auth_core::validate_request_body(req)
530 {
531 Ok(v) => v,
532 Err(resp) => return Ok(resp),
533 };
534
535 let two_factor = ctx
536 .database
537 .get_two_factor_by_user_id(user.id())
538 .await?
539 .ok_or_else(|| AuthError::not_found("Two-factor authentication not enabled"))?;
540
541 let codes_json = two_factor
542 .backup_codes()
543 .ok_or_else(|| AuthError::bad_request("No backup codes available"))?;
544
545 let hashed_codes: Vec<String> = serde_json::from_str(codes_json)
546 .map_err(|e| AuthError::internal(format!("Failed to parse backup codes: {}", e)))?;
547
548 let mut matched_index: Option<usize> = None;
550
551 for (i, hash_str) in hashed_codes.iter().enumerate() {
552 if better_auth_core::verify_password(None, &verify_req.code, hash_str)
553 .await
554 .is_ok()
555 {
556 matched_index = Some(i);
557 break;
558 }
559 }
560
561 let idx = matched_index.ok_or_else(|| AuthError::bad_request("Invalid backup code"))?;
562
563 let mut remaining_codes = hashed_codes;
565 remaining_codes.remove(idx);
566
567 let updated_codes_json = serde_json::to_string(&remaining_codes)
568 .map_err(|e| AuthError::internal(e.to_string()))?;
569
570 ctx.database
571 .update_two_factor_backup_codes(user.id(), &updated_codes_json)
572 .await?;
573
574 let session_manager =
576 better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
577 let session = session_manager.create_session(&user, None, None).await?;
578
579 ctx.database
581 .delete_verification(&pending_verification_id)
582 .await?;
583
584 let cookie_header = create_session_cookie(session.token(), ctx);
585 let response = VerifyBackupCodeResponse { user, session };
586
587 Ok(AuthResponse::json(200, &response)?.with_header("Set-Cookie", cookie_header))
588 }
589}
590
591impl Default for TwoFactorPlugin {
592 fn default() -> Self {
593 Self::new()
594 }
595}
596
597#[async_trait]
598impl<DB: DatabaseAdapter> AuthPlugin<DB> for TwoFactorPlugin {
599 fn name(&self) -> &'static str {
600 "two-factor"
601 }
602
603 fn routes(&self) -> Vec<AuthRoute> {
604 vec![
605 AuthRoute::post("/two-factor/enable", "enable_two_factor"),
606 AuthRoute::post("/two-factor/disable", "disable_two_factor"),
607 AuthRoute::post("/two-factor/get-totp-uri", "get_totp_uri"),
608 AuthRoute::post("/two-factor/verify-totp", "verify_totp"),
609 AuthRoute::post("/two-factor/send-otp", "send_otp"),
610 AuthRoute::post("/two-factor/verify-otp", "verify_otp"),
611 AuthRoute::post("/two-factor/generate-backup-codes", "generate_backup_codes"),
612 AuthRoute::post("/two-factor/verify-backup-code", "verify_backup_code"),
613 ]
614 }
615
616 async fn on_request(
617 &self,
618 req: &AuthRequest,
619 ctx: &AuthContext<DB>,
620 ) -> AuthResult<Option<AuthResponse>> {
621 match (req.method(), req.path()) {
622 (HttpMethod::Post, "/two-factor/enable") => {
623 Ok(Some(self.handle_enable(req, ctx).await?))
624 }
625 (HttpMethod::Post, "/two-factor/disable") => {
626 Ok(Some(self.handle_disable(req, ctx).await?))
627 }
628 (HttpMethod::Post, "/two-factor/get-totp-uri") => {
629 Ok(Some(self.handle_get_totp_uri(req, ctx).await?))
630 }
631 (HttpMethod::Post, "/two-factor/verify-totp") => {
632 Ok(Some(self.handle_verify_totp(req, ctx).await?))
633 }
634 (HttpMethod::Post, "/two-factor/send-otp") => {
635 Ok(Some(self.handle_send_otp(req, ctx).await?))
636 }
637 (HttpMethod::Post, "/two-factor/verify-otp") => {
638 Ok(Some(self.handle_verify_otp(req, ctx).await?))
639 }
640 (HttpMethod::Post, "/two-factor/generate-backup-codes") => {
641 Ok(Some(self.handle_generate_backup_codes(req, ctx).await?))
642 }
643 (HttpMethod::Post, "/two-factor/verify-backup-code") => {
644 Ok(Some(self.handle_verify_backup_code(req, ctx).await?))
645 }
646 _ => Ok(None),
647 }
648 }
649}