use std::fmt;
use std::sync::Arc;
use http::Method;
use crate::api::{ApiRequest, ApiResponse};
use crate::context::AuthContext;
use crate::error::RustAuthError;
use crate::plugin::{PluginAfterHook, PluginBeforeHook, PluginHookMatcher};
/// Configuration holding the optional global before/after hooks.
///
/// Each hook is stored as a shared trait object, so cloning this value only
/// clones the `Arc` handles. The derived `Default` leaves both hooks unset.
#[derive(Clone, Default)]
pub struct GlobalHooksOptions {
/// Hook adapted by `plugin_before_hooks`; `None` means no before hook is registered.
pub before: Option<Arc<dyn GlobalBeforeHook>>,
/// Hook adapted by `plugin_after_hooks`; `None` means no after hook is registered.
pub after: Option<Arc<dyn GlobalAfterHook>>,
}
impl GlobalHooksOptions {
    /// Creates an options value with neither hook configured.
    pub fn new() -> Self {
        Self {
            before: None,
            after: None,
        }
    }

    /// Returns the options with `hook` installed as the before hook.
    ///
    /// Any before hook that was already set is replaced; the after hook is
    /// carried over unchanged.
    #[must_use]
    pub fn before<H>(self, hook: H) -> Self
    where
        H: GlobalBeforeHook,
    {
        let shared: Arc<dyn GlobalBeforeHook> = Arc::new(hook);
        Self {
            before: Some(shared),
            ..self
        }
    }

    /// Returns the options with `hook` installed as the after hook.
    ///
    /// Any after hook that was already set is replaced; the before hook is
    /// carried over unchanged.
    #[must_use]
    pub fn after<H>(self, hook: H) -> Self
    where
        H: GlobalAfterHook,
    {
        let shared: Arc<dyn GlobalAfterHook> = Arc::new(hook);
        Self {
            after: Some(shared),
            ..self
        }
    }
}
/// Manual `Debug` output: the hooks are trait objects, so each configured
/// hook is rendered as a fixed placeholder string instead of its contents.
impl fmt::Debug for GlobalHooksOptions {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `Some(placeholder)` when the hook is set, `None` otherwise.
        let before = self.before.is_some().then_some("<global-before-hook>");
        let after = self.after.is_some().then_some("<global-after-hook>");

