use super::error::ExecError;
use super::Pool;
use crate::core::SqlValue;
/// Handle for one source row's side of a relation stored in a join ("through") table.
///
/// The methods on this type build SQL against `through`, filter on `src_col`,
/// and read or write `dst_col`.
pub struct M2MManager {
/// Key of the source row. `src_pk_i64` uses the `I64` and `I32` variants as-is
/// and maps every other variant to `0`.
pub src_pk: SqlValue,
/// Name of the join table; passed through the dialect's `quote_ident` before use.
pub through: &'static str,
/// Column in `through` that is compared against the source key.
pub src_col: &'static str,
/// Column in `through` that holds the related ids; read back as `i64`.
pub dst_col: &'static str,
}
impl M2MManager {
/// Fetches every `dst_col` value stored for this source row.
///
/// Issues `SELECT dst FROM through WHERE src = ?` with identifiers quoted and the
/// placeholder rendered by the pool's dialect. No `ORDER BY` is applied, so the
/// order of the returned ids is whatever the database yields.
///
/// # Errors
///
/// Propagates whatever `fetch_i64_col_pool` returns for the query or for
/// decoding a row as `i64`.
pub async fn all_pool(&self, pool: &Pool) -> Result<Vec<i64>, ExecError> {
    let d = pool.dialect();
    let table = d.quote_ident(self.through);
    let src = d.quote_ident(self.src_col);
    let dst = d.quote_ident(self.dst_col);
    let p1 = d.placeholder(1);

    let query = format!("SELECT {dst} FROM {table} WHERE {src} = {p1}");
    let params = vec![SqlValue::I64(self.src_pk_i64())];

    fetch_i64_col_pool(pool, &query, params, self.dst_col).await
}
/// Inserts the `(source key, dst_id)` pair into the through table.
///
/// The statement is written so that a conflicting row is skipped instead of
/// raising an error: the `"mysql"` dialect gets `INSERT IGNORE INTO`, every
/// other dialect gets `INSERT INTO … ON CONFLICT DO NOTHING`.
/// NOTE(review): this only de-duplicates if the table defines a uniqueness
/// constraint on the pair — confirm against the schema.
///
/// # Errors
///
/// Propagates the error returned by `raw_execute_pool`.
pub async fn add_pool(&self, dst_id: i64, pool: &Pool) -> Result<(), ExecError> {
    let d = pool.dialect();

    // MySQL spells "skip on conflict" as a keyword on the INSERT itself;
    // the other dialects use a trailing clause.
    let is_mysql = matches!(d.name(), "mysql");
    let insert_kw = if is_mysql { "INSERT IGNORE INTO" } else { "INSERT INTO" };
    let suffix = if is_mysql { "" } else { " ON CONFLICT DO NOTHING" };

    let through = d.quote_ident(self.through);
    let src = d.quote_ident(self.src_col);
    let dst = d.quote_ident(self.dst_col);
    let p1 = d.placeholder(1);
    let p2 = d.placeholder(2);
    let statement =
        format!("{insert_kw} {through} ({src}, {dst}) VALUES ({p1}, {p2}){suffix}");

    let params = vec![SqlValue::I64(self.src_pk_i64()), SqlValue::I64(dst_id)];
    super::executor::raw_execute_pool(pool, &statement, params).await?;
    Ok(())
}
/// Deletes the through-table row(s) linking this source row to `dst_id`.
///
/// Runs `DELETE FROM through WHERE src = ? AND dst = ?`. The number of affected
/// rows is not inspected, so removing a pair that is not present still returns `Ok(())`
/// as long as the statement itself succeeds.
///
/// # Errors
///
/// Propagates the error returned by `raw_execute_pool`.
pub async fn remove_pool(&self, dst_id: i64, pool: &Pool) -> Result<(), ExecError> {
    let d = pool.dialect();
    let (table, src, dst) = (
        d.quote_ident(self.through),
        d.quote_ident(self.src_col),
        d.quote_ident(self.dst_col),
    );
    let (p1, p2) = (d.placeholder(1), d.placeholder(2));

    let statement = format!("DELETE FROM {table} WHERE {src} = {p1} AND {dst} = {p2}");
    let params = vec![SqlValue::I64(self.src_pk_i64()), SqlValue::I64(dst_id)];

    super::executor::raw_execute_pool(pool, &statement, params).await?;
    Ok(())
}
/// Replaces the stored set of related ids for the source row with `ids`.
///
/// Runs inside one transaction on the given pool: first deletes every `through`
/// row whose `src_col` matches the source key, then — if any ids remain after
/// de-duplication — inserts them with a single multi-row `INSERT`, and finally
/// commits. An empty `ids` slice therefore just clears the relation.
///
/// Duplicate values in `ids` are collapsed (first occurrence wins), so one call
/// never tries to insert the same `(src, dst)` pair twice.
///
/// NOTE(review): the insert uses two bind parameters per id in one statement;
/// database backends cap the number of bind parameters, so a very large `ids`
/// slice may need chunking — confirm expected sizes with callers.
///
/// # Errors
///
/// Any sqlx error from `begin`, either statement, or `commit` is returned as
/// `ExecError::Driver`. On such an early return the transaction is dropped
/// without being committed.
pub async fn set_pool(&self, ids: &[i64], pool: &Pool) -> Result<(), ExecError> {
    use std::fmt::Write as _;

    let dialect = pool.dialect();
    // Resolve the source key once; it is bound in the DELETE and in every INSERT tuple.
    let src_pk = self.src_pk_i64();

    let del_sql = format!(
        "DELETE FROM {through} WHERE {src} = {p1}",
        through = dialect.quote_ident(self.through),
        src = dialect.quote_ident(self.src_col),
        p1 = dialect.placeholder(1),
    );

    // Collapse repeated ids while keeping first-seen order. Without this, a
    // repeated id would be sent twice in the same multi-row INSERT.
    let unique_ids: Vec<i64> = {
        let mut seen = std::collections::HashSet::with_capacity(ids.len());
        let kept: Vec<i64> = ids.iter().copied().filter(|id| seen.insert(*id)).collect();
        kept
    };

    let ins_sql_with_binds = if unique_ids.is_empty() {
        None
    } else {
        let mut sql = format!(
            "INSERT INTO {through} ({src}, {dst}) VALUES ",
            through = dialect.quote_ident(self.through),
            src = dialect.quote_ident(self.src_col),
            dst = dialect.quote_ident(self.dst_col),
        );
        let mut binds = Vec::with_capacity(unique_ids.len() * 2);
        for (i, dst_id) in unique_ids.iter().enumerate() {
            if i > 0 {
                sql.push_str(", ");
            }
            // Placeholders are numbered from 1; each tuple consumes two of them.
            let p_src = dialect.placeholder(i * 2 + 1);
            let p_dst = dialect.placeholder(i * 2 + 2);
            // Writing into a `String` cannot fail, so the `fmt::Result` is discarded.
            let _ = write!(sql, "({p_src}, {p_dst})");
            binds.push(SqlValue::I64(src_pk));
            binds.push(SqlValue::I64(*dst_id));
        }
        Some((sql, binds))
    };

    match pool {
        #[cfg(feature = "postgres")]
        Pool::Postgres(pg) => {
            let mut tx = pg.begin().await.map_err(ExecError::Driver)?;
            sqlx::query(&del_sql)
                .bind(src_pk)
                .execute(&mut *tx)
                .await
                .map_err(ExecError::Driver)?;
            if let Some((ins_sql, binds)) = ins_sql_with_binds {
                let mut q = sqlx::query(&ins_sql);
                for v in binds {
                    q = bind_pg(q, v);
                }
                q.execute(&mut *tx).await.map_err(ExecError::Driver)?;
            }
            tx.commit().await.map_err(ExecError::Driver)?;
            Ok(())
        }
        #[cfg(feature = "mysql")]
        Pool::Mysql(my) => {
            let mut tx = my.begin().await.map_err(ExecError::Driver)?;
            sqlx::query(&del_sql)
                .bind(src_pk)
                .execute(&mut *tx)
                .await
                .map_err(ExecError::Driver)?;
            if let Some((ins_sql, binds)) = ins_sql_with_binds {
                let mut q = sqlx::query(&ins_sql);
                for v in binds {
                    q = bind_my(q, v);
                }
                q.execute(&mut *tx).await.map_err(ExecError::Driver)?;
            }
            tx.commit().await.map_err(ExecError::Driver)?;
            Ok(())
        }
        #[cfg(feature = "sqlite")]
        Pool::Sqlite(sq) => {
            let mut tx = sq.begin().await.map_err(ExecError::Driver)?;
            sqlx::query(&del_sql)
                .bind(src_pk)
                .execute(&mut *tx)
                .await
                .map_err(ExecError::Driver)?;
            if let Some((ins_sql, binds)) = ins_sql_with_binds {
                let mut q = sqlx::query(&ins_sql);
                for v in binds {
                    q = bind_sqlite(q, v);
                }
                q.execute(&mut *tx).await.map_err(ExecError::Driver)?;
            }
            tx.commit().await.map_err(ExecError::Driver)?;
            Ok(())
        }
    }
}
/// Deletes every through-table row belonging to this source row.
///
/// Runs `DELETE FROM through WHERE src = ?` as a single statement via
/// `raw_execute_pool`; the affected-row count is discarded.
///
/// # Errors
///
/// Propagates the error returned by `raw_execute_pool`.
pub async fn clear_pool(&self, pool: &Pool) -> Result<(), ExecError> {
    let d = pool.dialect();
    let table = d.quote_ident(self.through);
    let src = d.quote_ident(self.src_col);
    let p1 = d.placeholder(1);

    let statement = format!("DELETE FROM {table} WHERE {src} = {p1}");
    let params = vec![SqlValue::I64(self.src_pk_i64())];

    super::executor::raw_execute_pool(pool, &statement, params).await?;
    Ok(())
}
/// Reports whether the through table links this source row to `dst_id`.
///
/// Runs `SELECT dst FROM through WHERE src = ? AND dst = ? LIMIT 1` and returns
/// `true` when at least one row comes back.
///
/// # Errors
///
/// Propagates whatever `fetch_i64_col_pool` returns for the query or for
/// decoding the selected column as `i64`.
pub async fn contains_pool(&self, dst_id: i64, pool: &Pool) -> Result<bool, ExecError> {
    let dialect = pool.dialect();
    // Select the destination column itself rather than a literal `1`.
    // `fetch_i64_col_pool` decodes the column as `i64`, and on PostgreSQL an
    // untyped integer literal is INT4, which sqlx refuses to decode as `i64`
    // (it only accepts INT8). Reading `dst_col` goes through exactly the same
    // decode path as `all_pool`, so it works wherever `all_pool` works.
    let sql = format!(
        "SELECT {dst} FROM {through} WHERE {src} = {p1} AND {dst} = {p2} LIMIT 1",
        through = dialect.quote_ident(self.through),
        src = dialect.quote_ident(self.src_col),
        dst = dialect.quote_ident(self.dst_col),
        p1 = dialect.placeholder(1),
        p2 = dialect.placeholder(2),
    );
    let binds = vec![SqlValue::I64(self.src_pk_i64()), SqlValue::I64(dst_id)];
    let rows = fetch_i64_col_pool(pool, &sql, binds, self.dst_col).await?;
    Ok(!rows.is_empty())
}
/// Returns the source key widened to `i64`.
///
/// `I64` is returned unchanged and `I32` is widened losslessly. Every other
/// `SqlValue` variant yields `0`.
// NOTE(review): the `0` fallback is silent — a non-integer key makes every
// query on this manager target key 0. Consider surfacing this as an error.
fn src_pk_i64(&self) -> i64 {
    if let SqlValue::I64(wide) = &self.src_pk {
        return *wide;
    }
    if let SqlValue::I32(narrow) = &self.src_pk {
        return i64::from(*narrow);
    }
    0
}
}
/// Convenience entry points that take a `sqlx::PgPool` directly.
///
/// Each method clones the pool handle, wraps it in `Pool::Postgres`, and
/// forwards to the matching `*_pool` method.
#[cfg(feature = "postgres")]
impl M2MManager {
    /// Same as `all_pool`, for a PostgreSQL pool.
    pub async fn all(&self, pool: &sqlx::PgPool) -> Result<Vec<i64>, ExecError> {
        let wrapped = Pool::Postgres(pool.clone());
        self.all_pool(&wrapped).await
    }

