1use std::future::Future;
4use std::pin::Pin;
5
6mod db;
7mod endpoint;
8mod error;
9mod hooks;
10mod init;
11mod password;
12mod rate_limit;
13mod schema;
14
15pub use db::{
16 PluginDatabaseAfterHookHandler, PluginDatabaseAfterInput, PluginDatabaseBeforeAction,
17 PluginDatabaseBeforeHookHandler, PluginDatabaseBeforeInput, PluginDatabaseHook,
18 PluginDatabaseHookContext, PluginDatabaseOperation, PluginMigration,
19};
20pub use endpoint::PluginEndpoint;
21pub use error::PluginErrorCode;
22pub use hooks::{
23 PluginAfterHook, PluginAfterHookAction, PluginAfterHookFuture, PluginAfterHookHandler,
24 PluginAsyncAfterHook, PluginAsyncAfterHookHandler, PluginAsyncBeforeHook,
25 PluginAsyncBeforeHookHandler, PluginBeforeHook, PluginBeforeHookAction, PluginBeforeHookFuture,
26 PluginBeforeHookHandler, PluginEndpointHooks, PluginHookMatcher,
27};
28pub use init::{PluginInitHandler, PluginInitOutput};
29pub use password::{
30 PluginPasswordValidationInput, PluginPasswordValidationRejection, PluginPasswordValidator,
31 PluginPasswordValidatorFuture, PluginPasswordValidatorHandler,
32};
33pub use rate_limit::PluginRateLimitRule;
34pub use schema::PluginSchemaContribution;
35
36use crate::api::AsyncAuthEndpoint;
37use crate::context::AuthContext;
38use crate::error::OpenAuthError;
39use http::{Request, Response};
40use openauth_oauth::oauth2::SocialOAuthProvider;
41use serde_json::Value;
42use std::fmt;
43use std::sync::Arc;
44
45pub type PluginBody = Vec<u8>;
46pub type PluginRequest = Request<PluginBody>;
47pub type PluginResponse = Response<PluginBody>;
48pub type PluginMiddlewareFuture<'a> =
49 Pin<Box<dyn Future<Output = Result<Option<PluginResponse>, OpenAuthError>> + Send + 'a>>;
50pub type PluginOnRequest = Arc<
51 dyn Fn(&AuthContext, PluginRequest) -> Result<PluginRequestAction, OpenAuthError> + Send + Sync,
52>;
53pub type PluginOnResponse = Arc<
54 dyn Fn(&AuthContext, &PluginRequest, PluginResponse) -> Result<PluginResponse, OpenAuthError>
55 + Send
56 + Sync,
57>;
58pub type PluginMiddlewareHandler = Arc<
59 dyn Fn(&AuthContext, &PluginRequest) -> Result<Option<PluginResponse>, OpenAuthError>
60 + Send
61 + Sync,
62>;
63pub type PluginAsyncMiddlewareHandler = Arc<
64 dyn for<'a> Fn(&'a AuthContext, &'a PluginRequest) -> PluginMiddlewareFuture<'a> + Send + Sync,
65>;
66
67#[derive(Clone)]
68pub struct AuthPlugin {
69 pub id: String,
70 pub version: Option<String>,
71 pub options: Option<Value>,
72 pub endpoints: Vec<AsyncAuthEndpoint>,
73 pub middlewares: Vec<PluginMiddleware>,
74 pub async_middlewares: Vec<PluginAsyncMiddleware>,
75 pub on_request: Option<PluginOnRequest>,
76 pub on_response: Option<PluginOnResponse>,
77 pub init: Option<PluginInitHandler>,
78 pub schema: Vec<PluginSchemaContribution>,
79 pub rate_limit: Vec<PluginRateLimitRule>,
80 pub hooks: PluginEndpointHooks,
81 pub error_codes: Vec<PluginErrorCode>,
82 pub database_hooks: Vec<PluginDatabaseHook>,
83 pub migrations: Vec<PluginMigration>,
84 pub social_providers: Vec<Arc<dyn SocialOAuthProvider>>,
85 pub password_validators: Vec<PluginPasswordValidator>,
86}
87
88impl AuthPlugin {
89 pub fn new(id: impl Into<String>) -> Self {
90 Self {
91 id: id.into(),
92 version: None,
93 options: None,
94 endpoints: Vec::new(),
95 middlewares: Vec::new(),
96 async_middlewares: Vec::new(),
97 on_request: None,
98 on_response: None,
99 init: None,
100 schema: Vec::new(),
101 rate_limit: Vec::new(),
102 hooks: PluginEndpointHooks::default(),
103 error_codes: Vec::new(),
104 database_hooks: Vec::new(),
105 migrations: Vec::new(),
106 social_providers: Vec::new(),
107 password_validators: Vec::new(),
108 }
109 }
110
111 pub fn with_version(mut self, version: impl Into<String>) -> Self {
112 self.version = Some(version.into());
113 self
114 }
115
116 pub fn with_options(mut self, options: Value) -> Self {
117 self.options = Some(options);
118 self
119 }
120
121 pub fn with_endpoint(mut self, endpoint: AsyncAuthEndpoint) -> Self {
122 self.endpoints.push(endpoint);
123 self
124 }
125
126 pub fn with_init<F>(mut self, init: F) -> Self
127 where
128 F: Fn(&AuthContext) -> Result<PluginInitOutput, OpenAuthError> + Send + Sync + 'static,
129 {
130 self.init = Some(Arc::new(init));
131 self
132 }
133
134 pub fn with_schema(mut self, contribution: PluginSchemaContribution) -> Self {
135 self.schema.push(contribution);
136 self
137 }
138
139 pub fn with_rate_limit(mut self, rule: PluginRateLimitRule) -> Self {
140 self.rate_limit.push(rule);
141 self
142 }
143
144 pub fn with_before_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
145 where
146 F: Fn(&AuthContext, PluginRequest) -> Result<PluginBeforeHookAction, OpenAuthError>
147 + Send
148 + Sync
149 + 'static,
150 {
151 self.hooks.before.push(PluginBeforeHook {
152 matcher: PluginHookMatcher::path(path),
153 handler: Arc::new(hook),
154 });
155 self
156 }
157
158 pub fn with_after_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
159 where
160 F: Fn(
161 &AuthContext,
162 &PluginRequest,
163 PluginResponse,
164 ) -> Result<PluginAfterHookAction, OpenAuthError>
165 + Send
166 + Sync
167 + 'static,
168 {
169 self.hooks.after.push(PluginAfterHook {
170 matcher: PluginHookMatcher::path(path),
171 handler: Arc::new(hook),
172 });
173 self
174 }
175
176 pub fn with_async_before_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
177 where
178 F: for<'a> Fn(&'a AuthContext, PluginRequest) -> PluginBeforeHookFuture<'a>
179 + Send
180 + Sync
181 + 'static,
182 {
183 self.hooks.async_before.push(PluginAsyncBeforeHook {
184 matcher: PluginHookMatcher::path(path),
185 handler: Arc::new(hook),
186 });
187 self
188 }
189
190 pub fn with_async_after_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
191 where
192 F: for<'a> Fn(
193 &'a AuthContext,
194 &'a PluginRequest,
195 PluginResponse,
196 ) -> PluginAfterHookFuture<'a>
197 + Send
198 + Sync
199 + 'static,
200 {
201 self.hooks.async_after.push(PluginAsyncAfterHook {
202 matcher: PluginHookMatcher::path(path),
203 handler: Arc::new(hook),
204 });
205 self
206 }
207
208 pub fn with_error_code(mut self, error_code: PluginErrorCode) -> Self {
209 self.error_codes.push(error_code);
210 self
211 }
212
213 pub fn with_database_hook(mut self, hook: PluginDatabaseHook) -> Self {
214 self.database_hooks.push(hook);
215 self
216 }
217
218 pub fn with_migration(mut self, migration: PluginMigration) -> Self {
219 self.migrations.push(migration);
220 self
221 }
222
223 pub fn with_social_provider(
224 mut self,
225 provider: impl Into<Arc<dyn SocialOAuthProvider>>,
226 ) -> Self {
227 self.social_providers.push(provider.into());
228 self
229 }
230
231 pub fn with_password_validator<F>(mut self, validator: F) -> Self
232 where
233 F: for<'a> Fn(
234 &'a AuthContext,
235 PluginPasswordValidationInput,
236 ) -> PluginPasswordValidatorFuture<'a>
237 + Send
238 + Sync
239 + 'static,
240 {
241 self.password_validators.push(PluginPasswordValidator {
242 handler: Arc::new(validator),
243 });
244 self
245 }
246
247 pub fn with_middleware<F>(mut self, path: impl Into<String>, middleware: F) -> Self
248 where
249 F: Fn(&AuthContext, &PluginRequest) -> Result<Option<PluginResponse>, OpenAuthError>
250 + Send
251 + Sync
252 + 'static,
253 {
254 self.middlewares.push(PluginMiddleware {
255 path: path.into(),
256 handler: Arc::new(middleware),
257 });
258 self
259 }
260
261 pub fn with_async_middleware<F>(mut self, path: impl Into<String>, middleware: F) -> Self
262 where
263 F: for<'a> Fn(&'a AuthContext, &'a PluginRequest) -> PluginMiddlewareFuture<'a>
264 + Send
265 + Sync
266 + 'static,
267 {
268 self.async_middlewares.push(PluginAsyncMiddleware {
269 path: path.into(),
270 handler: Arc::new(middleware),
271 });
272 self
273 }
274
275 pub fn with_on_request<F>(mut self, hook: F) -> Self
276 where
277 F: Fn(&AuthContext, PluginRequest) -> Result<PluginRequestAction, OpenAuthError>
278 + Send
279 + Sync
280 + 'static,
281 {
282 self.on_request = Some(Arc::new(hook));
283 self
284 }
285
286 pub fn with_on_response<F>(mut self, hook: F) -> Self
287 where
288 F: Fn(
289 &AuthContext,
290 &PluginRequest,
291 PluginResponse,
292 ) -> Result<PluginResponse, OpenAuthError>
293 + Send
294 + Sync
295 + 'static,
296 {
297 self.on_response = Some(Arc::new(hook));
298 self
299 }
300}
301
302impl fmt::Debug for AuthPlugin {
303 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
304 formatter
305 .debug_struct("AuthPlugin")
306 .field("id", &self.id)
307 .field("version", &self.version)
308 .field("options", &self.options)
309 .field("endpoints", &self.endpoints.len())
310 .field("middlewares", &self.middlewares)
311 .field("async_middlewares", &self.async_middlewares)
312 .field("on_request", &self.on_request.as_ref().map(|_| "<hook>"))
313 .field("on_response", &self.on_response.as_ref().map(|_| "<hook>"))
314 .field("init", &self.init.as_ref().map(|_| "<init>"))
315 .field("schema", &self.schema)
316 .field("rate_limit", &self.rate_limit)
317 .field("hooks", &self.hooks)
318 .field("error_codes", &self.error_codes)
319 .field("database_hooks", &self.database_hooks)
320 .field("migrations", &self.migrations)
321 .field(
322 "social_providers",
323 &self
324 .social_providers
325 .iter()
326 .map(|provider| provider.id())
327 .collect::<Vec<_>>(),
328 )
329 .field("password_validators", &self.password_validators)
330 .finish()
331 }
332}
333
334#[derive(Clone)]
335pub struct PluginMiddleware {
336 pub path: String,
337 pub handler: PluginMiddlewareHandler,
338}
339
340impl fmt::Debug for PluginMiddleware {
341 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
342 formatter
343 .debug_struct("PluginMiddleware")
344 .field("path", &self.path)
345 .field("handler", &"<middleware>")
346 .finish()
347 }
348}
349
350#[derive(Clone)]
351pub struct PluginAsyncMiddleware {
352 pub path: String,
353 pub handler: PluginAsyncMiddlewareHandler,
354}
355
356impl fmt::Debug for PluginAsyncMiddleware {
357 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
358 formatter
359 .debug_struct("PluginAsyncMiddleware")
360 .field("path", &self.path)
361 .field("handler", &"<async middleware>")
362 .finish()
363 }
364}
365
366pub enum PluginRequestAction {
367 Continue(PluginRequest),
368 Respond(PluginResponse),
369}