// better_auth/core/auth.rs

1use std::sync::Arc;
2use std::collections::HashMap;
3use chrono;
4
5use crate::types::{
6    AuthRequest, AuthResponse, UpdateUserRequest, UpdateUserResponse, 
7    DeleteUserResponse, UpdateUser, HttpMethod, User
8};
9use crate::error::{AuthError, AuthResult};
10use crate::adapters::DatabaseAdapter;
11use crate::core::{AuthConfig, AuthPlugin, AuthContext, SessionManager};
12
/// The main BetterAuth instance.
///
/// Produced by [`AuthBuilder::build`]. Requests passed to
/// [`BetterAuth::handle_request`] are first matched against the built-in
/// user-management endpoints and then offered to each registered plugin.
pub struct BetterAuth {
    /// Configuration that passed `AuthConfig::validate` in `AuthBuilder::build`.
    config: Arc<AuthConfig>,
    /// Plugins in registration order; requests are offered to them in this order.
    plugins: Vec<Box<dyn AuthPlugin>>,
    /// Database adapter taken from `config.database` at build time.
    database: Arc<dyn DatabaseAdapter>,
    /// Session manager built from the same config and database adapter.
    session_manager: SessionManager,
    /// Context passed to plugins; plugins may mutate it during `on_init`.
    context: AuthContext,
}
21
22/// Builder for configuring BetterAuth
23pub struct AuthBuilder {
24    config: AuthConfig,
25    plugins: Vec<Box<dyn AuthPlugin>>,
26}
27
28impl AuthBuilder {
29    pub fn new(config: AuthConfig) -> Self {
30        Self {
31            config,
32            plugins: Vec::new(),
33        }
34    }
35    
36    /// Add a plugin to the authentication system
37    pub fn plugin<P: AuthPlugin + 'static>(mut self, plugin: P) -> Self {
38        self.plugins.push(Box::new(plugin));
39        self
40    }
41    
42    /// Set the database adapter
43    pub fn database<D: DatabaseAdapter + 'static>(mut self, database: D) -> Self {
44        self.config.database = Some(Arc::new(database));
45        self
46    }
47    
48    /// Build the BetterAuth instance
49    pub async fn build(self) -> AuthResult<BetterAuth> {
50        // Validate configuration
51        self.config.validate()?;
52        
53        let config = Arc::new(self.config);
54        let database = config.database.as_ref().unwrap().clone();
55        
56        // Create session manager
57        let session_manager = SessionManager::new(config.clone(), database.clone());
58        
59        // Create context
60        let mut context = AuthContext::new(config.clone(), database.clone());
61        
62        // Initialize all plugins
63        for plugin in &self.plugins {
64            plugin.on_init(&mut context).await?;
65        }
66        
67        Ok(BetterAuth {
68            config,
69            plugins: self.plugins,
70            database,
71            session_manager,
72            context,
73        })
74    }
75}
76
77impl BetterAuth {
78    /// Create a new BetterAuth builder
79    pub fn new(config: AuthConfig) -> AuthBuilder {
80        AuthBuilder::new(config)
81    }
82    
83    /// Handle an authentication request
84    pub async fn handle_request(&self, req: AuthRequest) -> AuthResult<AuthResponse> {
85        // Handle core endpoints first
86        if let Some(response) = self.handle_core_request(&req).await? {
87            return Ok(response);
88        }
89        
90        // Try each plugin until one handles the request
91        for plugin in &self.plugins {
92            if let Some(response) = plugin.on_request(&req, &self.context).await? {
93                return Ok(response);
94            }
95        }
96        
97        // No plugin handled the request
98        Ok(AuthResponse::json(404, &serde_json::json!({
99            "error": "Not found",
100            "message": "No plugin found to handle this request"
101        }))?)
102    }
103    
104    /// Get the configuration
105    pub fn config(&self) -> &AuthConfig {
106        &self.config
107    }
108    
109    /// Get the database adapter
110    pub fn database(&self) -> &Arc<dyn DatabaseAdapter> {
111        &self.database
112    }
113    
114    /// Get the session manager
115    pub fn session_manager(&self) -> &SessionManager {
116        &self.session_manager
117    }
118    
119    /// Get all routes from plugins
120    pub fn routes(&self) -> Vec<(String, &dyn AuthPlugin)> {
121        let mut routes = Vec::new();
122        for plugin in &self.plugins {
123            for route in plugin.routes() {
124                routes.push((route.path, plugin.as_ref()));
125            }
126        }
127        routes
128    }
129    
130    /// Get all plugins (useful for Axum integration)
131    pub fn plugins(&self) -> &Vec<Box<dyn AuthPlugin>> {
132        &self.plugins
133    }
134    
135    /// Get plugin by name
136    pub fn get_plugin(&self, name: &str) -> Option<&dyn AuthPlugin> {
137        self.plugins.iter()
138            .find(|p| p.name() == name)
139            .map(|p| p.as_ref())
140    }
141    
142    /// List all plugin names
143    pub fn plugin_names(&self) -> Vec<&'static str> {
144        self.plugins.iter().map(|p| p.name()).collect()
145    }
146    
147    /// Handle core authentication requests (user profile management)
148    async fn handle_core_request(&self, req: &AuthRequest) -> AuthResult<Option<AuthResponse>> {
149        match (req.method(), req.path()) {
150            (HttpMethod::Post, "/update-user") => {
151                Ok(Some(self.handle_update_user(req).await?))
152            },
153            (HttpMethod::Delete, "/delete-user") => {
154                Ok(Some(self.handle_delete_user(req).await?))
155            },
156            _ => Ok(None), // Not a core endpoint
157        }
158    }
159    
160    /// Handle user profile update
161    async fn handle_update_user(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
162        // Extract and validate session
163        let current_user = self.extract_current_user(req).await?;
164        
165        // Parse request body
166        let update_req: UpdateUserRequest = match req.body_as_json() {
167            Ok(req) => req,
168            Err(e) => {
169                return Ok(AuthResponse::json(400, &serde_json::json!({
170                    "error": "Invalid request",
171                    "message": format!("Invalid JSON: {}", e)
172                }))?);
173            }
174        };
175        
176        // Convert to UpdateUser
177        let update_user = UpdateUser {
178            email: update_req.email,
179            name: update_req.name,
180            image: update_req.image,
181            email_verified: None, // Don't allow changing verification status through this endpoint
182            username: update_req.username,
183            display_username: update_req.display_username,
184            role: update_req.role,
185            banned: None, // Don't allow changing banned status through this endpoint
186            ban_reason: None,
187            ban_expires: None,
188            two_factor_enabled: None, // Don't allow changing 2FA status through this endpoint
189            metadata: update_req.metadata,
190        };
191        
192        // Update user in database
193        let updated_user = self.database.update_user(&current_user.id, update_user).await?;
194        
195        let response = UpdateUserResponse {
196            user: updated_user,
197        };
198        
199        Ok(AuthResponse::json(200, &response)?)
200    }
201    
202    /// Handle user deletion
203    async fn handle_delete_user(&self, req: &AuthRequest) -> AuthResult<AuthResponse> {
204        // Extract and validate session
205        let current_user = self.extract_current_user(req).await?;
206        
207        // Delete all user sessions first
208        self.database.delete_user_sessions(&current_user.id).await?;
209        
210        // Delete the user
211        self.database.delete_user(&current_user.id).await?;
212        
213        let response = DeleteUserResponse {
214            success: true,
215            message: "User account successfully deleted".to_string(),
216        };
217        
218        Ok(AuthResponse::json(200, &response)?)
219    }
220    
221    /// Extract current user from request (validates session)
222    async fn extract_current_user(&self, req: &AuthRequest) -> AuthResult<User> {
223        // Extract token from Authorization header
224        let token = self.extract_bearer_token(req)
225            .ok_or_else(|| AuthError::Unauthenticated)?;
226        
227        // Get session from database
228        let session = self.database.get_session(&token).await?
229            .ok_or_else(|| AuthError::SessionNotFound)?;
230        
231        // Check if session is expired
232        if session.expires_at < chrono::Utc::now() {
233            return Err(AuthError::SessionNotFound);
234        }
235        
236        // Get user from database
237        let user = self.database.get_user_by_id(&session.user_id).await?
238            .ok_or_else(|| AuthError::UserNotFound)?;
239        
240        Ok(user)
241    }
242    
243    /// Extract Bearer token from Authorization header
244    fn extract_bearer_token(&self, req: &AuthRequest) -> Option<String> {
245        req.headers.get("authorization")
246            .and_then(|auth| {
247                if auth.starts_with("Bearer ") {
248                    Some(auth[7..].to_string())
249                } else {
250                    None
251                }
252            })
253    }
254}