Skip to main content

better_auth_core/
plugin.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use crate::adapters::DatabaseAdapter;
6use crate::config::AuthConfig;
7use crate::email::EmailProvider;
8use crate::entity::AuthSession;
9use crate::error::{AuthError, AuthResult};
10use crate::session::SessionManager;
11use crate::types::{AuthRequest, AuthResponse, HttpMethod};
12
/// Action returned by [`AuthPlugin::before_request`].
///
/// A plugin wraps one of these variants in `Some(..)` to intervene before
/// route matching; returning `None` instead lets the request continue.
#[derive(Debug)]
pub enum BeforeRequestAction {
    /// Short-circuit with this response (e.g. return session JSON).
    Respond(AuthResponse),
    /// Inject a virtual session so downstream handlers see it as authenticated.
    InjectSession {
        /// Identifier of the user the virtual session is attributed to.
        user_id: String,
        /// Token identifying the virtual session.
        session_token: String,
    },
}
24
25/// Plugin trait that all authentication plugins must implement.
26///
27/// Generic over `DB` so that lifecycle hooks receive the adapter's concrete
28/// entity types (e.g., `DB::User`, `DB::Session`).
29#[async_trait]
30pub trait AuthPlugin<DB: DatabaseAdapter>: Send + Sync {
31    /// Plugin name - should be unique
32    fn name(&self) -> &'static str;
33
34    /// Routes that this plugin handles
35    fn routes(&self) -> Vec<AuthRoute>;
36
37    /// Called when the plugin is initialized
38    async fn on_init(&self, ctx: &mut AuthContext<DB>) -> AuthResult<()> {
39        let _ = ctx;
40        Ok(())
41    }
42
43    /// Called before route matching for every incoming request.
44    ///
45    /// Return `Some(BeforeRequestAction::Respond(..))` to short-circuit with a
46    /// response, `Some(BeforeRequestAction::InjectSession { .. })` to attach a
47    /// virtual session (e.g. API-key → session emulation), or `None` to let the
48    /// request continue to normal route matching.
49    async fn before_request(
50        &self,
51        _req: &AuthRequest,
52        _ctx: &AuthContext<DB>,
53    ) -> AuthResult<Option<BeforeRequestAction>> {
54        Ok(None)
55    }
56
57    /// Called for each request - return Some(response) to handle, None to pass through
58    async fn on_request(
59        &self,
60        req: &AuthRequest,
61        ctx: &AuthContext<DB>,
62    ) -> AuthResult<Option<AuthResponse>>;
63
64    /// Called after a user is created
65    async fn on_user_created(&self, user: &DB::User, ctx: &AuthContext<DB>) -> AuthResult<()> {
66        let _ = (user, ctx);
67        Ok(())
68    }
69
70    /// Called after a session is created
71    async fn on_session_created(
72        &self,
73        session: &DB::Session,
74        ctx: &AuthContext<DB>,
75    ) -> AuthResult<()> {
76        let _ = (session, ctx);
77        Ok(())
78    }
79
80    /// Called before a user is deleted
81    async fn on_user_deleted(&self, user_id: &str, ctx: &AuthContext<DB>) -> AuthResult<()> {
82        let _ = (user_id, ctx);
83        Ok(())
84    }
85
86    /// Called before a session is deleted
87    async fn on_session_deleted(
88        &self,
89        session_token: &str,
90        ctx: &AuthContext<DB>,
91    ) -> AuthResult<()> {
92        let _ = (session_token, ctx);
93        Ok(())
94    }
95}
96
/// Generates the [`AuthPlugin<DB>`] impl for a plugin with static route dispatch.
///
/// Eliminates the dual declaration of routes in `routes()` and `on_request()`
/// by generating both from a single route table.
///
/// # Syntax
///
/// ```ignore
/// impl_auth_plugin! {
///     MyPlugin, "my-plugin";
///     routes {
///         get "/status" => handle_status, "getStatus";
///         post "/submit" => handle_submit, "submit";
///     }
///     extra {
///         // optional: further trait items spliced verbatim into the impl
///     }
/// }
/// ```
///
/// Each route entry is `<method> <path literal> => <handler ident>, <operation id literal>`.
/// The method keyword must be one of the identifiers accepted by the `@pat`
/// rules (`get`, `post`, `put`, `delete`, `patch`, `head`).
///
/// Every handler is invoked as `self.handler(req, ctx).await?`, so it must be
/// an async method on the plugin whose success value is an `AuthResponse` and
/// whose error can be propagated with `?` from a function returning
/// `AuthResult`. A request is dispatched only when its method and path equal
/// a table entry exactly; anything else yields `Ok(None)`.
///
/// The expansion names items through `$crate::` (`HttpMethod`, `AuthRoute`,
/// `AuthPlugin`, `AuthRequest`, `AuthContext`, `AuthResult`, `AuthResponse`),
/// which presumes they are reachable from the crate root.
///
/// # Exceptions (must keep manual impl)
/// - `OAuthPlugin` — dynamic path matching for `/callback/{provider}`
/// - `SessionManagementPlugin` — match guards and OR patterns
/// - `EmailPasswordPlugin` — conditional routes based on config
/// - `UserManagementPlugin` — conditional routes based on config
/// - `PasswordManagementPlugin` — dynamic path matching for `/reset-password/{token}`
/// - `OrganizationPlugin` — handlers accept extra `&self.config` argument
#[macro_export]
macro_rules! impl_auth_plugin {
    // Internal rules: map a lowercase method keyword to its `HttpMethod`
    // variant. The expansion is used both as an expression (in `routes()`)
    // and as a pattern (in the `on_request` match).
    (@pat get) => { $crate::HttpMethod::Get };
    (@pat post) => { $crate::HttpMethod::Post };
    (@pat put) => { $crate::HttpMethod::Put };
    (@pat delete) => { $crate::HttpMethod::Delete };
    (@pat patch) => { $crate::HttpMethod::Patch };
    (@pat head) => { $crate::HttpMethod::Head };

    // Internal rules: map a method keyword to the matching `AuthRoute`
    // constructor. The main rule below does not invoke these; it builds
    // routes through `AuthRoute::new` and the `@pat` rules instead.
    (@route get) => { $crate::AuthRoute::get };
    (@route post) => { $crate::AuthRoute::post };
    (@route put) => { $crate::AuthRoute::put };
    (@route delete) => { $crate::AuthRoute::delete };

    // Main rule: plugin type and name, the route table, and an optional
    // `extra { .. }` block.
    (
        $plugin:ty, $name:expr;
        routes {
            $( $method:ident $path:literal => $handler:ident, $op_id:literal );* $(;)?
        }
        $( extra { $($extra:tt)* } )?
    ) => {
        #[::async_trait::async_trait]
        impl<DB: $crate::adapters::DatabaseAdapter> $crate::AuthPlugin<DB> for $plugin {
            fn name(&self) -> &'static str { $name }

            // One `AuthRoute` per table entry, in declaration order.
            fn routes(&self) -> Vec<$crate::AuthRoute> {
                vec![
                    $( $crate::AuthRoute::new($crate::impl_auth_plugin!(@pat $method), $path, $op_id), )*
                ]
            }

            // Dispatch on the exact (method, path) pair; unmatched requests
            // fall through with `Ok(None)`.
            async fn on_request(
                &self,
                req: &$crate::AuthRequest,
                ctx: &$crate::AuthContext<DB>,
            ) -> $crate::AuthResult<Option<$crate::AuthResponse>> {
                match (req.method(), req.path()) {
                    $(
                        ($crate::impl_auth_plugin!(@pat $method), $path) => {
                            Ok(Some(self.$handler(req, ctx).await?))
                        }
                    )*
                    _ => Ok(None),
                }
            }

            // Tokens from the optional `extra { .. }` block are emitted
            // verbatim inside the impl (e.g. overrides of default hooks).
            $( $($extra)* )?
        }
    };
}
159
/// Route definition for plugins
///
/// Pairs an HTTP method and a path with the identifier used for the route in
/// OpenAPI output.
#[derive(Debug, Clone)]
pub struct AuthRoute {
    /// Path of the route.
    pub path: String,
    /// HTTP method of the route.
    pub method: HttpMethod,
    /// Identifier used as the OpenAPI `operationId` for this route.
    pub operation_id: String,
}
168
/// Context passed to plugin methods
///
/// Bundles the shared configuration, the database adapter, an optional email
/// provider and a string-keyed map of JSON metadata.
pub struct AuthContext<DB: DatabaseAdapter> {
    /// Shared authentication configuration.
    pub config: Arc<AuthConfig>,
    /// Shared handle to the database adapter.
    pub database: Arc<DB>,
    /// Email provider, or `None` when no provider is available.
    pub email_provider: Option<Arc<dyn EmailProvider>>,
    /// Arbitrary JSON values keyed by string.
    pub metadata: HashMap<String, serde_json::Value>,
}
176
177impl AuthRoute {
178    pub fn new(
179        method: HttpMethod,
180        path: impl Into<String>,
181        operation_id: impl Into<String>,
182    ) -> Self {
183        Self {
184            path: path.into(),
185            method,
186            operation_id: operation_id.into(),
187        }
188    }
189
190    pub fn get(path: impl Into<String>, operation_id: impl Into<String>) -> Self {
191        Self::new(HttpMethod::Get, path, operation_id)
192    }
193
194    pub fn post(path: impl Into<String>, operation_id: impl Into<String>) -> Self {
195        Self::new(HttpMethod::Post, path, operation_id)
196    }
197
198    pub fn put(path: impl Into<String>, operation_id: impl Into<String>) -> Self {
199        Self::new(HttpMethod::Put, path, operation_id)
200    }
201
202    pub fn delete(path: impl Into<String>, operation_id: impl Into<String>) -> Self {
203        Self::new(HttpMethod::Delete, path, operation_id)
204    }
205}
206
207impl<DB: DatabaseAdapter> AuthContext<DB> {
208    pub fn new(config: Arc<AuthConfig>, database: Arc<DB>) -> Self {
209        let email_provider = config.email_provider.clone();
210        Self {
211            config,
212            database,
213            email_provider,
214            metadata: HashMap::new(),
215        }
216    }
217
218    pub fn set_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
219        self.metadata.insert(key.into(), value);
220    }
221
222    pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
223        self.metadata.get(key)
224    }
225
226    /// Get the email provider, returning an error if none is configured.
227    pub fn email_provider(&self) -> AuthResult<&dyn EmailProvider> {
228        self.email_provider
229            .as_deref()
230            .ok_or_else(|| AuthError::config("No email provider configured"))
231    }
232
233    /// Create a `SessionManager` from this context's config and database.
234    pub fn session_manager(&self) -> crate::session::SessionManager<DB> {
235        crate::session::SessionManager::new(self.config.clone(), self.database.clone())
236    }
237
238    /// Extract a session token from the request, validate the session, and
239    /// return the authenticated `(User, Session)` pair.
240    ///
241    /// This centralises the pattern previously duplicated across many plugins
242    /// (`get_authenticated_user`, `require_session`, etc.).
243    pub async fn require_session(&self, req: &AuthRequest) -> AuthResult<(DB::User, DB::Session)> {
244        let session_manager = self.session_manager();
245
246        if let Some(token) = session_manager.extract_session_token(req)
247            && let Some(session) = session_manager.get_session(&token).await?
248            && let Some(user) = self.database.get_user_by_id(session.user_id()).await?
249        {
250            return Ok((user, session));
251        }
252
253        Err(AuthError::Unauthenticated)
254    }
255}
256
/// Axum-friendly shared state type.
///
/// `config`, `database` and `email_provider` are held behind `Arc`, so cloning
/// them only copies a handle; `session_manager` is stored by value and is
/// cloned through its own `Clone` impl. This is what lets `AuthState` be used
/// directly as axum `State`.
///
/// NOTE(review): the state is only cheap to clone if `SessionManager::clone`
/// is itself cheap — confirm against `SessionManager`.
pub struct AuthState<DB: DatabaseAdapter> {
    /// Shared authentication configuration.
    pub config: Arc<AuthConfig>,
    /// Shared handle to the database adapter.
    pub database: Arc<DB>,
    /// Session manager used by handlers bound to this state.
    pub session_manager: SessionManager<DB>,
    /// Email provider, or `None` when no provider is available.
    pub email_provider: Option<Arc<dyn EmailProvider>>,
}
267
268impl<DB: DatabaseAdapter> Clone for AuthState<DB> {
269    fn clone(&self) -> Self {
270        Self {
271            config: self.config.clone(),
272            database: self.database.clone(),
273            session_manager: self.session_manager.clone(),
274            email_provider: self.email_provider.clone(),
275        }
276    }
277}
278
279impl<DB: DatabaseAdapter> AuthState<DB> {
280    /// Create a new `AuthState` from an `AuthContext` and `SessionManager`.
281    pub fn new(ctx: &AuthContext<DB>, session_manager: SessionManager<DB>) -> Self {
282        Self {
283            config: ctx.config.clone(),
284            database: ctx.database.clone(),
285            session_manager,
286            email_provider: ctx.email_provider.clone(),
287        }
288    }
289
290    /// Create an `AuthContext` for use with existing plugin handler methods.
291    pub fn to_context(&self) -> AuthContext<DB> {
292        let mut ctx = AuthContext::new(self.config.clone(), self.database.clone());
293        ctx.email_provider = self.email_provider.clone();
294        ctx
295    }
296
297    /// Build a `Set-Cookie` header value for a session token.
298    pub fn session_cookie(&self, token: &str) -> String {
299        crate::utils::cookie_utils::create_session_cookie(token, &self.config)
300    }
301
302    /// Build a `Set-Cookie` header value that clears the session cookie.
303    pub fn clear_session_cookie(&self) -> String {
304        crate::utils::cookie_utils::create_clear_session_cookie(&self.config)
305    }
306}
307
/// Plugin trait for axum-native routing.
///
/// Unlike [`AuthPlugin`] which uses the custom `AuthRequest`/`AuthResponse`
/// abstraction, `AxumPlugin` returns a standard `axum::Router` with handlers
/// already bound to routes. This eliminates the triple route-matching overhead
/// and enables use of axum extractors.
///
/// Only `name` and `router` are required; the lifecycle hooks below default
/// to returning `Ok(())` without touching their arguments.
///
/// Compiled only when the `axum` feature is enabled.
#[cfg(feature = "axum")]
#[async_trait]
pub trait AxumPlugin<DB: DatabaseAdapter>: Send + Sync {
    /// Plugin name — should be unique and match the `AuthPlugin` name when
    /// both traits are implemented on the same type.
    fn name(&self) -> &'static str;

    /// Return an axum `Router` with all routes for this plugin.
    ///
    /// The router uses `AuthState<DB>` as its state type.
    fn router(&self) -> axum::Router<AuthState<DB>>;

    /// Called after a user is created.
    ///
    /// The default implementation does nothing and returns `Ok(())`.
    async fn on_user_created(&self, _user: &DB::User, _ctx: &AuthContext<DB>) -> AuthResult<()> {
        Ok(())
    }

    /// Called after a session is created.
    ///
    /// The default implementation does nothing and returns `Ok(())`.
    async fn on_session_created(
        &self,
        _session: &DB::Session,
        _ctx: &AuthContext<DB>,
    ) -> AuthResult<()> {
        Ok(())
    }

    /// Called before a user is deleted.
    ///
    /// The default implementation does nothing and returns `Ok(())`.
    async fn on_user_deleted(&self, _user_id: &str, _ctx: &AuthContext<DB>) -> AuthResult<()> {
        Ok(())
    }

    /// Called before a session is deleted.
    ///
    /// The default implementation does nothing and returns `Ok(())`.
    async fn on_session_deleted(
        &self,
        _session_token: &str,
        _ctx: &AuthContext<DB>,
    ) -> AuthResult<()> {
        Ok(())
    }
}