use serde::{Deserialize, Serialize};
use tokio_postgres::Client;
use crate::postgres::PgError;
/// Outcome of a row-modifying statement issued by this module
/// (`update_cell`, `delete_rows`).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct UpdateOutcome {
    /// Row count reported by the driver for the executed statement.
    pub rows_affected: u64,
}
/// Sets one cell of the row identified by `ctid` to `new_value`, or to SQL `NULL`
/// when `new_value` is `None`.
///
/// Both the new value and the ctid are bound as `text` parameters and cast on the
/// server (`$n::text::<type>`). The intermediate `text` cast matters: with a bare
/// `$n::<type>` Postgres infers the parameter itself to be `<type>`, and a Rust
/// `&str` cannot be serialized as e.g. `int4` or `tid`, so the driver would reject
/// the statement before it is sent. This mirrors `$1::text[]::tid[]` in `delete_rows`.
///
/// `column_type` is inserted as a *quoted identifier*. It therefore has to be a
/// type name as stored in the catalog (e.g. `int4`, `_text`).
/// NOTE(review): SQL-standard spellings such as `integer` or
/// `character varying(255)` will not resolve once quoted. Confirm what callers pass.
///
/// Returns the affected-row count. `0` means no row currently has that `ctid`.
///
/// # Errors
/// Returns `PgError::Driver` for any driver/server error, including a value that
/// cannot be cast to `column_type` or a malformed `ctid`.
pub async fn update_cell(
    client: &Client,
    schema: &str,
    table: &str,
    column: &str,
    column_type: &str,
    new_value: Option<&str>,
    ctid: &str,
) -> Result<UpdateOutcome, PgError> {
    let qualified = format!("{}.{}", quote_ident(schema), quote_ident(table));
    let col = quote_ident(column);
    let rows_affected = match new_value {
        Some(value) => {
            let ty = quote_ident(column_type);
            let sql = format!(
                "UPDATE {qualified} SET {col} = $1::text::{ty} WHERE ctid = $2::text::tid"
            );
            client
                .execute(&sql, &[&value, &ctid])
                .await
                .map_err(PgError::Driver)?
        }
        None => {
            let sql = format!("UPDATE {qualified} SET {col} = NULL WHERE ctid = $1::text::tid");
            client
                .execute(&sql, &[&ctid])
                .await
                .map_err(PgError::Driver)?
        }
    };
    Ok(UpdateOutcome { rows_affected })
}
/// One explicitly-set column for `insert_row`.
#[derive(Clone, Debug)]
pub struct InsertColumnInput {
    /// Column name; quoted as an identifier when the statement is built.
    pub name: String,
    /// Type name the bound value is cast to; also quoted as an identifier.
    pub type_name: String,
    /// Textual value to bind; `None` is bound as a NULL parameter.
    pub value: Option<String>,
}
/// The `RETURNING` values of a row created by `insert_row`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct InsertedRow {
    /// One entry per returned column, in `RETURNING` order, each rendered as
    /// text; `None` represents SQL `NULL`.
    pub cells: Vec<Option<String>>,
}
/// Inserts a single row into `schema.table` and returns selected columns of it as text.
///
/// * `inputs` lists the columns to set explicitly. Each value is bound as a `text`
///   parameter and cast on the server to its `type_name` (`$n::text::<type>`). A bare
///   `$n::<type>` would make Postgres infer the parameter as `<type>`, which a Rust
///   `String` cannot be serialized into. `None` binds SQL `NULL`. When `inputs` is
///   empty the row is created with `DEFAULT VALUES`.
/// * `return_columns` lists the columns to return, each cast to `text`. The
///   pseudo-name `__pg_rowid__` yields the new row's `ctid`. When the list is empty,
///   only the `ctid` is returned (aliased `__pg_rowid__`).
///
/// The returned cells are in `RETURNING` order.
///
/// # Errors
/// * `PgError::Driver` if the statement fails or a returned cell cannot be read as text.
/// * `PgError::Connect` if the statement produced no row.
pub async fn insert_row(
    client: &Client,
    schema: &str,
    table: &str,
    inputs: &[InsertColumnInput],
    return_columns: &[String],
) -> Result<InsertedRow, PgError> {
    let qualified = format!("{}.{}", quote_ident(schema), quote_ident(table));

    let (column_clause, values_clause) = if inputs.is_empty() {
        (String::new(), "DEFAULT VALUES".to_string())
    } else {
        let cols = inputs
            .iter()
            .map(|i| quote_ident(&i.name))
            .collect::<Vec<_>>()
            .join(", ");
        let placeholders = inputs
            .iter()
            .enumerate()
            .map(|(idx, i)| format!("${}::text::{}", idx + 1, quote_ident(&i.type_name)))
            .collect::<Vec<_>>()
            .join(", ");
        (format!(" ({cols})"), format!("VALUES ({placeholders})"))
    };

    let returning = if return_columns.is_empty() {
        "ctid::text AS \"__pg_rowid__\"".to_string()
    } else {
        return_columns
            .iter()
            .map(|name| {
                let alias = quote_ident(name);
                if name == "__pg_rowid__" {
                    format!("ctid::text AS {alias}")
                } else {
                    // Column and alias share the same quoted identifier.
                    format!("{alias}::text AS {alias}")
                }
            })
            .collect::<Vec<_>>()
            .join(", ")
    };

    let sql = format!(
        "INSERT INTO {qualified}{column_clause} {values_clause} RETURNING {returning}"
    );

    let params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = inputs
        .iter()
        .map(|i| &i.value as &(dyn tokio_postgres::types::ToSql + Sync))
        .collect();

    let rows = client.query(&sql, &params).await.map_err(PgError::Driver)?;
    // NOTE(review): `Connect` is an odd variant for "no row returned"; it is kept so
    // callers that match on it are unaffected.
    let row = rows
        .into_iter()
        .next()
        .ok_or_else(|| PgError::Connect("INSERT returned no row".to_string()))?;

    // Every RETURNING expression is cast to text, so each cell decodes as Option<String>.
    let cells = (0..row.len())
        .map(|idx| {
            row.try_get::<_, Option<String>>(idx)
                .map_err(PgError::Driver)
        })
        .collect::<Result<Vec<_>, _>>()?;
    Ok(InsertedRow { cells })
}
/// Deletes every row of `schema.table` whose `ctid` appears in `ctids`.
///
/// The ctids are sent as a single `text[]` parameter and cast to `tid[]` on the
/// server. An empty `ctids` slice short-circuits to zero affected rows and sends
/// no statement.
///
/// # Errors
/// Returns `PgError::Driver` for any driver/server error (e.g. a malformed ctid).
pub async fn delete_rows(
    client: &Client,
    schema: &str,
    table: &str,
    ctids: &[String],
) -> Result<UpdateOutcome, PgError> {
    // Nothing to delete: skip the round-trip entirely.
    if ctids.is_empty() {
        return Ok(UpdateOutcome { rows_affected: 0 });
    }

    let target = [quote_ident(schema), quote_ident(table)].join(".");
    let statement = format!("DELETE FROM {target} WHERE ctid = ANY($1::text[]::tid[])");

    client
        .execute(&statement, &[&ctids])
        .await
        .map(|rows_affected| UpdateOutcome { rows_affected })
        .map_err(PgError::Driver)
}
/// Wraps `s` in double quotes as a PostgreSQL delimited identifier, doubling any
/// embedded `"` so the name cannot terminate the quoting early.
fn quote_ident(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    for ch in s.chars() {
        if ch == '"' {
            // An embedded quote is escaped by repeating it.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn quote_ident_handles_simple_and_quoted_names() {
        // (input, expected quoted form)
        let cases = [
            ("users", "\"users\""),
            ("MyTable", "\"MyTable\""),
            ("with\"quote", "\"with\"\"quote\""),
            ("order", "\"order\""),
            ("", "\"\""),
        ];
        for (input, expected) in cases {
            assert_eq!(quote_ident(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn update_outcome_round_trips() {
        let original = UpdateOutcome { rows_affected: 1 };
        let encoded = serde_json::to_string(&original).expect("serialize");
        let decoded: UpdateOutcome = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.rows_affected, original.rows_affected);
    }
}