Skip to main content

openauth_plugins/additional_fields/
mod.rs

1//! Additional fields plugin.
2
3use std::collections::BTreeMap;
4
5use openauth_core::db::{DbField, DbFieldType, DbValue};
6use openauth_core::options::{SessionAdditionalField, UserAdditionalField};
7use openauth_core::plugin::{AuthPlugin, PluginInitOutput, PluginSchemaContribution};
8
/// Identifier of this plugin, passed to `AuthPlugin::new` by [`additional_fields`].
///
/// The name suggests it mirrors the id of the corresponding upstream plugin.
/// NOTE(review): confirm against the upstream project before changing it.
pub const UPSTREAM_PLUGIN_ID: &str = "additional-fields";
10
11#[derive(Debug, Clone, Default, PartialEq)]
12pub struct AdditionalFieldsOptions {
13    pub user: BTreeMap<String, AdditionalField>,
14    pub session: BTreeMap<String, AdditionalField>,
15}
16
17impl AdditionalFieldsOptions {
18    pub fn new() -> Self {
19        Self::default()
20    }
21
22    #[must_use]
23    pub fn user_field(mut self, name: impl Into<String>, field: AdditionalField) -> Self {
24        self.user.insert(name.into(), field);
25        self
26    }
27
28    #[must_use]
29    pub fn session_field(mut self, name: impl Into<String>, field: AdditionalField) -> Self {
30        self.session.insert(name.into(), field);
31        self
32    }
33}
34
35#[derive(Debug, Clone, PartialEq)]
36pub struct AdditionalField {
37    pub field_type: DbFieldType,
38    pub required: bool,
39    pub input: bool,
40    pub returned: bool,
41    pub unique: bool,
42    pub index: bool,
43    pub default_value: Option<DbValue>,
44    pub db_name: Option<String>,
45}
46
47impl AdditionalField {
48    pub fn new(field_type: DbFieldType) -> Self {
49        Self {
50            field_type,
51            required: true,
52            input: true,
53            returned: true,
54            unique: false,
55            index: false,
56            default_value: None,
57            db_name: None,
58        }
59    }
60
61    #[must_use]
62    pub fn optional(mut self) -> Self {
63        self.required = false;
64        self
65    }
66
67    #[must_use]
68    pub fn generated(mut self) -> Self {
69        self.input = false;
70        self
71    }
72
73    #[must_use]
74    pub fn hidden(mut self) -> Self {
75        self.returned = false;
76        self
77    }
78
79    #[must_use]
80    pub fn unique(mut self) -> Self {
81        self.unique = true;
82        self
83    }
84
85    #[must_use]
86    pub fn indexed(mut self) -> Self {
87        self.index = true;
88        self
89    }
90
91    #[must_use]
92    pub fn default_value(mut self, value: DbValue) -> Self {
93        self.default_value = Some(value);
94        self
95    }
96
97    #[must_use]
98    pub fn db_name(mut self, db_name: impl Into<String>) -> Self {
99        self.db_name = Some(db_name.into());
100        self
101    }
102}
103
104pub fn additional_fields(options: AdditionalFieldsOptions) -> AuthPlugin {
105    AuthPlugin::new(UPSTREAM_PLUGIN_ID).with_init(move |_context| {
106        let mut output = PluginInitOutput::new();
107        for (name, field) in &options.user {
108            output = output
109                .schema(PluginSchemaContribution::field(
110                    "user",
111                    name.clone(),
112                    field.schema_field(name),
113                ))
114                .user_additional_field(name.clone(), field.user_runtime_field());
115        }
116        for (name, field) in &options.session {
117            output = output
118                .schema(PluginSchemaContribution::field(
119                    "session",
120                    name.clone(),
121                    field.schema_field(name),
122                ))
123                .session_additional_field(name.clone(), field.session_runtime_field());
124        }
125        Ok(output)
126    })
127}
128
129impl AdditionalField {
130    fn schema_field(&self, logical_name: &str) -> DbField {
131        let mut field = DbField::new(
132            self.db_name
133                .clone()
134                .unwrap_or_else(|| logical_name.to_owned()),
135            self.field_type.clone(),
136        );
137        if !self.required {
138            field = field.optional();
139        }
140        if self.unique {
141            field = field.unique();
142        }
143        if self.index {
144            field = field.indexed();
145        }
146        if !self.returned {
147            field = field.hidden();
148        }
149        if !self.input {
150            field = field.generated();
151        }
152        field
153    }
154
155    fn user_runtime_field(&self) -> UserAdditionalField {
156        let mut field = UserAdditionalField::new(self.field_type.clone());
157        field.required = self.required;
158        field.input = self.input;
159        field.returned = self.returned;
160        field.default_value = self.default_value.clone();
161        field.db_name = self.db_name.clone();
162        field
163    }
164
165    fn session_runtime_field(&self) -> SessionAdditionalField {
166        let mut field = SessionAdditionalField::new(self.field_type.clone());
167        field.required = self.required;
168        field.input = self.input;
169        field.returned = self.returned;
170        field.default_value = self.default_value.clone();
171        field.db_name = self.db_name.clone();
172        field
173    }
174}