Skip to main content

openauth_core/plugin/
init.rs

1//! Plugin initialization contracts.
2
3use super::db::{PluginDatabaseHook, PluginMigration};
4use super::error::PluginErrorCode;
5use super::rate_limit::PluginRateLimitRule;
6use super::schema::PluginSchemaContribution;
7use crate::context::AuthContext;
8use crate::error::OpenAuthError;
9use crate::options::{SessionAdditionalField, UserAdditionalField};
10#[cfg(feature = "oauth")]
11use openauth_oauth::oauth2::SocialOAuthProvider;
12use std::collections::BTreeMap;
13use std::fmt;
14use std::sync::Arc;
15
16pub type PluginInitHandler =
17    Arc<dyn Fn(&AuthContext) -> Result<PluginInitOutput, OpenAuthError> + Send + Sync>;
18
19/// Typed, additive output from a plugin init handler.
20#[derive(Clone, Default)]
21pub struct PluginInitOutput {
22    pub trusted_origins: Vec<String>,
23    pub disabled_paths: Vec<String>,
24    pub schema: Vec<PluginSchemaContribution>,
25    pub rate_limit: Vec<PluginRateLimitRule>,
26    pub error_codes: Vec<PluginErrorCode>,
27    pub database_hooks: Vec<PluginDatabaseHook>,
28    pub migrations: Vec<PluginMigration>,
29    #[cfg(feature = "oauth")]
30    pub social_providers: Vec<Arc<dyn SocialOAuthProvider>>,
31    pub user_additional_fields: BTreeMap<String, UserAdditionalField>,
32    pub session_additional_fields: BTreeMap<String, SessionAdditionalField>,
33}
34
35impl fmt::Debug for PluginInitOutput {
36    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
37        formatter
38            .debug_struct("PluginInitOutput")
39            .field("trusted_origins", &self.trusted_origins)
40            .field("disabled_paths", &self.disabled_paths)
41            .field("schema", &self.schema)
42            .field("rate_limit", &self.rate_limit)
43            .field("error_codes", &self.error_codes)
44            .field("database_hooks", &self.database_hooks)
45            .field("migrations", &self.migrations)
46            .field("user_additional_fields", &self.user_additional_fields)
47            .field("session_additional_fields", &self.session_additional_fields)
48            .field("social_providers", &debug_social_providers(self))
49            .finish()
50    }
51}
52
53impl PluginInitOutput {
54    pub fn new() -> Self {
55        Self::default()
56    }
57
58    #[must_use]
59    pub fn trusted_origin(mut self, origin: impl Into<String>) -> Self {
60        self.trusted_origins.push(origin.into());
61        self
62    }
63
64    #[must_use]
65    pub fn disabled_path(mut self, path: impl Into<String>) -> Self {
66        self.disabled_paths.push(path.into());
67        self
68    }
69
70    #[must_use]
71    pub fn schema(mut self, contribution: PluginSchemaContribution) -> Self {
72        self.schema.push(contribution);
73        self
74    }
75
76    #[must_use]
77    pub fn rate_limit(mut self, rule: PluginRateLimitRule) -> Self {
78        self.rate_limit.push(rule);
79        self
80    }
81
82    #[must_use]
83    pub fn error_code(mut self, code: PluginErrorCode) -> Self {
84        self.error_codes.push(code);
85        self
86    }
87
88    #[must_use]
89    pub fn database_hook(mut self, hook: PluginDatabaseHook) -> Self {
90        self.database_hooks.push(hook);
91        self
92    }
93
94    #[must_use]
95    pub fn migration(mut self, migration: PluginMigration) -> Self {
96        self.migrations.push(migration);
97        self
98    }
99
100    #[cfg(feature = "oauth")]
101    #[must_use]
102    pub fn social_provider(mut self, provider: impl Into<Arc<dyn SocialOAuthProvider>>) -> Self {
103        self.social_providers.push(provider.into());
104        self
105    }
106
107    #[must_use]
108    pub fn user_additional_field(
109        mut self,
110        name: impl Into<String>,
111        field: UserAdditionalField,
112    ) -> Self {
113        self.user_additional_fields.insert(name.into(), field);
114        self
115    }
116
117    #[must_use]
118    pub fn session_additional_field(
119        mut self,
120        name: impl Into<String>,
121        field: SessionAdditionalField,
122    ) -> Self {
123        self.session_additional_fields.insert(name.into(), field);
124        self
125    }
126}
127
/// Collects the id of every configured social OAuth provider, in insertion
/// order, for use in the `Debug` output of [`PluginInitOutput`].
#[cfg(feature = "oauth")]
fn debug_social_providers(output: &PluginInitOutput) -> Vec<&str> {
    let mut ids = Vec::with_capacity(output.social_providers.len());
    for provider in &output.social_providers {
        ids.push(provider.id());
    }
    ids
}
136
137#[cfg(not(feature = "oauth"))]
138fn debug_social_providers(_output: &PluginInitOutput) -> Vec<&'static str> {
139    Vec::new()
140}