use sqlx::{postgres::PgRow, PgPool};
use super::executor;
use crate::core::condition::SqlValue;
use crate::core::model::Model;
use crate::core::query::QueryBuilder;
use crate::core::sqlx::pg as sqlx_pg;
use crate::orm::model_query::ModelQuery;
/// PostgreSQL convenience API for any [`Model`] whose rows can be decoded from a [`PgRow`].
///
/// Every method is a provided (default) method. Most are thin wrappers that build a
/// `QueryBuilder` from the model (`Self::query()`, `Self::find(id)`, or a caller-supplied
/// builder) and hand it to the sibling `executor` module; a few compose several steps
/// themselves. Methods return `impl Future + Send` rather than being `async fn` so the
/// `Send` bound is part of the signature.
///
/// Unless stated otherwise, a `u64` result is whatever the corresponding `executor`
/// function reports (presumably the number of affected rows — confirm in `executor`).
pub trait PgModel: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin + 'static {
    /// Fetches every row matched by the model's base query (`Self::query()`).
    fn all(
        pool: &PgPool,
    ) -> impl std::future::Future<Output = Result<Vec<Self>, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        executor::fetch_all(pool, Self::query())
    }

    /// Fetches every row matched by the caller-supplied `builder`.
    fn find_where(
        pool: &PgPool,
        builder: QueryBuilder<Self>,
    ) -> impl std::future::Future<Output = Result<Vec<Self>, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        executor::fetch_all(pool, builder)
    }

    /// Looks up one row through `Self::find(id)`.
    ///
    /// Resolves to `Ok(None)` when the executor finds no row; use
    /// [`find_or_fail`](Self::find_or_fail) to turn that into an error instead.
    fn find_by_pk(
        pool: &PgPool,
        id: impl Into<SqlValue> + Send,
    ) -> impl std::future::Future<Output = Result<Option<Self>, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        executor::fetch_optional(pool, Self::find(id))
    }

    /// Counts the rows matched by the model's base query (`Self::query()`).
    fn count(pool: &PgPool) -> impl std::future::Future<Output = Result<i64, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        executor::count(pool, Self::query())
    }

    /// Counts the rows matched by the caller-supplied `builder`.
    fn count_where(
        pool: &PgPool,
        builder: QueryBuilder<Self>,
    ) -> impl std::future::Future<Output = Result<i64, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        executor::count(pool, builder)
    }

    /// Inserts one row into `Self::table_name()` from `(column, value)` pairs.
    fn create(
        pool: &PgPool,
        data: &[(&str, SqlValue)],
    ) -> impl std::future::Future<Output = Result<u64, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        executor::insert::<Self>(pool, Self::table_name(), data)
    }

    /// Applies the `(column, value)` pairs in `data` to the row selected by `Self::find(id)`.
    fn update_by_pk(
        pool: &PgPool,
        id: impl Into<SqlValue> + Send,
        data: &[(&str, SqlValue)],
    ) -> impl std::future::Future<Output = Result<u64, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        let builder = Self::find(id);
        executor::update::<Self>(pool, builder, data)
    }

    /// Deletes the row selected by `Self::find(id)`.
    fn delete_by_pk(
        pool: &PgPool,
        id: impl Into<SqlValue> + Send,
    ) -> impl std::future::Future<Output = Result<u64, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        executor::delete(pool, Self::find(id))
    }

    /// Deletes every row matched by the caller-supplied `builder`.
    fn delete_where(
        pool: &PgPool,
        builder: QueryBuilder<Self>,
    ) -> impl std::future::Future<Output = Result<u64, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        executor::delete(pool, builder)
    }

    /// Applies the `(column, value)` pairs in `data` to every row matched by `builder`.
    fn update_where(
        pool: &PgPool,
        builder: QueryBuilder<Self>,
        data: &[(&str, SqlValue)],
    ) -> impl std::future::Future<Output = Result<u64, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        executor::update::<Self>(pool, builder, data)
    }

    /// Inserts all `rows` into `Self::table_name()` through a single
    /// `executor::bulk_insert` call.
    ///
    /// See [`bulk_insert_chunked`](Self::bulk_insert_chunked) to split a large input
    /// into several calls.
    fn bulk_create(
        pool: &PgPool,
        rows: &[Vec<(&str, SqlValue)>],
    ) -> impl std::future::Future<Output = Result<u64, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        executor::bulk_insert::<Self>(pool, Self::table_name(), rows)
    }

    /// Forwards `rows` and `conflict_cols` to `executor::bulk_upsert`.
    ///
    /// `conflict_cols` presumably names the columns of the conflict target —
    /// confirm against `executor::bulk_upsert`.
    fn bulk_upsert(
        pool: &PgPool,
        rows: &[Vec<(&str, SqlValue)>],
        conflict_cols: &[&str],
    ) -> impl std::future::Future<Output = Result<u64, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        executor::bulk_upsert::<Self>(pool, rows, conflict_cols)
    }

    /// Inserts `rows` in consecutive batches of at most `chunk_size` rows and resolves
    /// to the sum of the per-batch results.
    ///
    /// A `chunk_size` of `0` is treated as `1`; `slice::chunks` would otherwise panic
    /// when the returned future is polled.
    ///
    /// No transaction is opened here: each batch is a separate `executor::bulk_insert`
    /// call on the pool, and the first failing batch aborts the loop with its error
    /// without undoing batches that were already sent.
    fn bulk_insert_chunked(
        pool: &PgPool,
        rows: &[Vec<(&str, SqlValue)>],
        chunk_size: usize,
    ) -> impl std::future::Future<Output = Result<u64, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        // Clamp instead of letting `chunks(0)` panic inside the future.
        let chunk_size = chunk_size.max(1);
        let pool = pool.clone();
        // Copy the input into owned storage that can be moved into the future.
        let owned: Vec<Vec<(String, SqlValue)>> = rows
            .iter()
            .map(|r| r.iter().map(|(k, v)| ((*k).to_owned(), v.clone())).collect())
            .collect();
        async move {
            let mut total = 0u64;
            for chunk in owned.chunks(chunk_size) {
                // `executor::bulk_insert` takes `&str` column names, so re-borrow per batch.
                let data: Vec<Vec<(&str, SqlValue)>> = chunk
                    .iter()
                    .map(|r| r.iter().map(|(k, v)| (k.as_str(), v.clone())).collect())
                    .collect();
                total += executor::bulk_insert::<Self>(&pool, Self::table_name(), &data).await?;
            }
            Ok(total)
        }
    }

    /// Inserts one row into `Self::table_name()` and resolves to the decoded row
    /// produced by `executor::insert_returning`.
    fn create_returning(
        pool: &PgPool,
        data: &[(&str, SqlValue)],
    ) -> impl std::future::Future<Output = Result<Self, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        executor::insert_returning::<Self>(pool, Self::table_name(), data)
    }

    /// Starts a [`ModelQuery`] from the base query with a `where_eq(col, val)`
    /// condition applied. Nothing is executed here.
    fn filter(col: &str, val: impl Into<SqlValue>) -> ModelQuery<Self>
    where
        Self: Sized + Sync + 'static,
    {
        ModelQuery::new(Self::query().where_eq(col, val))
    }

    /// Starts a [`ModelQuery`] from the unfiltered base query (`Self::query()`).
    fn all_query() -> ModelQuery<Self>
    where
        Self: Sized + Sync + 'static,
    {
        ModelQuery::new(Self::query())
    }

    /// Starts a [`ModelQuery`] from `Self::find(id)`.
    fn find_query(id: impl Into<SqlValue>) -> ModelQuery<Self>
    where
        Self: Sized + Sync + 'static,
    {
        ModelQuery::new(Self::find(id))
    }

    /// Like [`find_by_pk`](Self::find_by_pk), but a missing row is an error.
    ///
    /// # Errors
    ///
    /// Resolves to `sqlx::Error::RowNotFound` when no row is found, in addition to any
    /// error from the executor.
    fn find_or_fail(
        pool: &PgPool,
        id: impl Into<SqlValue> + Send,
    ) -> impl std::future::Future<Output = Result<Self, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        let pool = pool.clone();
        // Build the query eagerly so `id` is consumed before the future is created.
        let builder = Self::find(id);
        async move {
            executor::fetch_optional(&pool, builder)
                .await?
                .ok_or(sqlx::Error::RowNotFound)
        }
    }

    /// Deletes the row identified by this instance's `pk_value()`.
    fn delete_self(
        &self,
        pool: &PgPool,
    ) -> impl std::future::Future<Output = Result<u64, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        // Read the key up front; the async block below does not use `self`.
        let pk = self.pk_value();
        let pool = pool.clone();
        async move { executor::delete::<Self>(&pool, Self::find(pk)).await }
    }

    /// Inserts all `rows` into `Self::table_name()` and resolves to the decoded rows
    /// produced by `executor::bulk_insert_returning`.
    fn insert_many(
        pool: &PgPool,
        rows: &[Vec<(&str, SqlValue)>],
    ) -> impl std::future::Future<Output = Result<Vec<Self>, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        executor::bulk_insert_returning::<Self>(pool, Self::table_name(), rows)
    }

    /// Forwards `data` and `conflict_cols` to `executor::upsert_returning` and resolves
    /// to the decoded row it produces.
    fn upsert_returning(
        pool: &PgPool,
        data: &[(&str, SqlValue)],
        conflict_cols: &[&str],
    ) -> impl std::future::Future<Output = Result<Self, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        executor::upsert_returning::<Self>(pool, data, conflict_cols)
    }

    /// Runs the aggregate expression `SUM(col)` over the base query.
    ///
    /// `col` is interpolated verbatim into the expression text, so it must come from
    /// trusted code, never from user input.
    fn sum(
        pool: &PgPool,
        col: &str,
    ) -> impl std::future::Future<Output = Result<Option<f64>, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        let agg = format!("SUM({col})");
        let pool = pool.clone();
        let builder = Self::query();
        async move { executor::aggregate(&pool, builder, &agg).await }
    }

    /// Runs the aggregate expression `AVG(col)` over the base query.
    ///
    /// `col` is interpolated verbatim into the expression text, so it must come from
    /// trusted code, never from user input.
    fn avg(
        pool: &PgPool,
        col: &str,
    ) -> impl std::future::Future<Output = Result<Option<f64>, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        let agg = format!("AVG({col})");
        let pool = pool.clone();
        let builder = Self::query();
        async move { executor::aggregate(&pool, builder, &agg).await }
    }

    /// Runs the aggregate expression `MAX(col)` over the base query.
    ///
    /// `col` is interpolated verbatim into the expression text, so it must come from
    /// trusted code, never from user input.
    fn max(
        pool: &PgPool,
        col: &str,
    ) -> impl std::future::Future<Output = Result<Option<f64>, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        let agg = format!("MAX({col})");
        let pool = pool.clone();
        let builder = Self::query();
        async move { executor::aggregate(&pool, builder, &agg).await }
    }

    /// Runs the aggregate expression `MIN(col)` over the base query.
    ///
    /// `col` is interpolated verbatim into the expression text, so it must come from
    /// trusted code, never from user input.
    fn min(
        pool: &PgPool,
        col: &str,
    ) -> impl std::future::Future<Output = Result<Option<f64>, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        let agg = format!("MIN({col})");
        let pool = pool.clone();
        let builder = Self::query();
        async move { executor::aggregate(&pool, builder, &agg).await }
    }

    /// Soft-deletes the row selected by `Self::find(id)`.
    ///
    /// The column passed to `executor::soft_delete` is `Self::soft_delete_column()`,
    /// falling back to `"deleted_at"` when the model declares none.
    fn soft_delete_by_pk(
        pool: &PgPool,
        id: impl Into<SqlValue> + Send,
    ) -> impl std::future::Future<Output = Result<u64, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        let col = Self::soft_delete_column().unwrap_or("deleted_at");
        executor::soft_delete::<Self>(pool, Self::find(id), col)
    }

    /// Runs `executor::restore` on the row whose primary key equals `id`.
    ///
    /// Uses the same soft-delete column resolution as
    /// [`soft_delete_by_pk`](Self::soft_delete_by_pk).
    fn restore_by_pk(
        pool: &PgPool,
        id: impl Into<SqlValue> + Send,
    ) -> impl std::future::Future<Output = Result<u64, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        let col = Self::soft_delete_column().unwrap_or("deleted_at");
        // NOTE(review): unlike the other `*_by_pk` helpers this filters on
        // `Self::primary_key()` via `Self::query()` instead of using `Self::find(id)` —
        // presumably so a soft-delete scope applied by `find` cannot hide the row; confirm.
        let builder = Self::query().where_eq(Self::primary_key(), id);
        executor::restore::<Self>(pool, builder, col)
    }

    /// Runs `executor::touch` on the row selected by `Self::find(id)`
    /// (presumably refreshing its modification timestamp — confirm in `executor`).
    fn touch_by_pk(
        pool: &PgPool,
        id: impl Into<SqlValue> + Send,
    ) -> impl std::future::Future<Output = Result<u64, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        executor::touch::<Self>(pool, Self::find(id))
    }

    /// Soft-deletes every row matched by the caller-supplied `builder`.
    ///
    /// Uses the same soft-delete column resolution as
    /// [`soft_delete_by_pk`](Self::soft_delete_by_pk).
    fn soft_delete_where(
        pool: &PgPool,
        builder: QueryBuilder<Self>,
    ) -> impl std::future::Future<Output = Result<u64, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        let col = Self::soft_delete_column().unwrap_or("deleted_at");
        executor::soft_delete::<Self>(pool, builder, col)
    }

    /// Resolves to the first row matching every `(column, value)` pair in `find`, or,
    /// when there is none, inserts a row built from `find` followed by `create` and
    /// resolves to the inserted row.
    ///
    /// Both slices are copied into owned storage before the future is created; the work
    /// itself is done by `do_first_or_create` inside one transaction.
    fn first_or_create(
        pool: &PgPool,
        find: &[(&str, SqlValue)],
        create: &[(&str, SqlValue)],
    ) -> impl std::future::Future<Output = Result<Self, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        let pool = pool.clone();
        let find_owned: Vec<(String, SqlValue)> = find
            .iter()
            .map(|(k, v)| ((*k).to_owned(), v.clone()))
            .collect();
        let create_owned: Vec<(String, SqlValue)> = create
            .iter()
            .map(|(k, v)| ((*k).to_owned(), v.clone()))
            .collect();
        async move { do_first_or_create::<Self>(&pool, find_owned, create_owned).await }
    }

    /// Opens a `PgListener` on `pool` and starts listening on the channel named after
    /// `Self::table_name()`.
    fn subscribe(
        pool: &PgPool,
    ) -> impl std::future::Future<Output = Result<sqlx::postgres::PgListener, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        let pool = pool.clone();
        let channel = Self::table_name();
        async move {
            let mut listener = sqlx::postgres::PgListener::connect_with(&pool).await?;
            listener.listen(channel).await?;
            Ok(listener)
        }
    }

    /// Opens a `PgListener` on `pool` and starts listening on the given `channel`.
    fn subscribe_channel(
        pool: &PgPool,
        channel: &str,
    ) -> impl std::future::Future<Output = Result<sqlx::postgres::PgListener, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        let pool = pool.clone();
        // Owned copy that is moved into the future.
        let channel = channel.to_string();
        async move {
            let mut listener = sqlx::postgres::PgListener::connect_with(&pool).await?;
            listener.listen(&channel).await?;
            Ok(listener)
        }
    }

    /// Updates the first row matching every `(column, value)` pair in `find` with the
    /// pairs in `update`, or, when there is none, inserts a row built from `find`
    /// followed by `update`. Resolves to the resulting row either way.
    ///
    /// Both slices are copied into owned storage before the future is created; the work
    /// itself is done by `do_update_or_create` inside one transaction.
    fn update_or_create(
        pool: &PgPool,
        find: &[(&str, SqlValue)],
        update: &[(&str, SqlValue)],
    ) -> impl std::future::Future<Output = Result<Self, sqlx::Error>> + Send
    where
        Self: Sized,
    {
        let pool = pool.clone();
        let find_owned: Vec<(String, SqlValue)> = find
            .iter()
            .map(|(k, v)| ((*k).to_owned(), v.clone()))
            .collect();
        let update_owned: Vec<(String, SqlValue)> = update
            .iter()
            .map(|(k, v)| ((*k).to_owned(), v.clone()))
            .collect();
        async move { do_update_or_create::<Self>(&pool, find_owned, update_owned).await }
    }
}
/// Blanket implementation: every type that satisfies the `PgModel` supertrait bounds
/// gets all of the trait's provided methods without writing an `impl` of its own.
impl<M> PgModel for M
where
    M: Model + for<'row> sqlx::FromRow<'row, PgRow> + Send + Unpin + 'static,
{
}
/// Transactional body of `PgModel::first_or_create`.
///
/// Steps, all on one transaction started from `pool`:
///
/// 1. Build a query on `T::table_name()` with one `where_eq` per `find` pair and
///    `limit(1)`, append ` FOR UPDATE`, and fetch at most one row.
/// 2. If a row comes back, commit and return it unchanged.
/// 3. Otherwise insert a row from the merged `find` + `create` columns using
///    `INSERT … RETURNING *`, commit, and return the inserted row.
///
/// A column that appears in both `find` and `create` is sent to the INSERT once, with
/// the `create` value; listing it twice would make the statement name the same column
/// more than once.
///
/// NOTE(review): `FOR UPDATE` can only lock rows that already exist. Two concurrent
/// callers that both find nothing can therefore both insert; a unique constraint on the
/// `find` columns is needed to rule that out.
///
/// # Errors
///
/// Propagates any `sqlx::Error` from beginning the transaction, either statement, or the
/// commit. On an early error return the transaction is dropped without being committed.
async fn do_first_or_create<T>(
    pool: &PgPool,
    find: Vec<(String, SqlValue)>,
    create: Vec<(String, SqlValue)>,
) -> Result<T, sqlx::Error>
where
    T: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
    let mut tx = pool.begin().await?;

    // Lookup: every `find` pair becomes an equality filter.
    let mut qb: QueryBuilder<T> = QueryBuilder::new(T::table_name());
    for (col, val) in &find {
        qb = qb.where_eq(col.as_str(), val.clone());
    }
    let (base_sql, params) = qb.limit(1).to_sql();
    let sql_locked = format!("{base_sql} FOR UPDATE");
    let existing = sqlx_pg::build_query_as::<T>(&sql_locked, params)
        .fetch_optional(&mut *tx)
        .await?;
    if let Some(row) = existing {
        tx.commit().await?;
        return Ok(row);
    }

    // Merge `find` and `create` into one column list; on a repeated column the later
    // (`create`) value replaces the earlier one so each column is listed exactly once.
    let mut all_data: Vec<(&str, SqlValue)> = Vec::with_capacity(find.len() + create.len());
    for (col, val) in find.iter().chain(create.iter()) {
        let seen = all_data.iter().position(|(name, _)| *name == col.as_str());
        match seen {
            Some(idx) => all_data[idx].1 = val.clone(),
            None => all_data.push((col.as_str(), val.clone())),
        }
    }

    let (insert_sql, insert_params) = QueryBuilder::<T>::insert_sql(T::table_name(), &all_data);
    let sql_returning = format!("{insert_sql} RETURNING *");
    let new_row = sqlx_pg::build_query_as::<T>(&sql_returning, insert_params)
        .fetch_one(&mut *tx)
        .await?;
    tx.commit().await?;
    Ok(new_row)
}
/// Transactional body of `PgModel::update_or_create`.
///
/// Steps, all on one transaction started from `pool`:
///
/// 1. Build a query on `T::table_name()` with one `where_eq` per `find` pair and
///    `limit(1)`, append ` FOR UPDATE`, and fetch at most one row.
/// 2. If a row comes back, update it by primary key with the `update` pairs using
///    `UPDATE … RETURNING *`, commit, and return the row as the database reports it
///    after the update. Using `RETURNING` avoids a second round trip and still yields
///    the row when `update` changes the primary-key column itself.
/// 3. Otherwise insert a row from the merged `find` + `update` columns using
///    `INSERT … RETURNING *`, commit, and return the inserted row.
///
/// A column that appears in both `find` and `update` is sent to the INSERT once, with
/// the `update` value; listing it twice would make the statement name the same column
/// more than once.
///
/// NOTE(review): `FOR UPDATE` can only lock rows that already exist. Two concurrent
/// callers that both find nothing can therefore both insert; a unique constraint on the
/// `find` columns is needed to rule that out.
///
/// # Errors
///
/// Propagates any `sqlx::Error` from beginning the transaction, any statement, or the
/// commit. On an early error return the transaction is dropped without being committed.
async fn do_update_or_create<T>(
    pool: &PgPool,
    find: Vec<(String, SqlValue)>,
    update: Vec<(String, SqlValue)>,
) -> Result<T, sqlx::Error>
where
    T: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
    let mut tx = pool.begin().await?;

    // Lookup: every `find` pair becomes an equality filter.
    let mut qb: QueryBuilder<T> = QueryBuilder::new(T::table_name());
    for (col, val) in &find {
        qb = qb.where_eq(col.as_str(), val.clone());
    }
    let (base_sql, params) = qb.limit(1).to_sql();
    let sql_locked = format!("{base_sql} FOR UPDATE");
    let existing = sqlx_pg::build_query_as::<T>(&sql_locked, params)
        .fetch_optional(&mut *tx)
        .await?;

    if let Some(existing_row) = existing {
        // Target exactly the row that was locked above, by its primary key.
        let pk_val = T::pk_value(&existing_row);
        let update_data: Vec<(&str, SqlValue)> = update
            .iter()
            .map(|(k, v)| (k.as_str(), v.clone()))
            .collect();
        let (update_sql, update_params) = QueryBuilder::<T>::new(T::table_name())
            .where_eq(T::primary_key(), pk_val)
            .to_update_sql(&update_data);
        // Read the updated row back from the UPDATE itself, mirroring the insert path.
        let update_returning = format!("{update_sql} RETURNING *");
        let updated_row = sqlx_pg::build_query_as::<T>(&update_returning, update_params)
            .fetch_one(&mut *tx)
            .await?;
        tx.commit().await?;
        return Ok(updated_row);
    }

    // Merge `find` and `update` into one column list; on a repeated column the later
    // (`update`) value replaces the earlier one so each column is listed exactly once.
    let mut all_data: Vec<(&str, SqlValue)> = Vec::with_capacity(find.len() + update.len());
    for (col, val) in find.iter().chain(update.iter()) {
        let seen = all_data.iter().position(|(name, _)| *name == col.as_str());
        match seen {
            Some(idx) => all_data[idx].1 = val.clone(),
            None => all_data.push((col.as_str(), val.clone())),
        }
    }

    let (insert_sql, insert_params) = QueryBuilder::<T>::insert_sql(T::table_name(), &all_data);
    let sql_returning = format!("{insert_sql} RETURNING *");
    let new_row = sqlx_pg::build_query_as::<T>(&sql_returning, insert_params)
        .fetch_one(&mut *tx)
        .await?;
    tx.commit().await?;
    Ok(new_row)
}