use crate::{
builder::SqlFragment,
expr::Expr,
identifier::{escape_ident, from_qi, QualifiedIdentifier},
};
/// Fluent builder that assembles an SQL `UPDATE` statement as a [`SqlFragment`].
///
/// All fields start empty (see the derived `Default`); the builder methods
/// fill them in and `build` renders them in the order
/// table, `SET`, `WHERE`, `RETURNING`.
#[derive(Clone, Debug, Default)]
pub struct UpdateBuilder {
/// Target table fragment; `None` until `table` or `table_as` is called.
table: Option<SqlFragment>,
/// `(column name, value fragment)` assignments, kept in insertion order.
set: Vec<(String, SqlFragment)>,
/// Predicate fragments; `build` joins them with ` AND `.
where_clauses: Vec<SqlFragment>,
/// Items of the `RETURNING` list; `build` joins them with `, `.
returning: Vec<SqlFragment>,
}
impl UpdateBuilder {
    /// Creates an empty builder: no table, assignments, predicates, or
    /// `RETURNING` items.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the table to update, rendered from `qi` with `from_qi`.
    ///
    /// Overwrites any table configured by an earlier `table`/`table_as` call.
    pub fn table(mut self, qi: &QualifiedIdentifier) -> Self {
        let target = from_qi(qi);
        self.table = Some(SqlFragment::raw(target));
        self
    }

    /// Sets the table to update together with an alias, producing
    /// `<table> AS <alias>`.
    ///
    /// The table is rendered with `from_qi` and the alias is passed through
    /// `escape_ident`. Overwrites any previously configured table.
    pub fn table_as(mut self, qi: &QualifiedIdentifier, alias: &str) -> Self {
        let target = from_qi(qi);
        let escaped_alias = escape_ident(alias);
        self.table = Some(SqlFragment::raw(format!("{} AS {}", target, escaped_alias)));
        self
    }

    /// Adds a `column = <param>` assignment where `value` is attached as a
    /// bound parameter (via `SqlFragment::push_param`).
    ///
    /// The column name is escaped later, when `build` renders the statement.
    pub fn set<V: Into<crate::param::SqlParam>>(self, column: &str, value: V) -> Self {
        let mut bound = SqlFragment::new();
        bound.push_param(value);
        self.set_raw(column, bound)
    }

    /// Adds a `column = <value>` assignment whose right-hand side is the
    /// caller-supplied fragment, appended as-is by `build`.
    pub fn set_raw(mut self, column: &str, value: SqlFragment) -> Self {
        let assignment = (column.to_owned(), value);
        self.set.push(assignment);
        self
    }

    /// Adds a predicate built from `expr`; see `where_raw` for how predicates
    /// are combined.
    pub fn where_expr(self, expr: Expr) -> Self {
        self.where_raw(expr.into_fragment())
    }

    /// Adds a predicate fragment. Multiple predicates are joined with
    /// ` AND ` by `build`, without any added parentheses.
    pub fn where_raw(mut self, sql: SqlFragment) -> Self {
        self.where_clauses.push(sql);
        self
    }

    /// Adds `column` (passed through `escape_ident`) to the `RETURNING` list.
    pub fn returning(mut self, column: &str) -> Self {
        let item = SqlFragment::raw(escape_ident(column));
        self.returning.push(item);
        self
    }

    /// Adds `*` to the `RETURNING` list.
    pub fn returning_all(mut self) -> Self {
        let star = SqlFragment::raw("*");
        self.returning.push(star);
        self
    }

    /// Consumes the builder and renders the statement.
    ///
    /// The output always starts with `UPDATE `, followed by the table
    /// fragment if one was set. The ` SET `, ` WHERE `, and ` RETURNING `
    /// sections are each emitted only when they have at least one entry.
    pub fn build(self) -> SqlFragment {
        let Self {
            table,
            set,
            where_clauses,
            returning,
        } = self;

        let mut out = SqlFragment::new();
        out.push("UPDATE ");
        if let Some(target) = table {
            out.append(target);
        }

        // Each section uses a rolling separator: the keyword introduces the
        // first entry, the joiner precedes every later one. An empty list
        // therefore emits nothing at all.
        let mut lead = " SET ";
        for (column, value) in set {
            out.push(lead);
            out.push(&escape_ident(&column));
            out.push(" = ");
            out.append(value);
            lead = ", ";
        }

        let mut lead = " WHERE ";
        for predicate in where_clauses {
            out.push(lead);
            out.append(predicate);
            lead = " AND ";
        }

        let mut lead = " RETURNING ";
        for item in returning {
            out.push(lead);
            out.append(item);
            lead = ", ";
        }

        out
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::param::SqlParam;

    /// One assignment plus one predicate: all three keywords appear and two
    /// parameters are bound in total.
    #[test]
    fn test_simple_update() {
        let users = QualifiedIdentifier::new("public", "users");
        let builder = UpdateBuilder::new()
            .table(&users)
            .set("name", SqlParam::text("Jane"))
            .where_expr(Expr::eq("id", 1i64));
        let built = builder.build();

        for keyword in &["UPDATE", "SET", "WHERE"] {
            assert!(built.sql().contains(*keyword));
        }
        assert_eq!(built.params().len(), 2);
    }

    /// `returning_all` renders a `RETURNING *` clause.
    #[test]
    fn test_update_returning() {
        let users = QualifiedIdentifier::unqualified("users");
        let builder = UpdateBuilder::new()
            .table(&users)
            .set("status", SqlParam::text("active"))
            .returning_all();
        let built = builder.build();

        assert!(built.sql().contains("RETURNING *"));
    }

    /// Three assignments plus one predicate bind four parameters in total.
    #[test]
    fn test_update_multiple_sets() {
        let users = QualifiedIdentifier::unqualified("users");
        let builder = UpdateBuilder::new()
            .table(&users)
            .set("name", SqlParam::text("John"))
            .set("email", SqlParam::text("john@new.com"))
            // Passed as a text parameter through `set`, like the other values.
            .set("updated_at", SqlParam::text("now()"))
            .where_expr(Expr::eq("id", 5i64));
        let built = builder.build();

        assert_eq!(built.params().len(), 4);
    }
}