Skip to main content

openauth_core/plugin/
init.rs

1//! Plugin initialization contracts.
2
3use super::db::{PluginDatabaseHook, PluginMigration};
4use super::error::PluginErrorCode;
5use super::rate_limit::PluginRateLimitRule;
6use super::schema::PluginSchemaContribution;
7use crate::context::AuthContext;
8use crate::error::OpenAuthError;
9use crate::options::{SessionAdditionalField, UserAdditionalField};
10use openauth_oauth::oauth2::SocialOAuthProvider;
11use std::collections::BTreeMap;
12use std::fmt;
13use std::sync::Arc;
14
15pub type PluginInitHandler =
16    Arc<dyn Fn(&AuthContext) -> Result<PluginInitOutput, OpenAuthError> + Send + Sync>;
17
18/// Typed, additive output from a plugin init handler.
19#[derive(Clone, Default)]
20pub struct PluginInitOutput {
21    pub trusted_origins: Vec<String>,
22    pub disabled_paths: Vec<String>,
23    pub schema: Vec<PluginSchemaContribution>,
24    pub rate_limit: Vec<PluginRateLimitRule>,
25    pub error_codes: Vec<PluginErrorCode>,
26    pub database_hooks: Vec<PluginDatabaseHook>,
27    pub migrations: Vec<PluginMigration>,
28    pub social_providers: Vec<Arc<dyn SocialOAuthProvider>>,
29    pub user_additional_fields: BTreeMap<String, UserAdditionalField>,
30    pub session_additional_fields: BTreeMap<String, SessionAdditionalField>,
31}
32
33impl fmt::Debug for PluginInitOutput {
34    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
35        formatter
36            .debug_struct("PluginInitOutput")
37            .field("trusted_origins", &self.trusted_origins)
38            .field("disabled_paths", &self.disabled_paths)
39            .field("schema", &self.schema)
40            .field("rate_limit", &self.rate_limit)
41            .field("error_codes", &self.error_codes)
42            .field("database_hooks", &self.database_hooks)
43            .field("migrations", &self.migrations)
44            .field("user_additional_fields", &self.user_additional_fields)
45            .field("session_additional_fields", &self.session_additional_fields)
46            .field(
47                "social_providers",
48                &self
49                    .social_providers
50                    .iter()
51                    .map(|provider| provider.id())
52                    .collect::<Vec<_>>(),
53            )
54            .finish()
55    }
56}
57
58impl PluginInitOutput {
59    pub fn new() -> Self {
60        Self::default()
61    }
62
63    #[must_use]
64    pub fn trusted_origin(mut self, origin: impl Into<String>) -> Self {
65        self.trusted_origins.push(origin.into());
66        self
67    }
68
69    #[must_use]
70    pub fn disabled_path(mut self, path: impl Into<String>) -> Self {
71        self.disabled_paths.push(path.into());
72        self
73    }
74
75    #[must_use]
76    pub fn schema(mut self, contribution: PluginSchemaContribution) -> Self {
77        self.schema.push(contribution);
78        self
79    }
80
81    #[must_use]
82    pub fn rate_limit(mut self, rule: PluginRateLimitRule) -> Self {
83        self.rate_limit.push(rule);
84        self
85    }
86
87    #[must_use]
88    pub fn error_code(mut self, code: PluginErrorCode) -> Self {
89        self.error_codes.push(code);
90        self
91    }
92
93    #[must_use]
94    pub fn database_hook(mut self, hook: PluginDatabaseHook) -> Self {
95        self.database_hooks.push(hook);
96        self
97    }
98
99    #[must_use]
100    pub fn migration(mut self, migration: PluginMigration) -> Self {
101        self.migrations.push(migration);
102        self
103    }
104
105    #[must_use]
106    pub fn social_provider(mut self, provider: impl Into<Arc<dyn SocialOAuthProvider>>) -> Self {
107        self.social_providers.push(provider.into());
108        self
109    }
110
111    #[must_use]
112    pub fn user_additional_field(
113        mut self,
114        name: impl Into<String>,
115        field: UserAdditionalField,
116    ) -> Self {
117        self.user_additional_fields.insert(name.into(), field);
118        self
119    }
120
121    #[must_use]
122    pub fn session_additional_field(
123        mut self,
124        name: impl Into<String>,
125        field: SessionAdditionalField,
126    ) -> Self {
127        self.session_additional_fields.insert(name.into(), field);
128        self
129    }
130}