use std::marker::PhantomData;
use super::expr::ExprBuilder;
use super::value::SqlValue;
/// Typestate marker: the builder has not been given a target table yet.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct NoTable;
/// Typestate marker: a target table has been set via `from`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct HasTable;
/// Dynamically assembled `DELETE` statement.
///
/// `Table` is a zero-sized typestate marker ([`NoTable`] or [`HasTable`])
/// that records at compile time whether a target table has been chosen.
pub struct DeleteDyn<Table> {
    /// Carries the typestate; stores no runtime data.
    _state: PhantomData<Table>,
    /// Target table name, written into the SQL text verbatim.
    table: Option<String>,
    /// Optional filter; when `None` no `WHERE` is emitted.
    where_clause: Option<ExprBuilder>,
}
impl DeleteDyn<NoTable> {
#[must_use]
pub fn new() -> Self {
Self {
table: None,
where_clause: None,
_state: PhantomData,
}
}
}
impl Default for DeleteDyn<NoTable> {
fn default() -> Self {
Self::new()
}
}
impl DeleteDyn<NoTable> {
    /// Sets the target table and moves the builder into the [`HasTable`] state.
    ///
    /// `table` is copied into an owned `String`; it is later written into the
    /// SQL text as-is, without quoting or escaping.
    #[must_use]
    pub fn from(self, table: &str) -> DeleteDyn<HasTable> {
        let Self { where_clause, .. } = self;
        DeleteDyn {
            _state: PhantomData,
            table: Some(table.to_owned()),
            where_clause,
        }
    }
}
impl DeleteDyn<HasTable> {
#[must_use]
pub fn where_clause(mut self, expr: ExprBuilder) -> Self {
self.where_clause = Some(expr);
self
}
#[must_use]
pub fn build(self) -> (String, Vec<SqlValue>) {
let mut sql = String::from("DELETE FROM ");
let mut params = vec![];
if let Some(ref table) = self.table {
sql.push_str(table);
}
if let Some(ref where_expr) = self.where_clause {
sql.push_str(" WHERE ");
sql.push_str(where_expr.sql());
params.extend(where_expr.params().iter().cloned());
}
(sql, params)
}
#[must_use]
pub fn build_sql(self) -> String {
let (sql, _) = self.build();
sql
}
#[must_use]
pub const fn has_where_clause(&self) -> bool {
self.where_clause.is_some()
}
}
/// `DELETE` builder whose only path to `build` goes through `where_clause`,
/// so it cannot produce an unfiltered statement.
///
/// Wraps [`DeleteDyn`]; `Table` is the same typestate marker
/// ([`NoTable`] / [`HasTable`]).
pub struct SafeDeleteDyn<Table> {
    /// Underlying unrestricted builder that performs the actual rendering.
    inner: DeleteDyn<Table>,
}
impl SafeDeleteDyn<NoTable> {
    /// Creates an empty safe builder around a fresh [`DeleteDyn`].
    #[must_use]
    pub fn new() -> Self {
        let inner = DeleteDyn::<NoTable>::new();
        SafeDeleteDyn { inner }
    }

    /// Sets the target table and moves to the [`HasTable`] state.
    ///
    /// The returned builder has no `build` method; a `where_clause` must be
    /// supplied first.
    #[must_use]
    pub fn from(self, table: &str) -> SafeDeleteDyn<HasTable> {
        let Self { inner } = self;
        SafeDeleteDyn {
            inner: inner.from(table),
        }
    }
}
impl Default for SafeDeleteDyn<NoTable> {
fn default() -> Self {
Self::new()
}
}
/// Safe `DELETE` builder after a `WHERE` clause has been attached; this is the
/// only safe state that exposes `build` / `build_sql`.
pub struct SafeDeleteDynWithWhere {
    /// Underlying builder. When created via `SafeDeleteDyn::where_clause`,
    /// both the table and the filter are set.
    inner: DeleteDyn<HasTable>,
}
impl SafeDeleteDyn<HasTable> {
    /// Attaches the mandatory filter and moves to the buildable state.
    #[must_use]
    pub fn where_clause(self, expr: ExprBuilder) -> SafeDeleteDynWithWhere {
        let Self { inner } = self;
        let filtered = inner.where_clause(expr);
        SafeDeleteDynWithWhere { inner: filtered }
    }
}
impl SafeDeleteDynWithWhere {
    /// Renders the SQL and its bound parameters by delegating to the wrapped
    /// [`DeleteDyn`].
    #[must_use]
    pub fn build(self) -> (String, Vec<SqlValue>) {
        let Self { inner } = self;
        inner.build()
    }

    /// Renders only the SQL text by delegating to the wrapped [`DeleteDyn`].
    #[must_use]
    pub fn build_sql(self) -> String {
        let Self { inner } = self;
        inner.build_sql()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::dyn_col;

    #[test]
    fn test_simple_delete() {
        let (sql, params) = DeleteDyn::new()
            .from("users")
            .where_clause(dyn_col("id").eq(1_i32))
            .build();
        assert_eq!(sql, "DELETE FROM users WHERE id = ?");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_delete_all() {
        let (sql, params) = DeleteDyn::new().from("temp_data").build();
        assert_eq!(sql, "DELETE FROM temp_data");
        assert!(params.is_empty());
    }

    #[test]
    fn test_delete_complex_where() {
        let (sql, params) = DeleteDyn::new()
            .from("orders")
            .where_clause(
                dyn_col("status")
                    .eq("cancelled")
                    .and(dyn_col("created_at").lt("2024-01-01")),
            )
            .build();
        assert_eq!(
            sql,
            "DELETE FROM orders WHERE status = ? AND created_at < ?"
        );
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_safe_delete() {
        let (sql, params) = SafeDeleteDyn::new()
            .from("users")
            .where_clause(dyn_col("id").eq(1_i32))
            .build();
        assert_eq!(sql, "DELETE FROM users WHERE id = ?");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_safe_delete_default_build_sql() {
        let sql = SafeDeleteDyn::default()
            .from("users")
            .where_clause(dyn_col("id").eq(1_i32))
            .build_sql();
        assert_eq!(sql, "DELETE FROM users WHERE id = ?");
    }

    #[test]
    fn test_build_sql_and_has_where_clause() {
        let unfiltered = DeleteDyn::new().from("logs");
        assert!(!unfiltered.has_where_clause());
        assert_eq!(unfiltered.build_sql(), "DELETE FROM logs");

        let filtered = DeleteDyn::new()
            .from("logs")
            .where_clause(dyn_col("id").eq(7_i32));
        assert!(filtered.has_where_clause());
        assert_eq!(filtered.build_sql(), "DELETE FROM logs WHERE id = ?");
    }

    #[test]
    fn test_delete_sql_injection_prevention() {
        let malicious = "1; DROP TABLE users; --";
        let (sql, params) = DeleteDyn::new()
            .from("users")
            .where_clause(dyn_col("id").eq(malicious))
            .build();
        assert_eq!(sql, "DELETE FROM users WHERE id = ?");
        // Guard the index below so a regression fails with a clear message
        // instead of an out-of-bounds panic.
        assert_eq!(params.len(), 1);
        // The payload must travel as a bound value, never as SQL text.
        assert!(matches!(&params[0], SqlValue::Text(s) if s == malicious));
    }
}