Skip to main content

better_auth_api/plugins/
two_factor.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use totp_rs::{Algorithm, Secret, TOTP};
4use validator::Validate;
5
6use better_auth_core::adapters::DatabaseAdapter;
7use better_auth_core::entity::{AuthSession, AuthTwoFactor, AuthUser, AuthVerification};
8use better_auth_core::{AuthContext, AuthPlugin, AuthRoute};
9use better_auth_core::{AuthError, AuthResult};
10use better_auth_core::{
11    AuthRequest, AuthResponse, CreateTwoFactor, CreateVerification, HttpMethod, UpdateUser,
12};
13
14use better_auth_core::utils::cookie_utils::create_session_cookie;
15
16use super::StatusResponse;
17
18/// Two-factor authentication plugin providing TOTP, OTP, and backup code flows.
19pub struct TwoFactorPlugin {
20    config: TwoFactorConfig,
21}
22
/// Tunable settings for [`TwoFactorPlugin`].
#[derive(Clone, Debug)]
pub struct TwoFactorConfig {
    /// Issuer label handed to the TOTP builder when a request supplies none.
    pub issuer: String,
    /// How many backup codes are produced per batch.
    pub backup_code_count: usize,
    /// Number of characters in each backup code.
    pub backup_code_length: usize,
    /// Period argument handed to the TOTP builder.
    pub totp_period: u64,
    /// Digit-count argument handed to the TOTP builder.
    pub totp_digits: usize,
}

