use super::error::ExecError;
use super::Pool;
use crate::core::SqlValue;
/// Manager for a many-to-many relation stored in a junction (`through`)
/// table, scoped to a single source row.
///
/// Every statement built by this type's methods filters on
/// `src_col = src_pk` and reads or writes `dst_col`. Table and column names
/// are quoted by the pool's dialect before being spliced into SQL.
pub struct M2MManager {
/// Primary key of the source row. Only the `SqlValue::I64` and
/// `SqlValue::I32` variants are read; `src_pk_i64` maps every other
/// variant to `0`.
pub src_pk: SqlValue,
/// Name of the junction table holding the relation rows.
pub through: &'static str,
/// Junction-table column that is compared against `src_pk`.
pub src_col: &'static str,
/// Junction-table column holding the related ids (decoded as `i64`).
pub dst_col: &'static str,
}
impl M2MManager {
pub async fn all(&self, pool: &Pool) -> Result<Vec<i64>, ExecError> {
let dialect = pool.dialect();
let sql = format!(
"SELECT {dst} FROM {through} WHERE {src} = {p1}",
through = dialect.quote_ident(self.through),
src = dialect.quote_ident(self.src_col),
dst = dialect.quote_ident(self.dst_col),
p1 = dialect.placeholder(1),
);
let binds = vec![SqlValue::I64(self.src_pk_i64())];
fetch_i64_col_pool(pool, &sql, binds, self.dst_col).await
}
pub async fn add(&self, dst_id: i64, pool: &Pool) -> Result<(), ExecError> {
let dialect = pool.dialect();
let (insert_kw, suffix) = match dialect.name() {
"mysql" => ("INSERT IGNORE INTO", ""),
_ => ("INSERT INTO", " ON CONFLICT DO NOTHING"),
};
let sql = format!(
"{insert_kw} {through} ({src}, {dst}) VALUES ({p1}, {p2}){suffix}",
through = dialect.quote_ident(self.through),
src = dialect.quote_ident(self.src_col),
dst = dialect.quote_ident(self.dst_col),
p1 = dialect.placeholder(1),
p2 = dialect.placeholder(2),
);
let binds = vec![SqlValue::I64(self.src_pk_i64()), SqlValue::I64(dst_id)];
super::executor::raw_execute_pool(pool, &sql, binds).await?;
crate::signals::m2m::send_m2m_changed(crate::signals::m2m::M2mChangedContext {
action: crate::signals::m2m::M2mAction::Add,
through: self.through,
src_col: self.src_col,
dst_col: self.dst_col,
src_pk: self.src_pk_i64(),
dst_pks: vec![dst_id],
})
.await;
Ok(())
}
pub async fn remove(&self, dst_id: i64, pool: &Pool) -> Result<(), ExecError> {
let dialect = pool.dialect();
let sql = format!(
"DELETE FROM {through} WHERE {src} = {p1} AND {dst} = {p2}",
through = dialect.quote_ident(self.through),
src = dialect.quote_ident(self.src_col),
dst = dialect.quote_ident(self.dst_col),
p1 = dialect.placeholder(1),
p2 = dialect.placeholder(2),
);
let binds = vec![SqlValue::I64(self.src_pk_i64()), SqlValue::I64(dst_id)];
super::executor::raw_execute_pool(pool, &sql, binds).await?;
crate::signals::m2m::send_m2m_changed(crate::signals::m2m::M2mChangedContext {
action: crate::signals::m2m::M2mAction::Remove,
through: self.through,
src_col: self.src_col,
dst_col: self.dst_col,
src_pk: self.src_pk_i64(),
dst_pks: vec![dst_id],
})
.await;
Ok(())
}
pub async fn set(&self, ids: &[i64], pool: &Pool) -> Result<(), ExecError> {
let dialect = pool.dialect();
let del_sql = format!(
"DELETE FROM {through} WHERE {src} = {p1}",
through = dialect.quote_ident(self.through),
src = dialect.quote_ident(self.src_col),
p1 = dialect.placeholder(1),
);
let ins_sql_with_binds = if ids.is_empty() {
None
} else {
let mut sql = format!(
"INSERT INTO {through} ({src}, {dst}) VALUES ",
through = dialect.quote_ident(self.through),
src = dialect.quote_ident(self.src_col),
dst = dialect.quote_ident(self.dst_col),
);
let mut binds = Vec::with_capacity(ids.len() * 2);
let src_pk = self.src_pk_i64();
for (i, dst_id) in ids.iter().enumerate() {
if i > 0 {
sql.push_str(", ");
}
let p_src = dialect.placeholder(i * 2 + 1);
let p_dst = dialect.placeholder(i * 2 + 2);
sql.push_str(&format!("({p_src}, {p_dst})"));
binds.push(SqlValue::I64(src_pk));
binds.push(SqlValue::I64(*dst_id));
}
Some((sql, binds))
};
let mut tx = crate::sql::transaction_pool(pool).await?;
crate::sql::raw_execute_tx(&mut tx, &del_sql, vec![SqlValue::I64(self.src_pk_i64())])
.await?;
if let Some((ins_sql, binds)) = ins_sql_with_binds {
crate::sql::raw_execute_tx(&mut tx, &ins_sql, binds).await?;
}
tx.commit().await.map_err(ExecError::Driver)?;
crate::signals::m2m::send_m2m_changed(crate::signals::m2m::M2mChangedContext {
action: crate::signals::m2m::M2mAction::Set,
through: self.through,
src_col: self.src_col,
dst_col: self.dst_col,
src_pk: self.src_pk_i64(),
dst_pks: ids.to_vec(),
})
.await;
Ok(())
}
pub async fn clear(&self, pool: &Pool) -> Result<(), ExecError> {
let dialect = pool.dialect();
let sql = format!(
"DELETE FROM {through} WHERE {src} = {p1}",
through = dialect.quote_ident(self.through),
src = dialect.quote_ident(self.src_col),
p1 = dialect.placeholder(1),
);
let binds = vec![SqlValue::I64(self.src_pk_i64())];
super::executor::raw_execute_pool(pool, &sql, binds).await?;
crate::signals::m2m::send_m2m_changed(crate::signals::m2m::M2mChangedContext {
action: crate::signals::m2m::M2mAction::Clear,
through: self.through,
src_col: self.src_col,
dst_col: self.dst_col,
src_pk: self.src_pk_i64(),
dst_pks: Vec::new(),
})
.await;
Ok(())
}
pub async fn contains(&self, dst_id: i64, pool: &Pool) -> Result<bool, ExecError> {
let dialect = pool.dialect();
let sql = format!(
"SELECT 1 AS hit FROM {through} WHERE {src} = {p1} AND {dst} = {p2} LIMIT 1",
through = dialect.quote_ident(self.through),
src = dialect.quote_ident(self.src_col),
dst = dialect.quote_ident(self.dst_col),
p1 = dialect.placeholder(1),
p2 = dialect.placeholder(2),
);
let binds = vec![SqlValue::I64(self.src_pk_i64()), SqlValue::I64(dst_id)];
let rows = fetch_i64_col_pool(pool, &sql, binds, "hit").await?;
Ok(!rows.is_empty())
}
fn src_pk_i64(&self) -> i64 {
match &self.src_pk {
SqlValue::I64(v) => *v,
SqlValue::I32(v) => i64::from(*v),
_ => 0,
}
}
}
/// Manager for a many-to-many relation whose junction rows are keyed by both
/// the source row's primary key and the content-type id of the source model.
///
/// Every statement built by this type's methods filters on
/// `pk_col = src_pk AND ct_col = <content-type id>` and reads or writes
/// `dst_col`.
pub struct GenericM2MManager {
/// Primary key of the source row. Only the `SqlValue::I64` and
/// `SqlValue::I32` variants are read; `src_pk_i64` maps every other
/// variant to `0`.
pub src_pk: SqlValue,
/// Schema of the source model. Used to look up its `ContentType`; its
/// `table` is reported in `ExecError::ContentTypeNotRegistered` when the
/// lookup yields no id.
pub src_schema: &'static crate::core::ModelSchema,
/// Name of the junction table holding the relation rows.
pub through: &'static str,
/// Junction-table column that is compared against `src_pk`.
pub pk_col: &'static str,
/// Junction-table column that is compared against the content-type id.
pub ct_col: &'static str,
/// Junction-table column holding the related ids (decoded as `i64`).
pub dst_col: &'static str,
}
impl GenericM2MManager {
    /// Widens the source primary key to `i64`.
    ///
    /// NOTE(review): any variant other than `I64`/`I32` silently becomes
    /// `0`, so such a key is queried as row `0` rather than rejected.
    fn src_pk_i64(&self) -> i64 {
        match &self.src_pk {
            SqlValue::I64(v) => *v,
            SqlValue::I32(v) => i64::from(*v),
            _ => 0,
        }
    }