    /// Same as `add_pool`, for a PostgreSQL pool.
    pub async fn add(&self, dst_id: i64, pool: &sqlx::PgPool) -> Result<(), ExecError> {
        let wrapped = Pool::Postgres(pool.clone());
        self.add_pool(dst_id, &wrapped).await
    }

    /// Same as `remove_pool`, for a PostgreSQL pool.
    pub async fn remove(&self, dst_id: i64, pool: &sqlx::PgPool) -> Result<(), ExecError> {
        let wrapped = Pool::Postgres(pool.clone());
        self.remove_pool(dst_id, &wrapped).await
    }

    /// Same as `set_pool`, for a PostgreSQL pool.
    pub async fn set(&self, ids: &[i64], pool: &sqlx::PgPool) -> Result<(), ExecError> {
        let wrapped = Pool::Postgres(pool.clone());
        self.set_pool(ids, &wrapped).await
    }

    /// Same as `clear_pool`, for a PostgreSQL pool.
    pub async fn clear(&self, pool: &sqlx::PgPool) -> Result<(), ExecError> {
        let wrapped = Pool::Postgres(pool.clone());
        self.clear_pool(&wrapped).await
    }

    /// Same as `contains_pool`, for a PostgreSQL pool.
    pub async fn contains(&self, dst_id: i64, pool: &sqlx::PgPool) -> Result<bool, ExecError> {
        let wrapped = Pool::Postgres(pool.clone());
        self.contains_pool(dst_id, &wrapped).await
    }
}
/// Runs `sql` with `binds` on whichever backend `pool` wraps and reads the
/// column named `col_name` from every returned row as `i64`.
///
/// Binds are applied in order using the backend-specific `bind_*` helper.
/// Decoding stops at the first row that fails.
///
/// # Errors
///
/// Returns `ExecError::Driver` for a failed query or a failed column decode.
async fn fetch_i64_col_pool(
    pool: &Pool,
    sql: &str,
    binds: Vec<SqlValue>,
    col_name: &str,
) -> Result<Vec<i64>, ExecError> {
    match pool {
        #[cfg(feature = "postgres")]
        Pool::Postgres(pg) => {
            use sqlx::Row as _;
            let query = binds.into_iter().fold(sqlx::query(sql), bind_pg);
            let fetched = query.fetch_all(pg).await.map_err(ExecError::Driver)?;
            let mut values = Vec::with_capacity(fetched.len());
            for row in &fetched {
                let cell = row.try_get::<i64, _>(col_name).map_err(ExecError::Driver)?;
                values.push(cell);
            }
            Ok(values)
        }
        #[cfg(feature = "mysql")]
        Pool::Mysql(my) => {
            use sqlx::Row as _;
            let query = binds.into_iter().fold(sqlx::query(sql), bind_my);
            let fetched = query.fetch_all(my).await.map_err(ExecError::Driver)?;
            let mut values = Vec::with_capacity(fetched.len());
            for row in &fetched {
                let cell = row.try_get::<i64, _>(col_name).map_err(ExecError::Driver)?;
                values.push(cell);
            }
            Ok(values)
        }
        #[cfg(feature = "sqlite")]
        Pool::Sqlite(sq) => {
            use sqlx::Row as _;
            let query = binds.into_iter().fold(sqlx::query(sql), bind_sqlite);
            let fetched = query.fetch_all(sq).await.map_err(ExecError::Driver)?;
            let mut values = Vec::with_capacity(fetched.len());
            for row in &fetched {
                let cell = row.try_get::<i64, _>(col_name).map_err(ExecError::Driver)?;
                values.push(cell);
            }
            Ok(values)
        }
    }
}
/// Appends `value` as the next bind parameter of a PostgreSQL query.
///
/// `I64`, `I32`, `String` and `Bool` are bound with their native Rust type,
/// `Null` is bound as a `None::<i64>`, and any other variant is bound as the
/// text produced by `to_display_string`.
#[cfg(feature = "postgres")]
fn bind_pg<'q>(
    query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
    value: SqlValue,
) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {
    match value {
        SqlValue::Null => query.bind(None::<i64>),
        SqlValue::Bool(flag) => query.bind(flag),
        SqlValue::I32(int) => query.bind(int),
        SqlValue::I64(int) => query.bind(int),
        SqlValue::String(text) => query.bind(text),
        fallback => {
            let rendered = fallback.to_display_string();
            query.bind(rendered)
        }
    }
}
/// Appends `value` as the next bind parameter of a MySQL query.
///
/// `I64`, `I32`, `String` and `Bool` are bound with their native Rust type,
/// `Null` is bound as a `None::<i64>`, and any other variant is bound as the
/// text produced by `to_display_string`.
#[cfg(feature = "mysql")]
fn bind_my<'q>(
    query: sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments>,
    value: SqlValue,
) -> sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments> {
    match value {
        SqlValue::Null => query.bind(None::<i64>),
        SqlValue::Bool(flag) => query.bind(flag),
        SqlValue::I32(int) => query.bind(int),
        SqlValue::I64(int) => query.bind(int),
        SqlValue::String(text) => query.bind(text),
        fallback => {
            let rendered = fallback.to_display_string();
            query.bind(rendered)
        }
    }
}
/// Appends `value` as the next bind parameter of a SQLite query.
///
/// `I64`, `I32`, `String` and `Bool` are bound with their native Rust type,
/// `Null` is bound as a `None::<i64>`, and any other variant is bound as the
/// text produced by `to_display_string`.
#[cfg(feature = "sqlite")]
fn bind_sqlite<'q>(
    query: sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>>,
    value: SqlValue,
) -> sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>> {
    match value {
        SqlValue::Null => query.bind(None::<i64>),
        SqlValue::Bool(flag) => query.bind(flag),
        SqlValue::I32(int) => query.bind(int),
        SqlValue::I64(int) => query.bind(int),
        SqlValue::String(text) => query.bind(text),
        fallback => {
            let rendered = fallback.to_display_string();
            query.bind(rendered)
        }
    }
}