use std::marker::PhantomData;
use sqlx::postgres::PgRow;
use sqlx::PgPool;
use crate::core::condition::SqlValue;
use crate::core::model::Model;
use crate::core::sqlx::pg as sqlx_pg;
use crate::orm::postgres::pool;
/// Query builder for a many-to-many relation that goes through a pivot table.
///
/// Reads return rows of the related model `T`, joined through the pivot table.
/// Writes (`attach`, `detach`, `sync`, `toggle`, …) modify rows of the pivot table only.
pub struct PivotQuery<T> {
    // Owner-side key value; bound as `$1` and compared against `{through}.{fk}`.
owner_id: SqlValue,
    // Name of the pivot table.
through: &'static str,
    // Pivot column holding the owner's key (matched against `owner_id`).
fk: &'static str,
    // Pivot column holding the related key (joined against `T::primary_key()`).
rfk: &'static str,
    // Ties the builder to the related model type `T` without storing a value of it.
_marker: PhantomData<T>,
}
impl<T> PivotQuery<T>
where
    T: Model + for<'r> sqlx::FromRow<'r, PgRow> + Send + Unpin,
{
    /// Creates a query over the pivot table `through` for the owner `owner_id`.
    ///
    /// * `fk` — pivot column compared against `owner_id`.
    /// * `rfk` — pivot column joined against `T::primary_key()`.
    ///
    /// No SQL is executed until one of the async terminal methods is awaited.
    pub fn new(
        owner_id: impl Into<SqlValue>,
        through: &'static str,
        fk: &'static str,
        rfk: &'static str,
    ) -> Self {
        Self {
            owner_id: owner_id.into(),
            through,
            fk,
            rfk,
            _marker: PhantomData,
        }
    }

    /// Returns the pool installed in the current scope.
    ///
    /// # Errors
    /// `sqlx::Error::Configuration` when no pool is in scope.
    fn pool() -> Result<PgPool, sqlx::Error> {
        pool::try_current_pool().ok_or_else(|| {
            sqlx::Error::Configuration(
                "no database pool in scope — add OrmLayer to your router or \
                 call pool::with_pool() in tests"
                    .to_string()
                    .into(),
            )
        })
    }

    /// Builds `count` consecutive positional placeholders starting at `$first`.
    fn placeholders(first: usize, count: usize) -> Vec<String> {
        (first..first + count).map(|i| format!("${i}")).collect()
    }

    /// `FROM … INNER JOIN … WHERE …` fragment shared by all read queries; `$1` is the owner id.
    fn from_where_sql(&self) -> String {
        let t = T::table_name();
        let pk = T::primary_key();
        format!(
            "FROM {t} INNER JOIN {through} ON {through}.{rfk} = {t}.{pk} \
             WHERE {through}.{fk} = $1",
            through = self.through,
            rfk = self.rfk,
            fk = self.fk,
        )
    }

    /// Single-row pivot insert; `$1` = owner id, `$2` = related id. Existing pairs are skipped.
    fn insert_pair_sql(&self) -> String {
        format!(
            "INSERT INTO {through} ({fk}, {rfk}) VALUES ($1, $2) ON CONFLICT DO NOTHING",
            through = self.through,
            fk = self.fk,
            rfk = self.rfk,
        )
    }

    /// Deletes the pivot row(s) linking owner `$1` to related id `$2`.
    fn delete_pair_sql(&self) -> String {
        format!(
            "DELETE FROM {through} WHERE {fk} = $1 AND {rfk} = $2",
            through = self.through,
            fk = self.fk,
            rfk = self.rfk,
        )
    }

    /// Deletes every pivot row belonging to owner `$1`.
    fn delete_all_sql(&self) -> String {
        format!(
            "DELETE FROM {through} WHERE {fk} = $1",
            through = self.through,
            fk = self.fk,
        )
    }

    /// SQL + params selecting every related `T` row (`SELECT {table}.* …`).
    fn select_sql(&self) -> (String, Vec<SqlValue>) {
        let t = T::table_name();
        let sql = format!("SELECT {t}.* {}", self.from_where_sql());
        (sql, vec![self.owner_id.clone()])
    }

    /// SQL + params counting the related `T` rows.
    fn count_sql(&self) -> (String, Vec<SqlValue>) {
        let sql = format!("SELECT COUNT(*) {}", self.from_where_sql());
        (sql, vec![self.owner_id.clone()])
    }

    /// Fetches all related `T` rows.
    ///
    /// # Errors
    /// No pool in scope, or any database/decoding error.
    pub async fn get(self) -> Result<Vec<T>, sqlx::Error> {
        let pool = Self::pool()?;
        let (sql, params) = self.select_sql();
        sqlx_pg::fetch_all_as::<T>(&pool, &sql, params).await
    }

    /// Fetches one related `T` row, or `None` when there are none.
    ///
    /// No `ORDER BY` is applied, so which row is returned is up to the database.
    ///
    /// # Errors
    /// No pool in scope, or any database/decoding error.
    pub async fn first(self) -> Result<Option<T>, sqlx::Error> {
        let pool = Self::pool()?;
        let (base_sql, params) = self.select_sql();
        let sql = format!("{base_sql} LIMIT 1");
        sqlx_pg::fetch_optional_as::<T>(&pool, &sql, params).await
    }

    /// Counts the related `T` rows.
    ///
    /// # Errors
    /// No pool in scope, or any database/decoding error.
    pub async fn count(self) -> Result<i64, sqlx::Error> {
        use sqlx::Row;
        let pool = Self::pool()?;
        let (sql, params) = self.count_sql();
        let row = sqlx_pg::build_query(&sql, params).fetch_one(&pool).await?;
        row.try_get::<i64, _>(0)
    }

    /// Returns whether at least one related `T` row exists.
    ///
    /// Uses `SELECT EXISTS (…)` so the database can stop at the first match
    /// instead of counting every row.
    ///
    /// # Errors
    /// No pool in scope, or any database/decoding error.
    pub async fn exists(self) -> Result<bool, sqlx::Error> {
        use sqlx::Row;
        let pool = Self::pool()?;
        let sql = format!("SELECT EXISTS (SELECT 1 {})", self.from_where_sql());
        let row = sqlx_pg::build_query(&sql, vec![self.owner_id])
            .fetch_one(&pool)
            .await?;
        row.try_get::<bool, _>(0)
    }

    /// Inserts the (owner, `related_id`) pivot row. An already-present pair is left untouched.
    ///
    /// Returns the number of rows inserted (0 when the pair already existed).
    ///
    /// # Errors
    /// No pool in scope, or any database error.
    pub async fn attach(self, related_id: impl Into<SqlValue>) -> Result<u64, sqlx::Error> {
        let pool = Self::pool()?;
        let sql = self.insert_pair_sql();
        let params = vec![self.owner_id, related_id.into()];
        let result = sqlx_pg::build_query(&sql, params).execute(&pool).await?;
        Ok(result.rows_affected())
    }

    /// Like [`attach`](Self::attach), but also writes the extra pivot columns in `extra`.
    ///
    /// Column names in `extra` are interpolated into the SQL unquoted. They must be
    /// trusted identifiers, never user input. Values are bound as parameters.
    ///
    /// # Errors
    /// No pool in scope, or any database error.
    pub async fn attach_with(
        self,
        related_id: impl Into<SqlValue>,
        extra: &[(&str, SqlValue)],
    ) -> Result<u64, sqlx::Error> {
        let pool = Self::pool()?;
        let mut cols: Vec<&str> = vec![self.fk, self.rfk];
        cols.extend(extra.iter().map(|(col, _)| *col));
        // One placeholder per column, in column order: $1 owner, $2 related, $3.. extras.
        let placeholders = Self::placeholders(1, cols.len());
        let sql = format!(
            "INSERT INTO {through} ({cols}) VALUES ({ph}) ON CONFLICT DO NOTHING",
            through = self.through,
            cols = cols.join(", "),
            ph = placeholders.join(", "),
        );
        let mut params = vec![self.owner_id, related_id.into()];
        params.extend(extra.iter().map(|(_, val)| val.clone()));
        let result = sqlx_pg::build_query(&sql, params).execute(&pool).await?;
        Ok(result.rows_affected())
    }

    /// Removes the (owner, `related_id`) pivot row(s). Returns the number of rows deleted.
    ///
    /// # Errors
    /// No pool in scope, or any database error.
    pub async fn detach(self, related_id: impl Into<SqlValue>) -> Result<u64, sqlx::Error> {
        let pool = Self::pool()?;
        let sql = self.delete_pair_sql();
        let params = vec![self.owner_id, related_id.into()];
        let result = sqlx_pg::build_query(&sql, params).execute(&pool).await?;
        Ok(result.rows_affected())
    }

    /// Removes every pivot row of this owner. Returns the number of rows deleted.
    ///
    /// # Errors
    /// No pool in scope, or any database error.
    pub async fn detach_all(self) -> Result<u64, sqlx::Error> {
        let pool = Self::pool()?;
        let sql = self.delete_all_sql();
        let result = sqlx_pg::build_query(&sql, vec![self.owner_id])
            .execute(&pool)
            .await?;
        Ok(result.rows_affected())
    }

    /// Makes the owner's pivot rows match `ids` exactly.
    ///
    /// Pivot rows whose related id is not in `ids` are deleted, and missing ones are
    /// inserted. An empty `ids` detaches everything. The delete and the insert run in
    /// one transaction, so a failure leaves the pivot table unchanged.
    ///
    /// # Errors
    /// No pool in scope, or any database error. On error the transaction is rolled back.
    pub async fn sync(self, ids: &[impl Into<SqlValue> + Clone]) -> Result<(), sqlx::Error> {
        let pool = Self::pool()?;
        if ids.is_empty() {
            sqlx_pg::build_query(&self.delete_all_sql(), vec![self.owner_id])
                .execute(&pool)
                .await?;
            return Ok(());
        }

        let id_vals: Vec<SqlValue> = ids.iter().map(|v| v.clone().into()).collect();
        // Both statements share the parameter layout: $1 = owner, $2..=$n+1 = related ids.
        let placeholders = Self::placeholders(2, id_vals.len());
        let del_sql = format!(
            "DELETE FROM {through} WHERE {fk} = $1 AND {rfk} NOT IN ({ph})",
            through = self.through,
            fk = self.fk,
            rfk = self.rfk,
            ph = placeholders.join(", "),
        );
        let rows: Vec<String> = placeholders.iter().map(|p| format!("($1, {p})")).collect();
        // A single multi-row insert replaces the original one-statement-per-id loop.
        let ins_sql = format!(
            "INSERT INTO {through} ({fk}, {rfk}) VALUES {rows} ON CONFLICT DO NOTHING",
            through = self.through,
            fk = self.fk,
            rfk = self.rfk,
            rows = rows.join(", "),
        );

        let mut params = Vec::with_capacity(id_vals.len() + 1);
        params.push(self.owner_id);
        params.extend(id_vals);

        let mut tx = pool.begin().await?;
        sqlx_pg::build_query(&del_sql, params.clone())
            .execute(&mut *tx)
            .await?;
        sqlx_pg::build_query(&ins_sql, params)
            .execute(&mut *tx)
            .await?;
        tx.commit().await?;
        Ok(())
    }

    /// Updates extra pivot columns on the (owner, `related_id`) row.
    ///
    /// Returns the number of rows updated. An empty `data` is a no-op returning 0.
    /// Column names in `data` are interpolated unquoted. They must be trusted
    /// identifiers, never user input.
    ///
    /// # Errors
    /// No pool in scope, or any database error.
    pub async fn update_pivot(
        self,
        related_id: impl Into<SqlValue>,
        data: &[(&str, SqlValue)],
    ) -> Result<u64, sqlx::Error> {
        if data.is_empty() {
            return Ok(0);
        }
        let pool = Self::pool()?;
        // $1 and $2 are taken by the WHERE clause, so SET values start at $3.
        let set_clauses: Vec<String> = data
            .iter()
            .enumerate()
            .map(|(i, (col, _))| format!("{col} = ${}", i + 3))
            .collect();
        let sql = format!(
            "UPDATE {through} SET {sets} WHERE {fk} = $1 AND {rfk} = $2",
            through = self.through,
            sets = set_clauses.join(", "),
            fk = self.fk,
            rfk = self.rfk,
        );
        let mut params = vec![self.owner_id, related_id.into()];
        params.extend(data.iter().map(|(_, v)| v.clone()));
        let result = sqlx_pg::build_query(&sql, params).execute(&pool).await?;
        Ok(result.rows_affected())
    }

    /// Flips the attachment state of every id in `ids`. Attached ids are detached;
    /// the others are attached.
    ///
    /// Each id is handled by attempting the `DELETE` first and inserting only when
    /// nothing was deleted. This works for any key type. The previous implementation
    /// decoded existing keys as `i64` and silently ignored failures, so it never
    /// detached non-BIGINT keys.
    ///
    /// All changes run in one transaction. An id listed twice is toggled twice.
    ///
    /// # Errors
    /// No pool in scope, or any database error. On error the transaction is rolled back.
    pub async fn toggle(self, ids: &[impl Into<SqlValue> + Clone]) -> Result<(), sqlx::Error> {
        let pool = Self::pool()?;
        if ids.is_empty() {
            return Ok(());
        }
        let del_sql = self.delete_pair_sql();
        let ins_sql = self.insert_pair_sql();

        let mut tx = pool.begin().await?;
        for id in ids {
            let id_val: SqlValue = id.clone().into();
            let removed = sqlx_pg::build_query(&del_sql, vec![self.owner_id.clone(), id_val.clone()])
                .execute(&mut *tx)
                .await?
                .rows_affected();
            if removed == 0 {
                sqlx_pg::build_query(&ins_sql, vec![self.owner_id.clone(), id_val])
                    .execute(&mut *tx)
                    .await?;
            }
        }
        tx.commit().await?;
        Ok(())
    }
}