1use std::future::Future;
4use std::pin::Pin;
5
6mod db;
7mod endpoint;
8mod error;
9mod hooks;
10mod init;
11mod password;
12mod rate_limit;
13mod schema;
14
15pub use db::{
16 PluginDatabaseAfterHookHandler, PluginDatabaseAfterInput, PluginDatabaseBeforeAction,
17 PluginDatabaseBeforeHookHandler, PluginDatabaseBeforeInput, PluginDatabaseHook,
18 PluginDatabaseHookContext, PluginDatabaseOperation, PluginMigration,
19};
20pub use endpoint::PluginEndpoint;
21pub use error::PluginErrorCode;
22pub use hooks::{
23 PluginAfterHook, PluginAfterHookAction, PluginAfterHookFuture, PluginAfterHookHandler,
24 PluginAsyncAfterHook, PluginAsyncAfterHookHandler, PluginAsyncBeforeHook,
25 PluginAsyncBeforeHookHandler, PluginBeforeHook, PluginBeforeHookAction, PluginBeforeHookFuture,
26 PluginBeforeHookHandler, PluginEndpointHooks, PluginHookMatcher,
27};
28pub use init::{PluginInitHandler, PluginInitOutput};
29pub use password::{
30 PluginPasswordValidationInput, PluginPasswordValidationRejection, PluginPasswordValidator,
31 PluginPasswordValidatorFuture, PluginPasswordValidatorHandler,
32};
33pub use rate_limit::PluginRateLimitRule;
34pub use schema::PluginSchemaContribution;
35
36use crate::api::AsyncAuthEndpoint;
37use crate::context::AuthContext;
38use crate::error::OpenAuthError;
39use http::{Request, Response};
40#[cfg(feature = "oauth")]
41use openauth_oauth::oauth2::SocialOAuthProvider;
42use serde_json::Value;
43use std::fmt;
44use std::sync::Arc;
45
/// Raw body bytes carried by plugin-visible requests and responses.
pub type PluginBody = Vec<u8>;
/// HTTP request type handed to plugin hooks and middleware.
pub type PluginRequest = Request<PluginBody>;
/// HTTP response type produced or rewritten by plugin hooks and middleware.
pub type PluginResponse = Response<PluginBody>;
/// Boxed, `Send` future returned by async middleware; it may borrow for `'a`.
///
/// NOTE(review): the dispatcher lives elsewhere; presumably `Ok(Some(_))` means
/// "respond with this" and `Ok(None)` means "keep going" — confirm there.
pub type PluginMiddlewareFuture<'a> =
    Pin<Box<dyn Future<Output = Result<Option<PluginResponse>, OpenAuthError>> + Send + 'a>>;
/// Shared request hook stored by `AuthPlugin::with_on_request`.
///
/// Takes ownership of the request and returns a [`PluginRequestAction`].
pub type PluginOnRequest = Arc<
    dyn Fn(&AuthContext, PluginRequest) -> Result<PluginRequestAction, OpenAuthError> + Send + Sync,
>;
/// Shared response hook stored by `AuthPlugin::with_on_response`.
///
/// Receives the request by reference and the response by value, and returns the
/// response to use in its place.
pub type PluginOnResponse = Arc<
    dyn Fn(&AuthContext, &PluginRequest, PluginResponse) -> Result<PluginResponse, OpenAuthError>
        + Send
        + Sync,
>;
/// Shared synchronous middleware callback stored in [`PluginMiddleware`].
pub type PluginMiddlewareHandler = Arc<
    dyn Fn(&AuthContext, &PluginRequest) -> Result<Option<PluginResponse>, OpenAuthError>
        + Send
        + Sync,
>;
/// Shared asynchronous middleware callback stored in [`PluginAsyncMiddleware`].
///
/// The higher-ranked `for<'a>` bound lets the returned future borrow both the
/// context and the request for the same lifetime.
pub type PluginAsyncMiddlewareHandler = Arc<
    dyn for<'a> Fn(&'a AuthContext, &'a PluginRequest) -> PluginMiddlewareFuture<'a> + Send + Sync,
