Skip to main content

better_auth/core/
auth.rs

1use std::sync::Arc;
2
3use serde::Deserialize;
4
5use better_auth_core::{
6    AuthConfig, AuthContext, AuthError, AuthPlugin, AuthRequest, AuthResponse, AuthResult,
7    DatabaseAdapter, DatabaseHooks, DeleteUserResponse, EmailProvider, HttpMethod, OpenApiBuilder,
8    OpenApiSpec, SessionManager, UpdateUser, UpdateUserRequest,
9    entity::{AuthAccount, AuthSession, AuthUser, AuthVerification},
10    middleware::{
11        self, BodyLimitConfig, BodyLimitMiddleware, CorsConfig, CorsMiddleware, CsrfConfig,
12        CsrfMiddleware, Middleware, RateLimitConfig, RateLimitMiddleware,
13    },
14};
15
16#[derive(Debug, Deserialize)]
17struct ChangeEmailRequest {
18    #[serde(rename = "newEmail")]
19    new_email: String,
20}
21
22/// The main BetterAuth instance, generic over the database adapter.
23pub struct BetterAuth<DB: DatabaseAdapter> {
24    config: Arc<AuthConfig>,
25    plugins: Vec<Box<dyn AuthPlugin<DB>>>,
26    middlewares: Vec<Box<dyn Middleware>>,
27    database: Arc<DB>,
28    session_manager: SessionManager<DB>,
29    context: AuthContext<DB>,
30}
31
32/// Initial builder for configuring BetterAuth.
33///
34/// Call `.database(adapter)` to obtain a [`TypedAuthBuilder`] that can
35/// accept plugins and hooks.
36pub struct AuthBuilder {
37    config: AuthConfig,
38    csrf_config: Option<CsrfConfig>,
39    rate_limit_config: Option<RateLimitConfig>,
40    cors_config: Option<CorsConfig>,
41    body_limit_config: Option<BodyLimitConfig>,
42    custom_middlewares: Vec<Box<dyn Middleware>>,
43}
44
45/// Typed builder returned by [`AuthBuilder::database`].
46///
47/// Accepts plugins, hooks, and middleware before calling `.build()`.
48pub struct TypedAuthBuilder<DB: DatabaseAdapter> {
49    config: AuthConfig,
50    database: Arc<DB>,
51    plugins: Vec<Box<dyn AuthPlugin<DB>>>,
52    hooks: Vec<Arc<dyn DatabaseHooks<DB>>>,
53    csrf_config: Option<CsrfConfig>,
54    rate_limit_config: Option<RateLimitConfig>,
55    cors_config: Option<CorsConfig>,
56    body_limit_config: Option<BodyLimitConfig>,
57    custom_middlewares: Vec<Box<dyn Middleware>>,
58}
59
60impl AuthBuilder {
61    pub fn new(config: AuthConfig) -> Self {
62        Self {
63            config,
64            csrf_config: None,
65            rate_limit_config: None,
66            cors_config: None,
67            body_limit_config: None,
68            custom_middlewares: Vec::new(),
69        }
70    }
71
72    /// Set the database adapter, returning a [`TypedAuthBuilder`].
73    pub fn database<DB: DatabaseAdapter>(self, database: DB) -> TypedAuthBuilder<DB> {
74        TypedAuthBuilder {
75            config: self.config,
76            database: Arc::new(database),
77            plugins: Vec::new(),
78            hooks: Vec::new(),
79            csrf_config: self.csrf_config,
80            rate_limit_config: self.rate_limit_config,
81            cors_config: self.cors_config,
82            body_limit_config: self.body_limit_config,
83            custom_middlewares: self.custom_middlewares,
84        }
85    }
86
87    /// Configure CSRF protection.
88    pub fn csrf(mut self, config: CsrfConfig) -> Self {
89        self.csrf_config = Some(config);
90        self
91    }
92
93    /// Configure rate limiting.
94    pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
95        self.rate_limit_config = Some(config);
96        self
97    }
98
99    /// Configure CORS.
100    pub fn cors(mut self, config: CorsConfig) -> Self {
101        self.cors_config = Some(config);
102        self
103    }
104
105    /// Configure body size limit.
106    pub fn body_limit(mut self, config: BodyLimitConfig) -> Self {
107        self.body_limit_config = Some(config);
108        self
109    }
110
111    /// Set the email provider.
112    pub fn email_provider<E: EmailProvider + 'static>(mut self, provider: E) -> Self {
113        self.config.email_provider = Some(Arc::new(provider));
114        self
115    }
116}
117
118impl<DB: DatabaseAdapter> TypedAuthBuilder<DB> {
119    /// Add a plugin to the authentication system.
120    pub fn plugin<P: AuthPlugin<DB> + 'static>(mut self, plugin: P) -> Self {
121        self.plugins.push(Box::new(plugin));
122        self
123    }
124
125    /// Add a database lifecycle hook.
126    pub fn hook<H: DatabaseHooks<DB> + 'static>(mut self, hook: H) -> Self {
127        self.hooks.push(Arc::new(hook));
128        self
129    }
130
131    /// Configure CSRF protection.
132    pub fn csrf(mut self, config: CsrfConfig) -> Self {
133        self.csrf_config = Some(config);
134        self
135    }
136
137    /// Configure rate limiting.
138    pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
139        self.rate_limit_config = Some(config);
140        self
141    }
142
143    /// Configure CORS.
144    pub fn cors(mut self, config: CorsConfig) -> Self {
145        self.cors_config = Some(config);
146        self
147    }
148
149    /// Configure body size limit.
150    pub fn body_limit(mut self, config: BodyLimitConfig) -> Self {
151        self.body_limit_config = Some(config);
152        self
153    }
154
155    /// Set the email provider for sending emails.
156    pub fn email_provider<E: EmailProvider + 'static>(mut self, provider: E) -> Self {
157        self.config.email_provider = Some(Arc::new(provider));
158        self
159    }
160
161    /// Add a custom middleware.
162    pub fn middleware<M: Middleware + 'static>(mut self, mw: M) -> Self {
163        self.custom_middlewares.push(Box::new(mw));
164        self
165    }
166
167    /// Build the BetterAuth instance.
168    pub async fn build(self) -> AuthResult<BetterAuth<DB>> {
169        // Validate configuration
170        self.config.validate()?;
171
172        let config = Arc::new(self.config);
173
174        // If hooks are registered, the user should wrap the adapter themselves:
175        //   let db = HookedDatabaseAdapter::new(Arc::new(my_db)).with_hook(hook);
176        //   BetterAuth::new(config).database(db).plugin(...).build().await
177        if !self.hooks.is_empty() {
178            return Err(AuthError::config(
179                "Use HookedDatabaseAdapter directly: \
180                 BetterAuth::new(config).database(HookedDatabaseAdapter::new(Arc::new(db)).with_hook(h))",
181            ));
182        }
183
184        let database = self.database;
185
186        // Create session manager
187        let session_manager = SessionManager::new(config.clone(), database.clone());
188
189        // Create context
190        let mut context = AuthContext::new(config.clone(), database.clone());
191
192        // Initialize all plugins
193        for plugin in &self.plugins {
194            plugin.on_init(&mut context).await?;
195        }
196
197        // Build middleware chain (order matters: body limit → rate limit → CSRF → CORS → custom)
198        let mut middlewares: Vec<Box<dyn Middleware>> = vec![
199            Box::new(BodyLimitMiddleware::new(
200                self.body_limit_config.unwrap_or_default(),
201            )),
202            Box::new(RateLimitMiddleware::new(
203                self.rate_limit_config.unwrap_or_default(),
204            )),
205            Box::new(CsrfMiddleware::new(
206                self.csrf_config.unwrap_or_default(),
207                &config.base_url,
208            )),
209            Box::new(CorsMiddleware::new(self.cors_config.unwrap_or_default())),
210        ];
211
212        middlewares.extend(self.custom_middlewares);
213
214        Ok(BetterAuth {
215            config,
216            plugins: self.plugins,
217            middlewares,
218            database,
219            session_manager,
220            context,
221        })
222    }
223}
224
225impl<DB: DatabaseAdapter> BetterAuth<DB> {
226    /// Create a new BetterAuth builder.
227    #[allow(clippy::new_ret_no_self)]
228    pub fn new(config: AuthConfig) -> AuthBuilder {
229        AuthBuilder::new(config)
230    }
231
232    /// Handle an authentication request.
233    ///
234    /// Errors from plugins and core handlers are automatically converted
235    /// into standardized JSON responses via [`AuthError::into_response`],
236    /// producing `{ "message": "..." }` with the appropriate HTTP status code.
237    pub async fn handle_request(&self, req: AuthRequest) -> AuthResult<AuthResponse> {
238        match self.handle_request_inner(&req).await {
239            Ok(response) => {
240                // Run after-request middleware chain
241                middleware::run_after(&self.middlewares, &req, response).await
242            }
243            Err(err) => {
244                // Convert error to standardized response, then run after-middleware
245                let response = err.into_response();
246                middleware::run_after(&self.middlewares, &req, response).await
247            }
248        }
249    }
250
251    /// Inner request handler that may return errors.
252    async fn handle_request_inner(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
253        // Run before-request middleware chain
254        if let Some(response) = middleware::run_before(&self.middlewares, req).await? {
255            return Ok(response);
256        }
257
258        // Handle core endpoints first
259        if let Some(response) = self.handle_core_request(req).await? {
260            return Ok(response);
261        }
262
263        // Try each plugin until one handles the request
264        for plugin in &self.plugins {
265            if let Some(response) = plugin.on_request(req, &self.context).await? {
266                return Ok(response);
267            }
268        }
269
270        // No handler found
271        Err(AuthError::not_found("No handler found for this request"))
272    }
273
274    /// Get the configuration.
275    pub fn config(&self) -> &AuthConfig {
276        &self.config
277    }
278
279    /// Get the database adapter.
280    pub fn database(&self) -> &Arc<DB> {
281        &self.database
282    }
283
284    /// Get the session manager.
285    pub fn session_manager(&self) -> &SessionManager<DB> {
286        &self.session_manager
287    }
288
289    /// Get all routes from plugins.
290    pub fn routes(&self) -> Vec<(String, &dyn AuthPlugin<DB>)> {
291        let mut routes = Vec::new();
292        for plugin in &self.plugins {
293            for route in plugin.routes() {
294                routes.push((route.path, plugin.as_ref()));
295            }
296        }
297        routes
298    }
299
300    /// Get all plugins.
301    pub fn plugins(&self) -> &[Box<dyn AuthPlugin<DB>>] {
302        &self.plugins
303    }
304
305    /// Get plugin by name.
306    pub fn get_plugin(&self, name: &str) -> Option<&dyn AuthPlugin<DB>> {
307        self.plugins
308            .iter()
309            .find(|p| p.name() == name)
310            .map(|p| p.as_ref())
311    }
312
313    /// List all plugin names.
314    pub fn plugin_names(&self) -> Vec<&'static str> {
315        self.plugins.iter().map(|p| p.name()).collect()
316    }
317
318    /// Generate the OpenAPI spec for all registered routes.
319    pub fn openapi_spec(&self) -> OpenApiSpec {
320        let mut builder = OpenApiBuilder::new("Better Auth", env!("CARGO_PKG_VERSION"))
321            .description("Authentication API")
322            .core_routes();
323
324        for plugin in &self.plugins {
325            builder = builder.plugin(plugin.as_ref());
326        }
327
328        builder.build()
329    }
330
331    /// Handle core authentication requests.
332    async fn handle_core_request(&self, req: &AuthRequest) -> AuthResult<Option<AuthResponse>> {
333        match (req.method(), req.path()) {
334            (HttpMethod::Get, "/ok") => Ok(Some(AuthResponse::json(
335                200,
336                &serde_json::json!({ "status": true }),
337            )?)),
338            (HttpMethod::Get, "/error") => Ok(Some(AuthResponse::json(
339                200,
340                &serde_json::json!({ "status": false }),
341            )?)),
342            (HttpMethod::Get, "/reference/openapi.json") => {
343                let spec = self.openapi_spec();
344                Ok(Some(AuthResponse::json(200, &spec)?))
345            }
346            (HttpMethod::Post, "/update-user") => Ok(Some(self.handle_update_user(req).await?)),
347            (HttpMethod::Post, "/delete-user") => Ok(Some(self.handle_delete_user(req).await?)),
348            (HttpMethod::Post, "/change-email") => Ok(Some(self.handle_change_email(req).await?)),
349            (HttpMethod::Get, "/delete-user/callback") => {
350                Ok(Some(self.handle_delete_user_callback(req).await?))
351            }
352            _ => Ok(None),
353        }
354    }
355
356    /// Handle user profile update.
357    async fn handle_update_user(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
358        let current_user = self.extract_current_user(req).await?;
359
360        let update_req: UpdateUserRequest = req
361            .body_as_json()
362            .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
363
364        let update_user = UpdateUser {
365            email: update_req.email,
366            name: update_req.name,
367            image: update_req.image,
368            email_verified: None,
369            username: update_req.username,
370            display_username: update_req.display_username,
371            role: update_req.role,
372            banned: None,
373            ban_reason: None,
374            ban_expires: None,
375            two_factor_enabled: None,
376            metadata: update_req.metadata,
377        };
378
379        self.database
380            .update_user(current_user.id(), update_user)
381            .await?;
382
383        Ok(AuthResponse::json(
384            200,
385            &serde_json::json!({ "status": true }),
386        )?)
387    }
388
389    /// Handle user deletion.
390    async fn handle_delete_user(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
391        let current_user = self.extract_current_user(req).await?;
392
393        self.database
394            .delete_user_sessions(current_user.id())
395            .await?;
396        self.database.delete_user(current_user.id()).await?;
397
398        let response = DeleteUserResponse {
399            success: true,
400            message: "User account successfully deleted".to_string(),
401        };
402
403        Ok(AuthResponse::json(200, &response)?)
404    }
405
406    /// Handle email change.
407    async fn handle_change_email(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
408        let current_user = self.extract_current_user(req).await?;
409
410        let change_req: ChangeEmailRequest = req
411            .body_as_json()
412            .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
413
414        if !change_req.new_email.contains('@') || change_req.new_email.is_empty() {
415            return Err(AuthError::bad_request("Invalid email address"));
416        }
417
418        if self
419            .database
420            .get_user_by_email(&change_req.new_email)
421            .await?
422            .is_some()
423        {
424            return Err(AuthError::conflict("A user with this email already exists"));
425        }
426
427        let update_user = UpdateUser {
428            email: Some(change_req.new_email),
429            name: None,
430            image: None,
431            email_verified: Some(false),
432            username: None,
433            display_username: None,
434            role: None,
435            banned: None,
436            ban_reason: None,
437            ban_expires: None,
438            two_factor_enabled: None,
439            metadata: None,
440        };
441
442        self.database
443            .update_user(current_user.id(), update_user)
444            .await?;
445
446        Ok(AuthResponse::json(
447            200,
448            &serde_json::json!({ "status": true, "message": "Email updated" }),
449        )?)
450    }
451
452    /// Handle delete-user callback (token-based deletion confirmation).
453    async fn handle_delete_user_callback(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
454        let token = req
455            .query
456            .get("token")
457            .ok_or_else(|| AuthError::bad_request("Deletion token is required"))?;
458
459        let verification = self
460            .database
461            .get_verification_by_value(token)
462            .await?
463            .ok_or_else(|| AuthError::bad_request("Invalid or expired deletion token"))?;
464
465        let user_id = verification.identifier();
466
467        self.database.delete_user_sessions(user_id).await?;
468
469        let accounts = self.database.get_user_accounts(user_id).await?;
470        for account in accounts {
471            self.database.delete_account(account.id()).await?;
472        }
473
474        self.database.delete_user(user_id).await?;
475        self.database.delete_verification(verification.id()).await?;
476
477        let response = DeleteUserResponse {
478            success: true,
479            message: "User account successfully deleted".to_string(),
480        };
481
482        Ok(AuthResponse::json(200, &response)?)
483    }
484
485    /// Extract current user from request (validates session).
486    async fn extract_current_user(&self, req: &AuthRequest) -> AuthResult<DB::User> {
487        let token = self
488            .session_manager
489            .extract_session_token(req)
490            .ok_or(AuthError::Unauthenticated)?;
491
492        let session = self
493            .session_manager
494            .get_session(&token)
495            .await?
496            .ok_or(AuthError::SessionNotFound)?;
497
498        let user = self
499            .database
500            .get_user_by_id(session.user_id())
501            .await?
502            .ok_or(AuthError::UserNotFound)?;
503
504        Ok(user)
505    }
506}