use indexmap::IndexMap;
use serde::{Deserialize, Serialize};
use super::{IdGeneration, IdPolicy};
use crate::error::OpenAuthError;
mod builder;
pub use builder::auth_schema;
/// Backend selected for rate-limit state via `AuthSchemaOptions::rate_limit_storage`.
///
/// Defaults to [`RateLimitStorage::Memory`] (the `#[default]` variant).
// NOTE(review): how each variant affects the generated schema (e.g. whether a
// rate-limit table is emitted) is decided outside this file — confirm in `builder`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum RateLimitStorage {
#[default]
Memory,
Database,
SecondaryStorage,
}
/// Per-table customisation options; [`AuthSchemaOptions`] holds one per core table.
///
/// Usually built fluently with `with_name`, `with_field_name` and `with_field`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct TableOptions {
/// Optional replacement name for the table.
// NOTE(review): presumably used as the database table name by `builder` — confirm.
pub name: Option<String>,
/// Overrides mapping logical field names to database field names. Logical names
/// without an entry keep their logical name (see `TableOptions::field_name`).
pub field_names: IndexMap<String, String>,
/// Extra fields for this table, keyed by logical field name.
pub additional_fields: IndexMap<String, DbField>,
}
impl TableOptions {
    /// Returns these options with `name` stored in [`TableOptions::name`].
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }

    /// Returns these options with `logical_name` mapped to the database name `db_name`.
    ///
    /// A later call for the same `logical_name` replaces the earlier mapping.
    pub fn with_field_name(
        mut self,
        logical_name: impl Into<String>,
        db_name: impl Into<String>,
    ) -> Self {
        let (logical, physical) = (logical_name.into(), db_name.into());
        self.field_names.insert(logical, physical);
        self
    }

    /// Returns these options with an extra `field` registered under `logical_name`.
    ///
    /// A later call for the same `logical_name` replaces the earlier field.
    pub fn with_field(mut self, logical_name: impl Into<String>, field: DbField) -> Self {
        let key = logical_name.into();
        self.additional_fields.insert(key, field);
        self
    }

    /// Resolves `logical_name` to its configured database name, or returns the
    /// logical name itself when no override exists.
    fn field_name(&self, logical_name: &str) -> String {
        match self.field_names.get(logical_name) {
            Some(db_name) => db_name.clone(),
            None => logical_name.to_owned(),
        }
    }
}
/// Top-level configuration for the auth schema: ID policy, per-table options and
/// storage flags.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct AuthSchemaOptions {
/// ID policy for the schema (see [`IdPolicy`]).
pub id_policy: IdPolicy,
/// Options for the user table.
pub user: TableOptions,
/// Options for the account table.
pub account: TableOptions,
/// Options for the session table.
pub session: TableOptions,
/// Options for the verification table.
pub verification: TableOptions,
/// Options for the rate-limit table.
pub rate_limit: TableOptions,
// NOTE(review): the three flags below presumably decide whether session and
// verification data live in the database when secondary storage exists —
// confirm their exact interplay in `builder`.
pub has_secondary_storage: bool,
pub store_session_in_database: bool,
pub store_verification_in_database: bool,
/// Rate-limit backend; defaults to [`RateLimitStorage::Memory`].
pub rate_limit_storage: RateLimitStorage,
}
/// Logical value type of a [`DbField`].
// NOTE(review): the mapping to concrete column types per database is not in this
// file — confirm where adapters translate these variants.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DbFieldType {
String,
Number,
Boolean,
Timestamp,
Json,
StringArray,
NumberArray,
}
/// Referential action for a [`ForeignKey`] when the referenced row is deleted.
///
/// The variants are named after the SQL `ON DELETE` actions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum OnDelete {
NoAction,
Restrict,
Cascade,
SetNull,
SetDefault,
}
/// Reference from a [`DbField`] to a field of another table.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ForeignKey {
/// Referenced table.
// NOTE(review): logical or database table name? Confirm against callers.
pub table: String,
/// Referenced field on `table`.
pub field: String,
/// Action applied when the referenced row is deleted.
pub on_delete: OnDelete,
}
impl ForeignKey {
    /// Creates a reference to `field` on `table` with the given `on_delete` action.
    pub fn new(table: impl Into<String>, field: impl Into<String>, on_delete: OnDelete) -> Self {
        let table = table.into();
        let field = field.into();
        Self {
            table,
            field,
            on_delete,
        }
    }
}
/// Metadata describing one field (column) of a [`DbTable`].
///
/// Defaults are set by `DbField::new` and adjusted with its builder methods.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DbField {
/// Database field name (the logical name is the key in `DbTable::fields`).
pub name: String,
/// Logical value type.
pub field_type: DbFieldType,
/// `true` unless `optional()` was called.
pub required: bool,
/// Set by `unique()`.
pub unique: bool,
/// Set by `indexed()`.
pub index: bool,
/// `true` unless `hidden()` was called.
pub returned: bool,
/// `true` unless `generated()` / `generated_id()` was called.
pub input: bool,
/// Set by `references()`.
pub foreign_key: Option<ForeignKey>,
/// Set by `generated_id()`. `#[serde(default)]` makes a missing key deserialize
/// as `None`.
#[serde(default)]
pub generated_id: Option<IdGeneration>,
}
impl DbField {
    /// Creates a field whose database name is `name` and whose type is `field_type`.
    ///
    /// A fresh field is required, returned and accepted as input. It is not unique
    /// or indexed, and has no foreign key and no generated id.
    pub fn new(name: impl Into<String>, field_type: DbFieldType) -> Self {
        let name = name.into();
        Self {
            name,
            field_type,
            required: true,
            unique: false,
            index: false,
            returned: true,
            input: true,
            foreign_key: None,
            generated_id: None,
        }
    }

