1use crate::{
4 builder::SqlFragment,
5 expr::Expr,
6 identifier::{escape_ident, from_qi, QualifiedIdentifier},
7};
8
9#[derive(Clone, Debug, Default)]
11pub struct UpdateBuilder {
12 table: Option<SqlFragment>,
13 set: Vec<(String, SqlFragment)>,
14 where_clauses: Vec<SqlFragment>,
15 returning: Vec<SqlFragment>,
16}
17
18impl UpdateBuilder {
19 pub fn new() -> Self {
21 Self::default()
22 }
23
24 pub fn table(mut self, qi: &QualifiedIdentifier) -> Self {
26 self.table = Some(SqlFragment::raw(from_qi(qi)));
27 self
28 }
29
30 pub fn table_as(mut self, qi: &QualifiedIdentifier, alias: &str) -> Self {
32 self.table = Some(SqlFragment::raw(format!(
33 "{} AS {}",
34 from_qi(qi),
35 escape_ident(alias)
36 )));
37 self
38 }
39
40 pub fn set<V: Into<crate::param::SqlParam>>(mut self, column: &str, value: V) -> Self {
42 let mut frag = SqlFragment::new();
43 frag.push_param(value);
44 self.set.push((column.to_string(), frag));
45 self
46 }
47
48 pub fn set_raw(mut self, column: &str, value: SqlFragment) -> Self {
50 self.set.push((column.to_string(), value));
51 self
52 }
53
54 pub fn where_expr(mut self, expr: Expr) -> Self {
56 self.where_clauses.push(expr.into_fragment());
57 self
58 }
59
60 pub fn where_raw(mut self, sql: SqlFragment) -> Self {
62 self.where_clauses.push(sql);
63 self
64 }
65
66 pub fn returning(mut self, column: &str) -> Self {
68 self.returning
69 .push(SqlFragment::raw(escape_ident(column)));
70 self
71 }
72
73 pub fn returning_all(mut self) -> Self {
75 self.returning.push(SqlFragment::raw("*"));
76 self
77 }
78
79 pub fn build(self) -> SqlFragment {
81 let mut result = SqlFragment::new();
82
83 result.push("UPDATE ");
84
85 if let Some(table) = self.table {
86 result.append(table);
87 }
88
89 if !self.set.is_empty() {
91 result.push(" SET ");
92 for (i, (col, val)) in self.set.into_iter().enumerate() {
93 if i > 0 {
94 result.push(", ");
95 }
96 result.push(&escape_ident(&col));
97 result.push(" = ");
98 result.append(val);
99 }
100 }
101
102 if !self.where_clauses.is_empty() {
104 result.push(" WHERE ");
105 for (i, clause) in self.where_clauses.into_iter().enumerate() {
106 if i > 0 {
107 result.push(" AND ");
108 }
109 result.append(clause);
110 }
111 }
112
113 if !self.returning.is_empty() {
115 result.push(" RETURNING ");
116 for (i, ret) in self.returning.into_iter().enumerate() {
117 if i > 0 {
118 result.push(", ");
119 }
120 result.append(ret);
121 }
122 }
123
124 result
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::param::SqlParam;

    #[test]
    fn test_simple_update() {
        let qi = QualifiedIdentifier::new("public", "users");
        let sql = UpdateBuilder::new()
            .table(&qi)
            .set("name", SqlParam::text("Jane"))
            .where_expr(Expr::eq("id", 1i64))
            .build();

        let text = sql.sql();
        assert!(text.starts_with("UPDATE "));
        let set_pos = text.find(" SET ").expect("SET clause present");
        let where_pos = text.find(" WHERE ").expect("WHERE clause present");
        assert!(set_pos < where_pos, "SET must precede WHERE");
        // One bound value for the assignment, one for the predicate.
        assert_eq!(sql.params().len(), 2);
    }

    #[test]
    fn test_update_returning() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = UpdateBuilder::new()
            .table(&qi)
            .set("status", SqlParam::text("active"))
            .returning_all()
            .build();

        assert!(sql.sql().contains("RETURNING *"));
    }

    #[test]
    fn test_update_returning_column() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = UpdateBuilder::new()
            .table(&qi)
            .set("status", SqlParam::text("active"))
            .returning("id")
            .build();

        assert!(sql.sql().contains(" RETURNING "));
        assert!(!sql.sql().contains("RETURNING *"));
    }

    #[test]
    fn test_update_multiple_sets() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = UpdateBuilder::new()
            .table(&qi)
            .set("name", SqlParam::text("John"))
            .set("email", SqlParam::text("john@new.com"))
            .set("updated_at", SqlParam::text("2024-01-01T00:00:00Z"))
            .where_expr(Expr::eq("id", 5i64))
            .build();

        // Three bound assignment values plus one bound predicate value.
        assert_eq!(sql.params().len(), 4);
    }

    #[test]
    fn test_update_set_raw() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = UpdateBuilder::new()
            .table(&qi)
            .set("name", SqlParam::text("John"))
            .set_raw("updated_at", SqlFragment::raw("now()"))
            .where_expr(Expr::eq("id", 5i64))
            .build();

        // The raw expression is inlined rather than bound as a parameter.
        assert!(sql.sql().contains("now()"));
        assert_eq!(sql.params().len(), 2);
    }

    #[test]
    fn test_update_multiple_where_clauses() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = UpdateBuilder::new()
            .table(&qi)
            .set("status", SqlParam::text("active"))
            .where_expr(Expr::eq("id", 5i64))
            .where_expr(Expr::eq("tenant_id", 7i64))
            .build();

        assert!(sql.sql().contains(" AND "));
        assert_eq!(sql.params().len(), 3);
    }

    #[test]
    fn test_update_without_where() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = UpdateBuilder::new()
            .table(&qi)
            .set("status", SqlParam::text("active"))
            .build();

        assert!(!sql.sql().contains(" WHERE "));
        assert_eq!(sql.params().len(), 1);
    }
}