1use std::sync::Arc;
2
3use serde::Deserialize;
4
5use better_auth_core::{
6 AuthConfig, AuthContext, AuthError, AuthPlugin, AuthRequest, AuthResponse, AuthResult,
7 BeforeRequestAction, DatabaseAdapter, DatabaseHooks, EmailProvider, HttpMethod, OkResponse,
8 OpenApiBuilder, OpenApiSpec, SessionManager, StatusMessageResponse, SuccessMessageResponse,
9 UpdateUser, UpdateUserRequest, core_paths,
10 entity::{AuthAccount, AuthSession, AuthUser, AuthVerification},
11 middleware::{
12 self, BodyLimitConfig, BodyLimitMiddleware, CorsConfig, CorsMiddleware, CsrfConfig,
13 CsrfMiddleware, Middleware, RateLimitConfig, RateLimitMiddleware,
14 },
15};
16
17#[derive(Debug, Deserialize)]
18struct ChangeEmailRequest {
19 #[serde(rename = "newEmail")]
20 new_email: String,
21}
22
/// The assembled authentication service.
///
/// Created through `BetterAuth::new(config)` → `AuthBuilder::database(db)` →
/// `TypedAuthBuilder::build()`; incoming requests are dispatched with
/// `BetterAuth::handle_request`.
pub struct BetterAuth<DB: DatabaseAdapter> {
    /// Shared configuration; the same `Arc` is handed to the session manager,
    /// the plugin context and the CSRF middleware.
    config: Arc<AuthConfig>,
    /// Registered plugins, consulted in registration order.
    plugins: Vec<Box<dyn AuthPlugin<DB>>>,
    /// Built-in middlewares (body limit, rate limit, CSRF, CORS) followed by
    /// any custom middlewares.
    middlewares: Vec<Box<dyn Middleware>>,
    /// The database adapter shared with the session manager and context.
    database: Arc<DB>,
    /// Used to extract and load sessions for authenticated core endpoints.
    session_manager: SessionManager<DB>,
    /// Context passed to plugin hooks; plugins may mutate it in `on_init`.
    context: AuthContext<DB>,
}
32
/// First stage of building a [`BetterAuth`]: configuration and middleware
/// settings, before a database adapter has been chosen.
pub struct AuthBuilder {
    /// Core configuration (validated later, at build time).
    config: AuthConfig,
    /// CSRF settings; `None` means the default `CsrfConfig` is used.
    csrf_config: Option<CsrfConfig>,
    /// Rate-limit settings; `None` means the default `RateLimitConfig` is used.
    rate_limit_config: Option<RateLimitConfig>,
    /// CORS settings; `None` means the default `CorsConfig` is used.
    cors_config: Option<CorsConfig>,
    /// Body-size limit settings; `None` means the default `BodyLimitConfig` is used.
    body_limit_config: Option<BodyLimitConfig>,
    /// Extra middlewares appended after the built-in ones.
    custom_middlewares: Vec<Box<dyn Middleware>>,
}
45
/// Second stage of building a [`BetterAuth`]: the database adapter is fixed,
/// so plugins and hooks typed over `DB` can be registered.
pub struct TypedAuthBuilder<DB: DatabaseAdapter> {
    /// Core configuration (validated in `build`).
    config: AuthConfig,
    /// The database adapter, shared via `Arc` with the built instance.
    database: Arc<DB>,
    /// Plugins in registration order.
    plugins: Vec<Box<dyn AuthPlugin<DB>>>,
    /// Hooks registered via `hook`; `build` currently rejects a non-empty list
    /// and points the caller at `HookedDatabaseAdapter` instead.
    hooks: Vec<Arc<dyn DatabaseHooks<DB>>>,
    /// CSRF settings; `None` means the default `CsrfConfig` is used.
    csrf_config: Option<CsrfConfig>,
    /// Rate-limit settings; `None` means the default `RateLimitConfig` is used.
    rate_limit_config: Option<RateLimitConfig>,
    /// CORS settings; `None` means the default `CorsConfig` is used.
    cors_config: Option<CorsConfig>,
    /// Body-size limit settings; `None` means the default `BodyLimitConfig` is used.
    body_limit_config: Option<BodyLimitConfig>,
    /// Extra middlewares appended after the built-in ones.
    custom_middlewares: Vec<Box<dyn Middleware>>,
}
60
61impl AuthBuilder {
62 pub fn new(config: AuthConfig) -> Self {
63 Self {
64 config,
65 csrf_config: None,
66 rate_limit_config: None,
67 cors_config: None,
68 body_limit_config: None,
69 custom_middlewares: Vec::new(),
70 }
71 }
72
73 pub fn database<DB: DatabaseAdapter>(self, database: DB) -> TypedAuthBuilder<DB> {
75 TypedAuthBuilder {
76 config: self.config,
77 database: Arc::new(database),
78 plugins: Vec::new(),
79 hooks: Vec::new(),
80 csrf_config: self.csrf_config,
81 rate_limit_config: self.rate_limit_config,
82 cors_config: self.cors_config,
83 body_limit_config: self.body_limit_config,
84 custom_middlewares: self.custom_middlewares,
85 }
86 }
87
88 pub fn csrf(mut self, config: CsrfConfig) -> Self {
90 self.csrf_config = Some(config);
91 self
92 }
93
94 pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
96 self.rate_limit_config = Some(config);
97 self
98 }
99
100 pub fn cors(mut self, config: CorsConfig) -> Self {
102 self.cors_config = Some(config);
103 self
104 }
105
106 pub fn body_limit(mut self, config: BodyLimitConfig) -> Self {
108 self.body_limit_config = Some(config);
109 self
110 }
111
112 pub fn email_provider<E: EmailProvider + 'static>(mut self, provider: E) -> Self {
114 self.config.email_provider = Some(Arc::new(provider));
115 self
116 }
117}
118
119impl<DB: DatabaseAdapter> TypedAuthBuilder<DB> {
120 pub fn plugin<P: AuthPlugin<DB> + 'static>(mut self, plugin: P) -> Self {
122 self.plugins.push(Box::new(plugin));
123 self
124 }
125
126 pub fn hook<H: DatabaseHooks<DB> + 'static>(mut self, hook: H) -> Self {
128 self.hooks.push(Arc::new(hook));
129 self
130 }
131
132 pub fn csrf(mut self, config: CsrfConfig) -> Self {
134 self.csrf_config = Some(config);
135 self
136 }
137
138 pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
140 self.rate_limit_config = Some(config);
141 self
142 }
143
144 pub fn cors(mut self, config: CorsConfig) -> Self {
146 self.cors_config = Some(config);
147 self
148 }
149
150 pub fn body_limit(mut self, config: BodyLimitConfig) -> Self {
152 self.body_limit_config = Some(config);
153 self
154 }
155
156 pub fn email_provider<E: EmailProvider + 'static>(mut self, provider: E) -> Self {
158 self.config.email_provider = Some(Arc::new(provider));
159 self
160 }
161
162 pub fn middleware<M: Middleware + 'static>(mut self, mw: M) -> Self {
164 self.custom_middlewares.push(Box::new(mw));
165 self
166 }
167
168 pub async fn build(self) -> AuthResult<BetterAuth<DB>> {
170 self.config.validate()?;
172
173 let config = Arc::new(self.config);
174
175 if !self.hooks.is_empty() {
179 return Err(AuthError::config(
180 "Use HookedDatabaseAdapter directly: \
181 BetterAuth::new(config).database(HookedDatabaseAdapter::new(Arc::new(db)).with_hook(h))",
182 ));
183 }
184
185 let database = self.database;
186
187 let session_manager = SessionManager::new(config.clone(), database.clone());
189
190 let mut context = AuthContext::new(config.clone(), database.clone());
192
193 for plugin in &self.plugins {
195 plugin.on_init(&mut context).await?;
196 }
197
198 let mut middlewares: Vec<Box<dyn Middleware>> = vec![
200 Box::new(BodyLimitMiddleware::new(
201 self.body_limit_config.unwrap_or_default(),
202 )),
203 Box::new(RateLimitMiddleware::new(
204 self.rate_limit_config.unwrap_or_default(),
205 )),
206 Box::new(CsrfMiddleware::new(
207 self.csrf_config.unwrap_or_default(),
208 config.clone(),
209 )),
210 Box::new(CorsMiddleware::new(self.cors_config.unwrap_or_default())),
211 ];
212
213 middlewares.extend(self.custom_middlewares);
214
215 Ok(BetterAuth {
216 config,
217 plugins: self.plugins,
218 middlewares,
219 database,
220 session_manager,
221 context,
222 })
223 }
224}
225
226impl<DB: DatabaseAdapter> BetterAuth<DB> {
227 #[allow(clippy::new_ret_no_self)]
229 pub fn new(config: AuthConfig) -> AuthBuilder {
230 AuthBuilder::new(config)
231 }
232
233 pub async fn handle_request(&self, req: AuthRequest) -> AuthResult<AuthResponse> {
239 let mut req =
242 AuthRequest::from_parts(req.method, req.path, req.headers, req.body, req.query);
243
244 match self.handle_request_inner(&mut req).await {
245 Ok(response) => {
246 middleware::run_after(&self.middlewares, &req, response).await
248 }
249 Err(err) => {
250 let response = err.into_response();
252 middleware::run_after(&self.middlewares, &req, response).await
253 }
254 }
255 }
256
257 async fn handle_request_inner(&self, req: &mut AuthRequest) -> AuthResult<AuthResponse> {
259 if let Some(response) = middleware::run_before(&self.middlewares, req).await? {
261 return Ok(response);
262 }
263
264 let base_path = &self.config.base_path;
270 let stripped_path = if !base_path.is_empty() && base_path != "/" {
271 req.path().strip_prefix(base_path).unwrap_or(req.path())
272 } else {
273 req.path()
274 };
275
276 let mut internal_req = if stripped_path != req.path() {
278 let mut r = req.clone();
279 r.path = stripped_path.to_string();
280 r
281 } else {
282 req.clone()
283 };
284
285 if self.config.is_path_disabled(internal_req.path()) {
287 return Err(AuthError::not_found("This endpoint has been disabled"));
288 }
289
290 for plugin in &self.plugins {
293 if let Some(action) = plugin.before_request(&internal_req, &self.context).await? {
294 match action {
295 BeforeRequestAction::Respond(response) => {
296 return Ok(response);
297 }
298 BeforeRequestAction::InjectSession {
299 user_id,
300 session_token: _,
301 } => {
302 internal_req.set_virtual_user_id(user_id);
308 }
309 }
310 }
311 }
312
313 if let Some(response) = self.handle_core_request(&internal_req).await? {
315 return Ok(response);
316 }
317
318 for plugin in &self.plugins {
320 if let Some(response) = plugin.on_request(&internal_req, &self.context).await? {
321 return Ok(response);
322 }
323 }
324
325 Err(AuthError::not_found("No handler found for this request"))
327 }
328
329 pub fn config(&self) -> &AuthConfig {
331 &self.config
332 }
333
334 pub fn database(&self) -> &Arc<DB> {
336 &self.database
337 }
338
339 pub fn session_manager(&self) -> &SessionManager<DB> {
341 &self.session_manager
342 }
343
344 pub fn routes(&self) -> Vec<(String, &dyn AuthPlugin<DB>)> {
346 let mut routes = Vec::new();
347 for plugin in &self.plugins {
348 for route in plugin.routes() {
349 routes.push((route.path, plugin.as_ref()));
350 }
351 }
352 routes
353 }
354
355 pub fn plugins(&self) -> &[Box<dyn AuthPlugin<DB>>] {
357 &self.plugins
358 }
359
360 pub fn get_plugin(&self, name: &str) -> Option<&dyn AuthPlugin<DB>> {
362 self.plugins
363 .iter()
364 .find(|p| p.name() == name)
365 .map(|p| p.as_ref())
366 }
367
368 pub fn plugin_names(&self) -> Vec<&'static str> {
370 self.plugins.iter().map(|p| p.name()).collect()
371 }
372
373 pub fn openapi_spec(&self) -> OpenApiSpec {
375 let mut builder = OpenApiBuilder::new("Better Auth", env!("CARGO_PKG_VERSION"))
376 .description("Authentication API")
377 .core_routes();
378
379 for plugin in &self.plugins {
380 builder = builder.plugin(plugin.as_ref());
381 }
382
383 builder.build()
384 }
385
386 async fn handle_core_request(&self, req: &AuthRequest) -> AuthResult<Option<AuthResponse>> {
388 match (req.method(), req.path()) {
389 (HttpMethod::Get, core_paths::OK) => {
390 Ok(Some(AuthResponse::json(200, &OkResponse { ok: true })?))
391 }
392 (HttpMethod::Get, core_paths::ERROR) => {
393 Ok(Some(AuthResponse::json(200, &OkResponse { ok: false })?))
394 }
395 (HttpMethod::Get, core_paths::OPENAPI_SPEC) => {
396 let spec = self.openapi_spec();
397 Ok(Some(AuthResponse::json(200, &spec)?))
398 }
399 (HttpMethod::Post, core_paths::UPDATE_USER) => {
400 Ok(Some(self.handle_update_user(req).await?))
401 }
402 (HttpMethod::Post | HttpMethod::Delete, core_paths::DELETE_USER) => {
403 Ok(Some(self.handle_delete_user(req).await?))
404 }
405 (HttpMethod::Post, core_paths::CHANGE_EMAIL) => {
406 Ok(Some(self.handle_change_email(req).await?))
407 }
408 (HttpMethod::Get, core_paths::DELETE_USER_CALLBACK) => {
409 Ok(Some(self.handle_delete_user_callback(req).await?))
410 }
411 _ => Ok(None),
412 }
413 }
414
415 async fn handle_update_user(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
417 let current_user = self.extract_current_user(req).await?;
418
419 let update_req: UpdateUserRequest = req
420 .body_as_json()
421 .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
422
423 let update_user = UpdateUser {
424 email: update_req.email,
425 name: update_req.name,
426 image: update_req.image,
427 email_verified: None,
428 username: update_req.username,
429 display_username: update_req.display_username,
430 role: update_req.role,
431 banned: None,
432 ban_reason: None,
433 ban_expires: None,
434 two_factor_enabled: None,
435 metadata: update_req.metadata,
436 };
437
438 self.database
439 .update_user(current_user.id(), update_user)
440 .await?;
441
442 Ok(AuthResponse::json(
443 200,
444 &better_auth_core::StatusResponse { status: true },
445 )?)
446 }
447
448 async fn handle_delete_user(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
450 let current_user = self.extract_current_user(req).await?;
451
452 self.database
453 .delete_user_sessions(current_user.id())
454 .await?;
455 self.database.delete_user(current_user.id()).await?;
456
457 let response = SuccessMessageResponse {
458 success: true,
459 message: "User account successfully deleted".to_string(),
460 };
461
462 Ok(AuthResponse::json(200, &response)?)
463 }
464
465 async fn handle_change_email(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
467 let current_user = self.extract_current_user(req).await?;
468
469 let change_req: ChangeEmailRequest = req
470 .body_as_json()
471 .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
472
473 if !change_req.new_email.contains('@') || change_req.new_email.is_empty() {
474 return Err(AuthError::bad_request("Invalid email address"));
475 }
476
477 if self
478 .database
479 .get_user_by_email(&change_req.new_email)
480 .await?
481 .is_some()
482 {
483 return Err(AuthError::conflict("A user with this email already exists"));
484 }
485
486 let update_user = UpdateUser {
487 email: Some(change_req.new_email),
488 name: None,
489 image: None,
490 email_verified: Some(false),
491 username: None,
492 display_username: None,
493 role: None,
494 banned: None,
495 ban_reason: None,
496 ban_expires: None,
497 two_factor_enabled: None,
498 metadata: None,
499 };
500
501 self.database
502 .update_user(current_user.id(), update_user)
503 .await?;
504
505 Ok(AuthResponse::json(
506 200,
507 &StatusMessageResponse {
508 status: true,
509 message: "Email updated".to_string(),
510 },
511 )?)
512 }
513
514 async fn handle_delete_user_callback(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
516 let token = req
517 .query
518 .get("token")
519 .ok_or_else(|| AuthError::bad_request("Deletion token is required"))?;
520
521 let verification = self
522 .database
523 .get_verification_by_value(token)
524 .await?
525 .ok_or_else(|| AuthError::bad_request("Invalid or expired deletion token"))?;
526
527 let user_id = verification.identifier();
528
529 self.database.delete_user_sessions(user_id).await?;
530
531 let accounts = self.database.get_user_accounts(user_id).await?;
532 for account in accounts {
533 self.database.delete_account(account.id()).await?;
534 }
535
536 self.database.delete_user(user_id).await?;
537 self.database.delete_verification(verification.id()).await?;
538
539 let response = SuccessMessageResponse {
540 success: true,
541 message: "User account successfully deleted".to_string(),
542 };
543
544 Ok(AuthResponse::json(200, &response)?)
545 }
546
547 async fn extract_current_user(&self, req: &AuthRequest) -> AuthResult<DB::User> {
554 if let Some(uid) = req.virtual_user_id() {
556 return self
557 .database
558 .get_user_by_id(uid)
559 .await?
560 .ok_or(AuthError::UserNotFound);
561 }
562
563 let token = self
564 .session_manager
565 .extract_session_token(req)
566 .ok_or(AuthError::Unauthenticated)?;
567
568 let session = self
569 .session_manager
570 .get_session(&token)
571 .await?
572 .ok_or(AuthError::SessionNotFound)?;
573
574 let user = self
575 .database
576 .get_user_by_id(session.user_id())
577 .await?
578 .ok_or(AuthError::UserNotFound)?;
579
580 Ok(user)
581 }
582}