Skip to main content

oxide_sql_core/builder/
update.rs

1//! Dynamic UPDATE statement builder using the typestate pattern.
2//!
3//! This module provides string-based query building. For compile-time
4//! validated queries using schema traits, use `Update` from `builder::typed`.
5
6use std::marker::PhantomData;
7
8use super::expr::ExprBuilder;
9use super::value::{SqlValue, ToSqlValue};
10
11// Typestate markers
12
/// Marker: No table specified yet.
///
/// Zero-sized type used only as a type parameter of `UpdateDyn`; it carries
/// no data at runtime.
#[derive(Debug, Clone, Copy)]
pub struct NoTable;
/// Marker: Table has been specified.
///
/// Zero-sized type used only as a type parameter of `UpdateDyn`; it carries
/// no data at runtime.
#[derive(Debug, Clone, Copy)]
pub struct HasTable;
/// Marker: No SET clause specified yet.
///
/// Zero-sized type used only as a type parameter of `UpdateDyn`; it carries
/// no data at runtime.
#[derive(Debug, Clone, Copy)]
pub struct NoSet;
/// Marker: SET clause has been specified.
///
/// Zero-sized type used only as a type parameter of `UpdateDyn`; it carries
/// no data at runtime.
#[derive(Debug, Clone, Copy)]
pub struct HasSet;
21
/// An assignment in the SET clause.
///
/// Each assignment is rendered by `build` as `<column> = ?`, with `value`
/// returned separately in the parameter list.
struct Assignment {
    /// Column name; written into the SQL text exactly as supplied.
    column: String,
    /// Value returned as the parameter for this assignment's `?` placeholder.
    value: SqlValue,
}
27
/// A dynamic UPDATE statement builder using string-based column names.
///
/// For compile-time validated queries, use `Update` from `builder::typed`.
///
/// The `Table` and `Set` type parameters are typestate markers
/// (`NoTable`/`HasTable` and `NoSet`/`HasSet`) that decide which methods are
/// available: `set` requires a table, and `build` requires both a table and at
/// least one assignment.
pub struct UpdateDyn<Table, Set> {
    /// Target table; `None` until `table` has been called.
    table: Option<String>,
    /// SET assignments in the order they were added.
    assignments: Vec<Assignment>,
    /// Optional WHERE expression; `None` means no WHERE clause is emitted.
    where_clause: Option<ExprBuilder>,
    /// Carries the typestate markers without storing any data.
    _state: PhantomData<(Table, Set)>,
}
37
38impl UpdateDyn<NoTable, NoSet> {
39    /// Creates a new UPDATE builder.
40    #[must_use]
41    pub fn new() -> Self {
42        Self {
43            table: None,
44            assignments: vec![],
45            where_clause: None,
46            _state: PhantomData,
47        }
48    }
49}
50
51impl Default for UpdateDyn<NoTable, NoSet> {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57// Transition: NoTable -> HasTable
58impl<Set> UpdateDyn<NoTable, Set> {
59    /// Specifies the table to update.
60    #[must_use]
61    pub fn table(self, table: &str) -> UpdateDyn<HasTable, Set> {
62        UpdateDyn {
63            table: Some(String::from(table)),
64            assignments: self.assignments,
65            where_clause: self.where_clause,
66            _state: PhantomData,
67        }
68    }
69}
70
71// Transition: NoSet -> HasSet (requires table)
72impl UpdateDyn<HasTable, NoSet> {
73    /// Adds a SET assignment.
74    #[must_use]
75    pub fn set<T: ToSqlValue>(self, column: &str, value: T) -> UpdateDyn<HasTable, HasSet> {
76        UpdateDyn {
77            table: self.table,
78            assignments: vec![Assignment {
79                column: String::from(column),
80                value: value.to_sql_value(),
81            }],
82            where_clause: self.where_clause,
83            _state: PhantomData,
84        }
85    }
86}
87
88// Methods available after SET
89impl UpdateDyn<HasTable, HasSet> {
90    /// Adds another SET assignment.
91    #[must_use]
92    pub fn set<T: ToSqlValue>(mut self, column: &str, value: T) -> Self {
93        self.assignments.push(Assignment {
94            column: String::from(column),
95            value: value.to_sql_value(),
96        });
97        self
98    }
99
100    /// Adds a WHERE clause.
101    #[must_use]
102    pub fn where_clause(mut self, expr: ExprBuilder) -> Self {
103        self.where_clause = Some(expr);
104        self
105    }
106
107    /// Builds the UPDATE statement and returns SQL with parameters.
108    #[must_use]
109    pub fn build(self) -> (String, Vec<SqlValue>) {
110        let mut sql = String::from("UPDATE ");
111        let mut params = vec![];
112
113        if let Some(ref table) = self.table {
114            sql.push_str(table);
115        }
116
117        sql.push_str(" SET ");
118
119        let set_parts: Vec<String> = self
120            .assignments
121            .iter()
122            .map(|a| format!("{} = ?", a.column))
123            .collect();
124        sql.push_str(&set_parts.join(", "));
125
126        for assignment in self.assignments {
127            params.push(assignment.value);
128        }
129
130        if let Some(ref where_expr) = self.where_clause {
131            sql.push_str(" WHERE ");
132            sql.push_str(where_expr.sql());
133            params.extend(where_expr.params().iter().cloned());
134        }
135
136        (sql, params)
137    }
138
139    /// Builds the UPDATE statement and returns only the SQL string.
140    #[must_use]
141    pub fn build_sql(self) -> String {
142        let (sql, _) = self.build();
143        sql
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::dyn_col;

    /// A single assignment yields one placeholder and one parameter.
    #[test]
    fn test_simple_update() {
        let builder = UpdateDyn::new().table("users").set("name", "Bob");
        let (sql, params) = builder.build();

        assert_eq!(sql, "UPDATE users SET name = ?");
        assert_eq!(params.len(), 1);
    }

    /// Several assignments are joined with `, ` in the order they were added.
    #[test]
    fn test_update_multiple_columns() {
        let builder = UpdateDyn::new()
            .table("users")
            .set("name", "Bob")
            .set("email", "bob@example.com")
            .set("age", 30_i32);
        let (sql, params) = builder.build();

        assert_eq!(sql, "UPDATE users SET name = ?, email = ?, age = ?");
        assert_eq!(params.len(), 3);
    }

    /// A WHERE expression is appended and its parameter is included.
    #[test]
    fn test_update_with_where() {
        let by_id = dyn_col("id").eq(1_i32);
        let (sql, params) = UpdateDyn::new()
            .table("users")
            .set("active", false)
            .where_clause(by_id)
            .build();

        assert_eq!(sql, "UPDATE users SET active = ? WHERE id = ?");
        assert_eq!(params.len(), 2);
    }

    /// A hostile value ends up in the parameter list, not in the SQL text.
    #[test]
    fn test_update_sql_injection_prevention() {
        let malicious = "'; DROP TABLE users; --";
        let by_id = dyn_col("id").eq(1_i32);
        let (sql, params) = UpdateDyn::new()
            .table("users")
            .set("name", malicious)
            .where_clause(by_id)
            .build();

        assert_eq!(sql, "UPDATE users SET name = ? WHERE id = ?");
        let first_param = &params[0];
        assert!(matches!(first_param, SqlValue::Text(s) if s == malicious));
    }
}
197}