oxide_sql_core/builder/
update.rs1use std::marker::PhantomData;
7
8use super::expr::ExprBuilder;
9use super::value::{SqlValue, ToSqlValue};
10
/// Typestate marker: no table has been chosen yet.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct NoTable;

/// Typestate marker: a table has been chosen with `UpdateDyn::table`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct HasTable;

/// Typestate marker: no `SET` assignment has been added yet.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct NoSet;

/// Typestate marker: at least one `SET` assignment has been added, which is
/// required before the statement can be built.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct HasSet;
21
22struct Assignment {
24 column: String,
25 value: SqlValue,
26}
27
28pub struct UpdateDyn<Table, Set> {
32 table: Option<String>,
33 assignments: Vec<Assignment>,
34 where_clause: Option<ExprBuilder>,
35 _state: PhantomData<(Table, Set)>,
36}
37
38impl UpdateDyn<NoTable, NoSet> {
39 #[must_use]
41 pub fn new() -> Self {
42 Self {
43 table: None,
44 assignments: vec![],
45 where_clause: None,
46 _state: PhantomData,
47 }
48 }
49}
50
51impl Default for UpdateDyn<NoTable, NoSet> {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57impl<Set> UpdateDyn<NoTable, Set> {
59 #[must_use]
61 pub fn table(self, table: &str) -> UpdateDyn<HasTable, Set> {
62 UpdateDyn {
63 table: Some(String::from(table)),
64 assignments: self.assignments,
65 where_clause: self.where_clause,
66 _state: PhantomData,
67 }
68 }
69}
70
71impl UpdateDyn<HasTable, NoSet> {
73 #[must_use]
75 pub fn set<T: ToSqlValue>(self, column: &str, value: T) -> UpdateDyn<HasTable, HasSet> {
76 UpdateDyn {
77 table: self.table,
78 assignments: vec![Assignment {
79 column: String::from(column),
80 value: value.to_sql_value(),
81 }],
82 where_clause: self.where_clause,
83 _state: PhantomData,
84 }
85 }
86}
87
88impl UpdateDyn<HasTable, HasSet> {
90 #[must_use]
92 pub fn set<T: ToSqlValue>(mut self, column: &str, value: T) -> Self {
93 self.assignments.push(Assignment {
94 column: String::from(column),
95 value: value.to_sql_value(),
96 });
97 self
98 }
99
100 #[must_use]
102 pub fn where_clause(mut self, expr: ExprBuilder) -> Self {
103 self.where_clause = Some(expr);
104 self
105 }
106
107 #[must_use]
109 pub fn build(self) -> (String, Vec<SqlValue>) {
110 let mut sql = String::from("UPDATE ");
111 let mut params = vec![];
112
113 if let Some(ref table) = self.table {
114 sql.push_str(table);
115 }
116
117 sql.push_str(" SET ");
118
119 let set_parts: Vec<String> = self
120 .assignments
121 .iter()
122 .map(|a| format!("{} = ?", a.column))
123 .collect();
124 sql.push_str(&set_parts.join(", "));
125
126 for assignment in self.assignments {
127 params.push(assignment.value);
128 }
129
130 if let Some(ref where_expr) = self.where_clause {
131 sql.push_str(" WHERE ");
132 sql.push_str(where_expr.sql());
133 params.extend(where_expr.params().iter().cloned());
134 }
135
136 (sql, params)
137 }
138
139 #[must_use]
141 pub fn build_sql(self) -> String {
142 let (sql, _) = self.build();
143 sql
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::dyn_col;

    #[test]
    fn test_simple_update() {
        let (sql, params) = UpdateDyn::new().table("users").set("name", "Bob").build();

        assert_eq!(sql, "UPDATE users SET name = ?");
        assert_eq!(params.len(), 1);
        assert!(matches!(&params[0], SqlValue::Text(s) if s == "Bob"));
    }

    #[test]
    fn test_update_multiple_columns() {
        let (sql, params) = UpdateDyn::new()
            .table("users")
            .set("name", "Bob")
            .set("email", "bob@example.com")
            .set("age", 30_i32)
            .build();

        assert_eq!(sql, "UPDATE users SET name = ?, email = ?, age = ?");
        assert_eq!(params.len(), 3);
        // Parameters must line up with the `?` placeholders, in order.
        assert!(matches!(&params[0], SqlValue::Text(s) if s == "Bob"));
        assert!(matches!(&params[1], SqlValue::Text(s) if s == "bob@example.com"));
    }

    #[test]
    fn test_update_with_where() {
        let (sql, params) = UpdateDyn::new()
            .table("users")
            .set("active", false)
            .where_clause(dyn_col("id").eq(1_i32))
            .build();

        assert_eq!(sql, "UPDATE users SET active = ? WHERE id = ?");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_where_clause_replaces_previous() {
        let (sql, params) = UpdateDyn::new()
            .table("users")
            .set("name", "Bob")
            .where_clause(dyn_col("id").eq(1_i32))
            .where_clause(dyn_col("id").eq(2_i32))
            .build();

        // Only the last condition is kept; conditions are not AND-ed together.
        assert_eq!(sql, "UPDATE users SET name = ? WHERE id = ?");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_build_sql_matches_build() {
        let make = || UpdateDyn::new().table("users").set("name", "Bob");

        assert_eq!(make().build_sql(), make().build().0);
    }

    #[test]
    fn test_update_sql_injection_prevention() {
        let malicious = "'; DROP TABLE users; --";
        let (sql, params) = UpdateDyn::new()
            .table("users")
            .set("name", malicious)
            .where_clause(dyn_col("id").eq(1_i32))
            .build();

        assert_eq!(sql, "UPDATE users SET name = ? WHERE id = ?");
        assert_eq!(params.len(), 2);
        assert!(matches!(&params[0], SqlValue::Text(s) if s == malicious));
    }
}