    /// Resolves the content-type id for `src_schema`.
    ///
    /// # Errors
    /// Propagates any error from the `ContentType` lookup, and returns
    /// `ExecError::ContentTypeNotRegistered` (carrying the schema's table
    /// name) when the lookup finds nothing or the found entry has no id.
    async fn ct_id(&self, pool: &Pool) -> Result<i64, ExecError> {
        crate::contenttypes::ContentType::get_for_schema(pool, self.src_schema)
            .await?
            .and_then(|ct| ct.id.get().copied())
            // Build the error lazily; it is only needed on the failure path.
            .ok_or_else(|| ExecError::ContentTypeNotRegistered {
                table: self.src_schema.table,
            })
    }

    /// Sends the m2m-changed signal for this relation with the given
    /// `action` and affected destination ids. `pk_col` is reported as the
    /// signal's source column.
    async fn signal(&self, action: crate::signals::m2m::M2mAction, dst_pks: Vec<i64>) {
        crate::signals::m2m::send_m2m_changed(crate::signals::m2m::M2mChangedContext {
            action,
            through: self.through,
            src_col: self.pk_col,
            dst_col: self.dst_col,
            src_pk: self.src_pk_i64(),
            dst_pks,
        })
        .await;
    }

    /// Returns the `dst_col` value of every `through` row matching this
    /// manager's source key and content-type id.
    ///
    /// # Errors
    /// Propagates errors from the content-type lookup and from the query.
    pub async fn all(&self, pool: &Pool) -> Result<Vec<i64>, ExecError> {
        let dialect = pool.dialect();
        let ct = self.ct_id(pool).await?;
        let sql = format!(
            "SELECT {dst} FROM {through} WHERE {pk} = {p1} AND {ctc} = {p2}",
            through = dialect.quote_ident(self.through),
            pk = dialect.quote_ident(self.pk_col),
            ctc = dialect.quote_ident(self.ct_col),
            dst = dialect.quote_ident(self.dst_col),
            p1 = dialect.placeholder(1),
            p2 = dialect.placeholder(2),
        );
        let binds = vec![SqlValue::I64(self.src_pk_i64()), SqlValue::I64(ct)];
        fetch_i64_col_pool(pool, &sql, binds, self.dst_col).await
    }