    /// Marks the field as not required.
    pub fn optional(self) -> Self {
        Self {
            required: false,
            ..self
        }
    }

    /// Marks the field as unique.
    pub fn unique(self) -> Self {
        Self {
            unique: true,
            ..self
        }
    }

    /// Marks the field as indexed.
    pub fn indexed(self) -> Self {
        Self {
            index: true,
            ..self
        }
    }

    /// Clears the `returned` flag.
    pub fn hidden(self) -> Self {
        Self {
            returned: false,
            ..self
        }
    }

    /// Clears the `input` flag.
    pub fn generated(self) -> Self {
        Self {
            input: false,
            ..self
        }
    }

    /// Records the id `generation` strategy; such a field is also `generated()`.
    pub fn generated_id(self, generation: IdGeneration) -> Self {
        Self {
            generated_id: Some(generation),
            ..self
        }
        .generated()
    }

    /// Attaches `foreign_key`, replacing any previous reference.
    pub fn references(self, foreign_key: ForeignKey) -> Self {
        Self {
            foreign_key: Some(foreign_key),
            ..self
        }
    }
}
/// Metadata describing one table of a [`DbSchema`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DbTable {
/// Database table name (the logical name is the key in the owning schema).
pub name: String,
/// Fields keyed by logical field name, in insertion order.
pub fields: IndexMap<String, DbField>,
// NOTE(review): presumably an ordering hint for table creation/migration —
// confirm how `builder` and adapters use it.
pub order: Option<u16>,
}
impl DbTable {
    /// Returns the field registered under `logical_name`, if any.
    ///
    /// Only the logical key is matched; database field names are not consulted.
    pub fn field(&self, logical_name: &str) -> Option<&DbField> {
        self.fields
            .get_key_value(logical_name)
            .map(|(_, metadata)| metadata)
    }
}
/// Registry of all tables in the auth schema.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct DbSchema {
// Keyed by logical table name; lookups may also match `DbTable::name`.
tables: IndexMap<String, DbTable>,
}
impl DbSchema {
    /// Returns the table registered under `logical_name`.
    ///
    /// Only logical names are matched here. [`DbSchema::table_name`] and
    /// [`DbSchema::field`] also accept database table names.
    pub fn table(&self, logical_name: &str) -> Option<&DbTable> {
        self.tables.get(logical_name)
    }

    /// Resolves `table` (a logical or database table name) to its database table name.
    ///
    /// # Errors
    ///
    /// Returns [`OpenAuthError::TableNotFound`] if no table matches `table`.
    pub fn table_name(&self, table: &str) -> Result<&str, OpenAuthError> {
        self.resolve_table(table)
            .map(|(_, metadata)| metadata.name.as_str())
            .ok_or_else(|| OpenAuthError::TableNotFound {
                table: table.to_owned(),
            })
    }

    /// Resolves `field` on `table` to its database field name.
    ///
    /// Both `table` and `field` may be logical or database names.
    ///
    /// # Errors
    ///
    /// Same as [`DbSchema::field`]:
    /// - [`OpenAuthError::TableNotFound`] if `table` does not resolve.
    /// - [`OpenAuthError::FieldNotFound`] if the table exists but `field` does not.
    pub fn field_name(&self, table: &str, field: &str) -> Result<&str, OpenAuthError> {
        // `field` already reports the precise failure; re-mapping every error to
        // `FieldNotFound` would misreport a missing table as a missing field.
        self.field(table, field)
            .map(|metadata| metadata.name.as_str())
    }

    /// Returns the metadata of `field` on `table`.
    ///
    /// Both `table` and `field` may be logical or database names. Logical names win
    /// when both could match.
    ///
    /// # Errors
    ///
    /// - [`OpenAuthError::TableNotFound`] if `table` does not resolve.
    /// - [`OpenAuthError::FieldNotFound`] if the table exists but `field` does not.
    pub fn field(&self, table: &str, field: &str) -> Result<&DbField, OpenAuthError> {
        let (_, table_metadata) =
            self.resolve_table(table)
                .ok_or_else(|| OpenAuthError::TableNotFound {
                    table: table.to_owned(),
                })?;
        table_metadata
            .resolve_field(field)
            .ok_or_else(|| OpenAuthError::FieldNotFound {
                table: table.to_owned(),
                field: field.to_owned(),
            })
    }

