Skip to main content

openauth_core/db/
schema.rs

1use indexmap::IndexMap;
2use serde::{Deserialize, Serialize};
3
4use super::IdPolicy;
5use crate::error::OpenAuthError;
6
7/// Storage backend selected for rate limit counters.
8#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
9pub enum RateLimitStorage {
10    #[default]
11    Memory,
12    Database,
13    SecondaryStorage,
14}
15
16/// Per-table schema overrides.
17#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
18pub struct TableOptions {
19    pub name: Option<String>,
20    pub field_names: IndexMap<String, String>,
21    pub additional_fields: IndexMap<String, DbField>,
22}
23
24impl TableOptions {
25    /// Return a copy of these options with a custom database table name.
26    pub fn with_name(mut self, name: impl Into<String>) -> Self {
27        self.name = Some(name.into());
28        self
29    }
30
31    /// Return a copy of these options with a custom database column name.
32    pub fn with_field_name(
33        mut self,
34        logical_name: impl Into<String>,
35        db_name: impl Into<String>,
36    ) -> Self {
37        self.field_names.insert(logical_name.into(), db_name.into());
38        self
39    }
40
41    /// Return a copy of these options with an additional logical field.
42    pub fn with_field(mut self, logical_name: impl Into<String>, field: DbField) -> Self {
43        self.additional_fields.insert(logical_name.into(), field);
44        self
45    }
46
47    fn field_name(&self, logical_name: &str) -> String {
48        self.field_names
49            .get(logical_name)
50            .cloned()
51            .unwrap_or_else(|| logical_name.to_owned())
52    }
53}
54
55/// Options used to build OpenAuth's core database schema metadata.
56#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
57pub struct AuthSchemaOptions {
58    pub id_policy: IdPolicy,
59    pub user: TableOptions,
60    pub account: TableOptions,
61    pub session: TableOptions,
62    pub verification: TableOptions,
63    pub rate_limit: TableOptions,
64    pub has_secondary_storage: bool,
65    pub store_session_in_database: bool,
66    pub store_verification_in_database: bool,
67    pub rate_limit_storage: RateLimitStorage,
68}
69
70/// Supported database field kinds for core schema metadata.
71#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
72pub enum DbFieldType {
73    String,
74    Number,
75    Boolean,
76    Timestamp,
77    Json,
78    StringArray,
79    NumberArray,
80}
81
82/// Foreign key delete behavior.
83#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
84pub enum OnDelete {
85    NoAction,
86    Restrict,
87    Cascade,
88    SetNull,
89    SetDefault,
90}
91
92/// Foreign key metadata for adapter and migration implementations.
93#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
94pub struct ForeignKey {
95    pub table: String,
96    pub field: String,
97    pub on_delete: OnDelete,
98}
99
100impl ForeignKey {
101    pub fn new(table: impl Into<String>, field: impl Into<String>, on_delete: OnDelete) -> Self {
102        Self {
103            table: table.into(),
104            field: field.into(),
105            on_delete,
106        }
107    }
108}
109
110/// Field metadata used by adapters and migrations.
111#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
112pub struct DbField {
113    pub name: String,
114    pub field_type: DbFieldType,
115    pub required: bool,
116    pub unique: bool,
117    pub index: bool,
118    pub returned: bool,
119    pub input: bool,
120    pub foreign_key: Option<ForeignKey>,
121}
122
123impl DbField {
124    /// Create a required, returned, input-accepted field.
125    pub fn new(name: impl Into<String>, field_type: DbFieldType) -> Self {
126        Self {
127            name: name.into(),
128            field_type,
129            required: true,
130            unique: false,
131            index: false,
132            returned: true,
133            input: true,
134            foreign_key: None,
135        }
136    }
137
138    pub fn optional(mut self) -> Self {
139        self.required = false;
140        self
141    }
142
143    pub fn unique(mut self) -> Self {
144        self.unique = true;
145        self
146    }
147
148    pub fn indexed(mut self) -> Self {
149        self.index = true;
150        self
151    }
152
153    pub fn hidden(mut self) -> Self {
154        self.returned = false;
155        self
156    }
157
158    pub fn generated(mut self) -> Self {
159        self.input = false;
160        self
161    }
162
163    pub fn references(mut self, foreign_key: ForeignKey) -> Self {
164        self.foreign_key = Some(foreign_key);
165        self
166    }
167}
168
169/// Table metadata keyed by logical field name.
170#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
171pub struct DbTable {
172    pub name: String,
173    pub fields: IndexMap<String, DbField>,
174    pub order: Option<u16>,
175}
176
177impl DbTable {
178    pub fn field(&self, logical_name: &str) -> Option<&DbField> {
179        self.fields.get(logical_name)
180    }
181}
182
183/// Schema metadata keyed by logical table name.
184#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
185pub struct DbSchema {
186    tables: IndexMap<String, DbTable>,
187}
188
189impl DbSchema {
190    pub fn table(&self, logical_name: &str) -> Option<&DbTable> {
191        self.tables.get(logical_name)
192    }
193
194    /// Resolve a logical or physical table name to its physical database name.
195    pub fn table_name(&self, table: &str) -> Result<&str, OpenAuthError> {
196        self.resolve_table(table)
197            .map(|(_, table)| table.name.as_str())
198            .ok_or_else(|| OpenAuthError::TableNotFound {
199                table: table.to_owned(),
200            })
201    }
202
203    /// Resolve a logical or physical field name to its physical database column name.
204    pub fn field_name(&self, table: &str, field: &str) -> Result<&str, OpenAuthError> {
205        self.field(table, field)
206            .map(|field| field.name.as_str())
207            .map_err(|_| OpenAuthError::FieldNotFound {
208                table: table.to_owned(),
209                field: field.to_owned(),
210            })
211    }
212
213    /// Resolve field metadata from logical or physical table and field names.
214    pub fn field(&self, table: &str, field: &str) -> Result<&DbField, OpenAuthError> {
215        let (_, table_metadata) =
216            self.resolve_table(table)
217                .ok_or_else(|| OpenAuthError::TableNotFound {
218                    table: table.to_owned(),
219                })?;
220
221        table_metadata
222            .resolve_field(field)
223            .ok_or_else(|| OpenAuthError::FieldNotFound {
224                table: table.to_owned(),
225                field: field.to_owned(),
226            })
227    }
228
229    pub fn tables(&self) -> impl Iterator<Item = (&str, &DbTable)> {
230        self.tables
231            .iter()
232            .map(|(logical_name, table)| (logical_name.as_str(), table))
233    }
234
235    pub fn insert_plugin_table(
236        &mut self,
237        logical_name: String,
238        table: DbTable,
239    ) -> Result<(), OpenAuthError> {
240        if let Some(existing) = self.tables.get(&logical_name) {
241            if existing == &table {
242                return Ok(());
243            }
244            return Err(OpenAuthError::InvalidConfig(format!(
245                "plugin schema table `{logical_name}` conflicts with an existing table"
246            )));
247        }
248        if self
249            .tables
250            .values()
251            .any(|existing| existing.name == table.name)
252        {
253            return Err(OpenAuthError::InvalidConfig(format!(
254                "plugin schema table `{logical_name}` uses existing database table `{}`",
255                table.name
256            )));
257        }
258        self.tables.insert(logical_name, table);
259        Ok(())
260    }
261
262    pub fn insert_plugin_field(
263        &mut self,
264        table: &str,
265        logical_name: String,
266        field: DbField,
267    ) -> Result<(), OpenAuthError> {
268        let (_, table_metadata) =
269            self.resolve_table_mut(table)
270                .ok_or_else(|| OpenAuthError::TableNotFound {
271                    table: table.to_owned(),
272                })?;
273
274        if let Some(existing) = table_metadata.fields.get(&logical_name) {
275            if existing == &field {
276                return Ok(());
277            }
278            return Err(OpenAuthError::InvalidConfig(format!(
279                "plugin schema field `{logical_name}` conflicts with table `{table}`"
280            )));
281        }
282        if table_metadata
283            .fields
284            .values()
285            .any(|existing| existing.name == field.name)
286        {
287            return Err(OpenAuthError::InvalidConfig(format!(
288                "plugin schema field `{logical_name}` uses existing database field `{}` on table `{table}`",
289                field.name
290            )));
291        }
292        table_metadata.fields.insert(logical_name, field);
293        Ok(())
294    }
295
296    fn resolve_table(&self, table: &str) -> Option<(&str, &DbTable)> {
297        self.tables
298            .get_key_value(table)
299            .map(|(logical_name, table)| (logical_name.as_str(), table))
300            .or_else(|| {
301                self.tables
302                    .iter()
303                    .find(|(_, table_metadata)| table_metadata.name == table)
304                    .map(|(logical_name, table)| (logical_name.as_str(), table))
305            })
306    }
307
308    fn resolve_table_mut(&mut self, table: &str) -> Option<(&str, &mut DbTable)> {
309        if self.tables.contains_key(table) {
310            let (logical_name, table_metadata) = self.tables.get_key_value_mut(table)?;
311            return Some((logical_name.as_str(), table_metadata));
312        }
313        self.tables
314            .iter_mut()
315            .find(|(_, table_metadata)| table_metadata.name == table)
316            .map(|(logical_name, table)| (logical_name.as_str(), table))
317    }
318
319    fn insert(&mut self, logical_name: impl Into<String>, table: DbTable) {
320        self.tables.insert(logical_name.into(), table);
321    }
322}
323
324impl DbTable {
325    fn resolve_field(&self, field: &str) -> Option<&DbField> {
326        self.fields
327            .get(field)
328            .or_else(|| self.fields.values().find(|metadata| metadata.name == field))
329    }
330}
331
332/// Build OpenAuth's core database schema metadata.
333pub fn auth_schema(options: AuthSchemaOptions) -> DbSchema {
334    let mut schema = DbSchema::default();
335    let user_table_name = table_name(&options.user, "users");
336
337    schema.insert(
338        "user",
339        table(
340            &options.user,
341            "users",
342            Some(1),
343            [
344                ("id", options.id_policy.field()),
345                ("name", field(&options.user, "name", DbFieldType::String)),
346                (
347                    "email",
348                    field(&options.user, "email", DbFieldType::String).unique(),
349                ),
350                (
351                    "email_verified",
352                    field(&options.user, "email_verified", DbFieldType::Boolean).generated(),
353                ),
354                (
355                    "image",
356                    field(&options.user, "image", DbFieldType::String).optional(),
357                ),
358                (
359                    "created_at",
360                    field(&options.user, "created_at", DbFieldType::Timestamp).generated(),
361                ),
362                (
363                    "updated_at",
364                    field(&options.user, "updated_at", DbFieldType::Timestamp).generated(),
365                ),
366            ],
367        ),
368    );
369
370    if !options.has_secondary_storage || options.store_session_in_database {
371        schema.insert(
372            "session",
373            table(
374                &options.session,
375                "sessions",
376                Some(2),
377                [
378                    ("id", options.id_policy.field()),
379                    (
380                        "expires_at",
381                        field(&options.session, "expires_at", DbFieldType::Timestamp),
382                    ),
383                    (
384                        "token",
385                        field(&options.session, "token", DbFieldType::String).unique(),
386                    ),
387                    (
388                        "created_at",
389                        field(&options.session, "created_at", DbFieldType::Timestamp).generated(),
390                    ),
391                    (
392                        "updated_at",
393                        field(&options.session, "updated_at", DbFieldType::Timestamp).generated(),
394                    ),
395                    (
396                        "ip_address",
397                        field(&options.session, "ip_address", DbFieldType::String).optional(),
398                    ),
399                    (
400                        "user_agent",
401                        field(&options.session, "user_agent", DbFieldType::String).optional(),
402                    ),
403                    (
404                        "user_id",
405                        field(&options.session, "user_id", DbFieldType::String)
406                            .indexed()
407                            .references(ForeignKey::new(
408                                user_table_name.clone(),
409                                "id",
410                                OnDelete::Cascade,
411                            )),
412                    ),
413                ],
414            ),
415        );
416    }
417
418    schema.insert(
419        "account",
420        table(
421            &options.account,
422            "accounts",
423            Some(3),
424            [
425                ("id", options.id_policy.field()),
426                (
427                    "account_id",
428                    field(&options.account, "account_id", DbFieldType::String),
429                ),
430                (
431                    "provider_id",
432                    field(&options.account, "provider_id", DbFieldType::String),
433                ),
434                (
435                    "user_id",
436                    field(&options.account, "user_id", DbFieldType::String)
437                        .indexed()
438                        .references(ForeignKey::new(user_table_name, "id", OnDelete::Cascade)),
439                ),
440                (
441                    "access_token",
442                    field(&options.account, "access_token", DbFieldType::String)
443                        .optional()
444                        .hidden(),
445                ),
446                (
447                    "refresh_token",
448                    field(&options.account, "refresh_token", DbFieldType::String)
449                        .optional()
450                        .hidden(),
451                ),
452                (
453                    "id_token",
454                    field(&options.account, "id_token", DbFieldType::String)
455                        .optional()
456                        .hidden(),
457                ),
458                (
459                    "access_token_expires_at",
460                    field(
461                        &options.account,
462                        "access_token_expires_at",
463                        DbFieldType::Timestamp,
464                    )
465                    .optional()
466                    .hidden(),
467                ),
468                (
469                    "refresh_token_expires_at",
470                    field(
471                        &options.account,
472                        "refresh_token_expires_at",
473                        DbFieldType::Timestamp,
474                    )
475                    .optional()
476                    .hidden(),
477                ),
478                (
479                    "scope",
480                    field(&options.account, "scope", DbFieldType::String).optional(),
481                ),
482                (
483                    "password",
484                    field(&options.account, "password", DbFieldType::String)
485                        .optional()
486                        .hidden(),
487                ),
488                (
489                    "created_at",
490                    field(&options.account, "created_at", DbFieldType::Timestamp).generated(),
491                ),
492                (
493                    "updated_at",
494                    field(&options.account, "updated_at", DbFieldType::Timestamp).generated(),
495                ),
496            ],
497        ),
498    );
499
500    if !options.has_secondary_storage || options.store_verification_in_database {
501        schema.insert(
502            "verification",
503            table(
504                &options.verification,
505                "verifications",
506                Some(4),
507                [
508                    ("id", options.id_policy.field()),
509                    (
510                        "identifier",
511                        field(&options.verification, "identifier", DbFieldType::String).indexed(),
512                    ),
513                    (
514                        "value",
515                        field(&options.verification, "value", DbFieldType::String),
516                    ),
517                    (
518                        "expires_at",
519                        field(&options.verification, "expires_at", DbFieldType::Timestamp),
520                    ),
521                    (
522                        "created_at",
523                        field(&options.verification, "created_at", DbFieldType::Timestamp)
524                            .generated(),
525                    ),
526                    (
527                        "updated_at",
528                        field(&options.verification, "updated_at", DbFieldType::Timestamp)
529                            .generated(),
530                    ),
531                ],
532            ),
533        );
534    }
535
536    if options.rate_limit_storage == RateLimitStorage::Database {
537        schema.insert(
538            "rate_limit",
539            table(
540                &options.rate_limit,
541                "rate_limits",
542                None,
543                [
544                    (
545                        "key",
546                        field(&options.rate_limit, "key", DbFieldType::String).unique(),
547                    ),
548                    (
549                        "count",
550                        field(&options.rate_limit, "count", DbFieldType::Number),
551                    ),
552                    (
553                        "last_request",
554                        field(&options.rate_limit, "last_request", DbFieldType::Number),
555                    ),
556                ],
557            ),
558        );
559    }
560
561    schema
562}
563
564fn table<const N: usize>(
565    options: &TableOptions,
566    default_name: &str,
567    order: Option<u16>,
568    fields: [(&str, DbField); N],
569) -> DbTable {
570    let mut mapped_fields = fields
571        .into_iter()
572        .map(|(logical_name, field)| (logical_name.to_owned(), field))
573        .collect::<IndexMap<_, _>>();
574    mapped_fields.extend(options.additional_fields.clone());
575
576    DbTable {
577        name: table_name(options, default_name),
578        fields: mapped_fields,
579        order,
580    }
581}
582
583fn table_name(options: &TableOptions, default_name: &str) -> String {
584    options
585        .name
586        .clone()
587        .unwrap_or_else(|| default_name.to_owned())
588}
589
590fn field(options: &TableOptions, logical_name: &str, field_type: DbFieldType) -> DbField {
591    DbField::new(options.field_name(logical_name), field_type)
592}