//! Core `BetterAuth` entry point: builders, request dispatch and built-in
//! endpoints (`better_auth/core/auth.rs`).

1use std::sync::Arc;
2
3use serde::Deserialize;
4
5use better_auth_core::{
6    AuthConfig, AuthContext, AuthError, AuthPlugin, AuthRequest, AuthResponse, AuthResult,
7    DatabaseAdapter, DatabaseHooks, DeleteUserResponse, EmailProvider, HttpMethod, OkResponse,
8    OpenApiBuilder, OpenApiSpec, SessionManager, StatusMessageResponse, StatusResponse, UpdateUser,
9    UpdateUserRequest,
10    entity::{AuthAccount, AuthSession, AuthUser, AuthVerification},
11    middleware::{
12        self, BodyLimitConfig, BodyLimitMiddleware, CorsConfig, CorsMiddleware, CsrfConfig,
13        CsrfMiddleware, Middleware, RateLimitConfig, RateLimitMiddleware,
14    },
15};
16
17#[derive(Debug, Deserialize)]
18struct ChangeEmailRequest {
19    #[serde(rename = "newEmail")]
20    new_email: String,
21}
22
23/// The main BetterAuth instance, generic over the database adapter.
24pub struct BetterAuth<DB: DatabaseAdapter> {
25    config: Arc<AuthConfig>,
26    plugins: Vec<Box<dyn AuthPlugin<DB>>>,
27    middlewares: Vec<Box<dyn Middleware>>,
28    database: Arc<DB>,
29    session_manager: SessionManager<DB>,
30    context: AuthContext<DB>,
31}
32
33/// Initial builder for configuring BetterAuth.
34///
35/// Call `.database(adapter)` to obtain a [`TypedAuthBuilder`] that can
36/// accept plugins and hooks.
37pub struct AuthBuilder {
38    config: AuthConfig,
39    csrf_config: Option<CsrfConfig>,
40    rate_limit_config: Option<RateLimitConfig>,
41    cors_config: Option<CorsConfig>,
42    body_limit_config: Option<BodyLimitConfig>,
43    custom_middlewares: Vec<Box<dyn Middleware>>,
44}
45
46/// Typed builder returned by [`AuthBuilder::database`].
47///
48/// Accepts plugins, hooks, and middleware before calling `.build()`.
49pub struct TypedAuthBuilder<DB: DatabaseAdapter> {
50    config: AuthConfig,
51    database: Arc<DB>,
52    plugins: Vec<Box<dyn AuthPlugin<DB>>>,
53    hooks: Vec<Arc<dyn DatabaseHooks<DB>>>,
54    csrf_config: Option<CsrfConfig>,
55    rate_limit_config: Option<RateLimitConfig>,
56    cors_config: Option<CorsConfig>,
57    body_limit_config: Option<BodyLimitConfig>,
58    custom_middlewares: Vec<Box<dyn Middleware>>,
59}
60
61impl AuthBuilder {
62    pub fn new(config: AuthConfig) -> Self {
63        Self {
64            config,
65            csrf_config: None,
66            rate_limit_config: None,
67            cors_config: None,
68            body_limit_config: None,
69            custom_middlewares: Vec::new(),
70        }
71    }
72
73    /// Set the database adapter, returning a [`TypedAuthBuilder`].
74    pub fn database<DB: DatabaseAdapter>(self, database: DB) -> TypedAuthBuilder<DB> {
75        TypedAuthBuilder {
76            config: self.config,
77            database: Arc::new(database),
78            plugins: Vec::new(),
79            hooks: Vec::new(),
80            csrf_config: self.csrf_config,
81            rate_limit_config: self.rate_limit_config,
82            cors_config: self.cors_config,
83            body_limit_config: self.body_limit_config,
84            custom_middlewares: self.custom_middlewares,
85        }
86    }
87
88    /// Configure CSRF protection.
89    pub fn csrf(mut self, config: CsrfConfig) -> Self {
90        self.csrf_config = Some(config);
91        self
92    }
93
94    /// Configure rate limiting.
95    pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
96        self.rate_limit_config = Some(config);
97        self
98    }
99
100    /// Configure CORS.
101    pub fn cors(mut self, config: CorsConfig) -> Self {
102        self.cors_config = Some(config);
103        self
104    }
105
106    /// Configure body size limit.
107    pub fn body_limit(mut self, config: BodyLimitConfig) -> Self {
108        self.body_limit_config = Some(config);
109        self
110    }
111
112    /// Set the email provider.
113    pub fn email_provider<E: EmailProvider + 'static>(mut self, provider: E) -> Self {
114        self.config.email_provider = Some(Arc::new(provider));
115        self
116    }
117}
118
119impl<DB: DatabaseAdapter> TypedAuthBuilder<DB> {
120    /// Add a plugin to the authentication system.
121    pub fn plugin<P: AuthPlugin<DB> + 'static>(mut self, plugin: P) -> Self {
122        self.plugins.push(Box::new(plugin));
123        self
124    }
125
126    /// Add a database lifecycle hook.
127    pub fn hook<H: DatabaseHooks<DB> + 'static>(mut self, hook: H) -> Self {
128        self.hooks.push(Arc::new(hook));
129        self
130    }
131
132    /// Configure CSRF protection.
133    pub fn csrf(mut self, config: CsrfConfig) -> Self {
134        self.csrf_config = Some(config);
135        self
136    }
137
138    /// Configure rate limiting.
139    pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
140        self.rate_limit_config = Some(config);
141        self
142    }
143
144    /// Configure CORS.
145    pub fn cors(mut self, config: CorsConfig) -> Self {
146        self.cors_config = Some(config);
147        self
148    }
149
150    /// Configure body size limit.
151    pub fn body_limit(mut self, config: BodyLimitConfig) -> Self {
152        self.body_limit_config = Some(config);
153        self
154    }
155
156    /// Set the email provider for sending emails.
157    pub fn email_provider<E: EmailProvider + 'static>(mut self, provider: E) -> Self {
158        self.config.email_provider = Some(Arc::new(provider));
159        self
160    }
161
162    /// Add a custom middleware.
163    pub fn middleware<M: Middleware + 'static>(mut self, mw: M) -> Self {
164        self.custom_middlewares.push(Box::new(mw));
165        self
166    }
167
168    /// Build the BetterAuth instance.
169    pub async fn build(self) -> AuthResult<BetterAuth<DB>> {
170        // Validate configuration
171        self.config.validate()?;
172
173        let config = Arc::new(self.config);
174
175        // If hooks are registered, the user should wrap the adapter themselves:
176        //   let db = HookedDatabaseAdapter::new(Arc::new(my_db)).with_hook(hook);
177        //   BetterAuth::new(config).database(db).plugin(...).build().await
178        if !self.hooks.is_empty() {
179            return Err(AuthError::config(
180                "Use HookedDatabaseAdapter directly: \
181                 BetterAuth::new(config).database(HookedDatabaseAdapter::new(Arc::new(db)).with_hook(h))",
182            ));
183        }
184
185        let database = self.database;
186
187        // Create session manager
188        let session_manager = SessionManager::new(config.clone(), database.clone());
189
190        // Create context
191        let mut context = AuthContext::new(config.clone(), database.clone());
192
193        // Initialize all plugins
194        for plugin in &self.plugins {
195            plugin.on_init(&mut context).await?;
196        }
197
198        // Build middleware chain (order matters: body limit → rate limit → CSRF → CORS → custom)
199        let mut middlewares: Vec<Box<dyn Middleware>> = vec![
200            Box::new(BodyLimitMiddleware::new(
201                self.body_limit_config.unwrap_or_default(),
202            )),
203            Box::new(RateLimitMiddleware::new(
204                self.rate_limit_config.unwrap_or_default(),
205            )),
206            Box::new(CsrfMiddleware::new(
207                self.csrf_config.unwrap_or_default(),
208                &config.base_url,
209            )),
210            Box::new(CorsMiddleware::new(self.cors_config.unwrap_or_default())),
211        ];
212
213        middlewares.extend(self.custom_middlewares);
214
215        Ok(BetterAuth {
216            config,
217            plugins: self.plugins,
218            middlewares,
219            database,
220            session_manager,
221            context,
222        })
223    }
224}
225
226impl<DB: DatabaseAdapter> BetterAuth<DB> {
227    /// Create a new BetterAuth builder.
228    #[allow(clippy::new_ret_no_self)]
229    pub fn new(config: AuthConfig) -> AuthBuilder {
230        AuthBuilder::new(config)
231    }
232
233    /// Handle an authentication request.
234    ///
235    /// Errors from plugins and core handlers are automatically converted
236    /// into standardized JSON responses via [`AuthError::into_response`],
237    /// producing `{ "message": "..." }` with the appropriate HTTP status code.
238    pub async fn handle_request(&self, req: AuthRequest) -> AuthResult<AuthResponse> {
239        match self.handle_request_inner(&req).await {
240            Ok(response) => {
241                // Run after-request middleware chain
242                middleware::run_after(&self.middlewares, &req, response).await
243            }
244            Err(err) => {
245                // Convert error to standardized response, then run after-middleware
246                let response = err.into_response();
247                middleware::run_after(&self.middlewares, &req, response).await
248            }
249        }
250    }
251
252    /// Inner request handler that may return errors.
253    async fn handle_request_inner(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
254        // Run before-request middleware chain
255        if let Some(response) = middleware::run_before(&self.middlewares, req).await? {
256            return Ok(response);
257        }
258
259        // Handle core endpoints first
260        if let Some(response) = self.handle_core_request(req).await? {
261            return Ok(response);
262        }
263
264        // Try each plugin until one handles the request
265        for plugin in &self.plugins {
266            if let Some(response) = plugin.on_request(req, &self.context).await? {
267                return Ok(response);
268            }
269        }
270
271        // No handler found
272        Err(AuthError::not_found("No handler found for this request"))
273    }
274
275    /// Get the configuration.
276    pub fn config(&self) -> &AuthConfig {
277        &self.config
278    }
279
280    /// Get the database adapter.
281    pub fn database(&self) -> &Arc<DB> {
282        &self.database
283    }
284
285    /// Get the session manager.
286    pub fn session_manager(&self) -> &SessionManager<DB> {
287        &self.session_manager
288    }
289
290    /// Get all routes from plugins.
291    pub fn routes(&self) -> Vec<(String, &dyn AuthPlugin<DB>)> {
292        let mut routes = Vec::new();
293        for plugin in &self.plugins {
294            for route in plugin.routes() {
295                routes.push((route.path, plugin.as_ref()));
296            }
297        }
298        routes
299    }
300
301    /// Get all plugins.
302    pub fn plugins(&self) -> &[Box<dyn AuthPlugin<DB>>] {
303        &self.plugins
304    }
305
306    /// Get plugin by name.
307    pub fn get_plugin(&self, name: &str) -> Option<&dyn AuthPlugin<DB>> {
308        self.plugins
309            .iter()
310            .find(|p| p.name() == name)
311            .map(|p| p.as_ref())
312    }
313
314    /// List all plugin names.
315    pub fn plugin_names(&self) -> Vec<&'static str> {
316        self.plugins.iter().map(|p| p.name()).collect()
317    }
318
319    /// Generate the OpenAPI spec for all registered routes.
320    pub fn openapi_spec(&self) -> OpenApiSpec {
321        let mut builder = OpenApiBuilder::new("Better Auth", env!("CARGO_PKG_VERSION"))
322            .description("Authentication API")
323            .core_routes();
324
325        for plugin in &self.plugins {
326            builder = builder.plugin(plugin.as_ref());
327        }
328
329        builder.build()
330    }
331
332    /// Handle core authentication requests.
333    async fn handle_core_request(&self, req: &AuthRequest) -> AuthResult<Option<AuthResponse>> {
334        match (req.method(), req.path()) {
335            (HttpMethod::Get, "/ok") => {
336                Ok(Some(AuthResponse::json(200, &OkResponse { ok: true })?))
337            }
338            (HttpMethod::Get, "/error") => {
339                Ok(Some(AuthResponse::json(200, &OkResponse { ok: false })?))
340            }
341            (HttpMethod::Get, "/reference/openapi.json") => {
342                let spec = self.openapi_spec();
343                Ok(Some(AuthResponse::json(200, &spec)?))
344            }
345            (HttpMethod::Post, "/update-user") => Ok(Some(self.handle_update_user(req).await?)),
346            (HttpMethod::Post | HttpMethod::Delete, "/delete-user") => {
347                Ok(Some(self.handle_delete_user(req).await?))
348            }
349            (HttpMethod::Post, "/change-email") => Ok(Some(self.handle_change_email(req).await?)),
350            (HttpMethod::Get, "/delete-user/callback") => {
351                Ok(Some(self.handle_delete_user_callback(req).await?))
352            }
353            _ => Ok(None),
354        }
355    }
356
357    /// Handle user profile update.
358    async fn handle_update_user(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
359        let current_user = self.extract_current_user(req).await?;
360
361        let update_req: UpdateUserRequest = req
362            .body_as_json()
363            .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
364
365        let update_user = UpdateUser {
366            email: update_req.email,
367            name: update_req.name,
368            image: update_req.image,
369            email_verified: None,
370            username: update_req.username,
371            display_username: update_req.display_username,
372            role: update_req.role,
373            banned: None,
374            ban_reason: None,
375            ban_expires: None,
376            two_factor_enabled: None,
377            metadata: update_req.metadata,
378        };
379
380        self.database
381            .update_user(current_user.id(), update_user)
382            .await?;
383
384        Ok(AuthResponse::json(200, &StatusResponse { status: true })?)
385    }
386
387    /// Handle user deletion.
388    async fn handle_delete_user(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
389        let current_user = self.extract_current_user(req).await?;
390
391        self.database
392            .delete_user_sessions(current_user.id())
393            .await?;
394        self.database.delete_user(current_user.id()).await?;
395
396        let response = DeleteUserResponse {
397            success: true,
398            message: "User account successfully deleted".to_string(),
399        };
400
401        Ok(AuthResponse::json(200, &response)?)
402    }
403
404    /// Handle email change.
405    async fn handle_change_email(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
406        let current_user = self.extract_current_user(req).await?;
407
408        let change_req: ChangeEmailRequest = req
409            .body_as_json()
410            .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
411
412        if !change_req.new_email.contains('@') || change_req.new_email.is_empty() {
413            return Err(AuthError::bad_request("Invalid email address"));
414        }
415
416        if self
417            .database
418            .get_user_by_email(&change_req.new_email)
419            .await?
420            .is_some()
421        {
422            return Err(AuthError::conflict("A user with this email already exists"));
423        }
424
425        let update_user = UpdateUser {
426            email: Some(change_req.new_email),
427            name: None,
428            image: None,
429            email_verified: Some(false),
430            username: None,
431            display_username: None,
432            role: None,
433            banned: None,
434            ban_reason: None,
435            ban_expires: None,
436            two_factor_enabled: None,
437            metadata: None,
438        };
439
440        self.database
441            .update_user(current_user.id(), update_user)
442            .await?;
443
444        Ok(AuthResponse::json(
445            200,
446            &StatusMessageResponse {
447                status: true,
448                message: "Email updated".to_string(),
449            },
450        )?)
451    }
452
453    /// Handle delete-user callback (token-based deletion confirmation).
454    async fn handle_delete_user_callback(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
455        let token = req
456            .query
457            .get("token")
458            .ok_or_else(|| AuthError::bad_request("Deletion token is required"))?;
459
460        let verification = self
461            .database
462            .get_verification_by_value(token)
463            .await?
464            .ok_or_else(|| AuthError::bad_request("Invalid or expired deletion token"))?;
465
466        let user_id = verification.identifier();
467
468        self.database.delete_user_sessions(user_id).await?;
469
470        let accounts = self.database.get_user_accounts(user_id).await?;
471        for account in accounts {
472            self.database.delete_account(account.id()).await?;
473        }
474
475        self.database.delete_user(user_id).await?;
476        self.database.delete_verification(verification.id()).await?;
477
478        let response = DeleteUserResponse {
479            success: true,
480            message: "User account successfully deleted".to_string(),
481        };
482
483        Ok(AuthResponse::json(200, &response)?)
484    }
485
486    /// Extract current user from request (validates session).
487    async fn extract_current_user(&self, req: &AuthRequest) -> AuthResult<DB::User> {
488        let token = self
489            .session_manager
490            .extract_session_token(req)
491            .ok_or(AuthError::Unauthenticated)?;
492
493        let session = self
494            .session_manager
495            .get_session(&token)
496            .await?
497            .ok_or(AuthError::SessionNotFound)?;
498
499        let user = self
500            .database
501            .get_user_by_id(session.user_id())
502            .await?
503            .ok_or(AuthError::UserNotFound)?;
504
505        Ok(user)
506    }
507}