    /// Iterates over `(logical_name, table)` pairs in the map's insertion order.
    pub fn tables(&self) -> impl Iterator<Item = (&str, &DbTable)> {
        self.tables
            .iter()
            .map(|(logical_name, table)| (logical_name.as_str(), table))
    }

    /// Registers a plugin-provided table under `logical_name`.
    ///
    /// Re-registering an identical table under the same logical name is a no-op.
    ///
    /// # Errors
    ///
    /// Returns [`OpenAuthError::InvalidConfig`] if:
    /// - `logical_name` is already taken by a different table, or
    /// - another table already uses the database name `table.name`.
    pub fn insert_plugin_table(
        &mut self,
        logical_name: String,
        table: DbTable,
    ) -> Result<(), OpenAuthError> {
        if let Some(existing) = self.tables.get(&logical_name) {
            if existing == &table {
                return Ok(());
            }
            return Err(OpenAuthError::InvalidConfig(format!(
                "plugin schema table `{logical_name}` conflicts with an existing table"
            )));
        }
        if self
            .tables
            .values()
            .any(|existing| existing.name == table.name)
        {
            return Err(OpenAuthError::InvalidConfig(format!(
                "plugin schema table `{logical_name}` uses existing database table `{}`",
                table.name
            )));
        }
        self.tables.insert(logical_name, table);
        Ok(())
    }

    /// Registers a plugin-provided field under `logical_name` on `table`.
    ///
    /// `table` may be a logical or database table name. Re-registering an identical
    /// field under the same logical name is a no-op.
    ///
    /// # Errors
    ///
    /// - [`OpenAuthError::TableNotFound`] if `table` does not resolve.
    /// - [`OpenAuthError::InvalidConfig`] if `logical_name` is already taken by a
    ///   different field, or another field of the table already uses `field.name`.
    pub fn insert_plugin_field(
        &mut self,
        table: &str,
        logical_name: String,
        field: DbField,
    ) -> Result<(), OpenAuthError> {
        let (_, table_metadata) =
            self.resolve_table_mut(table)
                .ok_or_else(|| OpenAuthError::TableNotFound {
                    table: table.to_owned(),
                })?;
        if let Some(existing) = table_metadata.fields.get(&logical_name) {
            if existing == &field {
                return Ok(());
            }
            return Err(OpenAuthError::InvalidConfig(format!(
                "plugin schema field `{logical_name}` conflicts with table `{table}`"
            )));
        }
        if table_metadata
            .fields
            .values()
            .any(|existing| existing.name == field.name)
        {
            return Err(OpenAuthError::InvalidConfig(format!(
                "plugin schema field `{logical_name}` uses existing database field `{}` on table `{table}`",
                field.name
            )));
        }
        table_metadata.fields.insert(logical_name, field);
        Ok(())
    }

    /// Index of the table named `table`.
    ///
    /// The logical name is tried first (hash lookup). Otherwise the first table, in
    /// insertion order, whose database name equals `table` is used. This is the
    /// single source of truth for the resolution order used by `resolve_table` and
    /// `resolve_table_mut`.
    fn resolve_table_index(&self, table: &str) -> Option<usize> {
        self.tables.get_index_of(table).or_else(|| {
            self.tables
                .values()
                .position(|metadata| metadata.name == table)
        })
    }

    /// Resolves `table` to `(logical_name, table)`; see `resolve_table_index`.
    fn resolve_table(&self, table: &str) -> Option<(&str, &DbTable)> {
        let index = self.resolve_table_index(table)?;
        self.tables
            .get_index(index)
            .map(|(logical_name, metadata)| (logical_name.as_str(), metadata))
    }

    /// Mutable counterpart of `resolve_table`.
    ///
    /// Resolving to an index first avoids the double hash lookup (`contains_key` then
    /// a keyed `_mut` lookup) otherwise needed to satisfy the borrow checker.
    fn resolve_table_mut(&mut self, table: &str) -> Option<(&str, &mut DbTable)> {
        let index = self.resolve_table_index(table)?;
        self.tables
            .get_index_mut(index)
            .map(|(logical_name, metadata)| (logical_name.as_str(), metadata))
    }

    /// Inserts `table` under `logical_name`, replacing any table already stored under
    /// that key.
    ///
    /// Unlike `insert_plugin_table`, no conflict checks are performed.
    fn insert(&mut self, logical_name: impl Into<String>, table: DbTable) {
        self.tables.insert(logical_name.into(), table);
    }
}
impl DbTable {
    /// Finds a field by logical name. Failing that, returns the first field (in
    /// insertion order) whose database name equals `field`.
    fn resolve_field(&self, field: &str) -> Option<&DbField> {
        if let Some(by_logical_name) = self.fields.get(field) {
            return Some(by_logical_name);
        }
        self.fields
            .values()
            .find(|candidate| candidate.name == field)
    }
}