1use crate::args::Args;
4use crate::cond::{ArgsRef, Cond};
5use crate::cte::CTEBuilder;
6use crate::flavor::Flavor;
7use crate::injection::{Injection, InjectionMarker};
8use crate::modifiers::{Arg, Builder, escape};
9use crate::string_builder::StringBuilder;
10use crate::where_clause::{WhereClause, WhereClauseBuilder, WhereClauseRef};
11use std::cell::RefCell;
12use std::ops::Deref;
13use std::rc::Rc;
14
// Injection markers for `UpdateBuilder`. Each one names a position in the
// rendered statement; SQL passed to `UpdateBuilder::sql` is stored under the
// builder's current marker and written at that position by `build_with_flavor`.

/// Start of the statement, before any clause.
const UPDATE_MARKER_INIT: InjectionMarker = 0;
/// After the `WITH` (CTE) clause.
const UPDATE_MARKER_AFTER_WITH: InjectionMarker = 1;
/// After the `UPDATE <tables>` clause.
const UPDATE_MARKER_AFTER_UPDATE: InjectionMarker = 2;
/// After the `SET <assignments>` clause.
const UPDATE_MARKER_AFTER_SET: InjectionMarker = 3;
/// After the `WHERE` clause.
const UPDATE_MARKER_AFTER_WHERE: InjectionMarker = 4;
/// After the `ORDER BY` clause.
const UPDATE_MARKER_AFTER_ORDER_BY: InjectionMarker = 5;
/// After the `LIMIT` clause.
const UPDATE_MARKER_AFTER_LIMIT: InjectionMarker = 6;
/// After the `RETURNING` clause (or the SQL Server `OUTPUT` clause).
const UPDATE_MARKER_AFTER_RETURNING: InjectionMarker = 7;
23
/// Builder for SQL `UPDATE` statements.
///
/// Values supplied to the builder are registered in a shared [`Args`] store and
/// referenced from the generated SQL by placeholders. The final SQL text and
/// argument list are produced by [`Builder::build_with_flavor`]. The builder
/// dereferences to [`Cond`], so condition helpers can be called on it directly.
#[derive(Debug)]
pub struct UpdateBuilder {
    // Shared argument store; `cond` is constructed over the same handle.
    args: ArgsRef,
    // Condition helper bound to `args`, exposed through `Deref`.
    cond: Cond,

    // Tables listed after `UPDATE`.
    tables: Vec<String>,
    // `SET` assignments; empty strings are skipped when the SQL is built.
    assignments: Vec<String>,

    // The where clause and the placeholder under which it is registered in `args`.
    where_clause: Option<WhereClauseRef>,
    where_var: Option<String>,
    // Placeholder of the CTE registered in `args`, and a copy of the CTE itself
    // (kept to look up its table names).
    cte_var: Option<String>,
    cte: Option<CTEBuilder>,

    // `ORDER BY` columns and an optional direction ("ASC"/"DESC") written after them.
    order_by_cols: Vec<String>,
    order: Option<&'static str>,
    // Placeholder for the `LIMIT` value, if one is set.
    limit_var: Option<String>,
    // Columns for `RETURNING` (PostgreSQL/SQLite) or `OUTPUT INSERTED.*` (SQL Server).
    returning: Vec<String>,

