postrust_sql/
update.rs

1//! UPDATE statement builder.
2
3use crate::{
4    builder::SqlFragment,
5    expr::Expr,
6    identifier::{escape_ident, from_qi, QualifiedIdentifier},
7};
8
9/// Builder for UPDATE statements.
10#[derive(Clone, Debug, Default)]
11pub struct UpdateBuilder {
12    table: Option<SqlFragment>,
13    set: Vec<(String, SqlFragment)>,
14    where_clauses: Vec<SqlFragment>,
15    returning: Vec<SqlFragment>,
16}
17
18impl UpdateBuilder {
19    /// Create a new UPDATE builder.
20    pub fn new() -> Self {
21        Self::default()
22    }
23
24    /// Set the target table.
25    pub fn table(mut self, qi: &QualifiedIdentifier) -> Self {
26        self.table = Some(SqlFragment::raw(from_qi(qi)));
27        self
28    }
29
30    /// Set the target table with alias.
31    pub fn table_as(mut self, qi: &QualifiedIdentifier, alias: &str) -> Self {
32        self.table = Some(SqlFragment::raw(format!(
33            "{} AS {}",
34            from_qi(qi),
35            escape_ident(alias)
36        )));
37        self
38    }
39
40    /// Add a SET clause with parameterized value.
41    pub fn set<V: Into<crate::param::SqlParam>>(mut self, column: &str, value: V) -> Self {
42        let mut frag = SqlFragment::new();
43        frag.push_param(value);
44        self.set.push((column.to_string(), frag));
45        self
46    }
47
48    /// Add a SET clause with raw SQL.
49    pub fn set_raw(mut self, column: &str, value: SqlFragment) -> Self {
50        self.set.push((column.to_string(), value));
51        self
52    }
53
54    /// Add a WHERE clause.
55    pub fn where_expr(mut self, expr: Expr) -> Self {
56        self.where_clauses.push(expr.into_fragment());
57        self
58    }
59
60    /// Add a raw WHERE clause.
61    pub fn where_raw(mut self, sql: SqlFragment) -> Self {
62        self.where_clauses.push(sql);
63        self
64    }
65
66    /// Add RETURNING clause.
67    pub fn returning(mut self, column: &str) -> Self {
68        self.returning
69            .push(SqlFragment::raw(escape_ident(column)));
70        self
71    }
72
73    /// Add RETURNING * clause.
74    pub fn returning_all(mut self) -> Self {
75        self.returning.push(SqlFragment::raw("*"));
76        self
77    }
78
79    /// Build the UPDATE statement.
80    pub fn build(self) -> SqlFragment {
81        let mut result = SqlFragment::new();
82
83        result.push("UPDATE ");
84
85        if let Some(table) = self.table {
86            result.append(table);
87        }
88
89        // SET
90        if !self.set.is_empty() {
91            result.push(" SET ");
92            for (i, (col, val)) in self.set.into_iter().enumerate() {
93                if i > 0 {
94                    result.push(", ");
95                }
96                result.push(&escape_ident(&col));
97                result.push(" = ");
98                result.append(val);
99            }
100        }
101
102        // WHERE
103        if !self.where_clauses.is_empty() {
104            result.push(" WHERE ");
105            for (i, clause) in self.where_clauses.into_iter().enumerate() {
106                if i > 0 {
107                    result.push(" AND ");
108                }
109                result.append(clause);
110            }
111        }
112
113        // RETURNING
114        if !self.returning.is_empty() {
115            result.push(" RETURNING ");
116            for (i, ret) in self.returning.into_iter().enumerate() {
117                if i > 0 {
118                    result.push(", ");
119                }
120                result.append(ret);
121            }
122        }
123
124        result
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::param::SqlParam;

    #[test]
    fn test_simple_update() {
        let qi = QualifiedIdentifier::new("public", "users");
        let sql = UpdateBuilder::new()
            .table(&qi)
            .set("name", SqlParam::text("Jane"))
            .where_expr(Expr::eq("id", 1i64))
            .build();

        assert!(sql.sql().contains("UPDATE"));
        assert!(sql.sql().contains("SET"));
        assert!(sql.sql().contains("WHERE"));
        assert_eq!(sql.params().len(), 2);
    }

    #[test]
    fn test_update_returning() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = UpdateBuilder::new()
            .table(&qi)
            .set("status", SqlParam::text("active"))
            .returning_all()
            .build();

        assert!(sql.sql().contains("RETURNING *"));
    }

    #[test]
    fn test_update_returning_columns() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = UpdateBuilder::new()
            .table(&qi)
            .set("status", SqlParam::text("active"))
            .returning("id")
            .returning("status")
            .build();

        assert!(sql.sql().contains(" RETURNING "));
        assert!(!sql.sql().contains("RETURNING *"));
    }

    #[test]
    fn test_update_multiple_sets() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = UpdateBuilder::new()
            .table(&qi)
            .set("name", SqlParam::text("John"))
            .set("email", SqlParam::text("john@new.com"))
            // A bound literal value; use `set_raw` for SQL expressions such as now().
            .set("updated_at", SqlParam::text("2024-01-01T00:00:00Z"))
            .where_expr(Expr::eq("id", 5i64))
            .build();

        assert_eq!(sql.params().len(), 4); // 3 sets + 1 where
    }

    #[test]
    fn test_update_set_raw_binds_no_param() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = UpdateBuilder::new()
            .table(&qi)
            .set_raw("updated_at", SqlFragment::raw("now()"))
            .where_expr(Expr::eq("id", 5i64))
            .build();

        assert!(sql.sql().contains("now()"));
        assert_eq!(sql.params().len(), 1); // only the WHERE value is bound
    }

    #[test]
    fn test_update_table_as() {
        let qi = QualifiedIdentifier::new("public", "users");
        let sql = UpdateBuilder::new()
            .table_as(&qi, "u")
            .set("name", SqlParam::text("Jane"))
            .build();

        assert!(sql.sql().contains(" AS "));
        assert_eq!(sql.params().len(), 1);
    }

    #[test]
    fn test_update_without_where_omits_clauses() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = UpdateBuilder::new()
            .table(&qi)
            .set("status", SqlParam::text("inactive"))
            .build();

        assert!(sql.sql().starts_with("UPDATE "));
        assert!(!sql.sql().contains("WHERE"));
        assert!(!sql.sql().contains("RETURNING"));
    }
}
173}