1use indexmap::IndexMap;
2use serde::{Deserialize, Serialize};
3
4use super::IdPolicy;
5use crate::error::OpenAuthError;
6
7#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
9pub enum RateLimitStorage {
10 #[default]
11 Memory,
12 Database,
13 SecondaryStorage,
14}
15
16#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
18pub struct TableOptions {
19 pub name: Option<String>,
20 pub field_names: IndexMap<String, String>,
21 pub additional_fields: IndexMap<String, DbField>,
22}
23
24impl TableOptions {
25 pub fn with_name(mut self, name: impl Into<String>) -> Self {
27 self.name = Some(name.into());
28 self
29 }
30
31 pub fn with_field_name(
33 mut self,
34 logical_name: impl Into<String>,
35 db_name: impl Into<String>,
36 ) -> Self {
37 self.field_names.insert(logical_name.into(), db_name.into());
38 self
39 }
40
41 pub fn with_field(mut self, logical_name: impl Into<String>, field: DbField) -> Self {
43 self.additional_fields.insert(logical_name.into(), field);
44 self
45 }
46
47 fn field_name(&self, logical_name: &str) -> String {
48 self.field_names
49 .get(logical_name)
50 .cloned()
51 .unwrap_or_else(|| logical_name.to_owned())
52 }
53}
54
55#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
57pub struct AuthSchemaOptions {
58 pub id_policy: IdPolicy,
59 pub user: TableOptions,
60 pub account: TableOptions,
61 pub session: TableOptions,
62 pub verification: TableOptions,
63 pub rate_limit: TableOptions,
64 pub has_secondary_storage: bool,
65 pub store_session_in_database: bool,
66 pub store_verification_in_database: bool,
67 pub rate_limit_storage: RateLimitStorage,
68}
69
70#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
72pub enum DbFieldType {
73 String,
74 Number,
75 Boolean,
76 Timestamp,
77 Json,
78 StringArray,
79 NumberArray,
80}
81
82#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
84pub enum OnDelete {
85 NoAction,
86 Restrict,
87 Cascade,
88 SetNull,
89 SetDefault,
90}
91
92#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
94pub struct ForeignKey {
95 pub table: String,
96 pub field: String,
97 pub on_delete: OnDelete,
98}
99
100impl ForeignKey {
101 pub fn new(table: impl Into<String>, field: impl Into<String>, on_delete: OnDelete) -> Self {
102 Self {
103 table: table.into(),
104 field: field.into(),
105 on_delete,
106 }
107 }
108}
109
110#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
112pub struct DbField {
113 pub name: String,
114 pub field_type: DbFieldType,
115 pub required: bool,
116 pub unique: bool,
117 pub index: bool,
118 pub returned: bool,
119 pub input: bool,
120 pub foreign_key: Option<ForeignKey>,
121}
122
123impl DbField {
124 pub fn new(name: impl Into<String>, field_type: DbFieldType) -> Self {
126 Self {
127 name: name.into(),
128 field_type,
129 required: true,
130 unique: false,
131 index: false,
132 returned: true,
133 input: true,
134 foreign_key: None,
135 }
136 }
137
138 pub fn optional(mut self) -> Self {
139 self.required = false;
140 self
141 }
142
143 pub fn unique(mut self) -> Self {
144 self.unique = true;
145 self
146 }
147
148 pub fn indexed(mut self) -> Self {
149 self.index = true;
150 self
151 }
152
153 pub fn hidden(mut self) -> Self {
154 self.returned = false;
155 self
156 }
157
158 pub fn generated(mut self) -> Self {
159 self.input = false;
160 self
161 }
162
163 pub fn references(mut self, foreign_key: ForeignKey) -> Self {
164 self.foreign_key = Some(foreign_key);
165 self
166 }
167}
168
169#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
171pub struct DbTable {
172 pub name: String,
173 pub fields: IndexMap<String, DbField>,
174 pub order: Option<u16>,
175}
176
177impl DbTable {
178 pub fn field(&self, logical_name: &str) -> Option<&DbField> {
179 self.fields.get(logical_name)
180 }
181}
182
183#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
185pub struct DbSchema {
186 tables: IndexMap<String, DbTable>,
187}
188
189impl DbSchema {
190 pub fn table(&self, logical_name: &str) -> Option<&DbTable> {
191 self.tables.get(logical_name)
192 }
193
194 pub fn table_name(&self, table: &str) -> Result<&str, OpenAuthError> {
196 self.resolve_table(table)
197 .map(|(_, table)| table.name.as_str())
198 .ok_or_else(|| OpenAuthError::TableNotFound {
199 table: table.to_owned(),
200 })
201 }
202
203 pub fn field_name(&self, table: &str, field: &str) -> Result<&str, OpenAuthError> {
205 self.field(table, field)
206 .map(|field| field.name.as_str())
207 .map_err(|_| OpenAuthError::FieldNotFound {
208 table: table.to_owned(),
209 field: field.to_owned(),
210 })
211 }
212
213 pub fn field(&self, table: &str, field: &str) -> Result<&DbField, OpenAuthError> {
215 let (_, table_metadata) =
216 self.resolve_table(table)
217 .ok_or_else(|| OpenAuthError::TableNotFound {
218 table: table.to_owned(),
219 })?;
220
221 table_metadata
222 .resolve_field(field)
223 .ok_or_else(|| OpenAuthError::FieldNotFound {
224 table: table.to_owned(),
225 field: field.to_owned(),
226 })
227 }
228
229 pub fn tables(&self) -> impl Iterator<Item = (&str, &DbTable)> {
230 self.tables
231 .iter()
232 .map(|(logical_name, table)| (logical_name.as_str(), table))
233 }
234
235 pub fn insert_plugin_table(
236 &mut self,
237 logical_name: String,
238 table: DbTable,
239 ) -> Result<(), OpenAuthError> {
240 if let Some(existing) = self.tables.get(&logical_name) {
241 if existing == &table {
242 return Ok(());
243 }
244 return Err(OpenAuthError::InvalidConfig(format!(
245 "plugin schema table `{logical_name}` conflicts with an existing table"
246 )));
247 }
248 if self
249 .tables
250 .values()
251 .any(|existing| existing.name == table.name)
252 {
253 return Err(OpenAuthError::InvalidConfig(format!(
254 "plugin schema table `{logical_name}` uses existing database table `{}`",
255 table.name
256 )));
257 }
258 self.tables.insert(logical_name, table);
259 Ok(())
260 }
261
262 pub fn insert_plugin_field(
263 &mut self,
264 table: &str,
265 logical_name: String,
266 field: DbField,
267 ) -> Result<(), OpenAuthError> {
268 let (_, table_metadata) =
269 self.resolve_table_mut(table)
270 .ok_or_else(|| OpenAuthError::TableNotFound {
271 table: table.to_owned(),
272 })?;
273
274 if let Some(existing) = table_metadata.fields.get(&logical_name) {
275 if existing == &field {
276 return Ok(());
277 }
278 return Err(OpenAuthError::InvalidConfig(format!(
279 "plugin schema field `{logical_name}` conflicts with table `{table}`"
280 )));
281 }
282 if table_metadata
283 .fields
284 .values()
285 .any(|existing| existing.name == field.name)
286 {
287 return Err(OpenAuthError::InvalidConfig(format!(
288 "plugin schema field `{logical_name}` uses existing database field `{}` on table `{table}`",
289 field.name
290 )));
291 }
292 table_metadata.fields.insert(logical_name, field);
293 Ok(())
294 }
295
296 fn resolve_table(&self, table: &str) -> Option<(&str, &DbTable)> {
297 self.tables
298 .get_key_value(table)
299 .map(|(logical_name, table)| (logical_name.as_str(), table))
300 .or_else(|| {
301 self.tables
302 .iter()
303 .find(|(_, table_metadata)| table_metadata.name == table)
304 .map(|(logical_name, table)| (logical_name.as_str(), table))
305 })
306 }
307
308 fn resolve_table_mut(&mut self, table: &str) -> Option<(&str, &mut DbTable)> {
309 if self.tables.contains_key(table) {
310 let (logical_name, table_metadata) = self.tables.get_key_value_mut(table)?;
311 return Some((logical_name.as_str(), table_metadata));
312 }
313 self.tables
314 .iter_mut()
315 .find(|(_, table_metadata)| table_metadata.name == table)
316 .map(|(logical_name, table)| (logical_name.as_str(), table))
317 }
318
319 fn insert(&mut self, logical_name: impl Into<String>, table: DbTable) {
320 self.tables.insert(logical_name.into(), table);
321 }
322}
323
324impl DbTable {
325 fn resolve_field(&self, field: &str) -> Option<&DbField> {
326 self.fields
327 .get(field)
328 .or_else(|| self.fields.values().find(|metadata| metadata.name == field))
329 }
330}
331
332pub fn auth_schema(options: AuthSchemaOptions) -> DbSchema {
334 let mut schema = DbSchema::default();
335 let user_table_name = table_name(&options.user, "users");
336
337 schema.insert(
338 "user",
339 table(
340 &options.user,
341 "users",
342 Some(1),
343 [
344 ("id", options.id_policy.field()),
345 ("name", field(&options.user, "name", DbFieldType::String)),
346 (
347 "email",
348 field(&options.user, "email", DbFieldType::String).unique(),
349 ),
350 (
351 "email_verified",
352 field(&options.user, "email_verified", DbFieldType::Boolean).generated(),
353 ),
354 (
355 "image",
356 field(&options.user, "image", DbFieldType::String).optional(),
357 ),
358 (
359 "created_at",
360 field(&options.user, "created_at", DbFieldType::Timestamp).generated(),
361 ),
362 (
363 "updated_at",
364 field(&options.user, "updated_at", DbFieldType::Timestamp).generated(),
365 ),
366 ],
367 ),
368 );
369
370 if !options.has_secondary_storage || options.store_session_in_database {
371 schema.insert(
372 "session",
373 table(
374 &options.session,
375 "sessions",
376 Some(2),
377 [
378 ("id", options.id_policy.field()),
379 (
380 "expires_at",
381 field(&options.session, "expires_at", DbFieldType::Timestamp),
382 ),
383 (
384 "token",
385 field(&options.session, "token", DbFieldType::String).unique(),
386 ),
387 (
388 "created_at",
389 field(&options.session, "created_at", DbFieldType::Timestamp).generated(),
390 ),
391 (
392 "updated_at",
393 field(&options.session, "updated_at", DbFieldType::Timestamp).generated(),
394 ),
395 (
396 "ip_address",
397 field(&options.session, "ip_address", DbFieldType::String).optional(),
398 ),
399 (
400 "user_agent",
401 field(&options.session, "user_agent", DbFieldType::String).optional(),
402 ),
403 (
404 "user_id",
405 field(&options.session, "user_id", DbFieldType::String)
406 .indexed()
407 .references(ForeignKey::new(
408 user_table_name.clone(),
409 "id",
410 OnDelete::Cascade,
411 )),
412 ),
413 ],
414 ),
415 );
416 }
417
418 schema.insert(
419 "account",
420 table(
421 &options.account,
422 "accounts",
423 Some(3),
424 [
425 ("id", options.id_policy.field()),
426 (
427 "account_id",
428 field(&options.account, "account_id", DbFieldType::String),
429 ),
430 (
431 "provider_id",
432 field(&options.account, "provider_id", DbFieldType::String),
433 ),
434 (
435 "user_id",
436 field(&options.account, "user_id", DbFieldType::String)
437 .indexed()
438 .references(ForeignKey::new(user_table_name, "id", OnDelete::Cascade)),
439 ),
440 (
441 "access_token",
442 field(&options.account, "access_token", DbFieldType::String)
443 .optional()
444 .hidden(),
445 ),
446 (
447 "refresh_token",
448 field(&options.account, "refresh_token", DbFieldType::String)
449 .optional()
450 .hidden(),
451 ),
452 (
453 "id_token",
454 field(&options.account, "id_token", DbFieldType::String)
455 .optional()
456 .hidden(),
457 ),
458 (
459 "access_token_expires_at",
460 field(
461 &options.account,
462 "access_token_expires_at",
463 DbFieldType::Timestamp,
464 )
465 .optional()
466 .hidden(),
467 ),
468 (
469 "refresh_token_expires_at",
470 field(
471 &options.account,
472 "refresh_token_expires_at",
473 DbFieldType::Timestamp,
474 )
475 .optional()
476 .hidden(),
477 ),
478 (
479 "scope",
480 field(&options.account, "scope", DbFieldType::String).optional(),
481 ),
482 (
483 "password",
484 field(&options.account, "password", DbFieldType::String)
485 .optional()
486 .hidden(),
487 ),
488 (
489 "created_at",
490 field(&options.account, "created_at", DbFieldType::Timestamp).generated(),
491 ),
492 (
493 "updated_at",
494 field(&options.account, "updated_at", DbFieldType::Timestamp).generated(),
495 ),
496 ],
497 ),
498 );
499
500 if !options.has_secondary_storage || options.store_verification_in_database {
501 schema.insert(
502 "verification",
503 table(
504 &options.verification,
505 "verifications",
506 Some(4),
507 [
508 ("id", options.id_policy.field()),
509 (
510 "identifier",
511 field(&options.verification, "identifier", DbFieldType::String).indexed(),
512 ),
513 (
514 "value",
515 field(&options.verification, "value", DbFieldType::String),
516 ),
517 (
518 "expires_at",
519 field(&options.verification, "expires_at", DbFieldType::Timestamp),
520 ),
521 (
522 "created_at",
523 field(&options.verification, "created_at", DbFieldType::Timestamp)
524 .generated(),
525 ),
526 (
527 "updated_at",
528 field(&options.verification, "updated_at", DbFieldType::Timestamp)
529 .generated(),
530 ),
531 ],
532 ),
533 );
534 }
535
536 if options.rate_limit_storage == RateLimitStorage::Database {
537 schema.insert(
538 "rate_limit",
539 table(
540 &options.rate_limit,
541 "rate_limits",
542 None,
543 [
544 (
545 "key",
546 field(&options.rate_limit, "key", DbFieldType::String).unique(),
547 ),
548 (
549 "count",
550 field(&options.rate_limit, "count", DbFieldType::Number),
551 ),
552 (
553 "last_request",
554 field(&options.rate_limit, "last_request", DbFieldType::Number),
555 ),
556 ],
557 ),
558 );
559 }
560
561 schema
562}
563
564fn table<const N: usize>(
565 options: &TableOptions,
566 default_name: &str,
567 order: Option<u16>,
568 fields: [(&str, DbField); N],
569) -> DbTable {
570 let mut mapped_fields = fields
571 .into_iter()
572 .map(|(logical_name, field)| (logical_name.to_owned(), field))
573 .collect::<IndexMap<_, _>>();
574 mapped_fields.extend(options.additional_fields.clone());
575
576 DbTable {
577 name: table_name(options, default_name),
578 fields: mapped_fields,
579 order,
580 }
581}
582
583fn table_name(options: &TableOptions, default_name: &str) -> String {
584 options
585 .name
586 .clone()
587 .unwrap_or_else(|| default_name.to_owned())
588}
589
590fn field(options: &TableOptions, logical_name: &str, field_type: DbFieldType) -> DbField {
591 DbField::new(options.field_name(logical_name), field_type)
592}