use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use crate::adapters::DatabaseAdapter;
use crate::config::AuthConfig;
use crate::email::EmailProvider;
use crate::entity::AuthSession;
use crate::error::{AuthError, AuthResult};
use crate::session::SessionManager;
use crate::types::{AuthRequest, AuthResponse, HttpMethod};
/// Action a plugin may request from its `AuthPlugin::before_request` hook
/// (that hook returns `AuthResult<Option<BeforeRequestAction>>`; `None` means "no action").
#[derive(Debug)]
pub enum BeforeRequestAction {
/// Answer with this response.
/// NOTE(review): presumably the dispatcher short-circuits and skips `on_request` — confirm.
Respond(AuthResponse),
/// Attach the identified session to the request.
/// NOTE(review): how the injection is applied is decided by the dispatcher — confirm there.
InjectSession {
/// Id of the user the injected session belongs to.
user_id: String,
/// Token of the session to inject.
session_token: String,
},
}
/// Contract implemented by every authentication plugin for a given database adapter.
///
/// Implementors must provide [`name`](Self::name), [`routes`](Self::routes) and
/// [`on_request`](Self::on_request). Every other hook ships with a default body that
/// ignores its arguments and succeeds: `Ok(())` for the lifecycle hooks and `Ok(None)`
/// for [`before_request`](Self::before_request).
#[async_trait]
pub trait AuthPlugin<DB: DatabaseAdapter>: Send + Sync {
    /// Static identifier of this plugin.
    fn name(&self) -> &'static str;

    /// Route descriptors this plugin exposes.
    fn routes(&self) -> Vec<AuthRoute>;

    /// Initialization hook with mutable access to the context.
    ///
    /// Default: does nothing and returns `Ok(())`.
    async fn on_init(&self, _ctx: &mut AuthContext<DB>) -> AuthResult<()> {
        Ok(())
    }

    /// Pre-dispatch hook; `Some(action)` asks the caller to carry out a
    /// [`BeforeRequestAction`]. Presumably invoked before `on_request` — confirm in the
    /// dispatcher.
    ///
    /// Default: returns `Ok(None)` (no action).
    async fn before_request(
        &self,
        _req: &AuthRequest,
        _ctx: &AuthContext<DB>,
    ) -> AuthResult<Option<BeforeRequestAction>> {
        Ok(None)
    }

    /// Handles a request.
    ///
    /// `Ok(None)` signals that this plugin did not handle the request; this is what
    /// `impl_auth_plugin!` returns for any `(method, path)` not in its route table.
    async fn on_request(
        &self,
        req: &AuthRequest,
        ctx: &AuthContext<DB>,
    ) -> AuthResult<Option<AuthResponse>>;

    /// Hook receiving a user record when a user is created.
    ///
    /// Default: returns `Ok(())`.
    async fn on_user_created(&self, _user: &DB::User, _ctx: &AuthContext<DB>) -> AuthResult<()> {
        Ok(())
    }

    /// Hook receiving a session record when a session is created.
    ///
    /// Default: returns `Ok(())`.
    async fn on_session_created(
        &self,
        _session: &DB::Session,
        _ctx: &AuthContext<DB>,
    ) -> AuthResult<()> {
        Ok(())
    }

    /// Hook receiving the id of a deleted user.
    ///
    /// Default: returns `Ok(())`.
    async fn on_user_deleted(&self, _user_id: &str, _ctx: &AuthContext<DB>) -> AuthResult<()> {
        Ok(())
    }

