Skip to main content

better_auth_core/
plugin.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use crate::adapters::DatabaseAdapter;
6use crate::config::AuthConfig;
7use crate::email::EmailProvider;
8use crate::entity::AuthSession;
9use crate::error::{AuthError, AuthResult};
10use crate::types::{AuthRequest, AuthResponse, HttpMethod};
11
12/// Action returned by [`AuthPlugin::before_request`].
13#[derive(Debug)]
14pub enum BeforeRequestAction {
15    /// Short-circuit with this response (e.g. return session JSON).
16    Respond(AuthResponse),
17    /// Inject a virtual session so downstream handlers see it as authenticated.
18    InjectSession {
19        user_id: String,
20        session_token: String,
21    },
22}
23
24/// Plugin trait that all authentication plugins must implement.
25///
26/// Generic over `DB` so that lifecycle hooks receive the adapter's concrete
27/// entity types (e.g., `DB::User`, `DB::Session`).
28#[async_trait]
29pub trait AuthPlugin<DB: DatabaseAdapter>: Send + Sync {
30    /// Plugin name - should be unique
31    fn name(&self) -> &'static str;
32
33    /// Routes that this plugin handles
34    fn routes(&self) -> Vec<AuthRoute>;
35
36    /// Called when the plugin is initialized
37    async fn on_init(&self, ctx: &mut AuthContext<DB>) -> AuthResult<()> {
38        let _ = ctx;
39        Ok(())
40    }
41
42    /// Called before route matching for every incoming request.
43    ///
44    /// Return `Some(BeforeRequestAction::Respond(..))` to short-circuit with a
45    /// response, `Some(BeforeRequestAction::InjectSession { .. })` to attach a
46    /// virtual session (e.g. API-key → session emulation), or `None` to let the
47    /// request continue to normal route matching.
48    async fn before_request(
49        &self,
50        _req: &AuthRequest,
51        _ctx: &AuthContext<DB>,
52    ) -> AuthResult<Option<BeforeRequestAction>> {
53        Ok(None)
54    }
55
56    /// Called for each request - return Some(response) to handle, None to pass through
57    async fn on_request(
58        &self,
59        req: &AuthRequest,
60        ctx: &AuthContext<DB>,
61    ) -> AuthResult<Option<AuthResponse>>;
62
63    /// Called after a user is created
64    async fn on_user_created(&self, user: &DB::User, ctx: &AuthContext<DB>) -> AuthResult<()> {
65        let _ = (user, ctx);
66        Ok(())
67    }
68
69    /// Called after a session is created
70    async fn on_session_created(
71        &self,
72        session: &DB::Session,
73        ctx: &AuthContext<DB>,
74    ) -> AuthResult<()> {
75        let _ = (session, ctx);
76        Ok(())
77    }
78
79    /// Called before a user is deleted
80    async fn on_user_deleted(&self, user_id: &str, ctx: &AuthContext<DB>) -> AuthResult<()> {
81        let _ = (user_id, ctx);
82        Ok(())
83    }
84
85    /// Called before a session is deleted
86    async fn on_session_deleted(
87        &self,
88        session_token: &str,
89        ctx: &AuthContext<DB>,
90    ) -> AuthResult<()> {
91        let _ = (session_token, ctx);
92        Ok(())
93    }
94}
95
96/// Route definition for plugins
97#[derive(Debug, Clone)]
98pub struct AuthRoute {
99    pub path: String,
100    pub method: HttpMethod,
101    /// Identifier used as the OpenAPI `operationId` for this route.
102    pub operation_id: String,
103}
104
105/// Context passed to plugin methods
106pub struct AuthContext<DB: DatabaseAdapter> {
107    pub config: Arc<AuthConfig>,
108    pub database: Arc<DB>,
109    pub email_provider: Option<Arc<dyn EmailProvider>>,
110    pub metadata: HashMap<String, serde_json::Value>,
111}
112
113impl AuthRoute {
114    pub fn new(
115        method: HttpMethod,
116        path: impl Into<String>,
117        operation_id: impl Into<String>,
118    ) -> Self {
119        Self {
120            path: path.into(),
121            method,
122            operation_id: operation_id.into(),
123        }
124    }
125
126    pub fn get(path: impl Into<String>, operation_id: impl Into<String>) -> Self {
127        Self::new(HttpMethod::Get, path, operation_id)
128    }
129
130    pub fn post(path: impl Into<String>, operation_id: impl Into<String>) -> Self {
131        Self::new(HttpMethod::Post, path, operation_id)
132    }
133
134    pub fn put(path: impl Into<String>, operation_id: impl Into<String>) -> Self {
135        Self::new(HttpMethod::Put, path, operation_id)
136    }
137
138    pub fn delete(path: impl Into<String>, operation_id: impl Into<String>) -> Self {
139        Self::new(HttpMethod::Delete, path, operation_id)
140    }
141}
142
143impl<DB: DatabaseAdapter> AuthContext<DB> {
144    pub fn new(config: Arc<AuthConfig>, database: Arc<DB>) -> Self {
145        let email_provider = config.email_provider.clone();
146        Self {
147            config,
148            database,
149            email_provider,
150            metadata: HashMap::new(),
151        }
152    }
153
154    pub fn set_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
155        self.metadata.insert(key.into(), value);
156    }
157
158    pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
159        self.metadata.get(key)
160    }
161
162    /// Get the email provider, returning an error if none is configured.
163    pub fn email_provider(&self) -> AuthResult<&dyn EmailProvider> {
164        self.email_provider
165            .as_deref()
166            .ok_or_else(|| AuthError::config("No email provider configured"))
167    }
168
169    /// Extract a session token from the request, validate the session, and
170    /// return the authenticated `(User, Session)` pair.
171    ///
172    /// This centralises the pattern previously duplicated across many plugins
173    /// (`get_authenticated_user`, `require_session`, etc.).
174    pub async fn require_session(&self, req: &AuthRequest) -> AuthResult<(DB::User, DB::Session)> {
175        let session_manager =
176            crate::session::SessionManager::new(self.config.clone(), self.database.clone());
177
178        if let Some(token) = session_manager.extract_session_token(req)
179            && let Some(session) = session_manager.get_session(&token).await?
180            && let Some(user) = self.database.get_user_by_id(session.user_id()).await?
181        {
182            return Ok((user, session));
183        }
184
185        Err(AuthError::Unauthenticated)
186    }
187}