openauth_core/plugin/
init.rs1use super::db::{PluginDatabaseHook, PluginMigration};
4use super::error::PluginErrorCode;
5use super::rate_limit::PluginRateLimitRule;
6use super::schema::PluginSchemaContribution;
7use crate::context::AuthContext;
8use crate::error::OpenAuthError;
9use crate::options::{SessionAdditionalField, UserAdditionalField};
10#[cfg(feature = "oauth")]
11use openauth_oauth::oauth2::SocialOAuthProvider;
12use std::collections::BTreeMap;
13use std::fmt;
14use std::sync::Arc;
15
16pub type PluginInitHandler =
17 Arc<dyn Fn(&AuthContext) -> Result<PluginInitOutput, OpenAuthError> + Send + Sync>;
18
19#[derive(Clone, Default)]
21pub struct PluginInitOutput {
22 pub trusted_origins: Vec<String>,
23 pub disabled_paths: Vec<String>,
24 pub schema: Vec<PluginSchemaContribution>,
25 pub rate_limit: Vec<PluginRateLimitRule>,
26 pub error_codes: Vec<PluginErrorCode>,
27 pub database_hooks: Vec<PluginDatabaseHook>,
28 pub migrations: Vec<PluginMigration>,
29 #[cfg(feature = "oauth")]
30 pub social_providers: Vec<Arc<dyn SocialOAuthProvider>>,
31 pub user_additional_fields: BTreeMap<String, UserAdditionalField>,
32 pub session_additional_fields: BTreeMap<String, SessionAdditionalField>,
33}
34
35impl fmt::Debug for PluginInitOutput {
36 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
37 formatter
38 .debug_struct("PluginInitOutput")
39 .field("trusted_origins", &self.trusted_origins)
40 .field("disabled_paths", &self.disabled_paths)
41 .field("schema", &self.schema)
42 .field("rate_limit", &self.rate_limit)
43 .field("error_codes", &self.error_codes)
44 .field("database_hooks", &self.database_hooks)
45 .field("migrations", &self.migrations)
46 .field("user_additional_fields", &self.user_additional_fields)
47 .field("session_additional_fields", &self.session_additional_fields)
48 .field("social_providers", &debug_social_providers(self))
49 .finish()
50 }
51}
52
53impl PluginInitOutput {
54 pub fn new() -> Self {
55 Self::default()
56 }
57
58 #[must_use]
59 pub fn trusted_origin(mut self, origin: impl Into<String>) -> Self {
60 self.trusted_origins.push(origin.into());
61 self
62 }
63
64 #[must_use]
65 pub fn disabled_path(mut self, path: impl Into<String>) -> Self {
66 self.disabled_paths.push(path.into());
67 self
68 }
69
70 #[must_use]
71 pub fn schema(mut self, contribution: PluginSchemaContribution) -> Self {
72 self.schema.push(contribution);
73 self
74 }
75
76 #[must_use]
77 pub fn rate_limit(mut self, rule: PluginRateLimitRule) -> Self {
78 self.rate_limit.push(rule);
79 self
80 }
81
82 #[must_use]
83 pub fn error_code(mut self, code: PluginErrorCode) -> Self {
84 self.error_codes.push(code);
85 self
86 }
87
88 #[must_use]
89 pub fn database_hook(mut self, hook: PluginDatabaseHook) -> Self {
90 self.database_hooks.push(hook);
91 self
92 }
93
94 #[must_use]
95 pub fn migration(mut self, migration: PluginMigration) -> Self {
96 self.migrations.push(migration);
97 self
98 }
99
100 #[cfg(feature = "oauth")]
101 #[must_use]
102 pub fn social_provider(mut self, provider: impl Into<Arc<dyn SocialOAuthProvider>>) -> Self {
103 self.social_providers.push(provider.into());
104 self
105 }
106
107 #[must_use]
108 pub fn user_additional_field(
109 mut self,
110 name: impl Into<String>,
111 field: UserAdditionalField,
112 ) -> Self {
113 self.user_additional_fields.insert(name.into(), field);
114 self
115 }
116
117 #[must_use]
118 pub fn session_additional_field(
119 mut self,
120 name: impl Into<String>,
121 field: SessionAdditionalField,
122 ) -> Self {
123 self.session_additional_fields.insert(name.into(), field);
124 self
125 }
126}
127
/// Collects the id of every registered social provider, in registration order, for
/// use in the `Debug` output of [`PluginInitOutput`].
#[cfg(feature = "oauth")]
fn debug_social_providers(output: &PluginInitOutput) -> Vec<&str> {
    let mut ids = Vec::with_capacity(output.social_providers.len());
    for provider in &output.social_providers {
        ids.push(provider.id());
    }
    ids
}
136
137#[cfg(not(feature = "oauth"))]
138fn debug_social_providers(_output: &PluginInitOutput) -> Vec<&'static str> {
139 Vec::new()
140}