    /// Hook receiving the token of a deleted session.
    ///
    /// Default: returns `Ok(())`.
    async fn on_session_deleted(
        &self,
        _session_token: &str,
        _ctx: &AuthContext<DB>,
    ) -> AuthResult<()> {
        Ok(())
    }
}
/// Implements `AuthPlugin<DB>` (for every `DB: DatabaseAdapter`) on a plugin type from a
/// declarative route table.
///
/// ```ignore
/// impl_auth_plugin! {
///     MyPlugin, "my-plugin";
///     routes {
///         post "/sign-in" => handle_sign_in, "signIn";
///         get "/session" => handle_session, "getSession";
///     }
///     extra {
///         // Further `AuthPlugin` items, spliced verbatim into the generated impl.
///     }
/// }
/// ```
///
/// The generated impl provides:
/// * `name()`: returns the given name expression.
/// * `routes()`: one `AuthRoute` per table entry, in table order.
/// * `on_request()`: compares the request's `(method, path)` exactly with each entry. On a
///   match it evaluates `self.<handler>(req, ctx).await?` and wraps the result in
///   `Ok(Some(..))`. Any other request yields `Ok(None)`.
///
/// Each handler is therefore an async method taking `(req, ctx)` and returning a `Result`
/// whose `Ok` value is an `AuthResponse` and whose error converts into the crate's error
/// type via `?`.
///
/// Supported method keywords: `get`, `post`, `put`, `delete`, `patch`, `head`.
///
/// The `extra { ... }` section is optional. Its items live inside the generated
/// `impl<DB: ...>`, so they may refer to the type parameter `DB`.
///
/// The expansion names `::async_trait::async_trait`. The invoking crate must therefore
/// depend on `async_trait` directly.
#[macro_export]
macro_rules! impl_auth_plugin {
// Internal helper: method keyword -> `HttpMethod` variant path. It is expanded both as an
// expression (inside `routes()`) and as a pattern (inside `on_request()`).
(@pat get) => { $crate::HttpMethod::Get };
(@pat post) => { $crate::HttpMethod::Post };
(@pat put) => { $crate::HttpMethod::Put };
(@pat delete) => { $crate::HttpMethod::Delete };
(@pat patch) => { $crate::HttpMethod::Patch };
(@pat head) => { $crate::HttpMethod::Head };
// Internal helper: method keyword -> `AuthRoute` constructor path.
// NOTE(review): no arm of this macro invokes `@route`, and it only covers
// get/post/put/delete. Because the macro is `#[macro_export]`, these arms remain callable
// from other crates.
(@route get) => { $crate::AuthRoute::get };
(@route post) => { $crate::AuthRoute::post };
(@route put) => { $crate::AuthRoute::put };
(@route delete) => { $crate::AuthRoute::delete };
// Public entry point: `<plugin type>, <name expr>; routes { ... } [extra { ... }]`.
(
$plugin:ty, $name:expr;
routes {
// Each entry is `<method> "<path>" => <handler method>, "<operation id>"`, with entries
// separated by `;` and an optional trailing `;`.
$( $method:ident $path:literal => $handler:ident, $op_id:literal );* $(;)?
}
$( extra { $($extra:tt)* } )?
) => {
#[::async_trait::async_trait]
impl<DB: $crate::adapters::DatabaseAdapter> $crate::AuthPlugin<DB> for $plugin {
fn name(&self) -> &'static str { $name }
fn routes(&self) -> Vec<$crate::AuthRoute> {
vec![
$( $crate::AuthRoute::new($crate::impl_auth_plugin!(@pat $method), $path, $op_id), )*
]
}
async fn on_request(
&self,
req: &$crate::AuthRequest,
ctx: &$crate::AuthContext<DB>,
) -> $crate::AuthResult<Option<$crate::AuthResponse>> {
// Paths are compared as exact literals; there is no parameter or prefix matching.
match (req.method(), req.path()) {
$(
($crate::impl_auth_plugin!(@pat $method), $path) => {
Ok(Some(self.$handler(req, ctx).await?))
}
)*
// Not one of this plugin's routes.
_ => Ok(None),
}
}
$( $($extra)* )?
}
};
}
/// Descriptor of one HTTP route exposed by a plugin (returned by `AuthPlugin::routes`).
#[derive(Debug, Clone)]
pub struct AuthRoute {
/// Request path. For routes generated by `impl_auth_plugin!`, the same literal is compared
/// exactly with the incoming request path.
pub path: String,
/// HTTP method of the route.
pub method: HttpMethod,
/// Identifier of the operation.
/// NOTE(review): presumably consumed by API documentation (OpenAPI-style `operationId`) — confirm.
pub operation_id: String,
}
/// Per-plugin context handed to `AuthPlugin` hooks and request handlers.
pub struct AuthContext<DB: DatabaseAdapter> {
/// Shared authentication configuration.
pub config: Arc<AuthConfig>,
/// Shared database adapter.
pub database: Arc<DB>,
/// Optional email provider. `AuthContext::new` seeds it from `config.email_provider`.
pub email_provider: Option<Arc<dyn EmailProvider>>,
/// Free-form key/value store, empty after `AuthContext::new`. It is read and written via
/// `get_metadata` / `set_metadata`.
pub metadata: HashMap<String, serde_json::Value>,
}
impl AuthRoute {
    /// Creates a route descriptor for `method` on `path`, identified by `operation_id`.
    pub fn new(
        method: HttpMethod,
        path: impl Into<String>,
        operation_id: impl Into<String>,
    ) -> Self {
        Self {
            path: path.into(),
            method,
            operation_id: operation_id.into(),
        }
    }

