openauth_plugins/additional_fields/
mod.rs1use std::collections::BTreeMap;
4
5use openauth_core::db::{DbField, DbFieldType, DbValue};
6use openauth_core::options::{SessionAdditionalField, UserAdditionalField};
7use openauth_core::plugin::{AuthPlugin, PluginInitOutput, PluginSchemaContribution};
8
/// Identifier passed to `AuthPlugin::new` when [`additional_fields`] builds the plugin.
pub const UPSTREAM_PLUGIN_ID: &str = "additional-fields";
10
11#[derive(Debug, Clone, Default, PartialEq)]
12pub struct AdditionalFieldsOptions {
13 pub user: BTreeMap<String, AdditionalField>,
14 pub session: BTreeMap<String, AdditionalField>,
15}
16
17impl AdditionalFieldsOptions {
18 pub fn new() -> Self {
19 Self::default()
20 }
21
22 #[must_use]
23 pub fn user_field(mut self, name: impl Into<String>, field: AdditionalField) -> Self {
24 self.user.insert(name.into(), field);
25 self
26 }
27
28 #[must_use]
29 pub fn session_field(mut self, name: impl Into<String>, field: AdditionalField) -> Self {
30 self.session.insert(name.into(), field);
31 self
32 }
33}
34
35#[derive(Debug, Clone, PartialEq)]
36pub struct AdditionalField {
37 pub field_type: DbFieldType,
38 pub required: bool,
39 pub input: bool,
40 pub returned: bool,
41 pub unique: bool,
42 pub index: bool,
43 pub default_value: Option<DbValue>,
44 pub db_name: Option<String>,
45}
46
47impl AdditionalField {
48 pub fn new(field_type: DbFieldType) -> Self {
49 Self {
50 field_type,
51 required: true,
52 input: true,
53 returned: true,
54 unique: false,
55 index: false,
56 default_value: None,
57 db_name: None,
58 }
59 }
60
61 #[must_use]
62 pub fn optional(mut self) -> Self {
63 self.required = false;
64 self
65 }
66
67 #[must_use]
68 pub fn generated(mut self) -> Self {
69 self.input = false;
70 self
71 }
72
73 #[must_use]
74 pub fn hidden(mut self) -> Self {
75 self.returned = false;
76 self
77 }
78
79 #[must_use]
80 pub fn unique(mut self) -> Self {
81 self.unique = true;
82 self
83 }
84
85 #[must_use]
86 pub fn indexed(mut self) -> Self {
87 self.index = true;
88 self
89 }
90
91 #[must_use]
92 pub fn default_value(mut self, value: DbValue) -> Self {
93 self.default_value = Some(value);
94 self
95 }
96
97 #[must_use]
98 pub fn db_name(mut self, db_name: impl Into<String>) -> Self {
99 self.db_name = Some(db_name.into());
100 self
101 }
102}
103
104pub fn additional_fields(options: AdditionalFieldsOptions) -> AuthPlugin {
105 AuthPlugin::new(UPSTREAM_PLUGIN_ID).with_init(move |_context| {
106 let mut output = PluginInitOutput::new();
107 for (name, field) in &options.user {
108 output = output
109 .schema(PluginSchemaContribution::field(
110 "user",
111 name.clone(),
112 field.schema_field(name),
113 ))
114 .user_additional_field(name.clone(), field.user_runtime_field());
115 }
116 for (name, field) in &options.session {
117 output = output
118 .schema(PluginSchemaContribution::field(
119 "session",
120 name.clone(),
121 field.schema_field(name),
122 ))
123 .session_additional_field(name.clone(), field.session_runtime_field());
124 }
125 Ok(output)
126 })
127}
128
129impl AdditionalField {
130 fn schema_field(&self, logical_name: &str) -> DbField {
131 let mut field = DbField::new(
132 self.db_name
133 .clone()
134 .unwrap_or_else(|| logical_name.to_owned()),
135 self.field_type.clone(),
136 );
137 if !self.required {
138 field = field.optional();
139 }
140 if self.unique {
141 field = field.unique();
142 }
143 if self.index {
144 field = field.indexed();
145 }
146 if !self.returned {
147 field = field.hidden();
148 }
149 if !self.input {
150 field = field.generated();
151 }
152 field
153 }
154
155 fn user_runtime_field(&self) -> UserAdditionalField {
156 let mut field = UserAdditionalField::new(self.field_type.clone());
157 field.required = self.required;
158 field.input = self.input;
159 field.returned = self.returned;
160 field.default_value = self.default_value.clone();
161 field.db_name = self.db_name.clone();
162 field
163 }
164
165 fn session_runtime_field(&self) -> SessionAdditionalField {
166 let mut field = SessionAdditionalField::new(self.field_type.clone());
167 field.required = self.required;
168 field.input = self.input;
169 field.returned = self.returned;
170 field.default_value = self.default_value.clone();
171 field.db_name = self.db_name.clone();
172 field
173 }
174}