// File: openauth_core/plugin/hooks.rs

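//! Endpoint-scoped plugin hooks.
//!
//! Plugins register before and after hooks, in synchronous and asynchronous
//! flavours. Each hook carries a `PluginHookMatcher` that selects the endpoints
//! it applies to by path pattern and, optionally, HTTP method and operation id.
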
3use crate::api::{ApiRequest, ApiResponse};
4use crate::context::AuthContext;
5use crate::error::OpenAuthError;
6use http::Method;
7use std::fmt;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11
12pub type PluginBeforeHookHandler = Arc<
13    dyn Fn(&AuthContext, ApiRequest) -> Result<PluginBeforeHookAction, OpenAuthError> + Send + Sync,
14>;
15pub type PluginAfterHookHandler = Arc<
16    dyn Fn(&AuthContext, &ApiRequest, ApiResponse) -> Result<PluginAfterHookAction, OpenAuthError>
17        + Send
18        + Sync,
19>;
20pub type PluginBeforeHookFuture<'a> =
21    Pin<Box<dyn Future<Output = Result<PluginBeforeHookAction, OpenAuthError>> + Send + 'a>>;
22pub type PluginAfterHookFuture<'a> =
23    Pin<Box<dyn Future<Output = Result<PluginAfterHookAction, OpenAuthError>> + Send + 'a>>;
24pub type PluginAsyncBeforeHookHandler =
25    Arc<dyn for<'a> Fn(&'a AuthContext, ApiRequest) -> PluginBeforeHookFuture<'a> + Send + Sync>;
26pub type PluginAsyncAfterHookHandler = Arc<
27    dyn for<'a> Fn(&'a AuthContext, &'a ApiRequest, ApiResponse) -> PluginAfterHookFuture<'a>
28        + Send
29        + Sync,
30>;
31
32/// Action returned by a before endpoint hook.
33pub enum PluginBeforeHookAction {
34    Continue(ApiRequest),
35    Respond(ApiResponse),
36}
37
38/// Action returned by an after endpoint hook.
39pub enum PluginAfterHookAction {
40    Continue(ApiResponse),
41}
42
43/// Matcher used to select endpoint hooks.
44#[derive(Debug, Clone, PartialEq, Eq)]
45pub struct PluginHookMatcher {
46    pub path: String,
47    pub method: Option<Method>,
48    pub operation_id: Option<String>,
49}
50
51impl PluginHookMatcher {
52    pub fn path(path: impl Into<String>) -> Self {
53        Self {
54            path: path.into(),
55            method: None,
56            operation_id: None,
57        }
58    }
59
60    #[must_use]
61    pub fn method(mut self, method: Method) -> Self {
62        self.method = Some(method);
63        self
64    }
65
66    #[must_use]
67    pub fn operation_id(mut self, operation_id: impl Into<String>) -> Self {
68        self.operation_id = Some(operation_id.into());
69        self
70    }
71
72    pub fn matches(&self, method: &Method, path: &str, operation_id: Option<&str>) -> bool {
73        if self
74            .method
75            .as_ref()
76            .is_some_and(|expected| expected != method)
77        {
78            return false;
79        }
80        if self
81            .operation_id
82            .as_deref()
83            .is_some_and(|expected| Some(expected) != operation_id)
84        {
85            return false;
86        }
87        path_matches(&self.path, path)
88    }
89}
90
91#[derive(Clone)]
92pub struct PluginBeforeHook {
93    pub matcher: PluginHookMatcher,
94    pub handler: PluginBeforeHookHandler,
95}
96
97impl fmt::Debug for PluginBeforeHook {
98    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
99        formatter
100            .debug_struct("PluginBeforeHook")
101            .field("matcher", &self.matcher)
102            .field("handler", &"<before-hook>")
103            .finish()
104    }
105}
106
107#[derive(Clone)]
108pub struct PluginAfterHook {
109    pub matcher: PluginHookMatcher,
110    pub handler: PluginAfterHookHandler,
111}
112
113impl fmt::Debug for PluginAfterHook {
114    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
115        formatter
116            .debug_struct("PluginAfterHook")
117            .field("matcher", &self.matcher)
118            .field("handler", &"<after-hook>")
119            .finish()
120    }
121}
122
123#[derive(Clone)]
124pub struct PluginAsyncBeforeHook {
125    pub matcher: PluginHookMatcher,
126    pub handler: PluginAsyncBeforeHookHandler,
127}
128
129impl fmt::Debug for PluginAsyncBeforeHook {
130    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
131        formatter
132            .debug_struct("PluginAsyncBeforeHook")
133            .field("matcher", &self.matcher)
134            .field("handler", &"<async-before-hook>")
135            .finish()
136    }
137}
138
139#[derive(Clone)]
140pub struct PluginAsyncAfterHook {
141    pub matcher: PluginHookMatcher,
142    pub handler: PluginAsyncAfterHookHandler,
143}
144
145impl fmt::Debug for PluginAsyncAfterHook {
146    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
147        formatter
148            .debug_struct("PluginAsyncAfterHook")
149            .field("matcher", &self.matcher)
150            .field("handler", &"<async-after-hook>")
151            .finish()
152    }
153}
154
155#[derive(Debug, Clone, Default)]
156pub struct PluginEndpointHooks {
157    pub before: Vec<PluginBeforeHook>,
158    pub after: Vec<PluginAfterHook>,
159    pub async_before: Vec<PluginAsyncBeforeHook>,
160    pub async_after: Vec<PluginAsyncAfterHook>,
161}
162
/// Returns `true` when the request `path` satisfies the hook path `pattern`.
///
/// Two pattern forms are supported:
///
/// * **Wildcard patterns** contain one or more `*`. Each `*` matches any run of
///   characters, possibly empty and possibly spanning `/`. The text before the
///   first `*` must prefix `path`, the text after the last `*` must end what
///   follows that prefix, and any literals between wildcards must appear in
///   order.
/// * **Segment patterns** (no `*`) are compared segment by segment after trimming
///   leading and trailing `/`. A pattern segment starting with `:` matches any
///   single path segment. Every other segment must be equal, and both sides need
///   the same number of segments.
fn path_matches(pattern: &str, path: &str) -> bool {
    match pattern.rsplit_once('*') {
        Some((head, suffix)) => wildcard_matches(head, suffix, path),
        None => segments_match(pattern, path),
    }
}

/// Matches `path` against a wildcard pattern that has been split at its last `*`
/// into `head` (everything before it) and `suffix` (everything after it).
fn wildcard_matches(head: &str, suffix: &str, path: &str) -> bool {
    // Text before the first `*` is the anchored prefix. Whatever lies between the
    // first and last `*` is a list of inner literals.
    let (prefix, inner) = head.split_once('*').unwrap_or((head, ""));
    let Some(after_prefix) = path.strip_prefix(prefix) else {
        return false;
    };
    // Strip the suffix from the text *after* the prefix rather than testing
    // `path.ends_with`, so the prefix and suffix can never overlap.
    // For example, `/a*a` must not match `/a`.
    let Some(mut remaining) = after_prefix.strip_suffix(suffix) else {
        return false;
    };
    // Every gap is a `*`, so greedily taking the leftmost occurrence of each
    // inner literal is sufficient.
    for literal in inner.split('*').filter(|literal| !literal.is_empty()) {
        match remaining.find(literal) {
            Some(start) => remaining = &remaining[start + literal.len()..],
            None => return false,
        }
    }
    true
}

/// Compares `/`-separated segments without allocating; `:name` matches any segment.
fn segments_match(pattern: &str, path: &str) -> bool {
    let mut expected = pattern.trim_matches('/').split('/');
    let mut actual = path.trim_matches('/').split('/');
    loop {
        match (expected.next(), actual.next()) {
            (None, None) => return true,
            (Some(want), Some(got)) if want.starts_with(':') || want == got => {}
            // Either a segment differs or one side has more segments than the other.
            _ => return false,
        }
    }
}