openauth_core/plugin/
hooks.rs1use crate::api::{ApiRequest, ApiResponse};
4use crate::context::AuthContext;
5use crate::error::OpenAuthError;
6use http::Method;
7use std::fmt;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11
12pub type PluginBeforeHookHandler = Arc<
13 dyn Fn(&AuthContext, ApiRequest) -> Result<PluginBeforeHookAction, OpenAuthError> + Send + Sync,
14>;
15pub type PluginAfterHookHandler = Arc<
16 dyn Fn(&AuthContext, &ApiRequest, ApiResponse) -> Result<PluginAfterHookAction, OpenAuthError>
17 + Send
18 + Sync,
19>;
20pub type PluginBeforeHookFuture<'a> =
21 Pin<Box<dyn Future<Output = Result<PluginBeforeHookAction, OpenAuthError>> + Send + 'a>>;
22pub type PluginAfterHookFuture<'a> =
23 Pin<Box<dyn Future<Output = Result<PluginAfterHookAction, OpenAuthError>> + Send + 'a>>;
24pub type PluginAsyncBeforeHookHandler =
25 Arc<dyn for<'a> Fn(&'a AuthContext, ApiRequest) -> PluginBeforeHookFuture<'a> + Send + Sync>;
26pub type PluginAsyncAfterHookHandler = Arc<
27 dyn for<'a> Fn(&'a AuthContext, &'a ApiRequest, ApiResponse) -> PluginAfterHookFuture<'a>
28 + Send
29 + Sync,
30>;
31
32pub enum PluginBeforeHookAction {
34 Continue(ApiRequest),
35 Respond(ApiResponse),
36}
37
38pub enum PluginAfterHookAction {
40 Continue(ApiResponse),
41}
42
43#[derive(Debug, Clone, PartialEq, Eq)]
45pub struct PluginHookMatcher {
46 pub path: String,
47 pub method: Option<Method>,
48 pub operation_id: Option<String>,
49}
50
51impl PluginHookMatcher {
52 pub fn path(path: impl Into<String>) -> Self {
53 Self {
54 path: path.into(),
55 method: None,
56 operation_id: None,
57 }
58 }
59
60 #[must_use]
61 pub fn method(mut self, method: Method) -> Self {
62 self.method = Some(method);
63 self
64 }
65
66 #[must_use]
67 pub fn operation_id(mut self, operation_id: impl Into<String>) -> Self {
68 self.operation_id = Some(operation_id.into());
69 self
70 }
71
72 pub fn matches(&self, method: &Method, path: &str, operation_id: Option<&str>) -> bool {
73 if self
74 .method
75 .as_ref()
76 .is_some_and(|expected| expected != method)
77 {
78 return false;
79 }
80 if self
81 .operation_id
82 .as_deref()
83 .is_some_and(|expected| Some(expected) != operation_id)
84 {
85 return false;
86 }
87 path_matches(&self.path, path)
88 }
89}
90
91#[derive(Clone)]
92pub struct PluginBeforeHook {
93 pub matcher: PluginHookMatcher,
94 pub handler: PluginBeforeHookHandler,
95}
96
97impl fmt::Debug for PluginBeforeHook {
98 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
99 formatter
100 .debug_struct("PluginBeforeHook")
101 .field("matcher", &self.matcher)
102 .field("handler", &"<before-hook>")
103 .finish()
104 }
105}
106
107#[derive(Clone)]
108pub struct PluginAfterHook {
109 pub matcher: PluginHookMatcher,
110 pub handler: PluginAfterHookHandler,
111}
112
113impl fmt::Debug for PluginAfterHook {
114 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
115 formatter
116 .debug_struct("PluginAfterHook")
117 .field("matcher", &self.matcher)
118 .field("handler", &"<after-hook>")
119 .finish()
120 }
121}
122
123#[derive(Clone)]
124pub struct PluginAsyncBeforeHook {
125 pub matcher: PluginHookMatcher,
126 pub handler: PluginAsyncBeforeHookHandler,
127}
128
129impl fmt::Debug for PluginAsyncBeforeHook {
130 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
131 formatter
132 .debug_struct("PluginAsyncBeforeHook")
133 .field("matcher", &self.matcher)
134 .field("handler", &"<async-before-hook>")
135 .finish()
136 }
137}
138
139#[derive(Clone)]
140pub struct PluginAsyncAfterHook {
141 pub matcher: PluginHookMatcher,
142 pub handler: PluginAsyncAfterHookHandler,
143}
144
145impl fmt::Debug for PluginAsyncAfterHook {
146 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
147 formatter
148 .debug_struct("PluginAsyncAfterHook")
149 .field("matcher", &self.matcher)
150 .field("handler", &"<async-after-hook>")
151 .finish()
152 }
153}
154
155#[derive(Debug, Clone, Default)]
156pub struct PluginEndpointHooks {
157 pub before: Vec<PluginBeforeHook>,
158 pub after: Vec<PluginAfterHook>,
159 pub async_before: Vec<PluginAsyncBeforeHook>,
160 pub async_after: Vec<PluginAsyncAfterHook>,
161}
162
/// Returns `true` when a request `path` satisfies a hook matcher `pattern`.
///
/// Two pattern forms are supported:
///
/// * **Wildcard patterns** (containing `*`): each `*` matches any run of
///   characters, including `/` and the empty string. The literal text between
///   wildcards must appear in `path` in order without overlapping, e.g.
///   `/org/*/members/*`. Inside a wildcard pattern, `:name` is literal text.
/// * **Segment patterns** (no `*`): leading and trailing slashes are ignored,
///   both sides are split on `/`, the segment counts must be equal, and each
///   segment must be identical. The exception is a pattern segment starting
///   with `:` (e.g. `:id`), which matches any *non-empty* segment.
fn path_matches(pattern: &str, path: &str) -> bool {
    if pattern.contains('*') {
        return wildcard_matches(pattern, path);
    }
    let mut expected_segments = pattern.trim_matches('/').split('/');
    let mut actual_segments = path.trim_matches('/').split('/');
    loop {
        match (expected_segments.next(), actual_segments.next()) {
            // Both sides ran out together: equal segment counts, all matched.
            (None, None) => return true,
            (Some(expected), Some(actual)) if segment_matches(expected, actual) => {}
            // A segment differs, or one side has more segments than the other.
            _ => return false,
        }
    }
}

/// Glob-style match where every `*` in `pattern` matches any (possibly empty)
/// run of characters.
fn wildcard_matches(pattern: &str, path: &str) -> bool {
    let mut pieces = pattern.split('*');
    let head = pieces.next().unwrap_or("");
    let Some(rest) = path.strip_prefix(head) else {
        return false;
    };
    let mut middle: Vec<&str> = pieces.collect();
    let tail = middle.pop().unwrap_or("");
    // Strip the tail from what remains after the head, so the head and tail
    // can never share characters (e.g. "/a*a" must not match "/a").
    let Some(mut rest) = rest.strip_suffix(tail) else {
        return false;
    };
    // Leftmost-first placement of each inner literal is sufficient for
    // `*`-only globs: taking the earliest occurrence leaves the most room for
    // the remaining literals.
    for piece in middle {
        match rest.find(piece) {
            Some(start) => rest = &rest[start + piece.len()..],
            None => return false,
        }
    }
    true
}

/// Compares one pattern segment with one path segment; `:name` matches any
/// non-empty segment.
fn segment_matches(expected: &str, actual: &str) -> bool {
    if expected.starts_with(':') {
        !actual.is_empty()
    } else {
        expected == actual
    }
}