1use indexmap::IndexMap;
2use serde::{Deserialize, Serialize};
3
4use super::{IdGeneration, IdPolicy};
5use crate::error::OpenAuthError;
6
7mod builder;
8pub use builder::auth_schema;
9
/// Where rate-limit counters are persisted.
///
/// The derived `Default` is [`RateLimitStorage::Memory`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum RateLimitStorage {
    /// Memory-backed storage; the `#[default]` variant.
    #[default]
    Memory,
    /// Database-backed storage. NOTE(review): presumably uses the table configured via
    /// `AuthSchemaOptions::rate_limit` — confirm in the builder.
    Database,
    /// Storage in the configured secondary storage backend.
    SecondaryStorage,
}
18
/// Per-table customization used when building the auth schema.
///
/// NOTE(review): consumed by the `builder` submodule, which is not visible here.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct TableOptions {
    /// Custom database table name, if any (presumably overrides the builder's default name).
    pub name: Option<String>,
    /// Logical field name -> database column name; unmapped fields keep their logical name.
    pub field_names: IndexMap<String, String>,
    /// Additional fields for the table, keyed by logical name (populated by `with_field`).
    pub additional_fields: IndexMap<String, DbField>,
}
26
27impl TableOptions {
28 pub fn with_name(mut self, name: impl Into<String>) -> Self {
30 self.name = Some(name.into());
31 self
32 }
33
34 pub fn with_field_name(
36 mut self,
37 logical_name: impl Into<String>,
38 db_name: impl Into<String>,
39 ) -> Self {
40 self.field_names.insert(logical_name.into(), db_name.into());
41 self
42 }
43
44 pub fn with_field(mut self, logical_name: impl Into<String>, field: DbField) -> Self {
46 self.additional_fields.insert(logical_name.into(), field);
47 self
48 }
49
50 fn field_name(&self, logical_name: &str) -> String {
51 self.field_names
52 .get(logical_name)
53 .cloned()
54 .unwrap_or_else(|| logical_name.to_owned())
55 }
56}
57
/// Options controlling how the built-in auth schema is generated.
///
/// The derived `Default` sets every flag to `false` and `rate_limit_storage` to
/// [`RateLimitStorage::Memory`]. NOTE(review): the fields are consumed by the `builder`
/// submodule (not visible here); the per-field notes are inferred from names — confirm there.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct AuthSchemaOptions {
    /// Policy for record ids.
    pub id_policy: IdPolicy,
    /// Customization for the user table.
    pub user: TableOptions,
    /// Customization for the account table.
    pub account: TableOptions,
    /// Customization for the session table.
    pub session: TableOptions,
    /// Customization for the verification table.
    pub verification: TableOptions,
    /// Customization for the rate-limit table.
    pub rate_limit: TableOptions,
    /// Whether a secondary storage backend is configured.
    pub has_secondary_storage: bool,
    /// Whether sessions are stored in the database.
    pub store_session_in_database: bool,
    /// Whether verification records are stored in the database.
    pub store_verification_in_database: bool,
    /// Where rate-limit counters are kept.
    pub rate_limit_storage: RateLimitStorage,
}
72
/// Abstract column type of a schema field.
///
/// NOTE(review): the mapping to concrete database types is done elsewhere (not shown).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DbFieldType {
    /// Text value.
    String,
    /// Numeric value.
    Number,
    /// Boolean value.
    Boolean,
    /// Point in time.
    Timestamp,
    /// JSON document.
    Json,
    /// Array of text values.
    StringArray,
    /// Array of numeric values.
    NumberArray,
}
84
/// Action taken on a referencing row when the row referenced by a [`ForeignKey`] is deleted.
///
/// The variants mirror the SQL `ON DELETE` referential actions. NOTE(review): presumably
/// translated to SQL by an adapter/migration layer not shown here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum OnDelete {
    NoAction,
    Restrict,
    Cascade,
    SetNull,
    SetDefault,
}
94
/// Reference from a field to `field` on `table`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ForeignKey {
    /// Referenced table. NOTE(review): logical vs. database name is not established here.
    pub table: String,
    /// Referenced field on `table`.
    pub field: String,
    /// Action applied when the referenced row is deleted.
    pub on_delete: OnDelete,
}
102
103impl ForeignKey {
104 pub fn new(table: impl Into<String>, field: impl Into<String>, on_delete: OnDelete) -> Self {
105 Self {
106 table: table.into(),
107 field: field.into(),
108 on_delete,
109 }
110 }
111}
112
/// Metadata for a single column of a [`DbTable`].
///
/// Construct with [`DbField::new`] and refine with the builder methods.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DbField {
    /// Database column name (the logical name is the key in `DbTable::fields`).
    pub name: String,
    /// Column type.
    pub field_type: DbFieldType,
    /// `true` unless cleared by `optional()`.
    pub required: bool,
    /// Set by `unique()`.
    pub unique: bool,
    /// Set by `indexed()`.
    pub index: bool,
    /// `true` unless cleared by `hidden()`; presumably controls inclusion in API output.
    pub returned: bool,
    /// `true` unless cleared by `generated()`; presumably controls whether callers may supply it.
    pub input: bool,
    /// Optional reference to another table's field.
    pub foreign_key: Option<ForeignKey>,
    // `serde(default)` lets data serialized without this field deserialize as `None`.
    #[serde(default)]
    pub generated_id: Option<IdGeneration>,
}
127
128impl DbField {
129 pub fn new(name: impl Into<String>, field_type: DbFieldType) -> Self {
131 Self {
132 name: name.into(),
133 field_type,
134 required: true,
135 unique: false,
136 index: false,
137 returned: true,
138 input: true,
139 foreign_key: None,
140 generated_id: None,
141 }
142 }
143
144 pub fn optional(mut self) -> Self {
145 self.required = false;
146 self
147 }
148
149 pub fn unique(mut self) -> Self {
150 self.unique = true;
151 self
152 }
153
154 pub fn indexed(mut self) -> Self {
155 self.index = true;
156 self
157 }
158
159 pub fn hidden(mut self) -> Self {
160 self.returned = false;
161 self
162 }
163
164 pub fn generated(mut self) -> Self {
165 self.input = false;
166 self
167 }
168
169 pub fn generated_id(mut self, generation: IdGeneration) -> Self {
170 self.generated_id = Some(generation);
171 self.generated()
172 }
173
174 pub fn references(mut self, foreign_key: ForeignKey) -> Self {
175 self.foreign_key = Some(foreign_key);
176 self
177 }
178}
179
/// A table in a [`DbSchema`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DbTable {
    /// Database table name (lookups also fall back to matching this name).
    pub name: String,
    /// Fields keyed by logical name, in insertion order.
    pub fields: IndexMap<String, DbField>,
    /// NOTE(review): presumably an ordering hint for table creation/migration — confirm.
    pub order: Option<u16>,
}
187
188impl DbTable {
189 pub fn field(&self, logical_name: &str) -> Option<&DbField> {
190 self.fields.get(logical_name)
191 }
192}
193
/// The resolved database schema.
///
/// Access goes through the methods on `DbSchema`. Those accept either logical or database
/// names, and plugin additions are conflict-checked.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct DbSchema {
    /// Tables keyed by logical name, in insertion order.
    tables: IndexMap<String, DbTable>,
}
199
200impl DbSchema {
201 pub fn table(&self, logical_name: &str) -> Option<&DbTable> {
202 self.tables.get(logical_name)
203 }
204
205 pub fn table_name(&self, table: &str) -> Result<&str, OpenAuthError> {
207 self.resolve_table(table)
208 .map(|(_, table)| table.name.as_str())
209 .ok_or_else(|| OpenAuthError::TableNotFound {
210 table: table.to_owned(),
211 })
212 }
213
214 pub fn field_name(&self, table: &str, field: &str) -> Result<&str, OpenAuthError> {
216 self.field(table, field)
217 .map(|field| field.name.as_str())
218 .map_err(|_| OpenAuthError::FieldNotFound {
219 table: table.to_owned(),
220 field: field.to_owned(),
221 })
222 }
223
224 pub fn field(&self, table: &str, field: &str) -> Result<&DbField, OpenAuthError> {
226 let (_, table_metadata) =
227 self.resolve_table(table)
228 .ok_or_else(|| OpenAuthError::TableNotFound {
229 table: table.to_owned(),
230 })?;
231
232 table_metadata
233 .resolve_field(field)
234 .ok_or_else(|| OpenAuthError::FieldNotFound {
235 table: table.to_owned(),
236 field: field.to_owned(),
237 })
238 }
239
240 pub fn tables(&self) -> impl Iterator<Item = (&str, &DbTable)> {
241 self.tables
242 .iter()
243 .map(|(logical_name, table)| (logical_name.as_str(), table))
244 }
245
246 pub fn insert_plugin_table(
247 &mut self,
248 logical_name: String,
249 table: DbTable,
250 ) -> Result<(), OpenAuthError> {
251 if let Some(existing) = self.tables.get(&logical_name) {
252 if existing == &table {
253 return Ok(());
254 }
255 return Err(OpenAuthError::InvalidConfig(format!(
256 "plugin schema table `{logical_name}` conflicts with an existing table"
257 )));
258 }
259 if self
260 .tables
261 .values()
262 .any(|existing| existing.name == table.name)
263 {
264 return Err(OpenAuthError::InvalidConfig(format!(
265 "plugin schema table `{logical_name}` uses existing database table `{}`",
266 table.name
267 )));
268 }
269 self.tables.insert(logical_name, table);
270 Ok(())
271 }
272
273 pub fn insert_plugin_field(
274 &mut self,
275 table: &str,
276 logical_name: String,
277 field: DbField,
278 ) -> Result<(), OpenAuthError> {
279 let (_, table_metadata) =
280 self.resolve_table_mut(table)
281 .ok_or_else(|| OpenAuthError::TableNotFound {
282 table: table.to_owned(),
283 })?;
284
285 if let Some(existing) = table_metadata.fields.get(&logical_name) {
286 if existing == &field {
287 return Ok(());
288 }
289 return Err(OpenAuthError::InvalidConfig(format!(
290 "plugin schema field `{logical_name}` conflicts with table `{table}`"
291 )));
292 }
293 if table_metadata
294 .fields
295 .values()
296 .any(|existing| existing.name == field.name)
297 {
298 return Err(OpenAuthError::InvalidConfig(format!(
299 "plugin schema field `{logical_name}` uses existing database field `{}` on table `{table}`",
300 field.name
301 )));
302 }
303 table_metadata.fields.insert(logical_name, field);
304 Ok(())
305 }
306
307 fn resolve_table(&self, table: &str) -> Option<(&str, &DbTable)> {
308 self.tables
309 .get_key_value(table)
310 .map(|(logical_name, table)| (logical_name.as_str(), table))
311 .or_else(|| {
312 self.tables
313 .iter()
314 .find(|(_, table_metadata)| table_metadata.name == table)
315 .map(|(logical_name, table)| (logical_name.as_str(), table))
316 })
317 }
318
319 fn resolve_table_mut(&mut self, table: &str) -> Option<(&str, &mut DbTable)> {
320 if self.tables.contains_key(table) {
321 let (logical_name, table_metadata) = self.tables.get_key_value_mut(table)?;
322 return Some((logical_name.as_str(), table_metadata));
323 }
324 self.tables
325 .iter_mut()
326 .find(|(_, table_metadata)| table_metadata.name == table)
327 .map(|(logical_name, table)| (logical_name.as_str(), table))
328 }
329
330 fn insert(&mut self, logical_name: impl Into<String>, table: DbTable) {
331 self.tables.insert(logical_name.into(), table);
332 }
333}
334
335impl DbTable {
336 fn resolve_field(&self, field: &str) -> Option<&DbField> {
337 self.fields
338 .get(field)
339 .or_else(|| self.fields.values().find(|metadata| metadata.name == field))
340 }
341}