impl Default for TwoFactorConfig {
    /// Issuer `"BetterAuth"`, ten 8-character backup codes, and a TOTP
    /// period of 30 with 6 digits.
    fn default() -> Self {
        TwoFactorConfig {
            issuer: String::from("BetterAuth"),
            backup_code_count: 10,
            backup_code_length: 8,
            totp_period: 30,
            totp_digits: 6,
        }
    }
}
43
44// -- Request types --
45
46#[derive(Debug, Deserialize, Validate)]
47struct EnableRequest {
48    password: String,
49    issuer: Option<String>,
50}
51
52#[derive(Debug, Deserialize, Validate)]
53struct DisableRequest {
54    password: String,
55}
56
57#[derive(Debug, Deserialize, Validate)]
58struct GetTotpUriRequest {
59    password: String,
60}
61
62#[derive(Debug, Deserialize, Validate)]
63struct VerifyTotpRequest {
64    code: String,
65    #[serde(rename = "trustDevice")]
66    #[allow(dead_code)]
67    trust_device: Option<String>,
68}
69
70#[derive(Debug, Deserialize, Validate)]
71struct VerifyOtpRequest {
72    code: String,
73    #[serde(rename = "trustDevice")]
74    #[allow(dead_code)]
75    trust_device: Option<String>,
76}
77
78#[derive(Debug, Deserialize, Validate)]
79struct GenerateBackupCodesRequest {
80    password: String,
81}
82
83#[derive(Debug, Deserialize, Validate)]
84struct VerifyBackupCodeRequest {
85    code: String,
86    #[serde(rename = "disableSession")]
87    #[allow(dead_code)]
88    disable_session: Option<String>,
89    #[serde(rename = "trustDevice")]
90    #[allow(dead_code)]
91    trust_device: Option<String>,
92}
93
94// -- Response types --
95
96#[derive(Debug, Serialize)]
97struct EnableResponse {
98    #[serde(rename = "totpURI")]
99    totp_uri: String,
100    #[serde(rename = "backupCodes")]
101    backup_codes: Vec<String>,
102}
103
104#[derive(Debug, Serialize)]
105struct TotpUriResponse {
106    #[serde(rename = "totpURI")]
107    totp_uri: String,
108}
109
110#[derive(Debug, Serialize)]
111struct VerifyTotpResponse<U: Serialize> {
112    status: bool,
113    token: String,
114    user: U,
115}
116
117#[derive(Debug, Serialize)]
118struct VerifyBackupCodeResponse<U: Serialize, S: Serialize> {
119    user: U,
120    session: S,
121}
122
123#[derive(Debug, Serialize)]
124struct BackupCodesResponse {
125    status: bool,
126    #[serde(rename = "backupCodes")]
127    backup_codes: Vec<String>,
128}
129
130impl TwoFactorPlugin {
131    pub fn new() -> Self {
132        Self {
133            config: TwoFactorConfig::default(),
134        }
135    }
136
137    pub fn with_config(config: TwoFactorConfig) -> Self {
138        Self { config }
139    }
140
141    pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
142        self.config.issuer = issuer.into();
143        self
144    }
145
146    // -- Helpers --
147
148    fn generate_backup_codes(&self) -> Vec<String> {
149        use rand::Rng;
150        (0..self.config.backup_code_count)
151            .map(|_| {
152                rand::thread_rng()
153                    .sample_iter(&rand::distributions::Alphanumeric)
154                    .take(self.config.backup_code_length)
155                    .map(char::from)
156                    .collect::<String>()
157                    .to_uppercase()
158            })
159            .collect()
160    }
161
162    async fn hash_backup_codes(codes: &[String]) -> AuthResult<String> {
163        let mut hashed = Vec::with_capacity(codes.len());
164        for code in codes {
165            hashed.push(better_auth_core::hash_password(None, code).await?);
166        }
167        serde_json::to_string(&hashed).map_err(|e| AuthError::internal(e.to_string()))
168    }
169
170    fn build_totp(&self, secret: &[u8], email: &str, issuer: &str) -> AuthResult<TOTP> {
171        TOTP::new(
172            Algorithm::SHA1,
173            self.config.totp_digits,
174            1,
175            self.config.totp_period,
176            secret.to_vec(),
177            Some(issuer.to_string()),
178            email.to_string(),
179        )
180        .map_err(|e| AuthError::internal(format!("Failed to create TOTP: {}", e)))
181    }
182
183    // -- Session / auth helpers --
184
185    /// Extract the user_id from a `2fa_xxx` pending verification token.
186    async fn get_pending_2fa_user<DB: DatabaseAdapter>(
187        req: &AuthRequest,
188        ctx: &AuthContext<DB>,
189    ) -> AuthResult<(DB::User, String)> {
190        let token = req
191            .headers
192            .get("authorization")
193            .and_then(|v| v.strip_prefix("Bearer "))
194            .ok_or(AuthError::Unauthenticated)?;
195
196        if !token.starts_with("2fa_") {
197            return Err(AuthError::bad_request("Invalid 2FA pending token"));
198        }
199
200        let identifier = format!("2fa_pending:{}", token);
201        let verification = ctx
202            .database
203            .get_verification_by_identifier(&identifier)
204            .await?
205            .ok_or_else(|| AuthError::bad_request("Invalid or expired 2FA token"))?;
206
207        if verification.expires_at() < chrono::Utc::now() {
208            return Err(AuthError::bad_request("2FA token expired"));
209        }
210
211        let user_id = verification.value();
212        let user = ctx
213            .database
214            .get_user_by_id(user_id)
215            .await?
216            .ok_or(AuthError::UserNotFound)?;
217
218        Ok((user, verification.id().to_string()))
219    }
220
221    async fn verify_user_password<U: AuthUser>(user: &U, password: &str) -> AuthResult<()> {
222        let stored_hash = user
223            .metadata()
224            .get("password_hash")
225            .and_then(|v| v.as_str())
226            .ok_or(AuthError::InvalidCredentials)?;
227
228        better_auth_core::verify_password(None, password, stored_hash).await
229    }
230
231    // -- Handlers --
232
233    async fn handle_enable<DB: DatabaseAdapter>(
234        &self,
235        req: &AuthRequest,
236        ctx: &AuthContext<DB>,
237    ) -> AuthResult<AuthResponse> {
238        let (user, _session) = ctx.require_session(req).await?;
239
240        let enable_req: EnableRequest = match better_auth_core::validate_request_body(req) {
241            Ok(v) => v,
242            Err(resp) => return Ok(resp),
243        };
244
245        Self::verify_user_password(&user, &enable_req.password).await?;
246
247        // Generate TOTP secret
248        let secret = Secret::generate_secret();
249        let secret_encoded = secret.to_encoded().to_string();
250        let secret_bytes = secret.to_bytes().map_err(|e| {
251            AuthError::internal(format!("Failed to convert secret to bytes: {}", e))
252        })?;
253
254        let issuer = enable_req.issuer.as_deref().unwrap_or(&self.config.issuer);
255        let email = user.email().unwrap_or("user");
256
257        let totp = self.build_totp(&secret_bytes, email, issuer)?;
258        let totp_uri = totp.get_url();
259
260        // Generate and hash backup codes
261        let backup_codes = self.generate_backup_codes();
262        let hashed_codes = Self::hash_backup_codes(&backup_codes).await?;
263
264        // Store 2FA record
265        ctx.database
266            .create_two_factor(CreateTwoFactor {
267                user_id: user.id().to_string(),
268                secret: secret_encoded,
269                backup_codes: Some(hashed_codes),
270            })
271            .await?;
272
273        // Update user flag
274        ctx.database
275            .update_user(
276                user.id(),
277                UpdateUser {
278                    two_factor_enabled: Some(true),
279                    ..Default::default()
280                },
281            )
282            .await?;
283
284        let response = EnableResponse {
285            totp_uri,
286            backup_codes,
287        };
288        AuthResponse::json(200, &response).map_err(AuthError::from)
289    }
290
291    async fn handle_disable<DB: DatabaseAdapter>(
292        &self,
293        req: &AuthRequest,
294        ctx: &AuthContext<DB>,
295    ) -> AuthResult<AuthResponse> {
296        let (user, _session) = ctx.require_session(req).await?;
297
298        let disable_req: DisableRequest = match better_auth_core::validate_request_body(req) {
299            Ok(v) => v,
300            Err(resp) => return Ok(resp),
301        };
302
303        Self::verify_user_password(&user, &disable_req.password).await?;
304
305        ctx.database.delete_two_factor(user.id()).await?;
306
307        ctx.database
308            .update_user(
309                user.id(),
310                UpdateUser {
311                    two_factor_enabled: Some(false),
312                    ..Default::default()
313                },
314            )
315            .await?;
316
317        let response = StatusResponse { status: true };
318        AuthResponse::json(200, &response).map_err(AuthError::from)
319    }
320
321    async fn handle_get_totp_uri<DB: DatabaseAdapter>(
322        &self,
323        req: &AuthRequest,
324        ctx: &AuthContext<DB>,
325    ) -> AuthResult<AuthResponse> {
326        let (user, _session) = ctx.require_session(req).await?;
327
328        let uri_req: GetTotpUriRequest = match better_auth_core::validate_request_body(req) {
329            Ok(v) => v,
330            Err(resp) => return Ok(resp),
331        };
332
333        Self::verify_user_password(&user, &uri_req.password).await?;
334
335        let two_factor = ctx
336            .database
337            .get_two_factor_by_user_id(user.id())
338            .await?
339            .ok_or_else(|| AuthError::not_found("Two-factor authentication not enabled"))?;
340
341        let secret = Secret::Encoded(two_factor.secret().to_string());
342        let secret_bytes = secret
343            .to_bytes()
344            .map_err(|e| AuthError::internal(format!("Failed to decode secret: {}", e)))?;
345
346        let email = user.email().unwrap_or("user");
347        let totp = self.build_totp(&secret_bytes, email, &self.config.issuer)?;
348
349        let response = TotpUriResponse {
350            totp_uri: totp.get_url(),
351        };
352        AuthResponse::json(200, &response).map_err(AuthError::from)
353    }
354
355    async fn handle_verify_totp<DB: DatabaseAdapter>(
356        &self,
357        req: &AuthRequest,
358        ctx: &AuthContext<DB>,
359    ) -> AuthResult<AuthResponse> {
360        let (user, verification_id) = Self::get_pending_2fa_user(req, ctx).await?;
361
362        let verify_req: VerifyTotpRequest = match better_auth_core::validate_request_body(req) {
363            Ok(v) => v,
364            Err(resp) => return Ok(resp),
365        };
366
367        let two_factor = ctx
368            .database
369            .get_two_factor_by_user_id(user.id())
370            .await?
371            .ok_or_else(|| AuthError::not_found("Two-factor authentication not enabled"))?;
372
373        let secret = Secret::Encoded(two_factor.secret().to_string());
374        let secret_bytes = secret
375            .to_bytes()
376            .map_err(|e| AuthError::internal(format!("Failed to decode secret: {}", e)))?;
377
378        let email = user.email().unwrap_or("user");
379        let totp = self.build_totp(&secret_bytes, email, &self.config.issuer)?;
380
381        if !totp
382            .check_current(&verify_req.code)
383            .map_err(|e| AuthError::internal(format!("TOTP check error: {}", e)))?
384        {
385            return Err(AuthError::bad_request("Invalid TOTP code"));
386        }
387
388        // Code valid — create session
389        let session_manager =
390            better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
391        let session = session_manager.create_session(&user, None, None).await?;
392
393        // Delete the pending verification
394        ctx.database.delete_verification(&verification_id).await?;
395
396        let cookie_header = create_session_cookie(session.token(), ctx);
397        let response = VerifyTotpResponse {
398            status: true,
399            token: session.token().to_string(),
400            user,
401        };
402
403        Ok(AuthResponse::json(200, &response)?.with_header("Set-Cookie", cookie_header))
404    }
405
406    async fn handle_send_otp<DB: DatabaseAdapter>(
407        &self,
408        req: &AuthRequest,
409        ctx: &AuthContext<DB>,
410    ) -> AuthResult<AuthResponse> {
411        let (user, _verification_id) = Self::get_pending_2fa_user(req, ctx).await?;
412
413        // Generate 6-digit OTP
414        use rand::Rng;
415        let otp: String = format!("{:06}", rand::thread_rng().gen_range(0..1_000_000u32));
416
417        // Store the OTP verification (expires in 5 minutes)
418        let expires_at = chrono::Utc::now() + chrono::Duration::minutes(5);
419        ctx.database
420            .create_verification(CreateVerification {
421                identifier: format!("2fa_otp:{}", user.id()),
422                value: otp.clone(),
423                expires_at,
424            })
425            .await?;
426
427        // Send via email if provider is available
428        if let Some(email) = user.email()
429            && let Ok(provider) = ctx.email_provider()
430        {
431            let body = format!("Your 2FA verification code is: {}", otp);
432            let _ = provider
433                .send(email, "Your verification code", &body, &body)
434                .await;
435        }
436
437        let response = StatusResponse { status: true };
438        AuthResponse::json(200, &response).map_err(AuthError::from)
439    }
440
441    async fn handle_verify_otp<DB: DatabaseAdapter>(
442        &self,
443        req: &AuthRequest,
444        ctx: &AuthContext<DB>,
445    ) -> AuthResult<AuthResponse> {
446        let (user, pending_verification_id) = Self::get_pending_2fa_user(req, ctx).await?;
447
448        let verify_req: VerifyOtpRequest = match better_auth_core::validate_request_body(req) {
449            Ok(v) => v,
450            Err(resp) => return Ok(resp),
451        };
452
453        // Look up the OTP verification
454        let otp_identifier = format!("2fa_otp:{}", user.id());
455        let otp_verification = ctx
456            .database
457            .get_verification_by_identifier(&otp_identifier)
458            .await?
459            .ok_or_else(|| AuthError::bad_request("No OTP found. Please request a new one."))?;
460
461        if otp_verification.expires_at() < chrono::Utc::now() {
462            return Err(AuthError::bad_request("OTP has expired"));
463        }
464
465        if otp_verification.value() != verify_req.code {
466            return Err(AuthError::bad_request("Invalid OTP code"));
467        }
468
469        // Valid — create session
470        let session_manager =
471            better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
472        let session = session_manager.create_session(&user, None, None).await?;
473
474        // Clean up verifications
475        ctx.database
476            .delete_verification(otp_verification.id())
477            .await?;
478        ctx.database
479            .delete_verification(&pending_verification_id)
480            .await?;
481
482        let cookie_header = create_session_cookie(session.token(), ctx);
483        let response = VerifyTotpResponse {
484            status: true,
485            token: session.token().to_string(),
486            user,
487        };
488
489        Ok(AuthResponse::json(200, &response)?.with_header("Set-Cookie", cookie_header))
490    }
491
492    async fn handle_generate_backup_codes<DB: DatabaseAdapter>(
493        &self,
494        req: &AuthRequest,
495        ctx: &AuthContext<DB>,
496    ) -> AuthResult<AuthResponse> {
497        let (user, _session) = ctx.require_session(req).await?;
498
499        let gen_req: GenerateBackupCodesRequest = match better_auth_core::validate_request_body(req)
500        {
501            Ok(v) => v,
502            Err(resp) => return Ok(resp),
503        };
504
505        Self::verify_user_password(&user, &gen_req.password).await?;
506
507        // Generate new codes
508        let backup_codes = self.generate_backup_codes();
509        let hashed_codes = Self::hash_backup_codes(&backup_codes).await?;
510
511        ctx.database
512            .update_two_factor_backup_codes(user.id(), &hashed_codes)
513            .await?;
514
515        let response = BackupCodesResponse {
516            status: true,
517            backup_codes,
518        };
519        AuthResponse::json(200, &response).map_err(AuthError::from)
520    }
521
522    async fn handle_verify_backup_code<DB: DatabaseAdapter>(
523        &self,
524        req: &AuthRequest,
525        ctx: &AuthContext<DB>,
526    ) -> AuthResult<AuthResponse> {
527        let (user, pending_verification_id) = Self::get_pending_2fa_user(req, ctx).await?;
528
529        let verify_req: VerifyBackupCodeRequest = match better_auth_core::validate_request_body(req)
530        {
531            Ok(v) => v,
532            Err(resp) => return Ok(resp),
533        };
534
535        let two_factor = ctx
536            .database
537            .get_two_factor_by_user_id(user.id())
538            .await?
539            .ok_or_else(|| AuthError::not_found("Two-factor authentication not enabled"))?;
540
541        let codes_json = two_factor
542            .backup_codes()
543            .ok_or_else(|| AuthError::bad_request("No backup codes available"))?;
544
545        let hashed_codes: Vec<String> = serde_json::from_str(codes_json)
546            .map_err(|e| AuthError::internal(format!("Failed to parse backup codes: {}", e)))?;
547
548        // Try to match the provided code against each hashed code
549        let mut matched_index: Option<usize> = None;
550
551        for (i, hash_str) in hashed_codes.iter().enumerate() {
552            if better_auth_core::verify_password(None, &verify_req.code, hash_str)
553                .await
554                .is_ok()
555            {
556                matched_index = Some(i);
557                break;
558            }
559        }
560
561        let idx = matched_index.ok_or_else(|| AuthError::bad_request("Invalid backup code"))?;
562
563        // Remove used code and update
564        let mut remaining_codes = hashed_codes;
565        remaining_codes.remove(idx);
566
567        let updated_codes_json = serde_json::to_string(&remaining_codes)
568            .map_err(|e| AuthError::internal(e.to_string()))?;
569
570        ctx.database
571            .update_two_factor_backup_codes(user.id(), &updated_codes_json)
572            .await?;
573
574        // Create session
575        let session_manager =
576            better_auth_core::SessionManager::new(ctx.config.clone(), ctx.database.clone());
577        let session = session_manager.create_session(&user, None, None).await?;
578
579        // Clean up pending verification
580        ctx.database
581            .delete_verification(&pending_verification_id)
582            .await?;
583
584        let cookie_header = create_session_cookie(session.token(), ctx);
585        let response = VerifyBackupCodeResponse { user, session };
586
587        Ok(AuthResponse::json(200, &response)?.with_header("Set-Cookie", cookie_header))
588    }
589}
590
591impl Default for TwoFactorPlugin {
592    fn default() -> Self {
593        Self::new()
594    }
595}
596
597#[async_trait]
598impl<DB: DatabaseAdapter> AuthPlugin<DB> for TwoFactorPlugin {
599    fn name(&self) -> &'static str {
600        "two-factor"
601    }
602
603    fn routes(&self) -> Vec<AuthRoute> {
604        vec![
605            AuthRoute::post("/two-factor/enable", "enable_two_factor"),
606            AuthRoute::post("/two-factor/disable", "disable_two_factor"),
607            AuthRoute::post("/two-factor/get-totp-uri", "get_totp_uri"),
608            AuthRoute::post("/two-factor/verify-totp", "verify_totp"),
609            AuthRoute::post("/two-factor/send-otp", "send_otp"),
610            AuthRoute::post("/two-factor/verify-otp", "verify_otp"),
611            AuthRoute::post("/two-factor/generate-backup-codes", "generate_backup_codes"),
612            AuthRoute::post("/two-factor/verify-backup-code", "verify_backup_code"),
613        ]
614    }
615
616    async fn on_request(
617        &self,
618        req: &AuthRequest,
619        ctx: &AuthContext<DB>,
620    ) -> AuthResult<Option<AuthResponse>> {
621        match (req.method(), req.path()) {
622            (HttpMethod::Post, "/two-factor/enable") => {
623                Ok(Some(self.handle_enable(req, ctx).await?))
624            }
625            (HttpMethod::Post, "/two-factor/disable") => {
626                Ok(Some(self.handle_disable(req, ctx).await?))
627            }
628            (HttpMethod::Post, "/two-factor/get-totp-uri") => {
629                Ok(Some(self.handle_get_totp_uri(req, ctx).await?))
630            }
631            (HttpMethod::Post, "/two-factor/verify-totp") => {
632                Ok(Some(self.handle_verify_totp(req, ctx).await?))
633            }
634            (HttpMethod::Post, "/two-factor/send-otp") => {
635                Ok(Some(self.handle_send_otp(req, ctx).await?))
636            }
637            (HttpMethod::Post, "/two-factor/verify-otp") => {
638                Ok(Some(self.handle_verify_otp(req, ctx).await?))
639            }
640            (HttpMethod::Post, "/two-factor/generate-backup-codes") => {
641                Ok(Some(self.handle_generate_backup_codes(req, ctx).await?))
642            }
643            (HttpMethod::Post, "/two-factor/verify-backup-code") => {
644                Ok(Some(self.handle_verify_backup_code(req, ctx).await?))
645            }
646            _ => Ok(None),
647        }
648    }
649}