1use std::sync::Arc;
2
3use serde::Deserialize;
4
5use better_auth_core::{
6 AuthConfig, AuthContext, AuthError, AuthPlugin, AuthRequest, AuthResponse, AuthResult,
7 DatabaseAdapter, DatabaseHooks, DeleteUserResponse, EmailProvider, HttpMethod, OpenApiBuilder,
8 OpenApiSpec, SessionManager, UpdateUser, UpdateUserRequest,
9 entity::{AuthAccount, AuthSession, AuthUser, AuthVerification},
10 middleware::{
11 self, BodyLimitConfig, BodyLimitMiddleware, CorsConfig, CorsMiddleware, CsrfConfig,
12 CsrfMiddleware, Middleware, RateLimitConfig, RateLimitMiddleware,
13 },
14};
15
16#[derive(Debug, Deserialize)]
17struct ChangeEmailRequest {
18 #[serde(rename = "newEmail")]
19 new_email: String,
20}
21
22pub struct BetterAuth<DB: DatabaseAdapter> {
24 config: Arc<AuthConfig>,
25 plugins: Vec<Box<dyn AuthPlugin<DB>>>,
26 middlewares: Vec<Box<dyn Middleware>>,
27 database: Arc<DB>,
28 session_manager: SessionManager<DB>,
29 context: AuthContext<DB>,
30}
31
32pub struct AuthBuilder {
37 config: AuthConfig,
38 csrf_config: Option<CsrfConfig>,
39 rate_limit_config: Option<RateLimitConfig>,
40 cors_config: Option<CorsConfig>,
41 body_limit_config: Option<BodyLimitConfig>,
42 custom_middlewares: Vec<Box<dyn Middleware>>,
43}
44
45pub struct TypedAuthBuilder<DB: DatabaseAdapter> {
49 config: AuthConfig,
50 database: Arc<DB>,
51 plugins: Vec<Box<dyn AuthPlugin<DB>>>,
52 hooks: Vec<Arc<dyn DatabaseHooks<DB>>>,
53 csrf_config: Option<CsrfConfig>,
54 rate_limit_config: Option<RateLimitConfig>,
55 cors_config: Option<CorsConfig>,
56 body_limit_config: Option<BodyLimitConfig>,
57 custom_middlewares: Vec<Box<dyn Middleware>>,
58}
59
60impl AuthBuilder {
61 pub fn new(config: AuthConfig) -> Self {
62 Self {
63 config,
64 csrf_config: None,
65 rate_limit_config: None,
66 cors_config: None,
67 body_limit_config: None,
68 custom_middlewares: Vec::new(),
69 }
70 }
71
72 pub fn database<DB: DatabaseAdapter>(self, database: DB) -> TypedAuthBuilder<DB> {
74 TypedAuthBuilder {
75 config: self.config,
76 database: Arc::new(database),
77 plugins: Vec::new(),
78 hooks: Vec::new(),
79 csrf_config: self.csrf_config,
80 rate_limit_config: self.rate_limit_config,
81 cors_config: self.cors_config,
82 body_limit_config: self.body_limit_config,
83 custom_middlewares: self.custom_middlewares,
84 }
85 }
86
87 pub fn csrf(mut self, config: CsrfConfig) -> Self {
89 self.csrf_config = Some(config);
90 self
91 }
92
93 pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
95 self.rate_limit_config = Some(config);
96 self
97 }
98
99 pub fn cors(mut self, config: CorsConfig) -> Self {
101 self.cors_config = Some(config);
102 self
103 }
104
105 pub fn body_limit(mut self, config: BodyLimitConfig) -> Self {
107 self.body_limit_config = Some(config);
108 self
109 }
110
111 pub fn email_provider<E: EmailProvider + 'static>(mut self, provider: E) -> Self {
113 self.config.email_provider = Some(Arc::new(provider));
114 self
115 }
116}
117
118impl<DB: DatabaseAdapter> TypedAuthBuilder<DB> {
119 pub fn plugin<P: AuthPlugin<DB> + 'static>(mut self, plugin: P) -> Self {
121 self.plugins.push(Box::new(plugin));
122 self
123 }
124
125 pub fn hook<H: DatabaseHooks<DB> + 'static>(mut self, hook: H) -> Self {
127 self.hooks.push(Arc::new(hook));
128 self
129 }
130
131 pub fn csrf(mut self, config: CsrfConfig) -> Self {
133 self.csrf_config = Some(config);
134 self
135 }
136
137 pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
139 self.rate_limit_config = Some(config);
140 self
141 }
142
143 pub fn cors(mut self, config: CorsConfig) -> Self {
145 self.cors_config = Some(config);
146 self
147 }
148
149 pub fn body_limit(mut self, config: BodyLimitConfig) -> Self {
151 self.body_limit_config = Some(config);
152 self
153 }
154
155 pub fn email_provider<E: EmailProvider + 'static>(mut self, provider: E) -> Self {
157 self.config.email_provider = Some(Arc::new(provider));
158 self
159 }
160
161 pub fn middleware<M: Middleware + 'static>(mut self, mw: M) -> Self {
163 self.custom_middlewares.push(Box::new(mw));
164 self
165 }
166
167 pub async fn build(self) -> AuthResult<BetterAuth<DB>> {
169 self.config.validate()?;
171
172 let config = Arc::new(self.config);
173
174 if !self.hooks.is_empty() {
178 return Err(AuthError::config(
179 "Use HookedDatabaseAdapter directly: \
180 BetterAuth::new(config).database(HookedDatabaseAdapter::new(Arc::new(db)).with_hook(h))",
181 ));
182 }
183
184 let database = self.database;
185
186 let session_manager = SessionManager::new(config.clone(), database.clone());
188
189 let mut context = AuthContext::new(config.clone(), database.clone());
191
192 for plugin in &self.plugins {
194 plugin.on_init(&mut context).await?;
195 }
196
197 let mut middlewares: Vec<Box<dyn Middleware>> = vec![
199 Box::new(BodyLimitMiddleware::new(
200 self.body_limit_config.unwrap_or_default(),
201 )),
202 Box::new(RateLimitMiddleware::new(
203 self.rate_limit_config.unwrap_or_default(),
204 )),
205 Box::new(CsrfMiddleware::new(
206 self.csrf_config.unwrap_or_default(),
207 &config.base_url,
208 )),
209 Box::new(CorsMiddleware::new(self.cors_config.unwrap_or_default())),
210 ];
211
212 middlewares.extend(self.custom_middlewares);
213
214 Ok(BetterAuth {
215 config,
216 plugins: self.plugins,
217 middlewares,
218 database,
219 session_manager,
220 context,
221 })
222 }
223}
224
225impl<DB: DatabaseAdapter> BetterAuth<DB> {
226 #[allow(clippy::new_ret_no_self)]
228 pub fn new(config: AuthConfig) -> AuthBuilder {
229 AuthBuilder::new(config)
230 }
231
232 pub async fn handle_request(&self, req: AuthRequest) -> AuthResult<AuthResponse> {
238 match self.handle_request_inner(&req).await {
239 Ok(response) => {
240 middleware::run_after(&self.middlewares, &req, response).await
242 }
243 Err(err) => {
244 let response = err.into_response();
246 middleware::run_after(&self.middlewares, &req, response).await
247 }
248 }
249 }
250
251 async fn handle_request_inner(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
253 if let Some(response) = middleware::run_before(&self.middlewares, req).await? {
255 return Ok(response);
256 }
257
258 if let Some(response) = self.handle_core_request(req).await? {
260 return Ok(response);
261 }
262
263 for plugin in &self.plugins {
265 if let Some(response) = plugin.on_request(req, &self.context).await? {
266 return Ok(response);
267 }
268 }
269
270 Err(AuthError::not_found("No handler found for this request"))
272 }
273
274 pub fn config(&self) -> &AuthConfig {
276 &self.config
277 }
278
279 pub fn database(&self) -> &Arc<DB> {
281 &self.database
282 }
283
284 pub fn session_manager(&self) -> &SessionManager<DB> {
286 &self.session_manager
287 }
288
289 pub fn routes(&self) -> Vec<(String, &dyn AuthPlugin<DB>)> {
291 let mut routes = Vec::new();
292 for plugin in &self.plugins {
293 for route in plugin.routes() {
294 routes.push((route.path, plugin.as_ref()));
295 }
296 }
297 routes
298 }
299
300 pub fn plugins(&self) -> &[Box<dyn AuthPlugin<DB>>] {
302 &self.plugins
303 }
304
305 pub fn get_plugin(&self, name: &str) -> Option<&dyn AuthPlugin<DB>> {
307 self.plugins
308 .iter()
309 .find(|p| p.name() == name)
310 .map(|p| p.as_ref())
311 }
312
313 pub fn plugin_names(&self) -> Vec<&'static str> {
315 self.plugins.iter().map(|p| p.name()).collect()
316 }
317
318 pub fn openapi_spec(&self) -> OpenApiSpec {
320 let mut builder = OpenApiBuilder::new("Better Auth", env!("CARGO_PKG_VERSION"))
321 .description("Authentication API")
322 .core_routes();
323
324 for plugin in &self.plugins {
325 builder = builder.plugin(plugin.as_ref());
326 }
327
328 builder.build()
329 }
330
331 async fn handle_core_request(&self, req: &AuthRequest) -> AuthResult<Option<AuthResponse>> {
333 match (req.method(), req.path()) {
334 (HttpMethod::Get, "/ok") => Ok(Some(AuthResponse::json(
335 200,
336 &serde_json::json!({ "status": true }),
337 )?)),
338 (HttpMethod::Get, "/error") => Ok(Some(AuthResponse::json(
339 200,
340 &serde_json::json!({ "status": false }),
341 )?)),
342 (HttpMethod::Get, "/reference/openapi.json") => {
343 let spec = self.openapi_spec();
344 Ok(Some(AuthResponse::json(200, &spec)?))
345 }
346 (HttpMethod::Post, "/update-user") => Ok(Some(self.handle_update_user(req).await?)),
347 (HttpMethod::Post | HttpMethod::Delete, "/delete-user") => {
348 Ok(Some(self.handle_delete_user(req).await?))
349 }
350 (HttpMethod::Post, "/change-email") => Ok(Some(self.handle_change_email(req).await?)),
351 (HttpMethod::Get, "/delete-user/callback") => {
352 Ok(Some(self.handle_delete_user_callback(req).await?))
353 }
354 _ => Ok(None),
355 }
356 }
357
358 async fn handle_update_user(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
360 let current_user = self.extract_current_user(req).await?;
361
362 let update_req: UpdateUserRequest = req
363 .body_as_json()
364 .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
365
366 let update_user = UpdateUser {
367 email: update_req.email,
368 name: update_req.name,
369 image: update_req.image,
370 email_verified: None,
371 username: update_req.username,
372 display_username: update_req.display_username,
373 role: update_req.role,
374 banned: None,
375 ban_reason: None,
376 ban_expires: None,
377 two_factor_enabled: None,
378 metadata: update_req.metadata,
379 };
380
381 self.database
382 .update_user(current_user.id(), update_user)
383 .await?;
384
385 Ok(AuthResponse::json(
386 200,
387 &serde_json::json!({ "status": true }),
388 )?)
389 }
390
391 async fn handle_delete_user(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
393 let current_user = self.extract_current_user(req).await?;
394
395 self.database
396 .delete_user_sessions(current_user.id())
397 .await?;
398 self.database.delete_user(current_user.id()).await?;
399
400 let response = DeleteUserResponse {
401 success: true,
402 message: "User account successfully deleted".to_string(),
403 };
404
405 Ok(AuthResponse::json(200, &response)?)
406 }
407
408 async fn handle_change_email(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
410 let current_user = self.extract_current_user(req).await?;
411
412 let change_req: ChangeEmailRequest = req
413 .body_as_json()
414 .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
415
416 if !change_req.new_email.contains('@') || change_req.new_email.is_empty() {
417 return Err(AuthError::bad_request("Invalid email address"));
418 }
419
420 if self
421 .database
422 .get_user_by_email(&change_req.new_email)
423 .await?
424 .is_some()
425 {
426 return Err(AuthError::conflict("A user with this email already exists"));
427 }
428
429 let update_user = UpdateUser {
430 email: Some(change_req.new_email),
431 name: None,
432 image: None,
433 email_verified: Some(false),
434 username: None,
435 display_username: None,
436 role: None,
437 banned: None,
438 ban_reason: None,
439 ban_expires: None,
440 two_factor_enabled: None,
441 metadata: None,
442 };
443
444 self.database
445 .update_user(current_user.id(), update_user)
446 .await?;
447
448 Ok(AuthResponse::json(
449 200,
450 &serde_json::json!({ "status": true, "message": "Email updated" }),
451 )?)
452 }
453
454 async fn handle_delete_user_callback(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
456 let token = req
457 .query
458 .get("token")
459 .ok_or_else(|| AuthError::bad_request("Deletion token is required"))?;
460
461 let verification = self
462 .database
463 .get_verification_by_value(token)
464 .await?
465 .ok_or_else(|| AuthError::bad_request("Invalid or expired deletion token"))?;
466
467 let user_id = verification.identifier();
468
469 self.database.delete_user_sessions(user_id).await?;
470
471 let accounts = self.database.get_user_accounts(user_id).await?;
472 for account in accounts {
473 self.database.delete_account(account.id()).await?;
474 }
475
476 self.database.delete_user(user_id).await?;
477 self.database.delete_verification(verification.id()).await?;
478
479 let response = DeleteUserResponse {
480 success: true,
481 message: "User account successfully deleted".to_string(),
482 };
483
484 Ok(AuthResponse::json(200, &response)?)
485 }
486
487 async fn extract_current_user(&self, req: &AuthRequest) -> AuthResult<DB::User> {
489 let token = self
490 .session_manager
491 .extract_session_token(req)
492 .ok_or(AuthError::Unauthenticated)?;
493
494 let session = self
495 .session_manager
496 .get_session(&token)
497 .await?
498 .ok_or(AuthError::SessionNotFound)?;
499
500 let user = self
501 .database
502 .get_user_by_id(session.user_id())
503 .await?
504 .ok_or(AuthError::UserNotFound)?;
505
506 Ok(user)
507 }
508}