1use std::sync::Arc;
2
3use serde::Deserialize;
4
5use better_auth_core::{
6 AuthConfig, AuthContext, AuthError, AuthPlugin, AuthRequest, AuthResponse, AuthResult,
7 BeforeRequestAction, DatabaseAdapter, DatabaseHooks, EmailProvider, HttpMethod, OkResponse,
8 OpenApiBuilder, OpenApiSpec, SessionManager, StatusMessageResponse, SuccessMessageResponse,
9 UpdateUser, UpdateUserRequest, core_paths,
10 entity::{AuthAccount, AuthSession, AuthUser, AuthVerification},
11 middleware::{
12 self, BodyLimitConfig, BodyLimitMiddleware, CorsConfig, CorsMiddleware, CsrfConfig,
13 CsrfMiddleware, Middleware, RateLimitConfig, RateLimitMiddleware,
14 },
15};
16
17#[derive(Debug, Deserialize)]
18struct ChangeEmailRequest {
19 #[serde(rename = "newEmail")]
20 new_email: String,
21}
22
23pub struct BetterAuth<DB: DatabaseAdapter> {
25 config: Arc<AuthConfig>,
26 plugins: Vec<Box<dyn AuthPlugin<DB>>>,
27 middlewares: Vec<Box<dyn Middleware>>,
28 database: Arc<DB>,
29 session_manager: SessionManager<DB>,
30 context: AuthContext<DB>,
31 body_limit_config: BodyLimitConfig,
36}
37
38pub struct AuthBuilder {
43 config: AuthConfig,
44 csrf_config: Option<CsrfConfig>,
45 rate_limit_config: Option<RateLimitConfig>,
46 cors_config: Option<CorsConfig>,
47 body_limit_config: Option<BodyLimitConfig>,
48 custom_middlewares: Vec<Box<dyn Middleware>>,
49}
50
51pub struct TypedAuthBuilder<DB: DatabaseAdapter> {
55 config: AuthConfig,
56 database: Arc<DB>,
57 plugins: Vec<Box<dyn AuthPlugin<DB>>>,
58 hooks: Vec<Arc<dyn DatabaseHooks<DB>>>,
59 csrf_config: Option<CsrfConfig>,
60 rate_limit_config: Option<RateLimitConfig>,
61 cors_config: Option<CorsConfig>,
62 body_limit_config: Option<BodyLimitConfig>,
63 custom_middlewares: Vec<Box<dyn Middleware>>,
64}
65
66impl AuthBuilder {
67 pub fn new(config: AuthConfig) -> Self {
68 Self {
69 config,
70 csrf_config: None,
71 rate_limit_config: None,
72 cors_config: None,
73 body_limit_config: None,
74 custom_middlewares: Vec::new(),
75 }
76 }
77
78 pub fn database<DB: DatabaseAdapter>(self, database: DB) -> TypedAuthBuilder<DB> {
80 TypedAuthBuilder {
81 config: self.config,
82 database: Arc::new(database),
83 plugins: Vec::new(),
84 hooks: Vec::new(),
85 csrf_config: self.csrf_config,
86 rate_limit_config: self.rate_limit_config,
87 cors_config: self.cors_config,
88 body_limit_config: self.body_limit_config,
89 custom_middlewares: self.custom_middlewares,
90 }
91 }
92
93 pub fn csrf(mut self, config: CsrfConfig) -> Self {
95 self.csrf_config = Some(config);
96 self
97 }
98
99 pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
101 self.rate_limit_config = Some(config);
102 self
103 }
104
105 pub fn cors(mut self, config: CorsConfig) -> Self {
107 self.cors_config = Some(config);
108 self
109 }
110
111 pub fn body_limit(mut self, config: BodyLimitConfig) -> Self {
113 self.body_limit_config = Some(config);
114 self
115 }
116
117 pub fn email_provider<E: EmailProvider + 'static>(mut self, provider: E) -> Self {
119 self.config.email_provider = Some(Arc::new(provider));
120 self
121 }
122}
123
124impl<DB: DatabaseAdapter> TypedAuthBuilder<DB> {
125 pub fn plugin<P: AuthPlugin<DB> + 'static>(mut self, plugin: P) -> Self {
127 self.plugins.push(Box::new(plugin));
128 self
129 }
130
131 pub fn hook<H: DatabaseHooks<DB> + 'static>(mut self, hook: H) -> Self {
133 self.hooks.push(Arc::new(hook));
134 self
135 }
136
137 pub fn csrf(mut self, config: CsrfConfig) -> Self {
139 self.csrf_config = Some(config);
140 self
141 }
142
143 pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
145 self.rate_limit_config = Some(config);
146 self
147 }
148
149 pub fn cors(mut self, config: CorsConfig) -> Self {
151 self.cors_config = Some(config);
152 self
153 }
154
155 pub fn body_limit(mut self, config: BodyLimitConfig) -> Self {
157 self.body_limit_config = Some(config);
158 self
159 }
160
161 pub fn email_provider<E: EmailProvider + 'static>(mut self, provider: E) -> Self {
163 self.config.email_provider = Some(Arc::new(provider));
164 self
165 }
166
167 pub fn middleware<M: Middleware + 'static>(mut self, mw: M) -> Self {
169 self.custom_middlewares.push(Box::new(mw));
170 self
171 }
172
173 pub async fn build(self) -> AuthResult<BetterAuth<DB>> {
175 self.config.validate()?;
177
178 let config = Arc::new(self.config);
179
180 if !self.hooks.is_empty() {
184 return Err(AuthError::config(
185 "Use HookedDatabaseAdapter directly: \
186 BetterAuth::new(config).database(HookedDatabaseAdapter::new(Arc::new(db)).with_hook(h))",
187 ));
188 }
189
190 let database = self.database;
191
192 let session_manager = SessionManager::new(config.clone(), database.clone());
194
195 let mut context = AuthContext::new(config.clone(), database.clone());
197
198 for plugin in &self.plugins {
200 plugin.on_init(&mut context).await?;
201 }
202
203 let body_limit_config = self.body_limit_config.unwrap_or_default();
205 let mut middlewares: Vec<Box<dyn Middleware>> = vec![
206 Box::new(BodyLimitMiddleware::new(body_limit_config.clone())),
207 Box::new(RateLimitMiddleware::new(
208 self.rate_limit_config.unwrap_or_default(),
209 )),
210 Box::new(CsrfMiddleware::new(
211 self.csrf_config.unwrap_or_default(),
212 config.clone(),
213 )),
214 Box::new(CorsMiddleware::new(self.cors_config.unwrap_or_default())),
215 ];
216
217 middlewares.extend(self.custom_middlewares);
218
219 Ok(BetterAuth {
220 config,
221 plugins: self.plugins,
222 middlewares,
223 database,
224 session_manager,
225 context,
226 body_limit_config,
227 })
228 }
229}
230
231impl<DB: DatabaseAdapter> BetterAuth<DB> {
232 #[allow(clippy::new_ret_no_self)]
234 pub fn new(config: AuthConfig) -> AuthBuilder {
235 AuthBuilder::new(config)
236 }
237
238 pub async fn handle_request(&self, req: AuthRequest) -> AuthResult<AuthResponse> {
244 let mut req =
247 AuthRequest::from_parts(req.method, req.path, req.headers, req.body, req.query);
248
249 match self.handle_request_inner(&mut req).await {
250 Ok(response) => {
251 middleware::run_after(&self.middlewares, &req, response).await
253 }
254 Err(err) => {
255 let response = err.into_response();
257 middleware::run_after(&self.middlewares, &req, response).await
258 }
259 }
260 }
261
262 async fn handle_request_inner(&self, req: &mut AuthRequest) -> AuthResult<AuthResponse> {
264 if let Some(response) = middleware::run_before(&self.middlewares, req).await? {
266 return Ok(response);
267 }
268
269 let base_path = &self.config.base_path;
275 let stripped_path = if !base_path.is_empty() && base_path != "/" {
276 req.path().strip_prefix(base_path).unwrap_or(req.path())
277 } else {
278 req.path()
279 };
280
281 let mut internal_req = if stripped_path != req.path() {
283 let mut r = req.clone();
284 r.path = stripped_path.to_string();
285 r
286 } else {
287 req.clone()
288 };
289
290 if self.config.is_path_disabled(internal_req.path()) {
292 return Err(AuthError::not_found("This endpoint has been disabled"));
293 }
294
295 for plugin in &self.plugins {
298 if let Some(action) = plugin.before_request(&internal_req, &self.context).await? {
299 match action {
300 BeforeRequestAction::Respond(response) => {
301 return Ok(response);
302 }
303 BeforeRequestAction::InjectSession {
304 user_id,
305 session_token: _,
306 } => {
307 internal_req.set_virtual_user_id(user_id);
313 }
314 }
315 }
316 }
317
318 if let Some(response) = self.handle_core_request(&internal_req).await? {
320 return Ok(response);
321 }
322
323 for plugin in &self.plugins {
325 if let Some(response) = plugin.on_request(&internal_req, &self.context).await? {
326 return Ok(response);
327 }
328 }
329
330 Err(AuthError::not_found("No handler found for this request"))
332 }
333
334 pub fn config(&self) -> &AuthConfig {
336 &self.config
337 }
338
339 pub fn body_limit_config(&self) -> &BodyLimitConfig {
344 &self.body_limit_config
345 }
346
347 pub fn database(&self) -> &Arc<DB> {
349 &self.database
350 }
351
352 pub fn session_manager(&self) -> &SessionManager<DB> {
354 &self.session_manager
355 }
356
357 pub fn routes(&self) -> Vec<(String, &dyn AuthPlugin<DB>)> {
359 let mut routes = Vec::new();
360 for plugin in &self.plugins {
361 for route in plugin.routes() {
362 routes.push((route.path, plugin.as_ref()));
363 }
364 }
365 routes
366 }
367
368 pub fn plugins(&self) -> &[Box<dyn AuthPlugin<DB>>] {
370 &self.plugins
371 }
372
373 pub fn get_plugin(&self, name: &str) -> Option<&dyn AuthPlugin<DB>> {
375 self.plugins
376 .iter()
377 .find(|p| p.name() == name)
378 .map(|p| p.as_ref())
379 }
380
381 pub fn plugin_names(&self) -> Vec<&'static str> {
383 self.plugins.iter().map(|p| p.name()).collect()
384 }
385
386 pub fn openapi_spec(&self) -> OpenApiSpec {
388 let mut builder = OpenApiBuilder::new("Better Auth", env!("CARGO_PKG_VERSION"))
389 .description("Authentication API")
390 .core_routes();
391
392 for plugin in &self.plugins {
393 builder = builder.plugin(plugin.as_ref());
394 }
395
396 builder.build()
397 }
398
399 async fn handle_core_request(&self, req: &AuthRequest) -> AuthResult<Option<AuthResponse>> {
401 match (req.method(), req.path()) {
402 (HttpMethod::Get, core_paths::OK) => {
403 Ok(Some(AuthResponse::json(200, &OkResponse { ok: true })?))
404 }
405 (HttpMethod::Get, core_paths::ERROR) => {
406 Ok(Some(AuthResponse::json(200, &OkResponse { ok: false })?))
407 }
408 (HttpMethod::Get, core_paths::OPENAPI_SPEC) => {
409 let spec = self.openapi_spec();
410 Ok(Some(AuthResponse::json(200, &spec)?))
411 }
412 (HttpMethod::Post, core_paths::UPDATE_USER) => {
413 Ok(Some(self.handle_update_user(req).await?))
414 }
415 (HttpMethod::Post | HttpMethod::Delete, core_paths::DELETE_USER) => {
416 Ok(Some(self.handle_delete_user(req).await?))
417 }
418 (HttpMethod::Post, core_paths::CHANGE_EMAIL) => {
419 Ok(Some(self.handle_change_email(req).await?))
420 }
421 (HttpMethod::Get, core_paths::DELETE_USER_CALLBACK) => {
422 Ok(Some(self.handle_delete_user_callback(req).await?))
423 }
424 _ => Ok(None),
425 }
426 }
427
428 async fn handle_update_user(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
430 let current_user = self.extract_current_user(req).await?;
431
432 let update_req: UpdateUserRequest = req
433 .body_as_json()
434 .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
435
436 let update_user = UpdateUser {
437 email: update_req.email,
438 name: update_req.name,
439 image: update_req.image,
440 email_verified: None,
441 username: update_req.username,
442 display_username: update_req.display_username,
443 role: update_req.role,
444 banned: None,
445 ban_reason: None,
446 ban_expires: None,
447 two_factor_enabled: None,
448 metadata: update_req.metadata,
449 };
450
451 self.database
452 .update_user(current_user.id(), update_user)
453 .await?;
454
455 Ok(AuthResponse::json(
456 200,
457 &better_auth_core::StatusResponse { status: true },
458 )?)
459 }
460
461 async fn handle_delete_user(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
463 let current_user = self.extract_current_user(req).await?;
464
465 self.database
466 .delete_user_sessions(current_user.id())
467 .await?;
468 self.database.delete_user(current_user.id()).await?;
469
470 let response = SuccessMessageResponse {
471 success: true,
472 message: "User account successfully deleted".to_string(),
473 };
474
475 Ok(AuthResponse::json(200, &response)?)
476 }
477
478 async fn handle_change_email(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
480 let current_user = self.extract_current_user(req).await?;
481
482 let change_req: ChangeEmailRequest = req
483 .body_as_json()
484 .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
485
486 if !change_req.new_email.contains('@') || change_req.new_email.is_empty() {
487 return Err(AuthError::bad_request("Invalid email address"));
488 }
489
490 if self
491 .database
492 .get_user_by_email(&change_req.new_email)
493 .await?
494 .is_some()
495 {
496 return Err(AuthError::conflict("A user with this email already exists"));
497 }
498
499 let update_user = UpdateUser {
500 email: Some(change_req.new_email),
501 name: None,
502 image: None,
503 email_verified: Some(false),
504 username: None,
505 display_username: None,
506 role: None,
507 banned: None,
508 ban_reason: None,
509 ban_expires: None,
510 two_factor_enabled: None,
511 metadata: None,
512 };
513
514 self.database
515 .update_user(current_user.id(), update_user)
516 .await?;
517
518 Ok(AuthResponse::json(
519 200,
520 &StatusMessageResponse {
521 status: true,
522 message: "Email updated".to_string(),
523 },
524 )?)
525 }
526
527 async fn handle_delete_user_callback(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
529 let token = req
530 .query
531 .get("token")
532 .ok_or_else(|| AuthError::bad_request("Deletion token is required"))?;
533
534 let verification = self
535 .database
536 .get_verification_by_value(token)
537 .await?
538 .ok_or_else(|| AuthError::bad_request("Invalid or expired deletion token"))?;
539
540 let user_id = verification.identifier();
541
542 self.database.delete_user_sessions(user_id).await?;
543
544 let accounts = self.database.get_user_accounts(user_id).await?;
545 for account in accounts {
546 self.database.delete_account(account.id()).await?;
547 }
548
549 self.database.delete_user(user_id).await?;
550 self.database.delete_verification(verification.id()).await?;
551
552 let response = SuccessMessageResponse {
553 success: true,
554 message: "User account successfully deleted".to_string(),
555 };
556
557 Ok(AuthResponse::json(200, &response)?)
558 }
559
560 async fn extract_current_user(&self, req: &AuthRequest) -> AuthResult<DB::User> {
567 if let Some(uid) = req.virtual_user_id() {
569 return self
570 .database
571 .get_user_by_id(uid)
572 .await?
573 .ok_or(AuthError::UserNotFound);
574 }
575
576 let token = self
577 .session_manager
578 .extract_session_token(req)
579 .ok_or(AuthError::Unauthenticated)?;
580
581 let session = self
582 .session_manager
583 .get_session(&token)
584 .await?
585 .ok_or(AuthError::SessionNotFound)?;
586
587 let user = self
588 .database
589 .get_user_by_id(session.user_id())
590 .await?
591 .ok_or(AuthError::UserNotFound)?;
592
593 Ok(user)
594 }
595}