use sqlx::postgres::PgPoolCopyExt;
use crate::core::condition::SqlValue;
use crate::orm::postgres::model::PgModel;
pub struct BatchService<M>(std::marker::PhantomData<M>);
impl<M: PgModel> BatchService<M> {
pub async fn bulk_insert(
rows: &[Vec<(&str, SqlValue)>],
pool: &sqlx::PgPool,
) -> Result<u64, sqlx::Error> {
M::bulk_create(pool, rows).await
}
pub async fn bulk_insert_chunked(
rows: &[Vec<(&str, SqlValue)>],
chunk_size: usize,
pool: &sqlx::PgPool,
) -> Result<u64, sqlx::Error> {
M::bulk_insert_chunked(pool, rows, chunk_size).await
}
pub async fn bulk_upsert_by(
unique_col: &str,
rows: &[Vec<(&str, SqlValue)>],
pool: &sqlx::PgPool,
) -> Result<u64, sqlx::Error> {
M::bulk_upsert(pool, rows, &[unique_col]).await
}
pub async fn delete_where(
conditions: &[(&str, SqlValue)],
pool: &sqlx::PgPool,
) -> Result<u64, sqlx::Error> {
let mut builder = M::query();
for (col, val) in conditions {
builder = builder.where_eq(col, val.clone());
}
M::delete_where(pool, builder).await
}
pub async fn bulk_update(
data: &[(&str, SqlValue)],
ids: impl IntoIterator<Item = impl Into<SqlValue>>,
pool: &sqlx::PgPool,
) -> Result<u64, sqlx::Error> {
let id_vals: Vec<SqlValue> = ids.into_iter().map(Into::into).collect();
if id_vals.is_empty() {
return Ok(0);
}
let builder = M::query().where_in(M::primary_key(), id_vals);
M::update_where(pool, builder, data).await
}
pub async fn copy_insert(
rows: &[Vec<(&str, SqlValue)>],
pool: &sqlx::PgPool,
) -> Result<u64, sqlx::Error> {
if rows.is_empty() {
return Ok(0);
}
let cols: Vec<&str> = rows[0].iter().map(|(c, _)| *c).collect();
let col_list: Vec<String> = cols.iter().map(|c| format!("\"{c}\"")).collect();
let copy_sql = format!(
"COPY \"{}\" ({}) FROM STDIN (FORMAT CSV, NULL '')",
M::table_name(),
col_list.join(", ")
);
let mut copy = pool.copy_in_raw(©_sql).await?;
for row in rows {
let values: Vec<String> = row.iter().map(|(_, v)| csv_escape(v)).collect();
let line = format!("{}\n", values.join(","));
copy.send(line.as_bytes()).await?;
}
let rows_copied = copy.finish().await?;
Ok(rows_copied)
}
}
/// Renders one value as a field of a `COPY ... (FORMAT CSV, NULL '')` row.
///
/// - `SqlValue::Null` becomes an empty unquoted field, which the `NULL ''`
///   option maps to SQL `NULL`.
/// - Text and JSON are always double-quoted, with embedded quotes doubled.
///   An empty string is therefore sent as `""` and stays distinct from `NULL`.
/// - Arrays are first rendered as a PostgreSQL array literal and then
///   CSV-quoted as a whole, so quotes inside array elements survive both
///   encoding layers.
fn csv_escape(val: &SqlValue) -> String {
    match val {
        SqlValue::Null => String::new(),
        SqlValue::Bool(b) => pg_bool(*b).to_owned(),
        SqlValue::Integer(n) => n.to_string(),
        SqlValue::Float(f) => f.to_string(),
        SqlValue::Uuid(u) => u.to_string(),
        SqlValue::Text(s) => csv_quote(s),
        SqlValue::Json(j) => csv_quote(&j.to_string()),
        SqlValue::Array(vals) => csv_quote(&pg_array_literal(vals)),
    }
}

/// PostgreSQL's text form for a boolean (`t` / `f`).
fn pg_bool(b: bool) -> &'static str {
    if b {
        "t"
    } else {
        "f"
    }
}

/// Wraps `s` in double quotes for CSV, doubling any embedded `"`.
fn csv_quote(s: &str) -> String {
    format!("\"{}\"", s.replace('"', "\"\""))
}

/// Builds a PostgreSQL array literal such as `{1,NULL,"a \"b\""}`.
///
/// Nested `SqlValue::Array`s become nested braces (multi-dimensional arrays).
fn pg_array_literal(vals: &[SqlValue]) -> String {
    let items: Vec<String> = vals.iter().map(pg_array_element).collect();
    format!("{{{}}}", items.join(","))
}

/// Renders one element inside a PostgreSQL array literal.
///
/// - `NULL` must be the bare keyword; an empty field is not valid inside an
///   array literal.
/// - Text and JSON are double-quoted with `\` and `"` backslash-escaped, so
///   commas, braces, whitespace and the word "NULL" are taken literally.
/// - Numbers, booleans and UUIDs are emitted bare. This assumes their text
///   forms contain no commas, braces, quotes or whitespace.
fn pg_array_element(val: &SqlValue) -> String {
    match val {
        SqlValue::Null => "NULL".to_owned(),
        SqlValue::Array(inner) => pg_array_literal(inner),
        SqlValue::Bool(b) => pg_bool(*b).to_owned(),
        SqlValue::Integer(n) => n.to_string(),
        SqlValue::Float(f) => f.to_string(),
        SqlValue::Uuid(u) => u.to_string(),
        SqlValue::Text(s) => pg_array_quote(s),
        SqlValue::Json(j) => pg_array_quote(&j.to_string()),
    }
}

/// Double-quotes `s` as an array element, backslash-escaping `\` and `"`.
fn pg_array_quote(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 2);
    out.push('"');
    for ch in s.chars() {
        if matches!(ch, '"' | '\\') {
            out.push('\\');
        }
        out.push(ch);
    }
    out.push('"');
    out
}