    /// Shorthand for [`AuthRoute::new`] with `HttpMethod::Get`.
    pub fn get(path: impl Into<String>, operation_id: impl Into<String>) -> Self {
        Self::new(HttpMethod::Get, path, operation_id)
    }

    /// Shorthand for [`AuthRoute::new`] with `HttpMethod::Post`.
    pub fn post(path: impl Into<String>, operation_id: impl Into<String>) -> Self {
        Self::new(HttpMethod::Post, path, operation_id)
    }

    /// Shorthand for [`AuthRoute::new`] with `HttpMethod::Put`.
    pub fn put(path: impl Into<String>, operation_id: impl Into<String>) -> Self {
        Self::new(HttpMethod::Put, path, operation_id)
    }

    /// Shorthand for [`AuthRoute::new`] with `HttpMethod::Delete`.
    pub fn delete(path: impl Into<String>, operation_id: impl Into<String>) -> Self {
        Self::new(HttpMethod::Delete, path, operation_id)
    }

    /// Shorthand for [`AuthRoute::new`] with `HttpMethod::Patch`.
    ///
    /// `impl_auth_plugin!` accepts the `patch` keyword. This constructor keeps the shorthand
    /// set in line with the methods that macro supports.
    pub fn patch(path: impl Into<String>, operation_id: impl Into<String>) -> Self {
        Self::new(HttpMethod::Patch, path, operation_id)
    }

    /// Shorthand for [`AuthRoute::new`] with `HttpMethod::Head`.
    ///
    /// `impl_auth_plugin!` accepts the `head` keyword. This constructor keeps the shorthand
    /// set in line with the methods that macro supports.
    pub fn head(path: impl Into<String>, operation_id: impl Into<String>) -> Self {
        Self::new(HttpMethod::Head, path, operation_id)
    }
}
impl<DB: DatabaseAdapter> AuthContext<DB> {
    /// Builds a context over `config` and `database`.
    ///
    /// The email provider is copied from `config.email_provider` and the metadata map
    /// starts out empty.
    pub fn new(config: Arc<AuthConfig>, database: Arc<DB>) -> Self {
        Self {
            // Evaluated before `config` is moved into its own field below.
            email_provider: config.email_provider.clone(),
            metadata: HashMap::new(),
            config,
            database,
        }
    }

