Skip to main content

openauth_core/
plugin.rs

1//! Plugin contracts for OpenAuth extensions.
2
3use std::future::Future;
4use std::pin::Pin;
5
6mod db;
7mod endpoint;
8mod error;
9mod hooks;
10mod init;
11mod password;
12mod rate_limit;
13mod schema;
14
15pub use db::{
16    PluginDatabaseAfterHookHandler, PluginDatabaseAfterInput, PluginDatabaseBeforeAction,
17    PluginDatabaseBeforeHookHandler, PluginDatabaseBeforeInput, PluginDatabaseHook,
18    PluginDatabaseHookContext, PluginDatabaseOperation, PluginMigration,
19};
20pub use endpoint::PluginEndpoint;
21pub use error::PluginErrorCode;
22pub use hooks::{
23    PluginAfterHook, PluginAfterHookAction, PluginAfterHookFuture, PluginAfterHookHandler,
24    PluginAsyncAfterHook, PluginAsyncAfterHookHandler, PluginAsyncBeforeHook,
25    PluginAsyncBeforeHookHandler, PluginBeforeHook, PluginBeforeHookAction, PluginBeforeHookFuture,
26    PluginBeforeHookHandler, PluginEndpointHooks, PluginHookMatcher,
27};
28pub use init::{PluginInitHandler, PluginInitOutput};
29pub use password::{
30    PluginPasswordValidationInput, PluginPasswordValidationRejection, PluginPasswordValidator,
31    PluginPasswordValidatorFuture, PluginPasswordValidatorHandler,
32};
33pub use rate_limit::PluginRateLimitRule;
34pub use schema::PluginSchemaContribution;
35
36use crate::api::AsyncAuthEndpoint;
37use crate::context::AuthContext;
38use crate::error::OpenAuthError;
39use http::{Request, Response};
40use openauth_oauth::oauth2::SocialOAuthProvider;
41use serde_json::Value;
42use std::fmt;
43use std::sync::Arc;
44
45pub type PluginBody = Vec<u8>;
46pub type PluginRequest = Request<PluginBody>;
47pub type PluginResponse = Response<PluginBody>;
48pub type PluginMiddlewareFuture<'a> =
49    Pin<Box<dyn Future<Output = Result<Option<PluginResponse>, OpenAuthError>> + Send + 'a>>;
50pub type PluginOnRequest = Arc<
51    dyn Fn(&AuthContext, PluginRequest) -> Result<PluginRequestAction, OpenAuthError> + Send + Sync,
52>;
53pub type PluginOnResponse = Arc<
54    dyn Fn(&AuthContext, &PluginRequest, PluginResponse) -> Result<PluginResponse, OpenAuthError>
55        + Send
56        + Sync,
57>;
58pub type PluginMiddlewareHandler = Arc<
59    dyn Fn(&AuthContext, &PluginRequest) -> Result<Option<PluginResponse>, OpenAuthError>
60        + Send
61        + Sync,
62>;
63pub type PluginAsyncMiddlewareHandler = Arc<
64    dyn for<'a> Fn(&'a AuthContext, &'a PluginRequest) -> PluginMiddlewareFuture<'a> + Send + Sync,
65>;
66
67#[derive(Clone)]
68pub struct AuthPlugin {
69    pub id: String,
70    pub version: Option<String>,
71    pub options: Option<Value>,
72    pub endpoints: Vec<AsyncAuthEndpoint>,
73    pub middlewares: Vec<PluginMiddleware>,
74    pub async_middlewares: Vec<PluginAsyncMiddleware>,
75    pub on_request: Option<PluginOnRequest>,
76    pub on_response: Option<PluginOnResponse>,
77    pub init: Option<PluginInitHandler>,
78    pub schema: Vec<PluginSchemaContribution>,
79    pub rate_limit: Vec<PluginRateLimitRule>,
80    pub hooks: PluginEndpointHooks,
81    pub error_codes: Vec<PluginErrorCode>,
82    pub database_hooks: Vec<PluginDatabaseHook>,
83    pub migrations: Vec<PluginMigration>,
84    pub social_providers: Vec<Arc<dyn SocialOAuthProvider>>,
85    pub password_validators: Vec<PluginPasswordValidator>,
86}
87
88impl AuthPlugin {
89    pub fn new(id: impl Into<String>) -> Self {
90        Self {
91            id: id.into(),
92            version: None,
93            options: None,
94            endpoints: Vec::new(),
95            middlewares: Vec::new(),
96            async_middlewares: Vec::new(),
97            on_request: None,
98            on_response: None,
99            init: None,
100            schema: Vec::new(),
101            rate_limit: Vec::new(),
102            hooks: PluginEndpointHooks::default(),
103            error_codes: Vec::new(),
104            database_hooks: Vec::new(),
105            migrations: Vec::new(),
106            social_providers: Vec::new(),
107            password_validators: Vec::new(),
108        }
109    }
110
111    pub fn with_version(mut self, version: impl Into<String>) -> Self {
112        self.version = Some(version.into());
113        self
114    }
115
116    pub fn with_options(mut self, options: Value) -> Self {
117        self.options = Some(options);
118        self
119    }
120
121    pub fn with_endpoint(mut self, endpoint: AsyncAuthEndpoint) -> Self {
122        self.endpoints.push(endpoint);
123        self
124    }
125
126    pub fn with_init<F>(mut self, init: F) -> Self
127    where
128        F: Fn(&AuthContext) -> Result<PluginInitOutput, OpenAuthError> + Send + Sync + 'static,
129    {
130        self.init = Some(Arc::new(init));
131        self
132    }
133
134    pub fn with_schema(mut self, contribution: PluginSchemaContribution) -> Self {
135        self.schema.push(contribution);
136        self
137    }
138
139    pub fn with_rate_limit(mut self, rule: PluginRateLimitRule) -> Self {
140        self.rate_limit.push(rule);
141        self
142    }
143
144    pub fn with_before_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
145    where
146        F: Fn(&AuthContext, PluginRequest) -> Result<PluginBeforeHookAction, OpenAuthError>
147            + Send
148            + Sync
149            + 'static,
150    {
151        self.hooks.before.push(PluginBeforeHook {
152            matcher: PluginHookMatcher::path(path),
153            handler: Arc::new(hook),
154        });
155        self
156    }
157
158    pub fn with_after_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
159    where
160        F: Fn(
161                &AuthContext,
162                &PluginRequest,
163                PluginResponse,
164            ) -> Result<PluginAfterHookAction, OpenAuthError>
165            + Send
166            + Sync
167            + 'static,
168    {
169        self.hooks.after.push(PluginAfterHook {
170            matcher: PluginHookMatcher::path(path),
171            handler: Arc::new(hook),
172        });
173        self
174    }
175
176    pub fn with_async_before_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
177    where
178        F: for<'a> Fn(&'a AuthContext, PluginRequest) -> PluginBeforeHookFuture<'a>
179            + Send
180            + Sync
181            + 'static,
182    {
183        self.hooks.async_before.push(PluginAsyncBeforeHook {
184            matcher: PluginHookMatcher::path(path),
185            handler: Arc::new(hook),
186        });
187        self
188    }
189
190    pub fn with_async_after_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
191    where
192        F: for<'a> Fn(
193                &'a AuthContext,
194                &'a PluginRequest,
195                PluginResponse,
196            ) -> PluginAfterHookFuture<'a>
197            + Send
198            + Sync
199            + 'static,
200    {
201        self.hooks.async_after.push(PluginAsyncAfterHook {
202            matcher: PluginHookMatcher::path(path),
203            handler: Arc::new(hook),
204        });
205        self
206    }
207
208    pub fn with_error_code(mut self, error_code: PluginErrorCode) -> Self {
209        self.error_codes.push(error_code);
210        self
211    }
212
213    pub fn with_database_hook(mut self, hook: PluginDatabaseHook) -> Self {
214        self.database_hooks.push(hook);
215        self
216    }
217
218    pub fn with_migration(mut self, migration: PluginMigration) -> Self {
219        self.migrations.push(migration);
220        self
221    }
222
223    pub fn with_social_provider(
224        mut self,
225        provider: impl Into<Arc<dyn SocialOAuthProvider>>,
226    ) -> Self {
227        self.social_providers.push(provider.into());
228        self
229    }
230
231    pub fn with_password_validator<F>(mut self, validator: F) -> Self
232    where
233        F: for<'a> Fn(
234                &'a AuthContext,
235                PluginPasswordValidationInput,
236            ) -> PluginPasswordValidatorFuture<'a>
237            + Send
238            + Sync
239            + 'static,
240    {
241        self.password_validators.push(PluginPasswordValidator {
242            handler: Arc::new(validator),
243        });
244        self
245    }
246
247    pub fn with_middleware<F>(mut self, path: impl Into<String>, middleware: F) -> Self
248    where
249        F: Fn(&AuthContext, &PluginRequest) -> Result<Option<PluginResponse>, OpenAuthError>
250            + Send
251            + Sync
252            + 'static,
253    {
254        self.middlewares.push(PluginMiddleware {
255            path: path.into(),
256            handler: Arc::new(middleware),
257        });
258        self
259    }
260
261    pub fn with_async_middleware<F>(mut self, path: impl Into<String>, middleware: F) -> Self
262    where
263        F: for<'a> Fn(&'a AuthContext, &'a PluginRequest) -> PluginMiddlewareFuture<'a>
264            + Send
265            + Sync
266            + 'static,
267    {
268        self.async_middlewares.push(PluginAsyncMiddleware {
269            path: path.into(),
270            handler: Arc::new(middleware),
271        });
272        self
273    }
274
275    pub fn with_on_request<F>(mut self, hook: F) -> Self
276    where
277        F: Fn(&AuthContext, PluginRequest) -> Result<PluginRequestAction, OpenAuthError>
278            + Send
279            + Sync
280            + 'static,
281    {
282        self.on_request = Some(Arc::new(hook));
283        self
284    }
285
286    pub fn with_on_response<F>(mut self, hook: F) -> Self
287    where
288        F: Fn(
289                &AuthContext,
290                &PluginRequest,
291                PluginResponse,
292            ) -> Result<PluginResponse, OpenAuthError>
293            + Send
294            + Sync
295            + 'static,
296    {
297        self.on_response = Some(Arc::new(hook));
298        self
299    }
300}
301
302impl fmt::Debug for AuthPlugin {
303    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
304        formatter
305            .debug_struct("AuthPlugin")
306            .field("id", &self.id)
307            .field("version", &self.version)
308            .field("options", &self.options)
309            .field("endpoints", &self.endpoints.len())
310            .field("middlewares", &self.middlewares)
311            .field("async_middlewares", &self.async_middlewares)
312            .field("on_request", &self.on_request.as_ref().map(|_| "<hook>"))
313            .field("on_response", &self.on_response.as_ref().map(|_| "<hook>"))
314            .field("init", &self.init.as_ref().map(|_| "<init>"))
315            .field("schema", &self.schema)
316            .field("rate_limit", &self.rate_limit)
317            .field("hooks", &self.hooks)
318            .field("error_codes", &self.error_codes)
319            .field("database_hooks", &self.database_hooks)
320            .field("migrations", &self.migrations)
321            .field(
322                "social_providers",
323                &self
324                    .social_providers
325                    .iter()
326                    .map(|provider| provider.id())
327                    .collect::<Vec<_>>(),
328            )
329            .field("password_validators", &self.password_validators)
330            .finish()
331    }
332}
333
334#[derive(Clone)]
335pub struct PluginMiddleware {
336    pub path: String,
337    pub handler: PluginMiddlewareHandler,
338}
339
340impl fmt::Debug for PluginMiddleware {
341    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
342        formatter
343            .debug_struct("PluginMiddleware")
344            .field("path", &self.path)
345            .field("handler", &"<middleware>")
346            .finish()
347    }
348}
349
350#[derive(Clone)]
351pub struct PluginAsyncMiddleware {
352    pub path: String,
353    pub handler: PluginAsyncMiddlewareHandler,
354}
355
356impl fmt::Debug for PluginAsyncMiddleware {
357    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
358        formatter
359            .debug_struct("PluginAsyncMiddleware")
360            .field("path", &self.path)
361            .field("handler", &"<async middleware>")
362            .finish()
363    }
364}
365
366pub enum PluginRequestAction {
367    Continue(PluginRequest),
368    Respond(PluginResponse),
369}