Skip to main content

openauth_core/
plugin.rs

1//! Plugin contracts for OpenAuth extensions.
2
3use std::future::Future;
4use std::pin::Pin;
5
6mod db;
7mod endpoint;
8mod error;
9mod hooks;
10mod init;
11mod password;
12mod rate_limit;
13mod schema;
14
15pub use db::{
16    PluginDatabaseAfterHookHandler, PluginDatabaseAfterInput, PluginDatabaseBeforeAction,
17    PluginDatabaseBeforeHookHandler, PluginDatabaseBeforeInput, PluginDatabaseHook,
18    PluginDatabaseHookContext, PluginDatabaseOperation, PluginMigration,
19};
20pub use endpoint::PluginEndpoint;
21pub use error::PluginErrorCode;
22pub use hooks::{
23    PluginAfterHook, PluginAfterHookAction, PluginAfterHookFuture, PluginAfterHookHandler,
24    PluginAsyncAfterHook, PluginAsyncAfterHookHandler, PluginAsyncBeforeHook,
25    PluginAsyncBeforeHookHandler, PluginBeforeHook, PluginBeforeHookAction, PluginBeforeHookFuture,
26    PluginBeforeHookHandler, PluginEndpointHooks, PluginHookMatcher,
27};
28pub use init::{PluginInitHandler, PluginInitOutput};
29pub use password::{
30    PluginPasswordValidationInput, PluginPasswordValidationRejection, PluginPasswordValidator,
31    PluginPasswordValidatorFuture, PluginPasswordValidatorHandler,
32};
33pub use rate_limit::PluginRateLimitRule;
34pub use schema::PluginSchemaContribution;
35
36use crate::api::AsyncAuthEndpoint;
37use crate::context::AuthContext;
38use crate::error::OpenAuthError;
39use http::{Request, Response};
40#[cfg(feature = "oauth")]
41use openauth_oauth::oauth2::SocialOAuthProvider;
42use serde_json::Value;
43use std::fmt;
44use std::sync::Arc;
45
46pub type PluginBody = Vec<u8>;
47pub type PluginRequest = Request<PluginBody>;
48pub type PluginResponse = Response<PluginBody>;
49pub type PluginMiddlewareFuture<'a> =
50    Pin<Box<dyn Future<Output = Result<Option<PluginResponse>, OpenAuthError>> + Send + 'a>>;
51pub type PluginOnRequest = Arc<
52    dyn Fn(&AuthContext, PluginRequest) -> Result<PluginRequestAction, OpenAuthError> + Send + Sync,
53>;
54pub type PluginOnResponse = Arc<
55    dyn Fn(&AuthContext, &PluginRequest, PluginResponse) -> Result<PluginResponse, OpenAuthError>
56        + Send
57        + Sync,
58>;
59pub type PluginMiddlewareHandler = Arc<
60    dyn Fn(&AuthContext, &PluginRequest) -> Result<Option<PluginResponse>, OpenAuthError>
61        + Send
62        + Sync,
63>;
64pub type PluginAsyncMiddlewareHandler = Arc<
65    dyn for<'a> Fn(&'a AuthContext, &'a PluginRequest) -> PluginMiddlewareFuture<'a> + Send + Sync,
66>;
67
68#[derive(Clone)]
69pub struct AuthPlugin {
70    pub id: String,
71    pub version: Option<String>,
72    pub options: Option<Value>,
73    pub endpoints: Vec<AsyncAuthEndpoint>,
74    pub middlewares: Vec<PluginMiddleware>,
75    pub async_middlewares: Vec<PluginAsyncMiddleware>,
76    pub on_request: Option<PluginOnRequest>,
77    pub on_response: Option<PluginOnResponse>,
78    pub init: Option<PluginInitHandler>,
79    pub schema: Vec<PluginSchemaContribution>,
80    pub rate_limit: Vec<PluginRateLimitRule>,
81    pub hooks: PluginEndpointHooks,
82    pub error_codes: Vec<PluginErrorCode>,
83    pub database_hooks: Vec<PluginDatabaseHook>,
84    pub migrations: Vec<PluginMigration>,
85    #[cfg(feature = "oauth")]
86    pub social_providers: Vec<Arc<dyn SocialOAuthProvider>>,
87    pub password_validators: Vec<PluginPasswordValidator>,
88}
89
90impl AuthPlugin {
91    pub fn new(id: impl Into<String>) -> Self {
92        Self {
93            id: id.into(),
94            version: None,
95            options: None,
96            endpoints: Vec::new(),
97            middlewares: Vec::new(),
98            async_middlewares: Vec::new(),
99            on_request: None,
100            on_response: None,
101            init: None,
102            schema: Vec::new(),
103            rate_limit: Vec::new(),
104            hooks: PluginEndpointHooks::default(),
105            error_codes: Vec::new(),
106            database_hooks: Vec::new(),
107            migrations: Vec::new(),
108            #[cfg(feature = "oauth")]
109            social_providers: Vec::new(),
110            password_validators: Vec::new(),
111        }
112    }
113
114    pub fn with_version(mut self, version: impl Into<String>) -> Self {
115        self.version = Some(version.into());
116        self
117    }
118
119    pub fn with_options(mut self, options: Value) -> Self {
120        self.options = Some(options);
121        self
122    }
123
124    pub fn with_endpoint(mut self, endpoint: AsyncAuthEndpoint) -> Self {
125        self.endpoints.push(endpoint);
126        self
127    }
128
129    pub fn with_init<F>(mut self, init: F) -> Self
130    where
131        F: Fn(&AuthContext) -> Result<PluginInitOutput, OpenAuthError> + Send + Sync + 'static,
132    {
133        self.init = Some(Arc::new(init));
134        self
135    }
136
137    pub fn with_schema(mut self, contribution: PluginSchemaContribution) -> Self {
138        self.schema.push(contribution);
139        self
140    }
141
142    pub fn with_rate_limit(mut self, rule: PluginRateLimitRule) -> Self {
143        self.rate_limit.push(rule);
144        self
145    }
146
147    pub fn with_before_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
148    where
149        F: Fn(&AuthContext, PluginRequest) -> Result<PluginBeforeHookAction, OpenAuthError>
150            + Send
151            + Sync
152            + 'static,
153    {
154        self.hooks.before.push(PluginBeforeHook {
155            matcher: PluginHookMatcher::path(path),
156            handler: Arc::new(hook),
157        });
158        self
159    }
160
161    pub fn with_after_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
162    where
163        F: Fn(
164                &AuthContext,
165                &PluginRequest,
166                PluginResponse,
167            ) -> Result<PluginAfterHookAction, OpenAuthError>
168            + Send
169            + Sync
170            + 'static,
171    {
172        self.hooks.after.push(PluginAfterHook {
173            matcher: PluginHookMatcher::path(path),
174            handler: Arc::new(hook),
175        });
176        self
177    }
178
179    pub fn with_async_before_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
180    where
181        F: for<'a> Fn(&'a AuthContext, PluginRequest) -> PluginBeforeHookFuture<'a>
182            + Send
183            + Sync
184            + 'static,
185    {
186        self.hooks.async_before.push(PluginAsyncBeforeHook {
187            matcher: PluginHookMatcher::path(path),
188            handler: Arc::new(hook),
189        });
190        self
191    }
192
193    pub fn with_async_after_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
194    where
195        F: for<'a> Fn(
196                &'a AuthContext,
197                &'a PluginRequest,
198                PluginResponse,
199            ) -> PluginAfterHookFuture<'a>
200            + Send
201            + Sync
202            + 'static,
203    {
204        self.hooks.async_after.push(PluginAsyncAfterHook {
205            matcher: PluginHookMatcher::path(path),
206            handler: Arc::new(hook),
207        });
208        self
209    }
210
211    pub fn with_error_code(mut self, error_code: PluginErrorCode) -> Self {
212        self.error_codes.push(error_code);
213        self
214    }
215
216    pub fn with_database_hook(mut self, hook: PluginDatabaseHook) -> Self {
217        self.database_hooks.push(hook);
218        self
219    }
220
221    pub fn with_migration(mut self, migration: PluginMigration) -> Self {
222        self.migrations.push(migration);
223        self
224    }
225
226    #[cfg(feature = "oauth")]
227    pub fn with_social_provider(
228        mut self,
229        provider: impl Into<Arc<dyn SocialOAuthProvider>>,
230    ) -> Self {
231        self.social_providers.push(provider.into());
232        self
233    }
234
235    pub fn with_password_validator<F>(mut self, validator: F) -> Self
236    where
237        F: for<'a> Fn(
238                &'a AuthContext,
239                PluginPasswordValidationInput,
240            ) -> PluginPasswordValidatorFuture<'a>
241            + Send
242            + Sync
243            + 'static,
244    {
245        self.password_validators.push(PluginPasswordValidator {
246            handler: Arc::new(validator),
247        });
248        self
249    }
250
251    pub fn with_middleware<F>(mut self, path: impl Into<String>, middleware: F) -> Self
252    where
253        F: Fn(&AuthContext, &PluginRequest) -> Result<Option<PluginResponse>, OpenAuthError>
254            + Send
255            + Sync
256            + 'static,
257    {
258        self.middlewares.push(PluginMiddleware {
259            path: path.into(),
260            handler: Arc::new(middleware),
261        });
262        self
263    }
264
265    pub fn with_async_middleware<F>(mut self, path: impl Into<String>, middleware: F) -> Self
266    where
267        F: for<'a> Fn(&'a AuthContext, &'a PluginRequest) -> PluginMiddlewareFuture<'a>
268            + Send
269            + Sync
270            + 'static,
271    {
272        self.async_middlewares.push(PluginAsyncMiddleware {
273            path: path.into(),
274            handler: Arc::new(middleware),
275        });
276        self
277    }
278
279    pub fn with_on_request<F>(mut self, hook: F) -> Self
280    where
281        F: Fn(&AuthContext, PluginRequest) -> Result<PluginRequestAction, OpenAuthError>
282            + Send
283            + Sync
284            + 'static,
285    {
286        self.on_request = Some(Arc::new(hook));
287        self
288    }
289
290    pub fn with_on_response<F>(mut self, hook: F) -> Self
291    where
292        F: Fn(
293                &AuthContext,
294                &PluginRequest,
295                PluginResponse,
296            ) -> Result<PluginResponse, OpenAuthError>
297            + Send
298            + Sync
299            + 'static,
300    {
301        self.on_response = Some(Arc::new(hook));
302        self
303    }
304}
305
306impl fmt::Debug for AuthPlugin {
307    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
308        formatter
309            .debug_struct("AuthPlugin")
310            .field("id", &self.id)
311            .field("version", &self.version)
312            .field("options", &self.options)
313            .field("endpoints", &self.endpoints.len())
314            .field("middlewares", &self.middlewares)
315            .field("async_middlewares", &self.async_middlewares)
316            .field("on_request", &self.on_request.as_ref().map(|_| "<hook>"))
317            .field("on_response", &self.on_response.as_ref().map(|_| "<hook>"))
318            .field("init", &self.init.as_ref().map(|_| "<init>"))
319            .field("schema", &self.schema)
320            .field("rate_limit", &self.rate_limit)
321            .field("hooks", &self.hooks)
322            .field("error_codes", &self.error_codes)
323            .field("database_hooks", &self.database_hooks)
324            .field("migrations", &self.migrations)
325            .field("social_providers", &debug_social_providers(self))
326            .field("password_validators", &self.password_validators)
327            .finish()
328    }
329}
330
/// Collects the id of every registered social provider, for `Debug` output.
#[cfg(feature = "oauth")]
fn debug_social_providers(plugin: &AuthPlugin) -> Vec<&str> {
    let mut ids = Vec::with_capacity(plugin.social_providers.len());
    for provider in &plugin.social_providers {
        ids.push(provider.id());
    }
    ids
}
339
340#[cfg(not(feature = "oauth"))]
341fn debug_social_providers(_plugin: &AuthPlugin) -> Vec<&'static str> {
342    Vec::new()
343}
344
345#[derive(Clone)]
346pub struct PluginMiddleware {
347    pub path: String,
348    pub handler: PluginMiddlewareHandler,
349}
350
351impl fmt::Debug for PluginMiddleware {
352    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
353        formatter
354            .debug_struct("PluginMiddleware")
355            .field("path", &self.path)
356            .field("handler", &"<middleware>")
357            .finish()
358    }
359}
360
361#[derive(Clone)]
362pub struct PluginAsyncMiddleware {
363    pub path: String,
364    pub handler: PluginAsyncMiddlewareHandler,
365}
366
367impl fmt::Debug for PluginAsyncMiddleware {
368    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
369        formatter
370            .debug_struct("PluginAsyncMiddleware")
371            .field("path", &self.path)
372            .field("handler", &"<async middleware>")
373            .finish()
374    }
375}
376
377pub enum PluginRequestAction {
378    Continue(PluginRequest),
379    Respond(PluginResponse),
380}