    /// Stores `value` under `key`, overwriting any previous entry for that key.
    pub fn set_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
        let owned_key: String = key.into();
        self.metadata.insert(owned_key, value);
    }

    /// Looks up the metadata value stored under `key`, if any.
    pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
        self.metadata.get(key)
    }

    /// Borrows the configured email provider.
    ///
    /// # Errors
    ///
    /// Returns a configuration error when no provider is set.
    pub fn email_provider(&self) -> AuthResult<&dyn EmailProvider> {
        match self.email_provider.as_deref() {
            Some(provider) => Ok(provider),
            None => Err(AuthError::config("No email provider configured")),
        }
    }

    /// Creates a fresh session manager that shares this context's config and database.
    pub fn session_manager(&self) -> crate::session::SessionManager<DB> {
        crate::session::SessionManager::new(Arc::clone(&self.config), Arc::clone(&self.database))
    }

    /// Resolves the authenticated user and session for `req`.
    ///
    /// # Errors
    ///
    /// * Returns `AuthError::Unauthenticated` in any of these cases:
    ///   - no session token can be extracted from the request;
    ///   - no session is found for that token;
    ///   - no user is found for the session's user id.
    /// * Errors from the session or user lookups are propagated unchanged.
    pub async fn require_session(&self, req: &AuthRequest) -> AuthResult<(DB::User, DB::Session)> {
        let manager = self.session_manager();

        let Some(token) = manager.extract_session_token(req) else {
            return Err(AuthError::Unauthenticated);
        };
        let Some(session) = manager.get_session(&token).await? else {
            return Err(AuthError::Unauthenticated);
        };
        let Some(user) = self.database.get_user_by_id(session.user_id()).await? else {
            return Err(AuthError::Unauthenticated);
        };

        Ok((user, session))
    }
}
pub struct AuthState<DB: DatabaseAdapter> {
pub config: Arc<AuthConfig>,
pub database: Arc<DB>,
pub session_manager: SessionManager<DB>,
pub email_provider: Option<Arc<dyn EmailProvider>>,
}
impl<DB: DatabaseAdapter> Clone for AuthState<DB> {
fn clone(&self) -> Self {
Self {
config: self.config.clone(),
database: self.database.clone(),
session_manager: self.session_manager.clone(),
email_provider: self.email_provider.clone(),
}
}
}
impl<DB: DatabaseAdapter> AuthState<DB> {
pub fn new(ctx: &AuthContext<DB>, session_manager: SessionManager<DB>) -> Self {
Self {
config: ctx.config.clone(),
database: ctx.database.clone(),
session_manager,
email_provider: ctx.email_provider.clone(),
}
}
pub fn to_context(&self) -> AuthContext<DB> {
let mut ctx = AuthContext::new(self.config.clone(), self.database.clone());
ctx.email_provider = self.email_provider.clone();
ctx
}
pub fn session_cookie(&self, token: &str) -> String {
crate::utils::cookie_utils::create_session_cookie(token, &self.config)
}
pub fn clear_session_cookie(&self) -> String {
crate::utils::cookie_utils::create_clear_session_cookie(&self.config)
}
}
/// Plugin that contributes an `axum` router whose state type is `AuthState<DB>`.
/// NOTE(review): presumably the axum-native counterpart of `AuthPlugin` — confirm intended usage.
///
/// Only compiled with the `axum` feature. `name` and `router` are required. Every lifecycle
/// hook defaults to returning `Ok(())`.
#[cfg(feature = "axum")]
#[async_trait]
pub trait AxumPlugin<DB: DatabaseAdapter>: Send + Sync {
/// Static identifier of this plugin.
fn name(&self) -> &'static str;
/// Router contributed by this plugin; its handlers see `AuthState<DB>` as router state.
fn router(&self) -> axum::Router<AuthState<DB>>;
/// Hook receiving a user record when a user is created. Default: `Ok(())`.
async fn on_user_created(&self, _user: &DB::User, _ctx: &AuthContext<DB>) -> AuthResult<()> {
Ok(())
}
/// Hook receiving a session record when a session is created. Default: `Ok(())`.
async fn on_session_created(
&self,
_session: &DB::Session,
_ctx: &AuthContext<DB>,
) -> AuthResult<()> {
Ok(())
}
/// Hook receiving the id of a deleted user. Default: `Ok(())`.
async fn on_user_deleted(&self, _user_id: &str, _ctx: &AuthContext<DB>) -> AuthResult<()> {
Ok(())
}
/// Hook receiving the token of a deleted session. Default: `Ok(())`.
async fn on_session_deleted(
&self,
_session_token: &str,
_ctx: &AuthContext<DB>,
) -> AuthResult<()> {
Ok(())
}
}