use chrono::Utc;
use sqlx::{Pool, Postgres};
use crate::oauth2::{
errors::OAuth2Error,
types::{AccountSearchField, OAuth2Account, Provider, ProviderUserId},
};
use crate::storage::validate_postgres_table_schema;
use crate::userdb::DB_TABLE_USERS;
use super::config::DB_TABLE_OAUTH2_ACCOUNTS;
/// Creates the OAuth2 accounts table and an index on its `user_id` column.
///
/// `user_id` is declared as a foreign key to the users table's `id` with
/// `ON DELETE CASCADE`, so the users table must already exist when this runs.
/// Both statements use `IF NOT EXISTS`, so running this against an already
/// initialized database leaves the schema unchanged.
///
/// # Errors
///
/// Returns [`OAuth2Error::Storage`] carrying the database error text if either
/// statement fails; the index statement is not attempted if the table statement fails.
pub(super) async fn create_tables_postgres(pool: &Pool<Postgres>) -> Result<(), OAuth2Error> {
    let accounts = DB_TABLE_OAUTH2_ACCOUNTS.as_str();
    let users = DB_TABLE_USERS.as_str();

    let create_table = format!(
        r#"
        CREATE TABLE IF NOT EXISTS {accounts} (
            sequence_number BIGSERIAL PRIMARY KEY,
            id TEXT NOT NULL UNIQUE,
            user_id TEXT NOT NULL REFERENCES {users}(id) ON DELETE CASCADE,
            provider TEXT NOT NULL,
            provider_user_id TEXT NOT NULL,
            name TEXT NOT NULL,
            email TEXT NOT NULL,
            picture TEXT,
            metadata JSONB NOT NULL,
            created_at TIMESTAMPTZ NOT NULL,
            updated_at TIMESTAMPTZ NOT NULL,
            UNIQUE(provider, provider_user_id)
        )
        "#
    );

    // Dots are flattened so a schema-qualified table name (e.g. `schema.table`)
    // still produces a single plain identifier for the index.
    let index_name = format!("idx_{}_user_id", accounts.replace('.', "_"));
    let create_index = format!(
        r#"
        CREATE INDEX IF NOT EXISTS {index_name} ON {accounts}(user_id)
        "#
    );

    // Order matters: the index can only be created once the table exists.
    for statement in [create_table, create_index] {
        sqlx::query(&statement)
            .execute(pool)
            .await
            .map_err(|e| OAuth2Error::Storage(e.to_string()))?;
    }

    Ok(())
}
/// Verifies that the OAuth2 accounts table matches the column layout this module expects.
///
/// Each entry pairs a column name with its expected PostgreSQL type name. The type
/// names use the long catalog spelling (e.g. `timestamp with time zone` for the
/// `TIMESTAMPTZ` declared at creation). NOTE(review): presumably these are compared
/// against `information_schema` data types; confirm in `validate_postgres_table_schema`.
///
/// # Errors
///
/// Whatever [`validate_postgres_table_schema`] reports, with
/// [`OAuth2Error::Storage`] supplied to it as the error constructor.
pub(super) async fn validate_oauth2_tables_postgres(
    pool: &Pool<Postgres>,
) -> Result<(), OAuth2Error> {
    let columns = [
        // Row identity and ownership.
        ("sequence_number", "bigint"),
        ("id", "text"),
        ("user_id", "text"),
        // Link to the external identity provider.
        ("provider", "text"),
        ("provider_user_id", "text"),
        // Profile data copied from the provider.
        ("name", "text"),
        ("email", "text"),
        ("picture", "text"),
        ("metadata", "jsonb"),
        // Bookkeeping timestamps.
        ("created_at", "timestamp with time zone"),
        ("updated_at", "timestamp with time zone"),
    ];

    validate_postgres_table_schema(
        pool,
        DB_TABLE_OAUTH2_ACCOUNTS.as_str(),
        &columns,
        OAuth2Error::Storage,
    )
    .await
}
/// Fetches every OAuth2 account whose column selected by `field` equals the field's value.
///
/// Matching is exact equality on a single column; an empty `Vec` means no row matched.
///
/// # Errors
///
/// Returns [`OAuth2Error::Storage`] carrying the database error text if the query
/// or row decoding fails.
pub(super) async fn get_oauth2_accounts_by_field_postgres(
    pool: &Pool<Postgres>,
    field: &AccountSearchField,
) -> Result<Vec<OAuth2Account>, OAuth2Error> {
    let table = DB_TABLE_OAUTH2_ACCOUNTS.as_str();

    // The column name is always one of these fixed literals, never caller
    // input, so interpolating it into the SQL is safe; the value is bound.
    let (column, value) = match field {
        AccountSearchField::Id(id) => ("id", id.as_str()),
        AccountSearchField::UserId(user_id) => ("user_id", user_id.as_str()),
        AccountSearchField::Provider(provider) => ("provider", provider.as_str()),
        AccountSearchField::ProviderUserId(provider_user_id) => {
            ("provider_user_id", provider_user_id.as_str())
        }
        AccountSearchField::Name(name) => ("name", name.as_str()),
        AccountSearchField::Email(email) => ("email", email.as_str()),
    };
    let sql = format!("SELECT * FROM {table} WHERE {column} = $1");

    sqlx::query_as::<_, OAuth2Account>(&sql)
        .bind(value)
        .fetch_all(pool)
        .await
        .map_err(|e| OAuth2Error::Storage(e.to_string()))
}
/// Looks up the single OAuth2 account linked to `provider_user_id` at `provider`.
///
/// Returns `Ok(None)` when no row matches. At most one row can match, because the
/// table declares `UNIQUE(provider, provider_user_id)` at creation.
///
/// # Errors
///
/// Returns [`OAuth2Error::Storage`] carrying the database error text if the query
/// or row decoding fails.
pub(super) async fn get_oauth2_account_by_provider_postgres(
    pool: &Pool<Postgres>,
    provider: Provider,
    provider_user_id: ProviderUserId,
) -> Result<Option<OAuth2Account>, OAuth2Error> {
    let sql = format!(
        r#"
        SELECT * FROM {}
        WHERE provider = $1 AND provider_user_id = $2
        "#,
        DB_TABLE_OAUTH2_ACCOUNTS.as_str()
    );

    let lookup = sqlx::query_as::<_, OAuth2Account>(&sql)
        .bind(provider.as_str())
        .bind(provider_user_id.as_str())
        .fetch_optional(pool)
        .await;

    lookup.map_err(|e| OAuth2Error::Storage(e.to_string()))
}
/// Inserts `account`, or refreshes the existing row with the same
/// `(provider, provider_user_id)` pair, and returns the row as stored.
///
/// On insert, `id`, `user_id`, `provider`, `provider_user_id`, `name`, `email`,
/// `picture` and `metadata` come from `account`. Both `created_at` and `updated_at`
/// are set to the same current timestamp. On conflict, only `name`, `email`,
/// `picture`, `metadata` and `updated_at` are overwritten; the existing row keeps
/// its `id`, `user_id` and `created_at`. Any timestamps carried on `account` are ignored.
///
/// The upsert is a single `INSERT ... ON CONFLICT` statement, so it is atomic. A
/// select-then-insert sequence would let two concurrent first logins for the same
/// provider identity both miss the lookup, after which one would fail on the
/// `UNIQUE(provider, provider_user_id)` constraint. This relies on that constraint
/// existing, as created by `create_tables_postgres`.
///
/// # Errors
///
/// Returns [`OAuth2Error::Storage`] if `account.metadata` cannot be serialized to
/// JSON, or if the statement fails. For example, the statement fails when
/// `account.id` collides with a different row or `account.user_id` references no user.
pub(super) async fn upsert_oauth2_account_postgres(
    pool: &Pool<Postgres>,
    account: OAuth2Account,
) -> Result<OAuth2Account, OAuth2Error> {
    let table_name = DB_TABLE_OAUTH2_ACCOUNTS.as_str();
    let metadata = serde_json::to_value(&account.metadata)
        .map_err(|e| OAuth2Error::Storage(e.to_string()))?;
    // One timestamp, so a freshly inserted row has created_at == updated_at.
    let now = Utc::now();

    sqlx::query_as::<_, OAuth2Account>(&format!(
        r#"
        INSERT INTO {table_name}
            (id, user_id, provider, provider_user_id, name, email, picture, metadata, created_at, updated_at)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
        ON CONFLICT (provider, provider_user_id) DO UPDATE SET
            name = EXCLUDED.name,
            email = EXCLUDED.email,
            picture = EXCLUDED.picture,
            metadata = EXCLUDED.metadata,
            updated_at = EXCLUDED.updated_at
        RETURNING *
        "#
    ))
    .bind(&account.id)
    .bind(&account.user_id)
    .bind(&account.provider)
    .bind(&account.provider_user_id)
    .bind(&account.name)
    .bind(&account.email)
    .bind(&account.picture)
    .bind(metadata)
    .bind(now)
    .bind(now)
    .fetch_one(pool)
    .await
    .map_err(|e| OAuth2Error::Storage(e.to_string()))
}
/// Deletes every OAuth2 account whose column selected by `field` equals the field's value.
///
/// Matching is exact equality on a single column, so non-unique fields such as
/// `Provider`, `Name` or `Email` may remove many rows. Deleting zero rows is not an error.
///
/// # Errors
///
/// Returns [`OAuth2Error::Storage`] carrying the database error text if the
/// statement fails.
pub(super) async fn delete_oauth2_accounts_by_field_postgres(
    pool: &Pool<Postgres>,
    field: &AccountSearchField,
) -> Result<(), OAuth2Error> {
    let table = DB_TABLE_OAUTH2_ACCOUNTS.as_str();

    // The column name is always one of these fixed literals, never caller
    // input, so interpolating it into the SQL is safe; the value is bound.
    let (column, value) = match field {
        AccountSearchField::Id(id) => ("id", id.as_str()),
        AccountSearchField::UserId(user_id) => ("user_id", user_id.as_str()),
        AccountSearchField::Provider(provider) => ("provider", provider.as_str()),
        AccountSearchField::ProviderUserId(provider_user_id) => {
            ("provider_user_id", provider_user_id.as_str())
        }
        AccountSearchField::Name(name) => ("name", name.as_str()),
        AccountSearchField::Email(email) => ("email", email.as_str()),
    };
    let sql = format!("DELETE FROM {table} WHERE {column} = $1");

    sqlx::query(&sql)
        .bind(value)
        .execute(pool)
        .await
        .map(|_| ())
        .map_err(|e| OAuth2Error::Storage(e.to_string()))
}