1use std::sync::Arc;
2
3use serde::Deserialize;
4
5use better_auth_core::{
6 AuthConfig, AuthContext, AuthError, AuthPlugin, AuthRequest, AuthResponse, AuthResult,
7 DatabaseAdapter, DatabaseHooks, DeleteUserResponse, EmailProvider, HttpMethod, OpenApiBuilder,
8 OpenApiSpec, SessionManager, UpdateUser, UpdateUserRequest,
9 entity::{AuthAccount, AuthSession, AuthUser, AuthVerification},
10 middleware::{
11 self, BodyLimitConfig, BodyLimitMiddleware, CorsConfig, CorsMiddleware, CsrfConfig,
12 CsrfMiddleware, Middleware, RateLimitConfig, RateLimitMiddleware,
13 },
14};
15
16#[derive(Debug, Deserialize)]
17struct ChangeEmailRequest {
18 #[serde(rename = "newEmail")]
19 new_email: String,
20}
21
22pub struct BetterAuth<DB: DatabaseAdapter> {
24 config: Arc<AuthConfig>,
25 plugins: Vec<Box<dyn AuthPlugin<DB>>>,
26 middlewares: Vec<Box<dyn Middleware>>,
27 database: Arc<DB>,
28 session_manager: SessionManager<DB>,
29 context: AuthContext<DB>,
30}
31
32pub struct AuthBuilder {
37 config: AuthConfig,
38 csrf_config: Option<CsrfConfig>,
39 rate_limit_config: Option<RateLimitConfig>,
40 cors_config: Option<CorsConfig>,
41 body_limit_config: Option<BodyLimitConfig>,
42 custom_middlewares: Vec<Box<dyn Middleware>>,
43}
44
45pub struct TypedAuthBuilder<DB: DatabaseAdapter> {
49 config: AuthConfig,
50 database: Arc<DB>,
51 plugins: Vec<Box<dyn AuthPlugin<DB>>>,
52 hooks: Vec<Arc<dyn DatabaseHooks<DB>>>,
53 csrf_config: Option<CsrfConfig>,
54 rate_limit_config: Option<RateLimitConfig>,
55 cors_config: Option<CorsConfig>,
56 body_limit_config: Option<BodyLimitConfig>,
57 custom_middlewares: Vec<Box<dyn Middleware>>,
58}
59
60impl AuthBuilder {
61 pub fn new(config: AuthConfig) -> Self {
62 Self {
63 config,
64 csrf_config: None,
65 rate_limit_config: None,
66 cors_config: None,
67 body_limit_config: None,
68 custom_middlewares: Vec::new(),
69 }
70 }
71
72 pub fn database<DB: DatabaseAdapter>(self, database: DB) -> TypedAuthBuilder<DB> {
74 TypedAuthBuilder {
75 config: self.config,
76 database: Arc::new(database),
77 plugins: Vec::new(),
78 hooks: Vec::new(),
79 csrf_config: self.csrf_config,
80 rate_limit_config: self.rate_limit_config,
81 cors_config: self.cors_config,
82 body_limit_config: self.body_limit_config,
83 custom_middlewares: self.custom_middlewares,
84 }
85 }
86
87 pub fn csrf(mut self, config: CsrfConfig) -> Self {
89 self.csrf_config = Some(config);
90 self
91 }
92
93 pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
95 self.rate_limit_config = Some(config);
96 self
97 }
98
99 pub fn cors(mut self, config: CorsConfig) -> Self {
101 self.cors_config = Some(config);
102 self
103 }
104
105 pub fn body_limit(mut self, config: BodyLimitConfig) -> Self {
107 self.body_limit_config = Some(config);
108 self
109 }
110
111 pub fn email_provider<E: EmailProvider + 'static>(mut self, provider: E) -> Self {
113 self.config.email_provider = Some(Arc::new(provider));
114 self
115 }
116}
117
118impl<DB: DatabaseAdapter> TypedAuthBuilder<DB> {
119 pub fn plugin<P: AuthPlugin<DB> + 'static>(mut self, plugin: P) -> Self {
121 self.plugins.push(Box::new(plugin));
122 self
123 }
124
125 pub fn hook<H: DatabaseHooks<DB> + 'static>(mut self, hook: H) -> Self {
127 self.hooks.push(Arc::new(hook));
128 self
129 }
130
131 pub fn csrf(mut self, config: CsrfConfig) -> Self {
133 self.csrf_config = Some(config);
134 self
135 }
136
137 pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
139 self.rate_limit_config = Some(config);
140 self
141 }
142
143 pub fn cors(mut self, config: CorsConfig) -> Self {
145 self.cors_config = Some(config);
146 self
147 }
148
149 pub fn body_limit(mut self, config: BodyLimitConfig) -> Self {
151 self.body_limit_config = Some(config);
152 self
153 }
154
155 pub fn email_provider<E: EmailProvider + 'static>(mut self, provider: E) -> Self {
157 self.config.email_provider = Some(Arc::new(provider));
158 self
159 }
160
161 pub fn middleware<M: Middleware + 'static>(mut self, mw: M) -> Self {
163 self.custom_middlewares.push(Box::new(mw));
164 self
165 }
166
167 pub async fn build(self) -> AuthResult<BetterAuth<DB>> {
169 self.config.validate()?;
171
172 let config = Arc::new(self.config);
173
174 if !self.hooks.is_empty() {
178 return Err(AuthError::config(
179 "Use HookedDatabaseAdapter directly: \
180 BetterAuth::new(config).database(HookedDatabaseAdapter::new(Arc::new(db)).with_hook(h))",
181 ));
182 }
183
184 let database = self.database;
185
186 let session_manager = SessionManager::new(config.clone(), database.clone());
188
189 let mut context = AuthContext::new(config.clone(), database.clone());
191
192 for plugin in &self.plugins {
194 plugin.on_init(&mut context).await?;
195 }
196
197 let mut middlewares: Vec<Box<dyn Middleware>> = vec![
199 Box::new(BodyLimitMiddleware::new(
200 self.body_limit_config.unwrap_or_default(),
201 )),
202 Box::new(RateLimitMiddleware::new(
203 self.rate_limit_config.unwrap_or_default(),
204 )),
205 Box::new(CsrfMiddleware::new(
206 self.csrf_config.unwrap_or_default(),
207 &config.base_url,
208 )),
209 Box::new(CorsMiddleware::new(self.cors_config.unwrap_or_default())),
210 ];
211
212 middlewares.extend(self.custom_middlewares);
213
214 Ok(BetterAuth {
215 config,
216 plugins: self.plugins,
217 middlewares,
218 database,
219 session_manager,
220 context,
221 })
222 }
223}
224
225impl<DB: DatabaseAdapter> BetterAuth<DB> {
226 #[allow(clippy::new_ret_no_self)]
228 pub fn new(config: AuthConfig) -> AuthBuilder {
229 AuthBuilder::new(config)
230 }
231
232 pub async fn handle_request(&self, req: AuthRequest) -> AuthResult<AuthResponse> {
238 match self.handle_request_inner(&req).await {
239 Ok(response) => {
240 middleware::run_after(&self.middlewares, &req, response).await
242 }
243 Err(err) => {
244 let response = err.into_response();
246 middleware::run_after(&self.middlewares, &req, response).await
247 }
248 }
249 }
250
251 async fn handle_request_inner(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
253 if let Some(response) = middleware::run_before(&self.middlewares, req).await? {
255 return Ok(response);
256 }
257
258 if let Some(response) = self.handle_core_request(req).await? {
260 return Ok(response);
261 }
262
263 for plugin in &self.plugins {
265 if let Some(response) = plugin.on_request(req, &self.context).await? {
266 return Ok(response);
267 }
268 }
269
270 Err(AuthError::not_found("No handler found for this request"))
272 }
273
274 pub fn config(&self) -> &AuthConfig {
276 &self.config
277 }
278
279 pub fn database(&self) -> &Arc<DB> {
281 &self.database
282 }
283
284 pub fn session_manager(&self) -> &SessionManager<DB> {
286 &self.session_manager
287 }
288
289 pub fn routes(&self) -> Vec<(String, &dyn AuthPlugin<DB>)> {
291 let mut routes = Vec::new();
292 for plugin in &self.plugins {
293 for route in plugin.routes() {
294 routes.push((route.path, plugin.as_ref()));
295 }
296 }
297 routes
298 }
299
300 pub fn plugins(&self) -> &[Box<dyn AuthPlugin<DB>>] {
302 &self.plugins
303 }
304
305 pub fn get_plugin(&self, name: &str) -> Option<&dyn AuthPlugin<DB>> {
307 self.plugins
308 .iter()
309 .find(|p| p.name() == name)
310 .map(|p| p.as_ref())
311 }
312
313 pub fn plugin_names(&self) -> Vec<&'static str> {
315 self.plugins.iter().map(|p| p.name()).collect()
316 }
317
318 pub fn openapi_spec(&self) -> OpenApiSpec {
320 let mut builder = OpenApiBuilder::new("Better Auth", env!("CARGO_PKG_VERSION"))
321 .description("Authentication API")
322 .core_routes();
323
324 for plugin in &self.plugins {
325 builder = builder.plugin(plugin.as_ref());
326 }
327
328 builder.build()
329 }
330
331 async fn handle_core_request(&self, req: &AuthRequest) -> AuthResult<Option<AuthResponse>> {
333 match (req.method(), req.path()) {
334 (HttpMethod::Get, "/ok") => Ok(Some(AuthResponse::json(
335 200,
336 &serde_json::json!({ "status": true }),
337 )?)),
338 (HttpMethod::Get, "/error") => Ok(Some(AuthResponse::json(
339 200,
340 &serde_json::json!({ "status": false }),
341 )?)),
342 (HttpMethod::Get, "/reference/openapi.json") => {
343 let spec = self.openapi_spec();
344 Ok(Some(AuthResponse::json(200, &spec)?))
345 }
346 (HttpMethod::Post, "/update-user") => Ok(Some(self.handle_update_user(req).await?)),
347 (HttpMethod::Post, "/delete-user") => Ok(Some(self.handle_delete_user(req).await?)),
348 (HttpMethod::Post, "/change-email") => Ok(Some(self.handle_change_email(req).await?)),
349 (HttpMethod::Get, "/delete-user/callback") => {
350 Ok(Some(self.handle_delete_user_callback(req).await?))
351 }
352 _ => Ok(None),
353 }
354 }
355
356 async fn handle_update_user(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
358 let current_user = self.extract_current_user(req).await?;
359
360 let update_req: UpdateUserRequest = req
361 .body_as_json()
362 .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
363
364 let update_user = UpdateUser {
365 email: update_req.email,
366 name: update_req.name,
367 image: update_req.image,
368 email_verified: None,
369 username: update_req.username,
370 display_username: update_req.display_username,
371 role: update_req.role,
372 banned: None,
373 ban_reason: None,
374 ban_expires: None,
375 two_factor_enabled: None,
376 metadata: update_req.metadata,
377 };
378
379 self.database
380 .update_user(current_user.id(), update_user)
381 .await?;
382
383 Ok(AuthResponse::json(
384 200,
385 &serde_json::json!({ "status": true }),
386 )?)
387 }
388
389 async fn handle_delete_user(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
391 let current_user = self.extract_current_user(req).await?;
392
393 self.database
394 .delete_user_sessions(current_user.id())
395 .await?;
396 self.database.delete_user(current_user.id()).await?;
397
398 let response = DeleteUserResponse {
399 success: true,
400 message: "User account successfully deleted".to_string(),
401 };
402
403 Ok(AuthResponse::json(200, &response)?)
404 }
405
406 async fn handle_change_email(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
408 let current_user = self.extract_current_user(req).await?;
409
410 let change_req: ChangeEmailRequest = req
411 .body_as_json()
412 .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
413
414 if !change_req.new_email.contains('@') || change_req.new_email.is_empty() {
415 return Err(AuthError::bad_request("Invalid email address"));
416 }
417
418 if self
419 .database
420 .get_user_by_email(&change_req.new_email)
421 .await?
422 .is_some()
423 {
424 return Err(AuthError::conflict("A user with this email already exists"));
425 }
426
427 let update_user = UpdateUser {
428 email: Some(change_req.new_email),
429 name: None,
430 image: None,
431 email_verified: Some(false),
432 username: None,
433 display_username: None,
434 role: None,
435 banned: None,
436 ban_reason: None,
437 ban_expires: None,
438 two_factor_enabled: None,
439 metadata: None,
440 };
441
442 self.database
443 .update_user(current_user.id(), update_user)
444 .await?;
445
446 Ok(AuthResponse::json(
447 200,
448 &serde_json::json!({ "status": true, "message": "Email updated" }),
449 )?)
450 }
451
452 async fn handle_delete_user_callback(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
454 let token = req
455 .query
456 .get("token")
457 .ok_or_else(|| AuthError::bad_request("Deletion token is required"))?;
458
459 let verification = self
460 .database
461 .get_verification_by_value(token)
462 .await?
463 .ok_or_else(|| AuthError::bad_request("Invalid or expired deletion token"))?;
464
465 let user_id = verification.identifier();
466
467 self.database.delete_user_sessions(user_id).await?;
468
469 let accounts = self.database.get_user_accounts(user_id).await?;
470 for account in accounts {
471 self.database.delete_account(account.id()).await?;
472 }
473
474 self.database.delete_user(user_id).await?;
475 self.database.delete_verification(verification.id()).await?;
476
477 let response = DeleteUserResponse {
478 success: true,
479 message: "User account successfully deleted".to_string(),
480 };
481
482 Ok(AuthResponse::json(200, &response)?)
483 }
484
485 async fn extract_current_user(&self, req: &AuthRequest) -> AuthResult<DB::User> {
487 let token = self
488 .session_manager
489 .extract_session_token(req)
490 .ok_or(AuthError::Unauthenticated)?;
491
492 let session = self
493 .session_manager
494 .get_session(&token)
495 .await?
496 .ok_or(AuthError::SessionNotFound)?;
497
498 let user = self
499 .database
500 .get_user_by_id(session.user_id())
501 .await?
502 .ok_or(AuthError::UserNotFound)?;
503
504 Ok(user)
505 }
506}