use std::marker::PhantomData;
use super::expr::ExprBuilder;
use super::value::{SqlValue, ToSqlValue};
/// Type-state marker: the builder has no target table yet.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct NoTable;

/// Type-state marker: a target table has been supplied via `table()`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct HasTable;

/// Type-state marker: no `SET` assignment has been added yet.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct NoSet;

/// Type-state marker: at least one `SET` assignment has been added.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct HasSet;
/// One `column = value` pair of an `UPDATE ... SET` list.
struct Assignment {
    /// Column name, written into the SQL text as-is.
    column: String,
    /// Value bound to the `?` placeholder emitted for this column.
    value: SqlValue,
}
/// Type-state builder for a parameterized SQL `UPDATE` statement.
///
/// `Table` is [`NoTable`] until `table()` is called, after which it is [`HasTable`].
/// `Set` is [`NoSet`] until the first `set()` call, after which it is [`HasSet`].
///
/// `build`/`build_sql` are only available on `UpdateDyn<HasTable, HasSet>`. A statement
/// missing its table or every `SET` assignment is therefore rejected at compile time.
#[must_use = "an UpdateDyn builder does nothing until `build` or `build_sql` is called"]
pub struct UpdateDyn<Table, Set> {
    /// Target table; set to `Some` by `table()`, the only way to reach `HasTable`.
    table: Option<String>,
    /// `SET` assignments, kept in the order they were added.
    assignments: Vec<Assignment>,
    /// Optional `WHERE` condition; a later `where_clause()` call overwrites it.
    where_clause: Option<ExprBuilder>,
    /// Zero-sized carrier for the type-state parameters.
    _state: PhantomData<(Table, Set)>,
}
impl UpdateDyn<NoTable, NoSet> {
#[must_use]
pub fn new() -> Self {
Self {
table: None,
assignments: vec![],
where_clause: None,
_state: PhantomData,
}
}
}
impl Default for UpdateDyn<NoTable, NoSet> {
fn default() -> Self {
Self::new()
}
}
impl<Set> UpdateDyn<NoTable, Set> {
    /// Names the table to update and moves the builder into the `HasTable` state.
    ///
    /// Any assignments and `WHERE` condition already present are carried over unchanged.
    /// The name is stored as given; no quoting or escaping is applied here.
    #[must_use]
    pub fn table(self, table: &str) -> UpdateDyn<HasTable, Set> {
        let UpdateDyn {
            assignments,
            where_clause,
            ..
        } = self;
        UpdateDyn {
            table: Some(table.to_owned()),
            assignments,
            where_clause,
            _state: PhantomData,
        }
    }
}
impl UpdateDyn<HasTable, NoSet> {
    /// Adds the first `column = value` assignment and moves into the `HasSet` state.
    ///
    /// `value` is converted with `ToSqlValue::to_sql_value` and stored as data. It is
    /// later bound to a `?` placeholder rather than written into the SQL text.
    #[must_use]
    pub fn set<T: ToSqlValue>(self, column: &str, value: T) -> UpdateDyn<HasTable, HasSet> {
        let first = Assignment {
            column: column.to_owned(),
            value: value.to_sql_value(),
        };
        UpdateDyn {
            table: self.table,
            assignments: vec![first],
            where_clause: self.where_clause,
            _state: PhantomData,
        }
    }
}
impl UpdateDyn<HasTable, HasSet> {
    /// Appends another `column = value` assignment to the `SET` list.
    ///
    /// Assignments appear in the SQL in the order they were added. `build` returns
    /// their values as bind parameters in that same order.
    #[must_use]
    pub fn set<T: ToSqlValue>(mut self, column: &str, value: T) -> Self {
        self.assignments.push(Assignment {
            column: String::from(column),
            value: value.to_sql_value(),
        });
        self
    }

    /// Sets the `WHERE` condition, replacing any condition supplied earlier.
    #[must_use]
    pub fn where_clause(mut self, expr: ExprBuilder) -> Self {
        self.where_clause = Some(expr);
        self
    }

    /// Renders `UPDATE <table> SET c1 = ?, c2 = ?[ WHERE <expr>]` and its bind parameters.
    ///
    /// Parameters follow placeholder order: every `SET` value first, then the
    /// parameters of the `WHERE` expression.
    ///
    /// Values are always bound, never inlined. The table and column names, however,
    /// are written into the SQL verbatim, so they must come from trusted code and
    /// never from user input.
    #[must_use]
    pub fn build(self) -> (String, Vec<SqlValue>) {
        let Self {
            table,
            assignments,
            where_clause,
            ..
        } = self;

        let mut sql = String::from("UPDATE ");
        // Reaching `HasTable` requires `table()`, which always stores `Some`.
        if let Some(table) = &table {
            sql.push_str(table);
        }
        sql.push_str(" SET ");

        // Render each `col = ?` and collect its value in one pass, so placeholders and
        // parameters cannot drift out of order.
        let mut params = Vec::with_capacity(assignments.len());
        for (index, assignment) in assignments.into_iter().enumerate() {
            if index > 0 {
                sql.push_str(", ");
            }
            sql.push_str(&assignment.column);
            sql.push_str(" = ?");
            params.push(assignment.value);
        }

        if let Some(where_expr) = &where_clause {
            sql.push_str(" WHERE ");
            sql.push_str(where_expr.sql());
            params.extend(where_expr.params().iter().cloned());
        }
        (sql, params)
    }

    /// Renders only the SQL text produced by `build`, discarding the bind parameters.
    #[must_use]
    pub fn build_sql(self) -> String {
        self.build().0
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::dyn_col;

    #[test]
    fn test_simple_update() {
        let (sql, params) = UpdateDyn::new().table("users").set("name", "Bob").build();
        assert_eq!(sql, "UPDATE users SET name = ?");
        assert_eq!(params.len(), 1);
        assert!(matches!(&params[0], SqlValue::Text(s) if s == "Bob"));
    }

    #[test]
    fn test_update_multiple_columns() {
        let (sql, params) = UpdateDyn::new()
            .table("users")
            .set("name", "Bob")
            .set("email", "bob@example.com")
            .set("age", 30_i32)
            .build();
        assert_eq!(sql, "UPDATE users SET name = ?, email = ?, age = ?");
        assert_eq!(params.len(), 3);
        // Parameters must follow placeholder order.
        assert!(matches!(&params[0], SqlValue::Text(s) if s == "Bob"));
        assert!(matches!(&params[1], SqlValue::Text(s) if s == "bob@example.com"));
    }

    #[test]
    fn test_update_with_where() {
        let (sql, params) = UpdateDyn::new()
            .table("users")
            .set("active", false)
            .where_clause(dyn_col("id").eq(1_i32))
            .build();
        assert_eq!(sql, "UPDATE users SET active = ? WHERE id = ?");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_update_sql_injection_prevention() {
        let malicious = "'; DROP TABLE users; --";
        let (sql, params) = UpdateDyn::new()
            .table("users")
            .set("name", malicious)
            .where_clause(dyn_col("id").eq(1_i32))
            .build();
        assert_eq!(sql, "UPDATE users SET name = ? WHERE id = ?");
        assert!(matches!(&params[0], SqlValue::Text(s) if s == malicious));
    }

    #[test]
    fn test_build_sql_matches_build_and_default() {
        let sql = UpdateDyn::default().table("t").set("a", 1_i32).build_sql();
        assert_eq!(sql, "UPDATE t SET a = ?");
    }
}