    /// Inserts the `(source key, content-type id, dst_id)` row into the
    /// `through` table and then sends the `Add` m2m-changed signal.
    ///
    /// When the dialect reports the name `"mysql"` the statement is
    /// `INSERT IGNORE INTO …`; for every other dialect it is
    /// `INSERT INTO … ON CONFLICT DO NOTHING`.
    /// NOTE(review): presumably this makes re-adding an existing row a
    /// no-op when the table has a uniqueness constraint — confirm against
    /// the schema.
    ///
    /// # Errors
    /// Propagates errors from the content-type lookup and from executing the
    /// statement; in either case no signal is sent.
    pub async fn add(&self, dst_id: i64, pool: &Pool) -> Result<(), ExecError> {
        let dialect = pool.dialect();
        let ct = self.ct_id(pool).await?;
        let (insert_kw, suffix) = match dialect.name() {
            "mysql" => ("INSERT IGNORE INTO", ""),
            _ => ("INSERT INTO", " ON CONFLICT DO NOTHING"),
        };
        let sql = format!(
            "{insert_kw} {through} ({pk}, {ctc}, {dst}) VALUES ({p1}, {p2}, {p3}){suffix}",
            through = dialect.quote_ident(self.through),
            pk = dialect.quote_ident(self.pk_col),
            ctc = dialect.quote_ident(self.ct_col),
            dst = dialect.quote_ident(self.dst_col),
            p1 = dialect.placeholder(1),
            p2 = dialect.placeholder(2),
            p3 = dialect.placeholder(3),
        );
        let binds = vec![
            SqlValue::I64(self.src_pk_i64()),
            SqlValue::I64(ct),
            SqlValue::I64(dst_id),
        ];
        super::executor::raw_execute_pool(pool, &sql, binds).await?;
        self.signal(crate::signals::m2m::M2mAction::Add, vec![dst_id])
            .await;
        Ok(())
    }