        let mut out = f.debug_struct("GlobalHooksOptions");
        out.field("before", &before);
        out.field("after", &after);
        out.finish()
    }
}
/// A global hook that `plugin_before_hooks` wraps into a `PluginBeforeHook`.
///
/// Implementors must be `Send + Sync + 'static` so they can be stored behind
/// an `Arc<dyn GlobalBeforeHook>`.
pub trait GlobalBeforeHook: Send + Sync + 'static {
/// Inspects an incoming request and decides how to proceed.
///
/// * `context` - the auth context handed to the plugin hook.
/// * `request` - the request, passed by value.
/// * `method` - the request's HTTP method, as extracted by the adapter.
/// * `path` - the request URI path with the context's base path prefix
///   removed by the adapter.
///
/// Returns the [`GlobalHookAction`] to take, or an error that the adapter
/// propagates to its caller.
fn before(
&self,
context: &AuthContext,
request: ApiRequest,
method: &Method,
path: &str,
) -> Result<GlobalHookAction, RustAuthError>;
}
/// Blanket implementation so that any thread-safe `'static` closure or
/// function with the matching signature can be used as a before hook.
impl<F> GlobalBeforeHook for F
where
    F: Send
        + Sync
        + 'static
        + Fn(&AuthContext, ApiRequest, &Method, &str) -> Result<GlobalHookAction, RustAuthError>,
{
    /// Forwards all arguments to the wrapped callable and returns its result.
    fn before(
        &self,
        ctx: &AuthContext,
        req: ApiRequest,
        verb: &Method,
        route: &str,
    ) -> Result<GlobalHookAction, RustAuthError> {
        (self)(ctx, req, verb, route)
    }
}
/// A global hook that `plugin_after_hooks` wraps into a `PluginAfterHook`.
///
/// Implementors must be `Send + Sync + 'static` so they can be stored behind
/// an `Arc<dyn GlobalAfterHook>`.
pub trait GlobalAfterHook: Send + Sync + 'static {
/// Receives the response produced for a request and returns the response to use.
///
/// * `context` - the auth context handed to the plugin hook.
/// * `request` - the request the response belongs to, borrowed.
/// * `response` - the response, passed by value.
/// * `method` - the request's HTTP method, as extracted by the adapter.
/// * `path` - the request URI path with the context's base path prefix
///   removed by the adapter.
///
/// Returns the (possibly modified) response, or an error that the adapter
/// propagates to its caller.
fn after(
&self,
context: &AuthContext,
request: &ApiRequest,
response: ApiResponse,
method: &Method,
path: &str,
) -> Result<ApiResponse, RustAuthError>;
}
/// Blanket implementation so that any thread-safe `'static` closure or
/// function with the matching signature can be used as an after hook.
impl<F> GlobalAfterHook for F
where
    F: Send
        + Sync
        + 'static
        + Fn(&AuthContext, &ApiRequest, ApiResponse, &Method, &str) -> Result<ApiResponse, RustAuthError>,
{
    /// Forwards all arguments to the wrapped callable and returns its result.
    fn after(
        &self,
        ctx: &AuthContext,
        req: &ApiRequest,
        resp: ApiResponse,
        verb: &Method,
        route: &str,
    ) -> Result<ApiResponse, RustAuthError> {
        (self)(ctx, req, resp, verb, route)
    }
}
/// Outcome of a [`GlobalBeforeHook`].
///
/// `plugin_before_hooks` maps each variant onto the `PluginBeforeHookAction`
/// variant of the same name.
pub enum GlobalHookAction {
/// Carry on with the given request (forwarded as `PluginBeforeHookAction::Continue`).
Continue(ApiRequest),
/// Answer with the given response (forwarded as `PluginBeforeHookAction::Respond`).
// NOTE(review): presumably this short-circuits normal request handling — confirm in the plugin runner.
Respond(ApiResponse),
}
/// Adapts the configured global before hook into a list of plugin before hooks.
///
/// Returns an empty vector when `options.before` is unset. Otherwise returns a
/// single [`PluginBeforeHook`] whose matcher uses the path pattern `"/*"` with
/// no method or operation-id filter.
///
/// The wrapped handler passes the hook the request's HTTP method and the
/// request URI path with the context's base path (trailing slashes ignored)
/// removed from the front. Errors returned by the hook are propagated, and the
/// resulting [`GlobalHookAction`] is translated into the matching
/// `PluginBeforeHookAction` variant.
pub(crate) fn plugin_before_hooks(options: &GlobalHooksOptions) -> Vec<PluginBeforeHook> {
    let Some(hook) = options.before.clone() else {
        return Vec::new();
    };
    vec![PluginBeforeHook {
        matcher: PluginHookMatcher {
            path: "/*".to_owned(),
            method: None,
            operation_id: None,
        },
        handler: Arc::new(move |context, request| {
            let method = request.method().clone();
            let full_path = request.uri().path();
            // Remove the base path exactly once. `trim_start_matches` would keep
            // stripping while the remainder still starts with the base path, so a
            // base of `/api` and a request for `/api/api-key/create` would hand
            // the hook `-key/create`. If the path does not start with the base
            // path, it is passed through unchanged.
            let path = full_path
                .strip_prefix(context.base_path.trim_end_matches('/'))
                .unwrap_or(full_path)
                .to_owned();
            match hook.before(context, request, &method, &path)? {
                GlobalHookAction::Continue(request) => {
                    Ok(crate::plugin::PluginBeforeHookAction::Continue(request))
                }
                GlobalHookAction::Respond(response) => {
                    Ok(crate::plugin::PluginBeforeHookAction::Respond(response))
                }
            }
        }),
    }]
}
/// Adapts the configured global after hook into a list of plugin after hooks.
///
/// Returns an empty vector when `options.after` is unset. Otherwise returns a
/// single [`PluginAfterHook`] whose matcher uses the path pattern `"/*"` with
/// no method or operation-id filter.
///
/// The wrapped handler passes the hook the request's HTTP method and the
/// request URI path with the context's base path (trailing slashes ignored)
/// removed from the front. Errors returned by the hook are propagated, and the
/// response it returns is wrapped in `PluginAfterHookAction::Continue`.
pub(crate) fn plugin_after_hooks(options: &GlobalHooksOptions) -> Vec<PluginAfterHook> {
    let Some(hook) = options.after.clone() else {
        return Vec::new();
    };
    vec![PluginAfterHook {
        matcher: PluginHookMatcher {
            path: "/*".to_owned(),
            method: None,
            operation_id: None,
        },
        handler: Arc::new(move |context, request, response| {
            let method = request.method().clone();
            let full_path = request.uri().path();
            // Remove the base path exactly once. `trim_start_matches` would keep
            // stripping while the remainder still starts with the base path, so a
            // base of `/api` and a request for `/api/api-key/create` would hand
            // the hook `-key/create`. If the path does not start with the base
            // path, it is passed through unchanged.
            let path = full_path
                .strip_prefix(context.base_path.trim_end_matches('/'))
                .unwrap_or(full_path)
                .to_owned();
            let response = hook.after(context, request, response, &method, &path)?;
            Ok(crate::plugin::PluginAfterHookAction::Continue(response))
        }),
    }]
}
#[cfg(test)]
mod tests {
    use http::Method;

    use crate::api::{ApiRequest, ApiResponse};
    use crate::context::AuthContext;

    use super::*;

    /// Before hook that hands the request back untouched.
    struct PassThroughBefore;

    /// After hook that hands the response back untouched.
    struct PassThroughAfter;

    impl GlobalBeforeHook for PassThroughBefore {
        fn before(
            &self,
            _: &AuthContext,
            request: ApiRequest,
            _: &Method,
            _: &str,
        ) -> Result<GlobalHookAction, RustAuthError> {
            Ok(GlobalHookAction::Continue(request))
        }
    }

    impl GlobalAfterHook for PassThroughAfter {
        fn after(
            &self,
            _: &AuthContext,
            _: &ApiRequest,
            response: ApiResponse,
            _: &Method,
            _: &str,
        ) -> Result<ApiResponse, RustAuthError> {
            Ok(response)
        }
    }

    /// Chaining both builder methods must leave both hooks populated.
    #[test]
    fn global_hooks_options_supports_fluent_registration() {
        let GlobalHooksOptions { before, after } = GlobalHooksOptions::new()
            .before(PassThroughBefore)
            .after(PassThroughAfter);

        assert!(before.is_some());
        assert!(after.is_some());
    }
}