>;
67
/// Everything a single plugin contributes: endpoints, middleware, hooks, schema,
/// rate limits, error codes, database hooks, migrations, and validators.
///
/// Build one with [`AuthPlugin::new`] and the `with_*` builder methods. Closures
/// registered through those builders are wrapped in `Arc`, so cloning a plugin
/// shares them rather than copying them.
#[derive(Clone)]
pub struct AuthPlugin {
    /// Identifier passed to [`AuthPlugin::new`].
    pub id: String,
    /// Optional version string set by `with_version`.
    pub version: Option<String>,
    /// Optional JSON options set by `with_options`.
    pub options: Option<Value>,
    /// Endpoints appended by `with_endpoint`, in registration order.
    pub endpoints: Vec<AsyncAuthEndpoint>,
    /// Synchronous middleware appended by `with_middleware`.
    pub middlewares: Vec<PluginMiddleware>,
    /// Asynchronous middleware appended by `with_async_middleware`.
    pub async_middlewares: Vec<PluginAsyncMiddleware>,
    /// Request hook set by `with_on_request`; a later call replaces an earlier one.
    pub on_request: Option<PluginOnRequest>,
    /// Response hook set by `with_on_response`; a later call replaces an earlier one.
    pub on_response: Option<PluginOnResponse>,
    /// Init handler set by `with_init`; a later call replaces an earlier one.
    pub init: Option<PluginInitHandler>,
    /// Schema contributions appended by `with_schema`.
    pub schema: Vec<PluginSchemaContribution>,
    /// Rate-limit rules appended by `with_rate_limit`.
    pub rate_limit: Vec<PluginRateLimitRule>,
    /// Sync and async before/after endpoint hooks added by the `with_*_hook` builders.
    pub hooks: PluginEndpointHooks,
    /// Error codes appended by `with_error_code`.
    pub error_codes: Vec<PluginErrorCode>,
    /// Database hooks appended by `with_database_hook`.
    pub database_hooks: Vec<PluginDatabaseHook>,
    /// Migrations appended by `with_migration`.
    pub migrations: Vec<PluginMigration>,
    /// Social OAuth providers appended by `with_social_provider` (`oauth` feature only).
    #[cfg(feature = "oauth")]
    pub social_providers: Vec<Arc<dyn SocialOAuthProvider>>,
    /// Password validators appended by `with_password_validator`.
    pub password_validators: Vec<PluginPasswordValidator>,
}
89
90impl AuthPlugin {
91 pub fn new(id: impl Into<String>) -> Self {
92 Self {
93 id: id.into(),
94 version: None,
95 options: None,
96 endpoints: Vec::new(),
97 middlewares: Vec::new(),
98 async_middlewares: Vec::new(),
99 on_request: None,
100 on_response: None,
101 init: None,
102 schema: Vec::new(),
103 rate_limit: Vec::new(),
104 hooks: PluginEndpointHooks::default(),
105 error_codes: Vec::new(),
106 database_hooks: Vec::new(),
107 migrations: Vec::new(),
108 #[cfg(feature = "oauth")]
109 social_providers: Vec::new(),
110 password_validators: Vec::new(),
111 }
112 }
113
114 pub fn with_version(mut self, version: impl Into<String>) -> Self {
115 self.version = Some(version.into());
116 self
117 }
118
119 pub fn with_options(mut self, options: Value) -> Self {
120 self.options = Some(options);
121 self
122 }
123
124 pub fn with_endpoint(mut self, endpoint: AsyncAuthEndpoint) -> Self {
125 self.endpoints.push(endpoint);
126 self
127 }
128
129 pub fn with_init<F>(mut self, init: F) -> Self
130 where
131 F: Fn(&AuthContext) -> Result<PluginInitOutput, OpenAuthError> + Send + Sync + 'static,
132 {
133 self.init = Some(Arc::new(init));
134 self
135 }
136
137 pub fn with_schema(mut self, contribution: PluginSchemaContribution) -> Self {
138 self.schema.push(contribution);
139 self
140 }
141
142 pub fn with_rate_limit(mut self, rule: PluginRateLimitRule) -> Self {
143 self.rate_limit.push(rule);
144 self
145 }
146
147 pub fn with_before_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
148 where
149 F: Fn(&AuthContext, PluginRequest) -> Result<PluginBeforeHookAction, OpenAuthError>
150 + Send
151 + Sync
152 + 'static,
153 {
154 self.hooks.before.push(PluginBeforeHook {
155 matcher: PluginHookMatcher::path(path),
156 handler: Arc::new(hook),
157 });
158 self
159 }
160
161 pub fn with_after_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
162 where
163 F: Fn(
164 &AuthContext,
165 &PluginRequest,
166 PluginResponse,
167 ) -> Result<PluginAfterHookAction, OpenAuthError>
168 + Send
169 + Sync
170 + 'static,
171 {
172 self.hooks.after.push(PluginAfterHook {
173 matcher: PluginHookMatcher::path(path),
174 handler: Arc::new(hook),
175 });
176 self
177 }
178
179 pub fn with_async_before_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
180 where
181 F: for<'a> Fn(&'a AuthContext, PluginRequest) -> PluginBeforeHookFuture<'a>
182 + Send
183 + Sync
184 + 'static,
185 {
186 self.hooks.async_before.push(PluginAsyncBeforeHook {
187 matcher: PluginHookMatcher::path(path),
188 handler: Arc::new(hook),
189 });
190 self
191 }
192
193 pub fn with_async_after_hook<F>(mut self, path: impl Into<String>, hook: F) -> Self
194 where
195 F: for<'a> Fn(
196 &'a AuthContext,
197 &'a PluginRequest,
198 PluginResponse,
199 ) -> PluginAfterHookFuture<'a>
200 + Send
201 + Sync
202 + 'static,
203 {
204 self.hooks.async_after.push(PluginAsyncAfterHook {
205 matcher: PluginHookMatcher::path(path),
206 handler: Arc::new(hook),
207 });
208 self
209 }
210
211 pub fn with_error_code(mut self, error_code: PluginErrorCode) -> Self {
212 self.error_codes.push(error_code);
213 self
214 }
215
216 pub fn with_database_hook(mut self, hook: PluginDatabaseHook) -> Self {
217 self.database_hooks.push(hook);
218 self
219 }
220
221 pub fn with_migration(mut self, migration: PluginMigration) -> Self {
222 self.migrations.push(migration);
223 self
224 }
225
226 #[cfg(feature = "oauth")]
227 pub fn with_social_provider(
228 mut self,
229 provider: impl Into<Arc<dyn SocialOAuthProvider>>,
230 ) -> Self {
231 self.social_providers.push(provider.into());
232 self
233 }
234
235 pub fn with_password_validator<F>(mut self, validator: F) -> Self
236 where
237 F: for<'a> Fn(
238 &'a AuthContext,
239 PluginPasswordValidationInput,
240 ) -> PluginPasswordValidatorFuture<'a>
241 + Send
242 + Sync
243 + 'static,
244 {
245 self.password_validators.push(PluginPasswordValidator {
246 handler: Arc::new(validator),
247 });
248 self
249 }
250
251 pub fn with_middleware<F>(mut self, path: impl Into<String>, middleware: F) -> Self
252 where
253 F: Fn(&AuthContext, &PluginRequest) -> Result<Option<PluginResponse>, OpenAuthError>
254 + Send
255 + Sync
256 + 'static,
257 {
258 self.middlewares.push(PluginMiddleware {
259 path: path.into(),
260 handler: Arc::new(middleware),
261 });
262 self
263 }
264
265 pub fn with_async_middleware<F>(mut self, path: impl Into<String>, middleware: F) -> Self
266 where
267 F: for<'a> Fn(&'a AuthContext, &'a PluginRequest) -> PluginMiddlewareFuture<'a>
268 + Send
269 + Sync
270 + 'static,
271 {
272 self.async_middlewares.push(PluginAsyncMiddleware {
273 path: path.into(),
274 handler: Arc::new(middleware),
275 });
276 self
277 }
278
279 pub fn with_on_request<F>(mut self, hook: F) -> Self
280 where
281 F: Fn(&AuthContext, PluginRequest) -> Result<PluginRequestAction, OpenAuthError>
282 + Send
283 + Sync
284 + 'static,
285 {
286 self.on_request = Some(Arc::new(hook));
287 self
288 }
289
290 pub fn with_on_response<F>(mut self, hook: F) -> Self
291 where
292 F: Fn(
293 &AuthContext,
294 &PluginRequest,
295 PluginResponse,
296 ) -> Result<PluginResponse, OpenAuthError>
297 + Send
298 + Sync
299 + 'static,
300 {
301 self.on_response = Some(Arc::new(hook));
302 self
303 }
304}
305
306impl fmt::Debug for AuthPlugin {
307 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
308 formatter
309 .debug_struct("AuthPlugin")
310 .field("id", &self.id)
311 .field("version", &self.version)
312 .field("options", &self.options)
313 .field("endpoints", &self.endpoints.len())
314 .field("middlewares", &self.middlewares)
315 .field("async_middlewares", &self.async_middlewares)
316 .field("on_request", &self.on_request.as_ref().map(|_| "<hook>"))
317 .field("on_response", &self.on_response.as_ref().map(|_| "<hook>"))
318 .field("init", &self.init.as_ref().map(|_| "<init>"))
319 .field("schema", &self.schema)
320 .field("rate_limit", &self.rate_limit)
321 .field("hooks", &self.hooks)
322 .field("error_codes", &self.error_codes)
323 .field("database_hooks", &self.database_hooks)
324 .field("migrations", &self.migrations)
325 .field("social_providers", &debug_social_providers(self))
326 .field("password_validators", &self.password_validators)
327 .finish()
328 }
329}
330
331#[cfg(feature = "oauth")]
332fn debug_social_providers(plugin: &AuthPlugin) -> Vec<&str> {
333 plugin
334 .social_providers
335 .iter()
336 .map(|provider| provider.id())
337 .collect()
338}
339
340#[cfg(not(feature = "oauth"))]
341fn debug_social_providers(_plugin: &AuthPlugin) -> Vec<&'static str> {
342 Vec::new()
343}
344
345#[derive(Clone)]
346pub struct PluginMiddleware {
347 pub path: String,
348 pub handler: PluginMiddlewareHandler,
349}
350
351impl fmt::Debug for PluginMiddleware {
352 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
353 formatter
354 .debug_struct("PluginMiddleware")
355 .field("path", &self.path)
356 .field("handler", &"<middleware>")
357 .finish()
358 }
359}
360
361#[derive(Clone)]
362pub struct PluginAsyncMiddleware {
363 pub path: String,
364 pub handler: PluginAsyncMiddlewareHandler,
365}
366
367impl fmt::Debug for PluginAsyncMiddleware {
368 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
369 formatter
370 .debug_struct("PluginAsyncMiddleware")
371 .field("path", &self.path)
372 .field("handler", &"<async middleware>")
373 .finish()
374 }
375}
376
377pub enum PluginRequestAction {
378 Continue(PluginRequest),
379 Respond(PluginResponse),
380}