oxide_sql_core/builder/
delete.rs1use std::marker::PhantomData;
7
8use super::expr::ExprBuilder;
9use super::value::SqlValue;
10
/// Typestate marker: no table has been chosen yet (before `from` is called).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoTable;

/// Typestate marker: a table has been chosen with `from`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct HasTable;
17
/// Dynamic (string-based) builder for `DELETE` statements.
///
/// `Table` is a typestate marker: [`NoTable`] until `from` has been called,
/// then [`HasTable`]. Only the `HasTable` state exposes `where_clause`,
/// `build`, `build_sql` and `has_where_clause`.
pub struct DeleteDyn<Table> {
    /// Table name; set by `from` and pushed verbatim into the SQL by `build`.
    table: Option<String>,
    /// Optional `WHERE` condition; when `None`, `build` emits no `WHERE`
    /// clause at all.
    where_clause: Option<ExprBuilder>,
    /// Zero-sized field recording the typestate `Table`.
    _state: PhantomData<Table>,
}
30
31impl DeleteDyn<NoTable> {
32 #[must_use]
34 pub fn new() -> Self {
35 Self {
36 table: None,
37 where_clause: None,
38 _state: PhantomData,
39 }
40 }
41}
42
43impl Default for DeleteDyn<NoTable> {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49impl DeleteDyn<NoTable> {
51 #[must_use]
53 pub fn from(self, table: &str) -> DeleteDyn<HasTable> {
54 DeleteDyn {
55 table: Some(String::from(table)),
56 where_clause: self.where_clause,
57 _state: PhantomData,
58 }
59 }
60}
61
62impl DeleteDyn<HasTable> {
64 #[must_use]
69 pub fn where_clause(mut self, expr: ExprBuilder) -> Self {
70 self.where_clause = Some(expr);
71 self
72 }
73
74 #[must_use]
78 pub fn build(self) -> (String, Vec<SqlValue>) {
79 let mut sql = String::from("DELETE FROM ");
80 let mut params = vec![];
81
82 if let Some(ref table) = self.table {
83 sql.push_str(table);
84 }
85
86 if let Some(ref where_expr) = self.where_clause {
87 sql.push_str(" WHERE ");
88 sql.push_str(where_expr.sql());
89 params.extend(where_expr.params().iter().cloned());
90 }
91
92 (sql, params)
93 }
94
95 #[must_use]
97 pub fn build_sql(self) -> String {
98 let (sql, _) = self.build();
99 sql
100 }
101
102 #[must_use]
104 pub const fn has_where_clause(&self) -> bool {
105 self.where_clause.is_some()
106 }
107}
108
/// `DELETE` builder that cannot be built without a `WHERE` condition.
///
/// It wraps a [`DeleteDyn`]. `build` / `build_sql` exist only on
/// [`SafeDeleteDynWithWhere`], which is reachable solely through
/// `where_clause`.
pub struct SafeDeleteDyn<Table> {
    /// The underlying builder that produces the SQL.
    inner: DeleteDyn<Table>,
}
115
116impl SafeDeleteDyn<NoTable> {
117 #[must_use]
119 pub fn new() -> Self {
120 Self {
121 inner: DeleteDyn::new(),
122 }
123 }
124
125 #[must_use]
127 pub fn from(self, table: &str) -> SafeDeleteDyn<HasTable> {
128 SafeDeleteDyn {
129 inner: self.inner.from(table),
130 }
131 }
132}
133
134impl Default for SafeDeleteDyn<NoTable> {
135 fn default() -> Self {
136 Self::new()
137 }
138}
139
/// A [`SafeDeleteDyn`] whose table and `WHERE` condition are both set; the
/// only state of the safe builder that offers `build` / `build_sql`.
pub struct SafeDeleteDynWithWhere {
    /// Underlying builder with its table and `WHERE` condition set.
    inner: DeleteDyn<HasTable>,
}
144
145impl SafeDeleteDyn<HasTable> {
146 #[must_use]
148 pub fn where_clause(self, expr: ExprBuilder) -> SafeDeleteDynWithWhere {
149 SafeDeleteDynWithWhere {
150 inner: self.inner.where_clause(expr),
151 }
152 }
153}
154
155impl SafeDeleteDynWithWhere {
156 #[must_use]
158 pub fn build(self) -> (String, Vec<SqlValue>) {
159 self.inner.build()
160 }
161
162 #[must_use]
164 pub fn build_sql(self) -> String {
165 self.inner.build_sql()
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::dyn_col;

    #[test]
    fn test_simple_delete() {
        let (sql, params) = DeleteDyn::new()
            .from("users")
            .where_clause(dyn_col("id").eq(1_i32))
            .build();

        assert_eq!(sql, "DELETE FROM users WHERE id = ?");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_delete_all() {
        let (sql, params) = DeleteDyn::new().from("temp_data").build();

        assert_eq!(sql, "DELETE FROM temp_data");
        assert!(params.is_empty());
    }

    #[test]
    fn test_delete_complex_where() {
        let (sql, params) = DeleteDyn::new()
            .from("orders")
            .where_clause(
                dyn_col("status")
                    .eq("cancelled")
                    .and(dyn_col("created_at").lt("2024-01-01")),
            )
            .build();

        assert_eq!(
            sql,
            "DELETE FROM orders WHERE status = ? AND created_at < ?"
        );
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_safe_delete() {
        let (sql, params) = SafeDeleteDyn::new()
            .from("users")
            .where_clause(dyn_col("id").eq(1_i32))
            .build();

        assert_eq!(sql, "DELETE FROM users WHERE id = ?");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_safe_delete_build_sql() {
        let sql = SafeDeleteDyn::new()
            .from("users")
            .where_clause(dyn_col("id").eq(1_i32))
            .build_sql();

        assert_eq!(sql, "DELETE FROM users WHERE id = ?");
    }

    #[test]
    fn test_build_sql_matches_build() {
        let (expected, _) = DeleteDyn::new()
            .from("users")
            .where_clause(dyn_col("id").eq(1_i32))
            .build();
        let sql = DeleteDyn::new()
            .from("users")
            .where_clause(dyn_col("id").eq(1_i32))
            .build_sql();

        assert_eq!(sql, expected);
    }

    #[test]
    fn test_has_where_clause() {
        let without = DeleteDyn::new().from("users");
        assert!(!without.has_where_clause());

        let with = without.where_clause(dyn_col("id").eq(1_i32));
        assert!(with.has_where_clause());
    }

    /// User-supplied values must end up as bound parameters, never in the
    /// SQL text itself.
    #[test]
    fn test_delete_sql_injection_prevention() {
        let malicious = "1; DROP TABLE users; --";
        let (sql, params) = DeleteDyn::new()
            .from("users")
            .where_clause(dyn_col("id").eq(malicious))
            .build();

        assert_eq!(sql, "DELETE FROM users WHERE id = ?");
        assert!(matches!(&params[0], SqlValue::Text(s) if s == malicious));
    }
}