use std::collections::HashMap;
use std::fmt::Debug;
use chrono::{DateTime, Utc};
use serde::{Serialize, Deserialize};
use sqlx::{Pool, Postgres, Row};
use uuid::Uuid;
use crate::error::{ModelError, ModelResult};
use crate::query::QueryBuilder;
/// A primary-key value in one of the supported shapes.
///
/// NOTE(review): if this type is ever serialized with an index-based serde format
/// (e.g. bincode), the variant order is part of the wire format — append new
/// variants at the end rather than reordering.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum PrimaryKey {
    /// Single integer key.
    Integer(i64),
    /// Single UUID key.
    Uuid(Uuid),
    /// Multi-part key as name/value strings (presumably column name -> value; confirm).
    Composite(HashMap<String, String>),
}
/// Renders the key as text.
///
/// Integer and UUID keys print as their bare value. Composite keys print as
/// `name:value` pairs joined by `,`, in ascending name order.
impl std::fmt::Display for PrimaryKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PrimaryKey::Integer(id) => write!(f, "{}", id),
            PrimaryKey::Uuid(id) => write!(f, "{}", id),
            PrimaryKey::Composite(fields) => {
                // `HashMap` iteration order is unspecified and differs between map
                // instances, so two equal composite keys could otherwise render
                // differently. Sorting by name makes the output deterministic.
                let mut pairs: Vec<(&String, &String)> = fields.iter().collect();
                pairs.sort_unstable_by(|a, b| a.0.cmp(b.0));
                for (index, (name, value)) in pairs.into_iter().enumerate() {
                    if index > 0 {
                        f.write_str(",")?;
                    }
                    write!(f, "{}:{}", name, value)?;
                }
                Ok(())
            }
        }
    }
}
impl PrimaryKey {
pub fn as_i64(&self) -> Option<i64> {
match self {
PrimaryKey::Integer(id) => Some(*id),
_ => None,
}
}
pub fn as_uuid(&self) -> Option<Uuid> {
match self {
PrimaryKey::Uuid(id) => Some(*id),
_ => None,
}
}
}
pub trait Model: Send + Sync + Debug + Serialize + for<'de> Deserialize<'de> {
type PrimaryKey: Clone + Send + Sync + Debug + std::fmt::Display;
fn table_name() -> &'static str;
fn primary_key_name() -> &'static str {
"id"
}
fn primary_key(&self) -> Option<Self::PrimaryKey>;
fn set_primary_key(&mut self, key: Self::PrimaryKey);
fn uses_timestamps() -> bool {
false
}
fn uses_soft_deletes() -> bool {
false
}
fn created_at(&self) -> Option<DateTime<Utc>> {
None
}
fn set_created_at(&mut self, _timestamp: DateTime<Utc>) {}
fn updated_at(&self) -> Option<DateTime<Utc>> {
None
}
fn set_updated_at(&mut self, _timestamp: DateTime<Utc>) {}
fn deleted_at(&self) -> Option<DateTime<Utc>> {
None
}
fn set_deleted_at(&mut self, _timestamp: Option<DateTime<Utc>>) {}
fn is_soft_deleted(&self) -> bool {
self.deleted_at().is_some()
}
async fn find(pool: &Pool<Postgres>, id: Self::PrimaryKey) -> ModelResult<Option<Self>>
where
Self: Sized,
{
let query: QueryBuilder<Self> = QueryBuilder::new()
.select("*")
.from(Self::table_name())
.where_eq(Self::primary_key_name(), id.to_string());
let sql = query.to_sql();
let row = sqlx::query(&sql)
.fetch_optional(pool)
.await?;
match row {
Some(row) => {
let model = Self::from_row(&row)?;
Ok(Some(model))
}
None => Ok(None),
}
}
async fn find_or_fail(pool: &Pool<Postgres>, id: Self::PrimaryKey) -> ModelResult<Self>
where
Self: Sized,
{
Self::find(pool, id)
.await?
.ok_or_else(|| ModelError::NotFound(Self::table_name().to_string()))
}
async fn create(pool: &Pool<Postgres>, mut model: Self) -> ModelResult<Self>
where
Self: Sized,
{
if Self::uses_timestamps() {
let now = Utc::now();
model.set_created_at(now);
model.set_updated_at(now);
}
let insert_sql = format!("INSERT INTO {} DEFAULT VALUES RETURNING *", Self::table_name());
let row = sqlx::query(&insert_sql)
.fetch_one(pool)
.await?;
Self::from_row(&row)
}
async fn update(&mut self, pool: &Pool<Postgres>) -> ModelResult<()> {
if let Some(pk) = self.primary_key() {
if Self::uses_timestamps() {
self.set_updated_at(Utc::now());
}
let update_sql = format!(
"UPDATE {} SET updated_at = NOW() WHERE {} = $1",
Self::table_name(),
Self::primary_key_name()
);
sqlx::query(&update_sql)
.bind(pk.to_string())
.execute(pool)
.await?;
Ok(())
} else {
Err(ModelError::MissingPrimaryKey)
}
}
async fn delete(self, pool: &Pool<Postgres>) -> ModelResult<()> {
if let Some(pk) = self.primary_key() {
if Self::uses_soft_deletes() {
let soft_delete_sql = format!(
"UPDATE {} SET deleted_at = NOW() WHERE {} = $1",
Self::table_name(),
Self::primary_key_name()
);
sqlx::query(&soft_delete_sql)
.bind(pk.to_string())
.execute(pool)
.await?;
} else {
let delete_sql = format!(
"DELETE FROM {} WHERE {} = $1",
Self::table_name(),
Self::primary_key_name()
);
sqlx::query(&delete_sql)
.bind(pk.to_string())
.execute(pool)
.await?;
}
Ok(())
} else {
Err(ModelError::MissingPrimaryKey)
}
}
fn query() -> QueryBuilder<Self>
where
Self: Sized,
{
let builder = QueryBuilder::new()
.from(Self::table_name());
if Self::uses_soft_deletes() {
builder.where_null("deleted_at")
} else {
builder
}
}
async fn all(pool: &Pool<Postgres>) -> ModelResult<Vec<Self>>
where
Self: Sized,
{
let query = Self::query().select("*");
let sql = query.to_sql();
let rows = sqlx::query(&sql)
.fetch_all(pool)
.await?;
let mut models = Vec::new();
for row in rows {
models.push(Self::from_row(&row)?);
}
Ok(models)
}
async fn count(pool: &Pool<Postgres>) -> ModelResult<i64>
where
Self: Sized,
{
let query = Self::query().select("COUNT(*)");
let sql = query.to_sql();
let row = sqlx::query(&sql)
.fetch_one(pool)
.await?;
let count: i64 = row.try_get(0)?;
Ok(count)
}
fn from_row(row: &sqlx::postgres::PgRow) -> ModelResult<Self>
where
Self: Sized;
fn to_fields(&self) -> HashMap<String, serde_json::Value>;
}
/// Convenience operations layered on [`Model`]: reloading, existence checks and saving.
pub trait ModelExtensions: Model {
    /// Re-reads this record via [`Model::find`] and overwrites `self` with the result.
    ///
    /// # Errors
    /// - `ModelError::MissingPrimaryKey` when `self` has no key.
    /// - `ModelError::NotFound(table_name)` when no row matches.
    /// - Any error returned by [`Model::find`].
    async fn refresh(&mut self, pool: &Pool<Postgres>) -> ModelResult<()>
    where
        Self: Sized,
    {
        let Some(key) = self.primary_key() else {
            return Err(ModelError::MissingPrimaryKey);
        };
        let latest = Self::find(pool, key)
            .await?
            .ok_or_else(|| ModelError::NotFound(Self::table_name().to_string()))?;
        *self = latest;
        Ok(())
    }

    /// Reports whether [`Model::find`] locates a row for this record's key.
    ///
    /// A record without a key is reported as not existing; no query is run.
    async fn exists(&self, pool: &Pool<Postgres>) -> ModelResult<bool>
    where
        Self: Sized,
    {
        match self.primary_key() {
            None => Ok(false),
            Some(key) => {
                let found = Self::find(pool, key).await?;
                Ok(found.is_some())
            }
        }
    }

    /// Persists an already-stored record through [`Model::update`].
    ///
    /// # Errors
    /// Records with no key, or whose key has no row, get `ModelError::Validation`.
    /// Errors from the existence check or the update are propagated.
    async fn save(&mut self, pool: &Pool<Postgres>) -> ModelResult<()>
    where
        Self: Sized,
    {
        let persisted = match self.primary_key() {
            Some(_) => self.exists(pool).await?,
            None => false,
        };
        if !persisted {
            return Err(ModelError::Validation(
                "Cannot save new model without primary key support from derive macro"
                    .to_string(),
            ));
        }
        self.update(pool).await
    }
}
/// Blanket implementation: every [`Model`] gets the [`ModelExtensions`] helpers.
impl<T> ModelExtensions for T
where
    T: Model,
{
}