    /// Deletes the `(source key, content-type id, dst_id)` row from the
    /// `through` table and then sends the `Remove` m2m-changed signal.
    ///
    /// # Errors
    /// Propagates errors from the content-type lookup and from executing the
    /// statement; in either case no signal is sent.
    pub async fn remove(&self, dst_id: i64, pool: &Pool) -> Result<(), ExecError> {
        let dialect = pool.dialect();
        let ct = self.ct_id(pool).await?;
        let sql = format!(
            "DELETE FROM {through} WHERE {pk} = {p1} AND {ctc} = {p2} AND {dst} = {p3}",
            through = dialect.quote_ident(self.through),
            pk = dialect.quote_ident(self.pk_col),
            ctc = dialect.quote_ident(self.ct_col),
            dst = dialect.quote_ident(self.dst_col),
            p1 = dialect.placeholder(1),
            p2 = dialect.placeholder(2),
            p3 = dialect.placeholder(3),
        );
        let binds = vec![
            SqlValue::I64(self.src_pk_i64()),
            SqlValue::I64(ct),
            SqlValue::I64(dst_id),
        ];
        super::executor::raw_execute_pool(pool, &sql, binds).await?;
        self.signal(crate::signals::m2m::M2mAction::Remove, vec![dst_id])
            .await;
        Ok(())
    }

    /// Replaces the related ids of the source row with exactly `ids`.
    ///
    /// Inside one transaction this deletes every `through` row for the
    /// source key and content-type id and, when `ids` is non-empty, inserts
    /// one row per id with a single multi-row `INSERT`. After a successful
    /// commit the `Set` m2m-changed signal is sent with a copy of `ids`.
    ///
    /// NOTE(review): unlike `add`, the insert carries no conflict clause, so
    /// duplicate entries in `ids` are sent to the database as-is —
    /// presumably an error if the table enforces uniqueness; confirm.
    ///
    /// # Errors
    /// Propagates errors from the content-type lookup, from opening the
    /// transaction and from either statement, and maps a commit failure to
    /// `ExecError::Driver`. On any error the function returns before
    /// `commit` and before the signal is sent.
    pub async fn set(&self, ids: &[i64], pool: &Pool) -> Result<(), ExecError> {
        let dialect = pool.dialect();
        let ct = self.ct_id(pool).await?;
        let src_pk = self.src_pk_i64();
        let del_sql = format!(
            "DELETE FROM {through} WHERE {pk} = {p1} AND {ctc} = {p2}",
            through = dialect.quote_ident(self.through),
            pk = dialect.quote_ident(self.pk_col),
            ctc = dialect.quote_ident(self.ct_col),
            p1 = dialect.placeholder(1),
            p2 = dialect.placeholder(2),
        );
        // Only build the INSERT when there is something to insert.
        let ins = if ids.is_empty() {
            None
        } else {
            let mut sql = format!(
                "INSERT INTO {through} ({pk}, {ctc}, {dst}) VALUES ",
                through = dialect.quote_ident(self.through),
                pk = dialect.quote_ident(self.pk_col),
                ctc = dialect.quote_ident(self.ct_col),
                dst = dialect.quote_ident(self.dst_col),
            );
            let mut binds = Vec::with_capacity(ids.len() * 3);
            for (i, dst_id) in ids.iter().enumerate() {
                if i > 0 {
                    sql.push_str(", ");
                }
                // Row `i` uses placeholders 3i+1, 3i+2 and 3i+3, in bind order.
                let p1 = dialect.placeholder(i * 3 + 1);
                let p2 = dialect.placeholder(i * 3 + 2);
                let p3 = dialect.placeholder(i * 3 + 3);
                sql.push_str(&format!("({p1}, {p2}, {p3})"));
                binds.push(SqlValue::I64(src_pk));
                binds.push(SqlValue::I64(ct));
                binds.push(SqlValue::I64(*dst_id));
            }
            Some((sql, binds))
        };
        let mut tx = crate::sql::transaction_pool(pool).await?;
        crate::sql::raw_execute_tx(
            &mut tx,
            &del_sql,
            vec![SqlValue::I64(src_pk), SqlValue::I64(ct)],
        )
        .await?;
        if let Some((ins_sql, binds)) = ins {
            crate::sql::raw_execute_tx(&mut tx, &ins_sql, binds).await?;
        }
        tx.commit().await.map_err(ExecError::Driver)?;
        self.signal(crate::signals::m2m::M2mAction::Set, ids.to_vec())
            .await;
        Ok(())
    }

    /// Deletes every `through` row for the source key and content-type id
    /// and then sends the `Clear` m2m-changed signal with an empty id list.
    ///
    /// # Errors
    /// Propagates errors from the content-type lookup and from executing the
    /// statement; in either case no signal is sent.
    pub async fn clear(&self, pool: &Pool) -> Result<(), ExecError> {
        let dialect = pool.dialect();
        let ct = self.ct_id(pool).await?;
        let sql = format!(
            "DELETE FROM {through} WHERE {pk} = {p1} AND {ctc} = {p2}",
            through = dialect.quote_ident(self.through),
            pk = dialect.quote_ident(self.pk_col),
            ctc = dialect.quote_ident(self.ct_col),
            p1 = dialect.placeholder(1),
            p2 = dialect.placeholder(2),
        );
        let binds = vec![SqlValue::I64(self.src_pk_i64()), SqlValue::I64(ct)];
        super::executor::raw_execute_pool(pool, &sql, binds).await?;
        self.signal(crate::signals::m2m::M2mAction::Clear, Vec::new())
            .await;
        Ok(())
    }

