1use std::sync::Arc;
2
3use serde::Deserialize;
4
5use better_auth_core::{
6 AuthConfig, AuthContext, AuthError, AuthPlugin, AuthRequest, AuthResponse, AuthResult,
7 DatabaseAdapter, DatabaseHooks, DeleteUserResponse, EmailProvider, HttpMethod, OkResponse,
8 OpenApiBuilder, OpenApiSpec, SessionManager, StatusMessageResponse, StatusResponse, UpdateUser,
9 UpdateUserRequest,
10 entity::{AuthAccount, AuthSession, AuthUser, AuthVerification},
11 middleware::{
12 self, BodyLimitConfig, BodyLimitMiddleware, CorsConfig, CorsMiddleware, CsrfConfig,
13 CsrfMiddleware, Middleware, RateLimitConfig, RateLimitMiddleware,
14 },
15};
16
17#[derive(Debug, Deserialize)]
18struct ChangeEmailRequest {
19 #[serde(rename = "newEmail")]
20 new_email: String,
21}
22
23pub struct BetterAuth<DB: DatabaseAdapter> {
25 config: Arc<AuthConfig>,
26 plugins: Vec<Box<dyn AuthPlugin<DB>>>,
27 middlewares: Vec<Box<dyn Middleware>>,
28 database: Arc<DB>,
29 session_manager: SessionManager<DB>,
30 context: AuthContext<DB>,
31}
32
33pub struct AuthBuilder {
38 config: AuthConfig,
39 csrf_config: Option<CsrfConfig>,
40 rate_limit_config: Option<RateLimitConfig>,
41 cors_config: Option<CorsConfig>,
42 body_limit_config: Option<BodyLimitConfig>,
43 custom_middlewares: Vec<Box<dyn Middleware>>,
44}
45
46pub struct TypedAuthBuilder<DB: DatabaseAdapter> {
50 config: AuthConfig,
51 database: Arc<DB>,
52 plugins: Vec<Box<dyn AuthPlugin<DB>>>,
53 hooks: Vec<Arc<dyn DatabaseHooks<DB>>>,
54 csrf_config: Option<CsrfConfig>,
55 rate_limit_config: Option<RateLimitConfig>,
56 cors_config: Option<CorsConfig>,
57 body_limit_config: Option<BodyLimitConfig>,
58 custom_middlewares: Vec<Box<dyn Middleware>>,
59}
60
61impl AuthBuilder {
62 pub fn new(config: AuthConfig) -> Self {
63 Self {
64 config,
65 csrf_config: None,
66 rate_limit_config: None,
67 cors_config: None,
68 body_limit_config: None,
69 custom_middlewares: Vec::new(),
70 }
71 }
72
73 pub fn database<DB: DatabaseAdapter>(self, database: DB) -> TypedAuthBuilder<DB> {
75 TypedAuthBuilder {
76 config: self.config,
77 database: Arc::new(database),
78 plugins: Vec::new(),
79 hooks: Vec::new(),
80 csrf_config: self.csrf_config,
81 rate_limit_config: self.rate_limit_config,
82 cors_config: self.cors_config,
83 body_limit_config: self.body_limit_config,
84 custom_middlewares: self.custom_middlewares,
85 }
86 }
87
88 pub fn csrf(mut self, config: CsrfConfig) -> Self {
90 self.csrf_config = Some(config);
91 self
92 }
93
94 pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
96 self.rate_limit_config = Some(config);
97 self
98 }
99
100 pub fn cors(mut self, config: CorsConfig) -> Self {
102 self.cors_config = Some(config);
103 self
104 }
105
106 pub fn body_limit(mut self, config: BodyLimitConfig) -> Self {
108 self.body_limit_config = Some(config);
109 self
110 }
111
112 pub fn email_provider<E: EmailProvider + 'static>(mut self, provider: E) -> Self {
114 self.config.email_provider = Some(Arc::new(provider));
115 self
116 }
117}
118
119impl<DB: DatabaseAdapter> TypedAuthBuilder<DB> {
120 pub fn plugin<P: AuthPlugin<DB> + 'static>(mut self, plugin: P) -> Self {
122 self.plugins.push(Box::new(plugin));
123 self
124 }
125
126 pub fn hook<H: DatabaseHooks<DB> + 'static>(mut self, hook: H) -> Self {
128 self.hooks.push(Arc::new(hook));
129 self
130 }
131
132 pub fn csrf(mut self, config: CsrfConfig) -> Self {
134 self.csrf_config = Some(config);
135 self
136 }
137
138 pub fn rate_limit(mut self, config: RateLimitConfig) -> Self {
140 self.rate_limit_config = Some(config);
141 self
142 }
143
144 pub fn cors(mut self, config: CorsConfig) -> Self {
146 self.cors_config = Some(config);
147 self
148 }
149
150 pub fn body_limit(mut self, config: BodyLimitConfig) -> Self {
152 self.body_limit_config = Some(config);
153 self
154 }
155
156 pub fn email_provider<E: EmailProvider + 'static>(mut self, provider: E) -> Self {
158 self.config.email_provider = Some(Arc::new(provider));
159 self
160 }
161
162 pub fn middleware<M: Middleware + 'static>(mut self, mw: M) -> Self {
164 self.custom_middlewares.push(Box::new(mw));
165 self
166 }
167
168 pub async fn build(self) -> AuthResult<BetterAuth<DB>> {
170 self.config.validate()?;
172
173 let config = Arc::new(self.config);
174
175 if !self.hooks.is_empty() {
179 return Err(AuthError::config(
180 "Use HookedDatabaseAdapter directly: \
181 BetterAuth::new(config).database(HookedDatabaseAdapter::new(Arc::new(db)).with_hook(h))",
182 ));
183 }
184
185 let database = self.database;
186
187 let session_manager = SessionManager::new(config.clone(), database.clone());
189
190 let mut context = AuthContext::new(config.clone(), database.clone());
192
193 for plugin in &self.plugins {
195 plugin.on_init(&mut context).await?;
196 }
197
198 let mut middlewares: Vec<Box<dyn Middleware>> = vec![
200 Box::new(BodyLimitMiddleware::new(
201 self.body_limit_config.unwrap_or_default(),
202 )),
203 Box::new(RateLimitMiddleware::new(
204 self.rate_limit_config.unwrap_or_default(),
205 )),
206 Box::new(CsrfMiddleware::new(
207 self.csrf_config.unwrap_or_default(),
208 &config.base_url,
209 )),
210 Box::new(CorsMiddleware::new(self.cors_config.unwrap_or_default())),
211 ];
212
213 middlewares.extend(self.custom_middlewares);
214
215 Ok(BetterAuth {
216 config,
217 plugins: self.plugins,
218 middlewares,
219 database,
220 session_manager,
221 context,
222 })
223 }
224}
225
226impl<DB: DatabaseAdapter> BetterAuth<DB> {
227 #[allow(clippy::new_ret_no_self)]
229 pub fn new(config: AuthConfig) -> AuthBuilder {
230 AuthBuilder::new(config)
231 }
232
233 pub async fn handle_request(&self, req: AuthRequest) -> AuthResult<AuthResponse> {
239 match self.handle_request_inner(&req).await {
240 Ok(response) => {
241 middleware::run_after(&self.middlewares, &req, response).await
243 }
244 Err(err) => {
245 let response = err.into_response();
247 middleware::run_after(&self.middlewares, &req, response).await
248 }
249 }
250 }
251
252 async fn handle_request_inner(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
254 if let Some(response) = middleware::run_before(&self.middlewares, req).await? {
256 return Ok(response);
257 }
258
259 if let Some(response) = self.handle_core_request(req).await? {
261 return Ok(response);
262 }
263
264 for plugin in &self.plugins {
266 if let Some(response) = plugin.on_request(req, &self.context).await? {
267 return Ok(response);
268 }
269 }
270
271 Err(AuthError::not_found("No handler found for this request"))
273 }
274
275 pub fn config(&self) -> &AuthConfig {
277 &self.config
278 }
279
280 pub fn database(&self) -> &Arc<DB> {
282 &self.database
283 }
284
285 pub fn session_manager(&self) -> &SessionManager<DB> {
287 &self.session_manager
288 }
289
290 pub fn routes(&self) -> Vec<(String, &dyn AuthPlugin<DB>)> {
292 let mut routes = Vec::new();
293 for plugin in &self.plugins {
294 for route in plugin.routes() {
295 routes.push((route.path, plugin.as_ref()));
296 }
297 }
298 routes
299 }
300
301 pub fn plugins(&self) -> &[Box<dyn AuthPlugin<DB>>] {
303 &self.plugins
304 }
305
306 pub fn get_plugin(&self, name: &str) -> Option<&dyn AuthPlugin<DB>> {
308 self.plugins
309 .iter()
310 .find(|p| p.name() == name)
311 .map(|p| p.as_ref())
312 }
313
314 pub fn plugin_names(&self) -> Vec<&'static str> {
316 self.plugins.iter().map(|p| p.name()).collect()
317 }
318
319 pub fn openapi_spec(&self) -> OpenApiSpec {
321 let mut builder = OpenApiBuilder::new("Better Auth", env!("CARGO_PKG_VERSION"))
322 .description("Authentication API")
323 .core_routes();
324
325 for plugin in &self.plugins {
326 builder = builder.plugin(plugin.as_ref());
327 }
328
329 builder.build()
330 }
331
332 async fn handle_core_request(&self, req: &AuthRequest) -> AuthResult<Option<AuthResponse>> {
334 match (req.method(), req.path()) {
335 (HttpMethod::Get, "/ok") => {
336 Ok(Some(AuthResponse::json(200, &OkResponse { ok: true })?))
337 }
338 (HttpMethod::Get, "/error") => {
339 Ok(Some(AuthResponse::json(200, &OkResponse { ok: false })?))
340 }
341 (HttpMethod::Get, "/reference/openapi.json") => {
342 let spec = self.openapi_spec();
343 Ok(Some(AuthResponse::json(200, &spec)?))
344 }
345 (HttpMethod::Post, "/update-user") => Ok(Some(self.handle_update_user(req).await?)),
346 (HttpMethod::Post | HttpMethod::Delete, "/delete-user") => {
347 Ok(Some(self.handle_delete_user(req).await?))
348 }
349 (HttpMethod::Post, "/change-email") => Ok(Some(self.handle_change_email(req).await?)),
350 (HttpMethod::Get, "/delete-user/callback") => {
351 Ok(Some(self.handle_delete_user_callback(req).await?))
352 }
353 _ => Ok(None),
354 }
355 }
356
357 async fn handle_update_user(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
359 let current_user = self.extract_current_user(req).await?;
360
361 let update_req: UpdateUserRequest = req
362 .body_as_json()
363 .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
364
365 let update_user = UpdateUser {
366 email: update_req.email,
367 name: update_req.name,
368 image: update_req.image,
369 email_verified: None,
370 username: update_req.username,
371 display_username: update_req.display_username,
372 role: update_req.role,
373 banned: None,
374 ban_reason: None,
375 ban_expires: None,
376 two_factor_enabled: None,
377 metadata: update_req.metadata,
378 };
379
380 self.database
381 .update_user(current_user.id(), update_user)
382 .await?;
383
384 Ok(AuthResponse::json(200, &StatusResponse { status: true })?)
385 }
386
387 async fn handle_delete_user(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
389 let current_user = self.extract_current_user(req).await?;
390
391 self.database
392 .delete_user_sessions(current_user.id())
393 .await?;
394 self.database.delete_user(current_user.id()).await?;
395
396 let response = DeleteUserResponse {
397 success: true,
398 message: "User account successfully deleted".to_string(),
399 };
400
401 Ok(AuthResponse::json(200, &response)?)
402 }
403
404 async fn handle_change_email(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
406 let current_user = self.extract_current_user(req).await?;
407
408 let change_req: ChangeEmailRequest = req
409 .body_as_json()
410 .map_err(|e| AuthError::bad_request(format!("Invalid JSON: {}", e)))?;
411
412 if !change_req.new_email.contains('@') || change_req.new_email.is_empty() {
413 return Err(AuthError::bad_request("Invalid email address"));
414 }
415
416 if self
417 .database
418 .get_user_by_email(&change_req.new_email)
419 .await?
420 .is_some()
421 {
422 return Err(AuthError::conflict("A user with this email already exists"));
423 }
424
425 let update_user = UpdateUser {
426 email: Some(change_req.new_email),
427 name: None,
428 image: None,
429 email_verified: Some(false),
430 username: None,
431 display_username: None,
432 role: None,
433 banned: None,
434 ban_reason: None,
435 ban_expires: None,
436 two_factor_enabled: None,
437 metadata: None,
438 };
439
440 self.database
441 .update_user(current_user.id(), update_user)
442 .await?;
443
444 Ok(AuthResponse::json(
445 200,
446 &StatusMessageResponse {
447 status: true,
448 message: "Email updated".to_string(),
449 },
450 )?)
451 }
452
453 async fn handle_delete_user_callback(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
455 let token = req
456 .query
457 .get("token")
458 .ok_or_else(|| AuthError::bad_request("Deletion token is required"))?;
459
460 let verification = self
461 .database
462 .get_verification_by_value(token)
463 .await?
464 .ok_or_else(|| AuthError::bad_request("Invalid or expired deletion token"))?;
465
466 let user_id = verification.identifier();
467
468 self.database.delete_user_sessions(user_id).await?;
469
470 let accounts = self.database.get_user_accounts(user_id).await?;
471 for account in accounts {
472 self.database.delete_account(account.id()).await?;
473 }
474
475 self.database.delete_user(user_id).await?;
476 self.database.delete_verification(verification.id()).await?;
477
478 let response = DeleteUserResponse {
479 success: true,
480 message: "User account successfully deleted".to_string(),
481 };
482
483 Ok(AuthResponse::json(200, &response)?)
484 }
485
486 async fn extract_current_user(&self, req: &AuthRequest) -> AuthResult<DB::User> {
488 let token = self
489 .session_manager
490 .extract_session_token(req)
491 .ok_or(AuthError::Unauthenticated)?;
492
493 let session = self
494 .session_manager
495 .get_session(&token)
496 .await?
497 .ok_or(AuthError::SessionNotFound)?;
498
499 let user = self
500 .database
501 .get_user_by_id(session.user_id())
502 .await?
503 .ok_or(AuthError::UserNotFound)?;
504
505 Ok(user)
506 }
507}