openauth_core/plugin/
init.rs1use super::db::{PluginDatabaseHook, PluginMigration};
4use super::error::PluginErrorCode;
5use super::rate_limit::PluginRateLimitRule;
6use super::schema::PluginSchemaContribution;
7use crate::context::AuthContext;
8use crate::error::OpenAuthError;
9use crate::options::{SessionAdditionalField, UserAdditionalField};
10use openauth_oauth::oauth2::SocialOAuthProvider;
11use std::collections::BTreeMap;
12use std::fmt;
13use std::sync::Arc;
14
15pub type PluginInitHandler =
16 Arc<dyn Fn(&AuthContext) -> Result<PluginInitOutput, OpenAuthError> + Send + Sync>;
17
18#[derive(Clone, Default)]
20pub struct PluginInitOutput {
21 pub trusted_origins: Vec<String>,
22 pub disabled_paths: Vec<String>,
23 pub schema: Vec<PluginSchemaContribution>,
24 pub rate_limit: Vec<PluginRateLimitRule>,
25 pub error_codes: Vec<PluginErrorCode>,
26 pub database_hooks: Vec<PluginDatabaseHook>,
27 pub migrations: Vec<PluginMigration>,
28 pub social_providers: Vec<Arc<dyn SocialOAuthProvider>>,
29 pub user_additional_fields: BTreeMap<String, UserAdditionalField>,
30 pub session_additional_fields: BTreeMap<String, SessionAdditionalField>,
31}
32
33impl fmt::Debug for PluginInitOutput {
34 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
35 formatter
36 .debug_struct("PluginInitOutput")
37 .field("trusted_origins", &self.trusted_origins)
38 .field("disabled_paths", &self.disabled_paths)
39 .field("schema", &self.schema)
40 .field("rate_limit", &self.rate_limit)
41 .field("error_codes", &self.error_codes)
42 .field("database_hooks", &self.database_hooks)
43 .field("migrations", &self.migrations)
44 .field("user_additional_fields", &self.user_additional_fields)
45 .field("session_additional_fields", &self.session_additional_fields)
46 .field(
47 "social_providers",
48 &self
49 .social_providers
50 .iter()
51 .map(|provider| provider.id())
52 .collect::<Vec<_>>(),
53 )
54 .finish()
55 }
56}
57
58impl PluginInitOutput {
59 pub fn new() -> Self {
60 Self::default()
61 }
62
63 #[must_use]
64 pub fn trusted_origin(mut self, origin: impl Into<String>) -> Self {
65 self.trusted_origins.push(origin.into());
66 self
67 }
68
69 #[must_use]
70 pub fn disabled_path(mut self, path: impl Into<String>) -> Self {
71 self.disabled_paths.push(path.into());
72 self
73 }
74
75 #[must_use]
76 pub fn schema(mut self, contribution: PluginSchemaContribution) -> Self {
77 self.schema.push(contribution);
78 self
79 }
80
81 #[must_use]
82 pub fn rate_limit(mut self, rule: PluginRateLimitRule) -> Self {
83 self.rate_limit.push(rule);
84 self
85 }
86
87 #[must_use]
88 pub fn error_code(mut self, code: PluginErrorCode) -> Self {
89 self.error_codes.push(code);
90 self
91 }
92
93 #[must_use]
94 pub fn database_hook(mut self, hook: PluginDatabaseHook) -> Self {
95 self.database_hooks.push(hook);
96 self
97 }
98
99 #[must_use]
100 pub fn migration(mut self, migration: PluginMigration) -> Self {
101 self.migrations.push(migration);
102 self
103 }
104
105 #[must_use]
106 pub fn social_provider(mut self, provider: impl Into<Arc<dyn SocialOAuthProvider>>) -> Self {
107 self.social_providers.push(provider.into());
108 self
109 }
110
111 #[must_use]
112 pub fn user_additional_field(
113 mut self,
114 name: impl Into<String>,
115 field: UserAdditionalField,
116 ) -> Self {
117 self.user_additional_fields.insert(name.into(), field);
118 self
119 }
120
121 #[must_use]
122 pub fn session_additional_field(
123 mut self,
124 name: impl Into<String>,
125 field: SessionAdditionalField,
126 ) -> Self {
127 self.session_additional_fields.insert(name.into(), field);
128 self
129 }
130}