// better_auth/core/auth.rs

1use std::sync::Arc;
2
3use serde::Deserialize;
4
5use better_auth_core::{
6    AuthConfig, AuthContext, AuthError, AuthPlugin, AuthRequest, AuthResponse, AuthResult,
7    BeforeRequestAction, DatabaseAdapter, DatabaseHooks, EmailProvider, HttpMethod, OkResponse,
8    OpenApiBuilder, OpenApiSpec, SessionManager, StatusMessageResponse, SuccessMessageResponse,
9    UpdateUser, UpdateUserRequest, core_paths,
10    entity::{AuthAccount, AuthSession, AuthUser, AuthVerification},
11    middleware::{
12        self, BodyLimitConfig, BodyLimitMiddleware, CorsConfig, CorsMiddleware, CsrfConfig,
13        CsrfMiddleware, Middleware, RateLimitConfig, RateLimitMiddleware,
14    },
15};
16
17#[derive(Debug, Deserialize)]
18struct ChangeEmailRequest {
19    #[serde(rename = "newEmail")]
20    new_email: String,
21}
22
23/// The main BetterAuth instance, generic over the database adapter.
24pub struct BetterAuth<DB: DatabaseAdapter> {
25    config: Arc<AuthConfig>,
26    plugins: Vec<Box<dyn AuthPlugin<DB>>>,
27    middlewares: Vec<Box<dyn Middleware>>,
28    database: Arc<DB>,
29    session_manager: SessionManager<DB>,
30    context: AuthContext<DB>,
31}
32
33/// Initial builder for configuring BetterAuth.
34///
35/// Call `.database(adapter)` to obtain a [`TypedAuthBuilder`] that can
36/// accept plugins and hooks.
37pub struct AuthBuilder {
38    config: AuthConfig,
39    csrf_config: Option<CsrfConfig>,
40    rate_limit_config: Option<RateLimitConfig>,
41    cors_config: Option<CorsConfig>,
42    body_limit_config: Option<BodyLimitConfig>,
43    custom_middlewares: Vec<Box<dyn Middleware>>,
44}
45
46/// Typed builder returned by [`AuthBuilder::database`].
47///
48/// Accepts plugins, hooks, and middleware before calling `.build()`.
49pub struct TypedAuthBuilder<DB: DatabaseAdapter> {
50    config: AuthConfig,
51    database: Arc<DB>,
52    plugins: Vec<Box<dyn AuthPlugin<DB>>>,
53    hooks: Vec<Arc<dyn DatabaseHooks<DB>>>,
54    csrf_config: Option<CsrfConfig>,
55    rate_limit_config: Option<RateLimitConfig>,
56    cors_config: Option<CorsConfig>,
57    body_limit_config: Option<BodyLimitConfig>,
58    custom_middlewares: Vec<Box<dyn Middleware>>,
59}
60
61impl AuthBuilder {
62    pub fn new(config: AuthConfig) -> Self {
63        Self {
64            config,
65            csrf_config: None,
66            rate_limit_config: None,
67            cors_config: None,
68            body_limit_config: None,
69            custom_middlewares: Vec::new(),
70        }
71    }
72
73    /// Set the database adapter, returning a [`TypedAuthBuilder`].
74    pub fn database<DB: DatabaseAdapter>(self, database: DB) -> TypedAuthBuilder<DB> {
75        TypedAuthBuilder {
76            config: self.config,
77            database: Arc::new(database),
78            plugins: Vec::new(),
79            hooks: Vec::new(),
80            csrf_config: self.csrf_config,
81            rate_limit_config: self.rate_limit_config,
82            cors_config: self.cors_config,
83            body_limit_config: self.body_limit_config,
84            custom_middlewares: self.custom_middlewares,
85        }
86    }
87
88    /// Configure CSRF protection.
89    pub fn csrf(mut self, config: CsrfConfig) -> Self {
90        self.csrf_config = Some(config);
91        self
92    }
93
94    /// Configure rate limiting.
95    pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
96        self.rate_limit_config = Some(config);
97        self
98    }
99
100    /// Configure CORS.
101    pub fn cors(mut self, config: CorsConfig) -> Self {
102        self.cors_config = Some(config);
103        self
104    }
105
106    /// Configure body size limit.
107    pub fn body_limit(mut self, config: BodyLimitConfig) -> Self {
108        self.body_limit_config = Some(config);
109        self
110    }
111
112    /// Set the email provider.
113    pub fn email_provider<E: EmailProvider + 'static>(mut self, provider: E) -> Self {
114        self.config.email_provider = Some(Arc::new(provider));
115        self
116    }
117}
118
119impl<DB: DatabaseAdapter> TypedAuthBuilder<DB> {
120    /// Add a plugin to the authentication system.
121    pub fn plugin<P: AuthPlugin<DB> + 'static>(mut self, plugin: P) -> Self {
122        self.plugins.push(Box::new(plugin));
123        self
124    }
125
126    /// Add a database lifecycle hook.
127    pub fn hook<H: DatabaseHooks<DB> + 'static>(mut self, hook: H) -> Self {
128        self.hooks.push(Arc::new(hook));
129        self
130    }
131
132    /// Configure CSRF protection.
133    pub fn csrf(mut self, config: CsrfConfig) -> Self {
134        self.csrf_config = Some(config);
135        self
136    }
137
138    /// Configure rate limiting.
139    pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
140        self.rate_limit_config = Some(config);
141        self
142    }
143
144    /// Configure CORS.
145    pub fn cors(mut self, config: CorsConfig) -> Self {
146        self.cors_config = Some(config);
147        self
148    }
149
150    /// Configure body size limit.
151    pub fn body_limit(mut self, config: BodyLimitConfig) -> Self {
152        self.body_limit_config = Some(config);
153        self
154    }
155
156    /// Set the email provider for sending emails.
157    pub fn email_provider<E: EmailProvider + 'static>(mut self, provider: E) -> Self {
158        self.config.email_provider = Some(Arc::new(provider));
159        self
160    }
161
162    /// Add a custom middleware.
163    pub fn middleware<M: Middleware + 'static>(mut self, mw: M) -> Self {
164        self.custom_middlewares.push(Box::new(mw));
165        self
166    }
167
168    /// Build the BetterAuth instance.
169    pub async fn build(self) -> AuthResult<BetterAuth<DB>> {
170        // Validate configuration
171        self.config.validate()?;
172
173        let config = Arc::new(self.config);
174
175        // If hooks are registered, the user should wrap the adapter themselves:
176        //   let db = HookedDatabaseAdapter::new(Arc::new(my_db)).with_hook(hook);
177        //   BetterAuth::new(config).database(db).plugin(...).build().await
178        if !self.hooks.is_empty() {
179            return Err(AuthError::config(
180                "Use HookedDatabaseAdapter directly: \
181                 BetterAuth::new(config).database(HookedDatabaseAdapter::new(Arc::new(db)).with_hook(h))",
182            ));
183        }
184
185        let database = self.database;
186
187        // Create session manager
188        let session_manager = SessionManager::new(config.clone(), database.clone());
189
190        // Create context
191        let mut context = AuthContext::new(config.clone(), database.clone());
192
193        // Initialize all plugins
194        for plugin in &self.plugins {
195            plugin.on_init(&mut context).await?;
196        }
197
198        // Build middleware chain (order matters: body limit → rate limit → CSRF → CORS → custom)
199        let mut middlewares: Vec<Box<dyn Middleware>> = vec![
200            Box::new(BodyLimitMiddleware::new(
201                self.body_limit_config.unwrap_or_default(),
202            )),
203            Box::new(RateLimitMiddleware::new(
204                self.rate_limit_config.unwrap_or_default(),
205            )),
206            Box::new(CsrfMiddleware::new(
207                self.csrf_config.unwrap_or_default(),
208                config.clone(),
209            )),
210            Box::new(CorsMiddleware::new(self.cors_config.unwrap_or_default())),
211        ];
212
213        middlewares.extend(self.custom_middlewares);
214
215        Ok(BetterAuth {
216            config,
217            plugins: self.plugins,
218            middlewares,
219            database,
220            session_manager,
221            context,
222        })
223    }
224}
225
226impl<DB: DatabaseAdapter> BetterAuth<DB> {
227    /// Create a new BetterAuth builder.
228    #[allow(clippy::new_ret_no_self)]
229    pub fn new(config: AuthConfig) -> AuthBuilder {
230        AuthBuilder::new(config)
231    }
232
233    /// Handle an authentication request.
234    ///
235    /// Errors from plugins and core handlers are automatically converted
236    /// into standardized JSON responses via [`AuthError::into_response`],
237    /// producing `{ "message": "..." }` with the appropriate HTTP status code.
238    pub async fn handle_request(&self, req: AuthRequest) -> AuthResult<AuthResponse> {
239        // Ignore any caller-supplied virtual session value; only internal
240        // before_request hooks may inject this during dispatch.
241        let mut req =
242            AuthRequest::from_parts(req.method, req.path, req.headers, req.body, req.query);
243
244        match self.handle_request_inner(&mut req).await {
245            Ok(response) => {
246                // Run after-request middleware chain
247                middleware::run_after(&self.middlewares, &req, response).await
248            }
249            Err(err) => {
250                // Convert error to standardized response, then run after-middleware
251                let response = err.into_response();
252                middleware::run_after(&self.middlewares, &req, response).await
253            }
254        }
255    }
256
257    /// Inner request handler that may return errors.
258    async fn handle_request_inner(&self, req: &mut AuthRequest) -> AuthResult<AuthResponse> {
259        // Run before-request middleware chain
260        if let Some(response) = middleware::run_before(&self.middlewares, req).await? {
261            return Ok(response);
262        }
263
264        // Strip base_path prefix from the request path for internal routing.
265        // This happens BEFORE plugin hooks so that `before_request` sees the
266        // same normalised path that `on_request` / core handlers use.
267        // External callers send e.g. "/api/auth/sign-in/email"; internally
268        // handlers match against "/sign-in/email".
269        let base_path = &self.config.base_path;
270        let stripped_path = if !base_path.is_empty() && base_path != "/" {
271            req.path().strip_prefix(base_path).unwrap_or(req.path())
272        } else {
273            req.path()
274        };
275
276        // Build a request with the stripped path for all subsequent dispatch
277        let mut internal_req = if stripped_path != req.path() {
278            let mut r = req.clone();
279            r.path = stripped_path.to_string();
280            r
281        } else {
282            req.clone()
283        };
284
285        // Check if this path is disabled
286        if self.config.is_path_disabled(internal_req.path()) {
287            return Err(AuthError::not_found("This endpoint has been disabled"));
288        }
289
290        // Run plugin before_request hooks (e.g. API-key → session emulation)
291        // Plugins now see the normalised (base_path-stripped) path.
292        for plugin in &self.plugins {
293            if let Some(action) = plugin.before_request(&internal_req, &self.context).await? {
294                match action {
295                    BeforeRequestAction::Respond(response) => {
296                        return Ok(response);
297                    }
298                    BeforeRequestAction::InjectSession {
299                        user_id,
300                        session_token: _,
301                    } => {
302                        // Set the virtual user id on the request so that
303                        // `extract_current_user` can resolve the user without
304                        // creating a real database session.  This mirrors the
305                        // TypeScript `ctx.context.session` virtual-session
306                        // approach — no DB writes on every API-key request.
307                        internal_req.set_virtual_user_id(user_id);
308                    }
309                }
310            }
311        }
312
313        // Handle core endpoints first
314        if let Some(response) = self.handle_core_request(&internal_req).await? {
315            return Ok(response);
316        }
317
318        // Try each plugin until one handles the request
319        for plugin in &self.plugins {
320            if let Some(response) = plugin.on_request(&internal_req, &self.context).await? {
321                return Ok(response);
322            }
323        }
324
325        // No handler found
326        Err(AuthError::not_found("No handler found for this request"))
327    }
328
329    /// Get the configuration.
330    pub fn config(&self) -> &AuthConfig {
331        &self.config
332    }
333
334    /// Get the database adapter.
335    pub fn database(&self) -> &Arc<DB> {
336        &self.database
337    }
338
339    /// Get the session manager.
340    pub fn session_manager(&self) -> &SessionManager<DB> {
341        &self.session_manager
342    }
343
344    /// Get all routes from plugins.
345    pub fn routes(&self) -> Vec<(String, &dyn AuthPlugin<DB>)> {
346        let mut routes = Vec::new();
347        for plugin in &self.plugins {
348            for route in plugin.routes() {
349                routes.push((route.path, plugin.as_ref()));
350            }
351        }
352        routes
353    }
354
355    /// Get all plugins.
356    pub fn plugins(&self) -> &[Box<dyn AuthPlugin<DB>>] {
357        &self.plugins
358    }
359
360    /// Get plugin by name.
361    pub fn get_plugin(&self, name: &str) -> Option<&dyn AuthPlugin<DB>> {
362        self.plugins
363            .iter()
364            .find(|p| p.name() == name)
365            .map(|p| p.as_ref())
366    }
367
368    /// List all plugin names.
369    pub fn plugin_names(&self) -> Vec<&'static str> {
370        self.plugins.iter().map(|p| p.name()).collect()
371    }
372
373    /// Generate the OpenAPI spec for all registered routes.
374    pub fn openapi_spec(&self) -> OpenApiSpec {
375        let mut builder = OpenApiBuilder::new("Better Auth", env!("CARGO_PKG_VERSION"))
376            .description("Authentication API")
377            .core_routes();
378
379        for plugin in &self.plugins {
380            builder = builder.plugin(plugin.as_ref());
381        }
382
383        builder.build()
384    }
385
386    /// Handle core authentication requests.
387    async fn handle_core_request(&self, req: &AuthRequest) -> AuthResult<Option<AuthResponse>> {
388        match (req.method(), req.path()) {
389            (HttpMethod::Get, core_paths::OK) => {
390                Ok(Some(AuthResponse::json(200, &OkResponse { ok: true })?))
391            }
392            (HttpMethod::Get, core_paths::ERROR) => {
393                Ok(Some(AuthResponse::json(200, &OkResponse { ok: false })?))
394            }
395            (HttpMethod::Get, core_paths::OPENAPI_SPEC) => {
396                let spec = self.openapi_spec();
397                Ok(Some(AuthResponse::json(200, &spec)?))
398            }
399            (HttpMethod::Post, core_paths::UPDATE_USER) => {
400                Ok(Some(self.handle_update_user(req).await?))
401            }
402            (HttpMethod::Post | HttpMethod::Delete, core_paths::DELETE_USER) => {
403                Ok(Some(self.handle_delete_user(req).await?))
404            }
405            (HttpMethod::Post, core_paths::CHANGE_EMAIL) => {
406                Ok(Some(self.handle_change_email(req).await?))
407            }
408            (HttpMethod::Get, core_paths::DELETE_USER_CALLBACK) => {
409                Ok(Some(self.handle_delete_user_callback(req).await?))
410            }
411            _ => Ok(None),
412        }
413    }
414
415    /// Handle user profile update.
416    async fn handle_update_user(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
417        let current_user = self.extract_current_user(req).await?;
418
419        let update_req: UpdateUserRequest = req
420            .body_as_json()
421            .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
422
423        let update_user = UpdateUser {
424            email: update_req.email,
425            name: update_req.name,
426            image: update_req.image,
427            email_verified: None,
428            username: update_req.username,
429            display_username: update_req.display_username,
430            role: update_req.role,
431            banned: None,
432            ban_reason: None,
433            ban_expires: None,
434            two_factor_enabled: None,
435            metadata: update_req.metadata,
436        };
437
438        self.database
439            .update_user(current_user.id(), update_user)
440            .await?;
441
442        Ok(AuthResponse::json(
443            200,
444            &better_auth_core::StatusResponse { status: true },
445        )?)
446    }
447
448    /// Handle user deletion.
449    async fn handle_delete_user(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
450        let current_user = self.extract_current_user(req).await?;
451
452        self.database
453            .delete_user_sessions(current_user.id())
454            .await?;
455        self.database.delete_user(current_user.id()).await?;
456
457        let response = SuccessMessageResponse {
458            success: true,
459            message: "User account successfully deleted".to_string(),
460        };
461
462        Ok(AuthResponse::json(200, &response)?)
463    }
464
465    /// Handle email change.
466    async fn handle_change_email(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
467        let current_user = self.extract_current_user(req).await?;
468
469        let change_req: ChangeEmailRequest = req
470            .body_as_json()
471            .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
472
473        if !change_req.new_email.contains('@') || change_req.new_email.is_empty() {
474            return Err(AuthError::bad_request("Invalid email address"));
475        }
476
477        if self
478            .database
479            .get_user_by_email(&change_req.new_email)
480            .await?
481            .is_some()
482        {
483            return Err(AuthError::conflict("A user with this email already exists"));
484        }
485
486        let update_user = UpdateUser {
487            email: Some(change_req.new_email),
488            name: None,
489            image: None,
490            email_verified: Some(false),
491            username: None,
492            display_username: None,
493            role: None,
494            banned: None,
495            ban_reason: None,
496            ban_expires: None,
497            two_factor_enabled: None,
498            metadata: None,
499        };
500
501        self.database
502            .update_user(current_user.id(), update_user)
503            .await?;
504
505        Ok(AuthResponse::json(
506            200,
507            &StatusMessageResponse {
508                status: true,
509                message: "Email updated".to_string(),
510            },
511        )?)
512    }
513
514    /// Handle delete-user callback (token-based deletion confirmation).
515    async fn handle_delete_user_callback(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
516        let token = req
517            .query
518            .get("token")
519            .ok_or_else(|| AuthError::bad_request("Deletion token is required"))?;
520
521        let verification = self
522            .database
523            .get_verification_by_value(token)
524            .await?
525            .ok_or_else(|| AuthError::bad_request("Invalid or expired deletion token"))?;
526
527        let user_id = verification.identifier();
528
529        self.database.delete_user_sessions(user_id).await?;
530
531        let accounts = self.database.get_user_accounts(user_id).await?;
532        for account in accounts {
533            self.database.delete_account(account.id()).await?;
534        }
535
536        self.database.delete_user(user_id).await?;
537        self.database.delete_verification(verification.id()).await?;
538
539        let response = SuccessMessageResponse {
540            success: true,
541            message: "User account successfully deleted".to_string(),
542        };
543
544        Ok(AuthResponse::json(200, &response)?)
545    }
546
547    /// Extract current user from request (validates session).
548    ///
549    /// If a virtual session was injected by a `before_request` hook (e.g.
550    /// API-key session emulation), the user is resolved directly by ID
551    /// **without** a database session lookup — matching the TypeScript
552    /// `ctx.context.session` virtual-session behaviour.
553    async fn extract_current_user(&self, req: &AuthRequest) -> AuthResult<DB::User> {
554        // Fast path: virtual session injected by before_request hook
555        if let Some(uid) = req.virtual_user_id() {
556            return self
557                .database
558                .get_user_by_id(uid)
559                .await?
560                .ok_or(AuthError::UserNotFound);
561        }
562
563        let token = self
564            .session_manager
565            .extract_session_token(req)
566            .ok_or(AuthError::Unauthenticated)?;
567
568        let session = self
569            .session_manager
570            .get_session(&token)
571            .await?
572            .ok_or(AuthError::SessionNotFound)?;
573
574        let user = self
575            .database
576            .get_user_by_id(session.user_id())
577            .await?
578            .ok_or(AuthError::UserNotFound)?;
579
580        Ok(user)
581    }
582}