    // Raw SQL fragments keyed by marker, and the marker that `sql()` currently targets.
    injection: Injection,
    marker: InjectionMarker,
}
45
46impl Deref for UpdateBuilder {
47 type Target = Cond;
48 fn deref(&self) -> &Self::Target {
49 &self.cond
50 }
51}
52
53impl Default for UpdateBuilder {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl Clone for UpdateBuilder {
60 fn clone(&self) -> Self {
61 self.clone_builder()
62 }
63}
64
65impl UpdateBuilder {
66 pub fn new() -> Self {
67 let args = Rc::new(RefCell::new(Args::default()));
68 let cond = Cond::with_args(args.clone());
69 Self {
70 args,
71 cond,
72 tables: Vec::new(),
73 assignments: Vec::new(),
74 where_clause: None,
75 where_var: None,
76 cte_var: None,
77 cte: None,
78 order_by_cols: Vec::new(),
79 order: None,
80 limit_var: None,
81 returning: Vec::new(),
82 injection: Injection::new(),
83 marker: UPDATE_MARKER_INIT,
84 }
85 }
86
87 pub fn set_flavor(&mut self, flavor: Flavor) -> Flavor {
88 let mut a = self.args.borrow_mut();
89 let old = a.flavor;
90 a.flavor = flavor;
91 old
92 }
93
94 pub fn flavor(&self) -> Flavor {
95 self.args.borrow().flavor
96 }
97
98 pub fn with(&mut self, cte: &CTEBuilder) -> &mut Self {
99 let cte_clone = cte.clone();
100 let ph = self.var(Arg::Builder(Box::new(cte.clone())));
101 self.cte = Some(cte_clone);
102 self.cte_var = Some(ph);
103 self.marker = UPDATE_MARKER_AFTER_WHERE; self
105 }
106
107 fn table_names(&self) -> Vec<String> {
108 let mut table_names = Vec::new();
109 if !self.tables.is_empty() {
110 table_names.extend(self.tables.clone());
111 }
112 if let Some(cte) = &self.cte {
113 table_names.extend(cte.table_names_for_from());
114 }
115 table_names
116 }
117
118 pub fn where_clause(&self) -> Option<WhereClauseRef> {
119 self.where_clause.clone()
120 }
121
122 pub fn set_where_clause(&mut self, wc: Option<WhereClauseRef>) -> &mut Self {
123 match wc {
124 None => {
125 self.where_clause = None;
126 self.where_var = None;
127 }
128 Some(wc) => {
129 if let Some(ph) = &self.where_var {
130 self.args.borrow_mut().replace(
131 ph,
132 Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))),
133 );
134 } else {
135 let ph = self.var(Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))));
136 self.where_var = Some(ph);
137 }
138 self.where_clause = Some(wc);
139 }
140 }
141 self
142 }
143
144 pub fn clear_where_clause(&mut self) -> &mut Self {
145 self.set_where_clause(None)
146 }
147
148 pub fn clone_builder(&self) -> Self {
149 let old_args = self.args.borrow().clone();
150 let args = Rc::new(RefCell::new(old_args));
151 let cond = Cond::with_args(args.clone());
152
153 let mut cloned = Self {
154 args,
155 cond,
156 tables: self.tables.clone(),
157 assignments: self.assignments.clone(),
158 where_clause: self.where_clause.clone(),
159 where_var: self.where_var.clone(),
160 cte_var: self.cte_var.clone(),
161 cte: self.cte.clone(),
162 order_by_cols: self.order_by_cols.clone(),
163 order: self.order,
164 limit_var: self.limit_var.clone(),
165 returning: self.returning.clone(),
166 injection: self.injection.clone(),
167 marker: self.marker,
168 };
169
170 if let (Some(wc), Some(ph)) = (&self.where_clause, &self.where_var) {
171 let new_wc = Rc::new(RefCell::new(wc.borrow().clone()));
172 cloned.where_clause = Some(new_wc.clone());
173 cloned
174 .args
175 .borrow_mut()
176 .replace(ph, Arg::Builder(Box::new(WhereClauseBuilder::new(new_wc))));
177 }
178
179 if let (Some(cte), Some(ph)) = (&self.cte, &self.cte_var) {
180 let cte_for_arg = cte.clone();
181 let cte_for_field = cte_for_arg.clone();
182 cloned.cte = Some(cte_for_field);
183 cloned
184 .args
185 .borrow_mut()
186 .replace(ph, Arg::Builder(Box::new(cte_for_arg)));
187 }
188
189 cloned
190 }
191
192 fn var(&self, v: impl Into<Arg>) -> String {
193 self.args.borrow_mut().add(v)
194 }
195
196 pub fn update(&mut self, tables: impl IntoIterator<Item = impl Into<String>>) -> &mut Self {
197 self.tables = tables.into_iter().map(Into::into).collect();
198 self.marker = UPDATE_MARKER_AFTER_UPDATE;
199 self
200 }
201
202 pub fn set(&mut self, assignments: impl IntoIterator<Item = impl Into<String>>) -> &mut Self {
203 self.assignments = assignments.into_iter().map(Into::into).collect();
204 self.marker = UPDATE_MARKER_AFTER_SET;
205 self
206 }
207
208 pub fn set_more(
209 &mut self,
210 assignments: impl IntoIterator<Item = impl Into<String>>,
211 ) -> &mut Self {
212 self.assignments
213 .extend(assignments.into_iter().map(Into::into));
214 self.marker = UPDATE_MARKER_AFTER_SET;
215 self
216 }
217
218 pub fn where_(&mut self, and_expr: impl IntoIterator<Item = impl Into<String>>) -> &mut Self {
219 let exprs: Vec<String> = and_expr.into_iter().map(Into::into).collect();
220 if exprs.is_empty() || exprs.iter().all(|s| s.is_empty()) {
221 return self;
222 }
223
224 if self.where_clause.is_none() {
225 let wc = WhereClause::new();
226 let ph = self.var(Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))));
227 self.where_clause = Some(wc);
228 self.where_var = Some(ph);
229 }
230
231 self.where_clause
232 .as_ref()
233 .unwrap()
234 .borrow_mut()
235 .add_where_expr(self.args.clone(), exprs);
236 self.marker = UPDATE_MARKER_AFTER_WITH;
237 self
238 }
239
240 pub fn add_where_expr(
241 &mut self,
242 args: ArgsRef,
243 exprs: impl IntoIterator<Item = impl Into<String>>,
244 ) -> &mut Self {
245 let exprs: Vec<String> = exprs.into_iter().map(Into::into).collect();
246 if exprs.is_empty() || exprs.iter().all(|s| s.is_empty()) {
247 return self;
248 }
249 if self.where_clause.is_none() {
250 let wc = WhereClause::new();
251 let ph = self.var(Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))));
252 self.where_clause = Some(wc);
253 self.where_var = Some(ph);
254 }
255 self.where_clause
256 .as_ref()
257 .unwrap()
258 .borrow_mut()
259 .add_where_expr(args, exprs);
260 self.marker = UPDATE_MARKER_AFTER_WHERE;
261 self
262 }
263
264 pub fn add_where_clause(&mut self, other: &WhereClause) -> &mut Self {
265 if self.where_clause.is_none() {
266 let wc = WhereClause::new();
267 let ph = self.var(Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))));
268 self.where_clause = Some(wc);
269 self.where_var = Some(ph);
270 }
271 self.where_clause
272 .as_ref()
273 .unwrap()
274 .borrow_mut()
275 .add_where_clause(other);
276 self
277 }
278
279 pub fn add_where_clause_ref(&mut self, other: &WhereClauseRef) -> &mut Self {
280 if self.where_clause.is_none() {
281 let wc = WhereClause::new();
282 let ph = self.var(Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))));
283 self.where_clause = Some(wc);
284 self.where_var = Some(ph);
285 }
286 self.where_clause
287 .as_ref()
288 .unwrap()
289 .borrow_mut()
290 .add_where_clause(&other.borrow());
291 self
292 }
293
294 pub fn assign(&self, field: &str, value: impl Into<Arg>) -> String {
295 format!("{} = {}", escape(field), self.var(value))
296 }
297
298 pub fn incr(&self, field: &str) -> String {
299 let f = escape(field);
300 format!("{f} = {f} + 1")
301 }
302
303 pub fn decr(&self, field: &str) -> String {
304 let f = escape(field);
305 format!("{f} = {f} - 1")
306 }
307
308 pub fn add_(&self, field: &str, value: impl Into<Arg>) -> String {
309 let f = escape(field);
310 format!("{f} = {f} + {}", self.var(value))
311 }
312
313 pub fn add(&self, field: &str, value: impl Into<Arg>) -> String {
315 self.add_(field, value)
316 }
317
318 pub fn sub(&self, field: &str, value: impl Into<Arg>) -> String {
319 let f = escape(field);
320 format!("{f} = {f} - {}", self.var(value))
321 }
322
323 pub fn mul(&self, field: &str, value: impl Into<Arg>) -> String {
324 let f = escape(field);
325 format!("{f} = {f} * {}", self.var(value))
326 }
327
328 pub fn div(&self, field: &str, value: impl Into<Arg>) -> String {
329 let f = escape(field);
330 format!("{f} = {f} / {}", self.var(value))
331 }
332
333 pub fn order_by(&mut self, cols: impl IntoIterator<Item = impl Into<String>>) -> &mut Self {
334 self.order_by_cols = cols.into_iter().map(Into::into).collect();
335 self.marker = UPDATE_MARKER_AFTER_ORDER_BY;
336 self
337 }
338
339 pub fn order_by_asc(&mut self, col: impl Into<String>) -> &mut Self {
340 self.order_by_cols.push(format!("{} ASC", col.into()));
341 self.marker = UPDATE_MARKER_AFTER_ORDER_BY;
342 self
343 }
344
345 pub fn order_by_desc(&mut self, col: impl Into<String>) -> &mut Self {
346 self.order_by_cols.push(format!("{} DESC", col.into()));
347 self.marker = UPDATE_MARKER_AFTER_ORDER_BY;
348 self
349 }
350
351 pub fn asc(&mut self) -> &mut Self {
352 self.order = Some("ASC");
353 self.marker = UPDATE_MARKER_AFTER_ORDER_BY;
354 self
355 }
356
357 pub fn desc(&mut self) -> &mut Self {
358 self.order = Some("DESC");
359 self.marker = UPDATE_MARKER_AFTER_ORDER_BY;
360 self
361 }
362
363 pub fn limit(&mut self, limit: i64) -> &mut Self {
364 if limit < 0 {
365 self.limit_var = None;
366 return self;
367 }
368 self.limit_var = Some(self.var(limit));
369 self.marker = UPDATE_MARKER_AFTER_LIMIT;
370 self
371 }
372
373 pub fn returning(&mut self, cols: impl IntoIterator<Item = impl Into<String>>) -> &mut Self {
374 self.returning = cols.into_iter().map(Into::into).collect();
375 self.marker = UPDATE_MARKER_AFTER_RETURNING;
376 self
377 }
378
379 pub fn num_assignment(&self) -> usize {
381 self.assignments.iter().filter(|s| !s.is_empty()).count()
382 }
383
384 pub fn sql(&mut self, sql: impl Into<String>) -> &mut Self {
385 self.injection.sql(self.marker, sql);
386 self
387 }
388}
389
390impl Builder for UpdateBuilder {
391 fn build_with_flavor(&self, flavor: Flavor, initial_arg: &[Arg]) -> (String, Vec<Arg>) {
392 let mut buf = StringBuilder::new();
393 write_injection(&mut buf, &self.injection, UPDATE_MARKER_INIT);
394
395 if let Some(ph) = &self.cte_var {
396 buf.write_leading(ph);
397 write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_WITH);
398 }
399
400 match flavor {
401 Flavor::MySQL => {
402 let table_names = self.table_names();
403 if !table_names.is_empty() {
404 buf.write_leading("UPDATE");
405 buf.write_str(" ");
406 buf.write_str(&table_names.join(", "));
407 }
408 }
409 _ => {
410 if !self.tables.is_empty() {
411 buf.write_leading("UPDATE");
412 buf.write_str(" ");
413 buf.write_str(&self.tables.join(", "));
414 }
415 }
416 }
417 write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_UPDATE);
418
419 let assigns: Vec<String> = self
420 .assignments
421 .iter()
422 .filter(|s| !s.is_empty())
423 .cloned()
424 .collect();
425 if !assigns.is_empty() {
426 buf.write_leading("SET");
427 buf.write_str(" ");
428 buf.write_str(&assigns.join(", "));
429 }
430 write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_SET);
431
432 if flavor != Flavor::MySQL
433 && let Some(cte) = &self.cte
434 {
435 let cte_table_names = cte.table_names_for_from();
436 if !cte_table_names.is_empty() {
437 buf.write_leading("FROM");
438 buf.write_str(" ");
439 buf.write_str(&cte_table_names.join(", "));
440 }
441 }
442
443 if flavor == Flavor::SQLServer && !self.returning.is_empty() {
444 buf.write_leading("OUTPUT");
445 buf.write_str(" ");
446 let prefixed: Vec<String> = self
447 .returning
448 .iter()
449 .map(|c| format!("INSERTED.{c}"))
450 .collect();
451 buf.write_str(&prefixed.join(", "));
452 write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_RETURNING);
453 }
454
455 if let Some(ph) = &self.where_var {
456 buf.write_leading(ph);
457 write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_WHERE);
458 }
459
460 if !self.order_by_cols.is_empty() {
461 buf.write_leading("ORDER BY");
462 buf.write_str(" ");
463 buf.write_str(&self.order_by_cols.join(", "));
464 if let Some(order) = self.order {
465 buf.write_str(" ");
466 buf.write_str(order);
467 }
468 write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_ORDER_BY);
469 }
470
471 if let Some(lim) = &self.limit_var {
472 buf.write_leading("LIMIT");
473 buf.write_str(" ");
474 buf.write_str(lim);
475 write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_LIMIT);
476 }
477
478 if (flavor == Flavor::PostgreSQL || flavor == Flavor::SQLite) && !self.returning.is_empty()
479 {
480 buf.write_leading("RETURNING");
481 buf.write_str(" ");
482 buf.write_str(&self.returning.join(", "));
483 write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_RETURNING);
484 }
485
486 self.args
487 .borrow()
488 .compile_with_flavor(&buf.into_string(), flavor, initial_arg)
489 }
490
491 fn flavor(&self) -> Flavor {
492 self.flavor()
493 }
494}
495
496fn write_injection(buf: &mut StringBuilder, inj: &Injection, marker: InjectionMarker) {
497 let sqls = inj.at(marker);
498 if sqls.is_empty() {
499 return;
500 }
501 buf.write_leading("");
502 buf.write_str(&sqls.join(" "));
503}