// better_auth/core/auth.rs

1use std::sync::Arc;
2
3use serde::Deserialize;
4
5use better_auth_core::{
6    AuthConfig, AuthContext, AuthError, AuthPlugin, AuthRequest, AuthResponse, AuthResult,
7    BeforeRequestAction, DatabaseAdapter, DatabaseHooks, EmailProvider, HttpMethod, OkResponse,
8    OpenApiBuilder, OpenApiSpec, SessionManager, StatusMessageResponse, SuccessMessageResponse,
9    UpdateUser, UpdateUserRequest, core_paths,
10    entity::{AuthAccount, AuthSession, AuthUser, AuthVerification},
11    middleware::{
12        self, BodyLimitConfig, BodyLimitMiddleware, CorsConfig, CorsMiddleware, CsrfConfig,
13        CsrfMiddleware, Middleware, RateLimitConfig, RateLimitMiddleware,
14    },
15};
16
17#[derive(Debug, Deserialize)]
18struct ChangeEmailRequest {
19    #[serde(rename = "newEmail")]
20    new_email: String,
21}
22
/// The main BetterAuth instance, generic over the database adapter.
///
/// Produced by [`TypedAuthBuilder::build`]. Incoming requests go through
/// [`BetterAuth::handle_request`], which runs the middleware chain, the
/// built-in core endpoints, and finally the plugin handlers.
pub struct BetterAuth<DB: DatabaseAdapter> {
    /// Validated configuration, shared (via `Arc`) with the session manager
    /// and the plugin context.
    config: Arc<AuthConfig>,
    /// Registered plugins; `before_request` and `on_request` are consulted
    /// in registration order.
    plugins: Vec<Box<dyn AuthPlugin<DB>>>,
    /// Built-in middlewares (body limit, rate limit, CSRF, CORS) followed by
    /// custom ones; run before dispatch and again on every response.
    middlewares: Vec<Box<dyn Middleware>>,
    /// Shared database adapter.
    database: Arc<DB>,
    /// Extracts and resolves session tokens for authenticated core endpoints.
    session_manager: SessionManager<DB>,
    /// Context passed to plugins in `on_init`, `before_request` and `on_request`.
    context: AuthContext<DB>,
    /// Kept on the built instance so the axum entry handler can bound
    /// `to_bytes` at the caller-configured limit instead of a hard-coded
    /// 1 MiB (which would otherwise override the user's
    /// `AuthBuilder::body_limit(...)` choice).
    body_limit_config: BodyLimitConfig,
}
37
/// Initial builder for configuring BetterAuth.
///
/// Call `.database(adapter)` to obtain a [`TypedAuthBuilder`] that can
/// accept plugins and hooks. Every setting chosen here is carried over to
/// the typed builder.
pub struct AuthBuilder {
    /// Core configuration; validated only when the typed builder is built.
    config: AuthConfig,
    /// CSRF settings; `None` falls back to `CsrfConfig::default()` at build time.
    csrf_config: Option<CsrfConfig>,
    /// Rate-limit settings; `None` falls back to the default at build time.
    rate_limit_config: Option<RateLimitConfig>,
    /// CORS settings; `None` falls back to the default at build time.
    cors_config: Option<CorsConfig>,
    /// Body-size limit; `None` falls back to the default at build time.
    body_limit_config: Option<BodyLimitConfig>,
    /// Extra middlewares, appended after the built-in chain at build time.
    custom_middlewares: Vec<Box<dyn Middleware>>,
}
50
/// Typed builder returned by [`AuthBuilder::database`].
///
/// Accepts plugins, hooks, and middleware before calling `.build()`.
pub struct TypedAuthBuilder<DB: DatabaseAdapter> {
    /// Core configuration; validated by `build`.
    config: AuthConfig,
    /// Database adapter, shared with the session manager and plugin context.
    database: Arc<DB>,
    /// Plugins, initialised by `build` in registration order.
    plugins: Vec<Box<dyn AuthPlugin<DB>>>,
    /// Database lifecycle hooks. `build` currently returns a configuration
    /// error if this is non-empty; the adapter must be wrapped in
    /// `HookedDatabaseAdapter` by the caller instead.
    hooks: Vec<Arc<dyn DatabaseHooks<DB>>>,
    /// CSRF settings; `None` falls back to `CsrfConfig::default()`.
    csrf_config: Option<CsrfConfig>,
    /// Rate-limit settings; `None` falls back to the default.
    rate_limit_config: Option<RateLimitConfig>,
    /// CORS settings; `None` falls back to the default.
    cors_config: Option<CorsConfig>,
    /// Body-size limit; `None` falls back to the default.
    body_limit_config: Option<BodyLimitConfig>,
    /// Extra middlewares, appended after the built-in chain.
    custom_middlewares: Vec<Box<dyn Middleware>>,
}
65
66impl AuthBuilder {
67    pub fn new(config: AuthConfig) -> Self {
68        Self {
69            config,
70            csrf_config: None,
71            rate_limit_config: None,
72            cors_config: None,
73            body_limit_config: None,
74            custom_middlewares: Vec::new(),
75        }
76    }
77
78    /// Set the database adapter, returning a [`TypedAuthBuilder`].
79    pub fn database<DB: DatabaseAdapter>(self, database: DB) -> TypedAuthBuilder<DB> {
80        TypedAuthBuilder {
81            config: self.config,
82            database: Arc::new(database),
83            plugins: Vec::new(),
84            hooks: Vec::new(),
85            csrf_config: self.csrf_config,
86            rate_limit_config: self.rate_limit_config,
87            cors_config: self.cors_config,
88            body_limit_config: self.body_limit_config,
89            custom_middlewares: self.custom_middlewares,
90        }
91    }
92
93    /// Configure CSRF protection.
94    pub fn csrf(mut self, config: CsrfConfig) -> Self {
95        self.csrf_config = Some(config);
96        self
97    }
98
99    /// Configure rate limiting.
100    pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
101        self.rate_limit_config = Some(config);
102        self
103    }
104
105    /// Configure CORS.
106    pub fn cors(mut self, config: CorsConfig) -> Self {
107        self.cors_config = Some(config);
108        self
109    }
110
111    /// Configure body size limit.
112    pub fn body_limit(mut self, config: BodyLimitConfig) -> Self {
113        self.body_limit_config = Some(config);
114        self
115    }
116
117    /// Set the email provider.
118    pub fn email_provider<E: EmailProvider + 'static>(mut self, provider: E) -> Self {
119        self.config.email_provider = Some(Arc::new(provider));
120        self
121    }
122}
123
124impl<DB: DatabaseAdapter> TypedAuthBuilder<DB> {
125    /// Add a plugin to the authentication system.
126    pub fn plugin<P: AuthPlugin<DB> + 'static>(mut self, plugin: P) -> Self {
127        self.plugins.push(Box::new(plugin));
128        self
129    }
130
131    /// Add a database lifecycle hook.
132    pub fn hook<H: DatabaseHooks<DB> + 'static>(mut self, hook: H) -> Self {
133        self.hooks.push(Arc::new(hook));
134        self
135    }
136
137    /// Configure CSRF protection.
138    pub fn csrf(mut self, config: CsrfConfig) -> Self {
139        self.csrf_config = Some(config);
140        self
141    }
142
143    /// Configure rate limiting.
144    pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
145        self.rate_limit_config = Some(config);
146        self
147    }
148
149    /// Configure CORS.
150    pub fn cors(mut self, config: CorsConfig) -> Self {
151        self.cors_config = Some(config);
152        self
153    }
154
155    /// Configure body size limit.
156    pub fn body_limit(mut self, config: BodyLimitConfig) -> Self {
157        self.body_limit_config = Some(config);
158        self
159    }
160
161    /// Set the email provider for sending emails.
162    pub fn email_provider<E: EmailProvider + 'static>(mut self, provider: E) -> Self {
163        self.config.email_provider = Some(Arc::new(provider));
164        self
165    }
166
167    /// Add a custom middleware.
168    pub fn middleware<M: Middleware + 'static>(mut self, mw: M) -> Self {
169        self.custom_middlewares.push(Box::new(mw));
170        self
171    }
172
173    /// Build the BetterAuth instance.
174    pub async fn build(self) -> AuthResult<BetterAuth<DB>> {
175        // Validate configuration
176        self.config.validate()?;
177
178        let config = Arc::new(self.config);
179
180        // If hooks are registered, the user should wrap the adapter themselves:
181        //   let db = HookedDatabaseAdapter::new(Arc::new(my_db)).with_hook(hook);
182        //   BetterAuth::new(config).database(db).plugin(...).build().await
183        if !self.hooks.is_empty() {
184            return Err(AuthError::config(
185                "Use HookedDatabaseAdapter directly: \
186                 BetterAuth::new(config).database(HookedDatabaseAdapter::new(Arc::new(db)).with_hook(h))",
187            ));
188        }
189
190        let database = self.database;
191
192        // Create session manager
193        let session_manager = SessionManager::new(config.clone(), database.clone());
194
195        // Create context
196        let mut context = AuthContext::new(config.clone(), database.clone());
197
198        // Initialize all plugins
199        for plugin in &self.plugins {
200            plugin.on_init(&mut context).await?;
201        }
202
203        // Build middleware chain (order matters: body limit → rate limit → CSRF → CORS → custom)
204        let body_limit_config = self.body_limit_config.unwrap_or_default();
205        let mut middlewares: Vec<Box<dyn Middleware>> = vec![
206            Box::new(BodyLimitMiddleware::new(body_limit_config.clone())),
207            Box::new(RateLimitMiddleware::new(
208                self.rate_limit_config.unwrap_or_default(),
209            )),
210            Box::new(CsrfMiddleware::new(
211                self.csrf_config.unwrap_or_default(),
212                config.clone(),
213            )),
214            Box::new(CorsMiddleware::new(self.cors_config.unwrap_or_default())),
215        ];
216
217        middlewares.extend(self.custom_middlewares);
218
219        Ok(BetterAuth {
220            config,
221            plugins: self.plugins,
222            middlewares,
223            database,
224            session_manager,
225            context,
226            body_limit_config,
227        })
228    }
229}
230
231impl<DB: DatabaseAdapter> BetterAuth<DB> {
232    /// Create a new BetterAuth builder.
233    #[allow(clippy::new_ret_no_self)]
234    pub fn new(config: AuthConfig) -> AuthBuilder {
235        AuthBuilder::new(config)
236    }
237
238    /// Handle an authentication request.
239    ///
240    /// Errors from plugins and core handlers are automatically converted
241    /// into standardized JSON responses via [`AuthError::into_response`],
242    /// producing `{ "message": "..." }` with the appropriate HTTP status code.
243    pub async fn handle_request(&self, req: AuthRequest) -> AuthResult<AuthResponse> {
244        // Ignore any caller-supplied virtual session value; only internal
245        // before_request hooks may inject this during dispatch.
246        let mut req =
247            AuthRequest::from_parts(req.method, req.path, req.headers, req.body, req.query);
248
249        match self.handle_request_inner(&mut req).await {
250            Ok(response) => {
251                // Run after-request middleware chain
252                middleware::run_after(&self.middlewares, &req, response).await
253            }
254            Err(err) => {
255                // Convert error to standardized response, then run after-middleware
256                let response = err.into_response();
257                middleware::run_after(&self.middlewares, &req, response).await
258            }
259        }
260    }
261
262    /// Inner request handler that may return errors.
263    async fn handle_request_inner(&self, req: &mut AuthRequest) -> AuthResult<AuthResponse> {
264        // Run before-request middleware chain
265        if let Some(response) = middleware::run_before(&self.middlewares, req).await? {
266            return Ok(response);
267        }
268
269        // Strip base_path prefix from the request path for internal routing.
270        // This happens BEFORE plugin hooks so that `before_request` sees the
271        // same normalised path that `on_request` / core handlers use.
272        // External callers send e.g. "/api/auth/sign-in/email"; internally
273        // handlers match against "/sign-in/email".
274        let base_path = &self.config.base_path;
275        let stripped_path = if !base_path.is_empty() && base_path != "/" {
276            req.path().strip_prefix(base_path).unwrap_or(req.path())
277        } else {
278            req.path()
279        };
280
281        // Build a request with the stripped path for all subsequent dispatch
282        let mut internal_req = if stripped_path != req.path() {
283            let mut r = req.clone();
284            r.path = stripped_path.to_string();
285            r
286        } else {
287            req.clone()
288        };
289
290        // Check if this path is disabled
291        if self.config.is_path_disabled(internal_req.path()) {
292            return Err(AuthError::not_found("This endpoint has been disabled"));
293        }
294
295        // Run plugin before_request hooks (e.g. API-key → session emulation)
296        // Plugins now see the normalised (base_path-stripped) path.
297        for plugin in &self.plugins {
298            if let Some(action) = plugin.before_request(&internal_req, &self.context).await? {
299                match action {
300                    BeforeRequestAction::Respond(response) => {
301                        return Ok(response);
302                    }
303                    BeforeRequestAction::InjectSession {
304                        user_id,
305                        session_token: _,
306                    } => {
307                        // Set the virtual user id on the request so that
308                        // `extract_current_user` can resolve the user without
309                        // creating a real database session.  This mirrors the
310                        // TypeScript `ctx.context.session` virtual-session
311                        // approach — no DB writes on every API-key request.
312                        internal_req.set_virtual_user_id(user_id);
313                    }
314                }
315            }
316        }
317
318        // Handle core endpoints first
319        if let Some(response) = self.handle_core_request(&internal_req).await? {
320            return Ok(response);
321        }
322
323        // Try each plugin until one handles the request
324        for plugin in &self.plugins {
325            if let Some(response) = plugin.on_request(&internal_req, &self.context).await? {
326                return Ok(response);
327            }
328        }
329
330        // No handler found
331        Err(AuthError::not_found("No handler found for this request"))
332    }
333
334    /// Get the configuration.
335    pub fn config(&self) -> &AuthConfig {
336        &self.config
337    }
338
339    /// Get the body-limit configuration; the axum entry handler uses
340    /// this to cap `to_bytes` at the same ceiling the user configured on
341    /// `AuthBuilder::body_limit`, rather than overriding it with a
342    /// hard-coded value.
343    pub fn body_limit_config(&self) -> &BodyLimitConfig {
344        &self.body_limit_config
345    }
346
347    /// Get the database adapter.
348    pub fn database(&self) -> &Arc<DB> {
349        &self.database
350    }
351
352    /// Get the session manager.
353    pub fn session_manager(&self) -> &SessionManager<DB> {
354        &self.session_manager
355    }
356
357    /// Get all routes from plugins.
358    pub fn routes(&self) -> Vec<(String, &dyn AuthPlugin<DB>)> {
359        let mut routes = Vec::new();
360        for plugin in &self.plugins {
361            for route in plugin.routes() {
362                routes.push((route.path, plugin.as_ref()));
363            }
364        }
365        routes
366    }
367
368    /// Get all plugins.
369    pub fn plugins(&self) -> &[Box<dyn AuthPlugin<DB>>] {
370        &self.plugins
371    }
372
373    /// Get plugin by name.
374    pub fn get_plugin(&self, name: &str) -> Option<&dyn AuthPlugin<DB>> {
375        self.plugins
376            .iter()
377            .find(|p| p.name() == name)
378            .map(|p| p.as_ref())
379    }
380
381    /// List all plugin names.
382    pub fn plugin_names(&self) -> Vec<&'static str> {
383        self.plugins.iter().map(|p| p.name()).collect()
384    }
385
386    /// Generate the OpenAPI spec for all registered routes.
387    pub fn openapi_spec(&self) -> OpenApiSpec {
388        let mut builder = OpenApiBuilder::new("Better Auth", env!("CARGO_PKG_VERSION"))
389            .description("Authentication API")
390            .core_routes();
391
392        for plugin in &self.plugins {
393            builder = builder.plugin(plugin.as_ref());
394        }
395
396        builder.build()
397    }
398
399    /// Handle core authentication requests.
400    async fn handle_core_request(&self, req: &AuthRequest) -> AuthResult<Option<AuthResponse>> {
401        match (req.method(), req.path()) {
402            (HttpMethod::Get, core_paths::OK) => {
403                Ok(Some(AuthResponse::json(200, &OkResponse { ok: true })?))
404            }
405            (HttpMethod::Get, core_paths::ERROR) => {
406                Ok(Some(AuthResponse::json(200, &OkResponse { ok: false })?))
407            }
408            (HttpMethod::Get, core_paths::OPENAPI_SPEC) => {
409                let spec = self.openapi_spec();
410                Ok(Some(AuthResponse::json(200, &spec)?))
411            }
412            (HttpMethod::Post, core_paths::UPDATE_USER) => {
413                Ok(Some(self.handle_update_user(req).await?))
414            }
415            (HttpMethod::Post | HttpMethod::Delete, core_paths::DELETE_USER) => {
416                Ok(Some(self.handle_delete_user(req).await?))
417            }
418            (HttpMethod::Post, core_paths::CHANGE_EMAIL) => {
419                Ok(Some(self.handle_change_email(req).await?))
420            }
421            (HttpMethod::Get, core_paths::DELETE_USER_CALLBACK) => {
422                Ok(Some(self.handle_delete_user_callback(req).await?))
423            }
424            _ => Ok(None),
425        }
426    }
427
428    /// Handle user profile update.
429    async fn handle_update_user(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
430        let current_user = self.extract_current_user(req).await?;
431
432        let update_req: UpdateUserRequest = req
433            .body_as_json()
434            .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
435
436        let update_user = UpdateUser {
437            email: update_req.email,
438            name: update_req.name,
439            image: update_req.image,
440            email_verified: None,
441            username: update_req.username,
442            display_username: update_req.display_username,
443            role: update_req.role,
444            banned: None,
445            ban_reason: None,
446            ban_expires: None,
447            two_factor_enabled: None,
448            metadata: update_req.metadata,
449        };
450
451        self.database
452            .update_user(current_user.id(), update_user)
453            .await?;
454
455        Ok(AuthResponse::json(
456            200,
457            &better_auth_core::StatusResponse { status: true },
458        )?)
459    }
460
461    /// Handle user deletion.
462    async fn handle_delete_user(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
463        let current_user = self.extract_current_user(req).await?;
464
465        self.database
466            .delete_user_sessions(current_user.id())
467            .await?;
468        self.database.delete_user(current_user.id()).await?;
469
470        let response = SuccessMessageResponse {
471            success: true,
472            message: "User account successfully deleted".to_string(),
473        };
474
475        Ok(AuthResponse::json(200, &response)?)
476    }
477
478    /// Handle email change.
479    async fn handle_change_email(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
480        let current_user = self.extract_current_user(req).await?;
481
482        let change_req: ChangeEmailRequest = req
483            .body_as_json()
484            .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
485
486        if !change_req.new_email.contains('@') || change_req.new_email.is_empty() {
487            return Err(AuthError::bad_request("Invalid email address"));
488        }
489
490        if self
491            .database
492            .get_user_by_email(&change_req.new_email)
493            .await?
494            .is_some()
495        {
496            return Err(AuthError::conflict("A user with this email already exists"));
497        }
498
499        let update_user = UpdateUser {
500            email: Some(change_req.new_email),
501            name: None,
502            image: None,
503            email_verified: Some(false),
504            username: None,
505            display_username: None,
506            role: None,
507            banned: None,
508            ban_reason: None,
509            ban_expires: None,
510            two_factor_enabled: None,
511            metadata: None,
512        };
513
514        self.database
515            .update_user(current_user.id(), update_user)
516            .await?;
517
518        Ok(AuthResponse::json(
519            200,
520            &StatusMessageResponse {
521                status: true,
522                message: "Email updated".to_string(),
523            },
524        )?)
525    }
526
527    /// Handle delete-user callback (token-based deletion confirmation).
528    async fn handle_delete_user_callback(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
529        let token = req
530            .query
531            .get("token")
532            .ok_or_else(|| AuthError::bad_request("Deletion token is required"))?;
533
534        let verification = self
535            .database
536            .get_verification_by_value(token)
537            .await?
538            .ok_or_else(|| AuthError::bad_request("Invalid or expired deletion token"))?;
539
540        let user_id = verification.identifier();
541
542        self.database.delete_user_sessions(user_id).await?;
543
544        let accounts = self.database.get_user_accounts(user_id).await?;
545        for account in accounts {
546            self.database.delete_account(account.id()).await?;
547        }
548
549        self.database.delete_user(user_id).await?;
550        self.database.delete_verification(verification.id()).await?;
551
552        let response = SuccessMessageResponse {
553            success: true,
554            message: "User account successfully deleted".to_string(),
555        };
556
557        Ok(AuthResponse::json(200, &response)?)
558    }
559
560    /// Extract current user from request (validates session).
561    ///
562    /// If a virtual session was injected by a `before_request` hook (e.g.
563    /// API-key session emulation), the user is resolved directly by ID
564    /// **without** a database session lookup — matching the TypeScript
565    /// `ctx.context.session` virtual-session behaviour.
566    async fn extract_current_user(&self, req: &AuthRequest) -> AuthResult<DB::User> {
567        // Fast path: virtual session injected by before_request hook
568        if let Some(uid) = req.virtual_user_id() {
569            return self
570                .database
571                .get_user_by_id(uid)
572                .await?
573                .ok_or(AuthError::UserNotFound);
574        }
575
576        let token = self
577            .session_manager
578            .extract_session_token(req)
579            .ok_or(AuthError::Unauthenticated)?;
580
581        let session = self
582            .session_manager
583            .get_session(&token)
584            .await?
585            .ok_or(AuthError::SessionNotFound)?;
586
587        let user = self
588            .database
589            .get_user_by_id(session.user_id())
590            .await?
591            .ok_or(AuthError::UserNotFound)?;
592
593        Ok(user)
594    }
595}