    /// Reports whether a `through` row exists for the
    /// `(source key, content-type id, dst_id)` triple.
    ///
    /// Selects `dst_col` with `LIMIT 1` and returns `true` when the query
    /// yields at least one row.
    ///
    /// # Errors
    /// Propagates errors from the content-type lookup and from the query.
    pub async fn contains(&self, dst_id: i64, pool: &Pool) -> Result<bool, ExecError> {
        let dialect = pool.dialect();
        let ct = self.ct_id(pool).await?;
        let sql = format!(
            "SELECT {dst} FROM {through} WHERE {pk} = {p1} AND {ctc} = {p2} AND {dst} = {p3} LIMIT 1",
            through = dialect.quote_ident(self.through),
            pk = dialect.quote_ident(self.pk_col),
            ctc = dialect.quote_ident(self.ct_col),
            dst = dialect.quote_ident(self.dst_col),
            p1 = dialect.placeholder(1),
            p2 = dialect.placeholder(2),
            p3 = dialect.placeholder(3),
        );
        let binds = vec![
            SqlValue::I64(self.src_pk_i64()),
            SqlValue::I64(ct),
            SqlValue::I64(dst_id),
        ];
        // The query selects `dst_col`, so label the fetched column with that
        // name rather than an alias that does not appear in the SQL.
        let rows = fetch_i64_col_pool(pool, &sql, binds, self.dst_col).await?;
        Ok(!rows.is_empty())
    }
}
/// Deprecated `_pool`-suffixed aliases. Each one forwards its arguments
/// unchanged to the method of the same name without the suffix.
impl M2MManager {
    /// Deprecated alias that forwards to `all`.
    #[deprecated(note = "renamed to `all` — drop the `_pool` suffix")]
    pub async fn all_pool(&self, pool: &Pool) -> Result<Vec<i64>, ExecError> {
        Self::all(self, pool).await
    }

    /// Deprecated alias that forwards to `add`.
    #[deprecated(note = "renamed to `add` — drop the `_pool` suffix")]
    pub async fn add_pool(&self, dst_id: i64, pool: &Pool) -> Result<(), ExecError> {
        Self::add(self, dst_id, pool).await
    }

    /// Deprecated alias that forwards to `remove`.
    #[deprecated(note = "renamed to `remove` — drop the `_pool` suffix")]
    pub async fn remove_pool(&self, dst_id: i64, pool: &Pool) -> Result<(), ExecError> {
        Self::remove(self, dst_id, pool).await
    }

    /// Deprecated alias that forwards to `set`.
    #[deprecated(note = "renamed to `set` — drop the `_pool` suffix")]
    pub async fn set_pool(&self, ids: &[i64], pool: &Pool) -> Result<(), ExecError> {
        Self::set(self, ids, pool).await
    }

    /// Deprecated alias that forwards to `clear`.
    #[deprecated(note = "renamed to `clear` — drop the `_pool` suffix")]
    pub async fn clear_pool(&self, pool: &Pool) -> Result<(), ExecError> {
        Self::clear(self, pool).await
    }

    /// Deprecated alias that forwards to `contains`.
    #[deprecated(note = "renamed to `contains` — drop the `_pool` suffix")]
    pub async fn contains_pool(&self, dst_id: i64, pool: &Pool) -> Result<bool, ExecError> {
        Self::contains(self, dst_id, pool).await
    }
}
/// Runs `sql` with `binds` against `pool` and returns the single `i64`
/// column of every resulting row, in the order the rows were returned.
///
/// Each row is decoded as a one-element tuple `(i64,)`. `_col_name` is
/// accepted but not used: the value is taken by position, not by name.
///
/// # Errors
/// Propagates any error from `crate::sql::raw_query_pool`.
async fn fetch_i64_col_pool(
    pool: &Pool,
    sql: &str,
    binds: Vec<SqlValue>,
    _col_name: &str,
) -> Result<Vec<i64>, ExecError> {
    let tuples: Vec<(i64,)> = crate::sql::raw_query_pool(sql, binds, pool).await?;
    let mut values = Vec::with_capacity(tuples.len());
    for row in tuples {
        values.push(row.0);
    }
    Ok(values)
}