1use crate::args::Args;
4use crate::cond::{ArgsRef, Cond};
5use crate::cte::CTEBuilder;
6use crate::flavor::Flavor;
7use crate::injection::{Injection, InjectionMarker};
8use crate::macros::{IntoStrings, collect_into_strings};
9use crate::modifiers::{Arg, Builder, escape};
10use crate::string_builder::StringBuilder;
11use crate::where_clause::{WhereClause, WhereClauseBuilder, WhereClauseRef};
12use std::cell::RefCell;
13use std::ops::Deref;
14use std::rc::Rc;
15
// Injection markers: positions in the generated UPDATE statement at which raw
// SQL fragments registered through `UpdateBuilder::sql` are written. The
// builder's `marker` field records the position the next fragment goes to, and
// `build_with_flavor` writes each marker's fragments right after its clause.

/// Very start of the statement, before anything else.
const UPDATE_MARKER_INIT: InjectionMarker = 0;
/// After the CTE placeholder (only written when a CTE is attached).
const UPDATE_MARKER_AFTER_WITH: InjectionMarker = 1;
/// After the `UPDATE <tables>` clause (written even if no tables are set).
const UPDATE_MARKER_AFTER_UPDATE: InjectionMarker = 2;
/// After the `SET ...` clause (written even if no assignments are set).
const UPDATE_MARKER_AFTER_SET: InjectionMarker = 3;
/// After the WHERE placeholder (only written when a where clause exists).
const UPDATE_MARKER_AFTER_WHERE: InjectionMarker = 4;
/// After the `ORDER BY ...` clause (only written when order-by columns exist).
const UPDATE_MARKER_AFTER_ORDER_BY: InjectionMarker = 5;
/// After the `LIMIT ...` clause (only written when a limit is set).
const UPDATE_MARKER_AFTER_LIMIT: InjectionMarker = 6;
/// After `OUTPUT ...` (SQLServer) or `RETURNING ...` (PostgreSQL/SQLite).
const UPDATE_MARKER_AFTER_RETURNING: InjectionMarker = 7;
24
/// Builder for SQL `UPDATE` statements.
///
/// Clause methods record the pieces of the statement; `build_with_flavor`
/// assembles them in a fixed order for the chosen [`Flavor`]. Values are kept
/// in a shared argument store and referenced from the SQL text through
/// placeholders returned by that store.
#[derive(Debug)]
pub struct UpdateBuilder {
    /// Shared argument store; `cond` is constructed over the same store.
    args: ArgsRef,
    /// Condition helpers, exposed on the builder through `Deref`.
    cond: Cond,

    /// Target tables set by `update`.
    tables: Vec<String>,
    /// `SET` expressions; empty strings are skipped when building.
    assignments: Vec<String>,

    /// WHERE clause, created lazily by the `where`-family methods.
    where_clause: Option<WhereClauseRef>,
    /// Placeholder in the argument store that renders `where_clause`.
    where_var: Option<String>,
    /// Placeholder in the argument store that renders `cte`.
    cte_var: Option<String>,
    /// CTE attached by `with`; its table names also feed UPDATE/FROM.
    cte: Option<CTEBuilder>,

    /// `ORDER BY` columns.
    order_by_cols: Vec<String>,
    /// Optional direction appended once after the whole `ORDER BY` list.
    order: Option<&'static str>,
    /// Placeholder for the LIMIT value; `None` means no LIMIT clause.
    limit_var: Option<String>,
    /// Columns for `RETURNING` (PostgreSQL/SQLite) or `OUTPUT INSERTED.*` (SQLServer).
    returning: Vec<String>,

    /// Raw SQL fragments keyed by injection marker.
    injection: Injection,
    /// Marker at which the next `sql` call injects its fragment.
    marker: InjectionMarker,
}
46
47impl Deref for UpdateBuilder {
48 type Target = Cond;
49 fn deref(&self) -> &Self::Target {
50 &self.cond
51 }
52}
53
54impl Default for UpdateBuilder {
55 fn default() -> Self {
56 Self::new()
57 }
58}
59
60impl Clone for UpdateBuilder {
61 fn clone(&self) -> Self {
62 self.clone_builder()
63 }
64}
65
66impl UpdateBuilder {
67 pub fn new() -> Self {
68 let args = Rc::new(RefCell::new(Args::default()));
69 let cond = Cond::with_args(args.clone());
70 Self {
71 args,
72 cond,
73 tables: Vec::new(),
74 assignments: Vec::new(),
75 where_clause: None,
76 where_var: None,
77 cte_var: None,
78 cte: None,
79 order_by_cols: Vec::new(),
80 order: None,
81 limit_var: None,
82 returning: Vec::new(),
83 injection: Injection::new(),
84 marker: UPDATE_MARKER_INIT,
85 }
86 }
87
88 pub fn set_flavor(&mut self, flavor: Flavor) -> Flavor {
89 let mut a = self.args.borrow_mut();
90 let old = a.flavor;
91 a.flavor = flavor;
92 old
93 }
94
95 pub fn flavor(&self) -> Flavor {
96 self.args.borrow().flavor
97 }
98
99 pub fn with(&mut self, cte: &CTEBuilder) -> &mut Self {
100 let cte_clone = cte.clone();
101 let ph = self.var(Arg::Builder(Box::new(cte.clone())));
102 self.cte = Some(cte_clone);
103 self.cte_var = Some(ph);
104 self.marker = UPDATE_MARKER_AFTER_WHERE; self
106 }
107
108 fn table_names(&self) -> Vec<String> {
109 let mut table_names = Vec::new();
110 if !self.tables.is_empty() {
111 table_names.extend(self.tables.clone());
112 }
113 if let Some(cte) = &self.cte {
114 table_names.extend(cte.table_names_for_from());
115 }
116 table_names
117 }
118
119 pub fn where_clause(&self) -> Option<WhereClauseRef> {
120 self.where_clause.clone()
121 }
122
123 pub fn set_where_clause(&mut self, wc: Option<WhereClauseRef>) -> &mut Self {
124 match wc {
125 None => {
126 self.where_clause = None;
127 self.where_var = None;
128 }
129 Some(wc) => {
130 if let Some(ph) = &self.where_var {
131 self.args.borrow_mut().replace(
132 ph,
133 Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))),
134 );
135 } else {
136 let ph = self.var(Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))));
137 self.where_var = Some(ph);
138 }
139 self.where_clause = Some(wc);
140 }
141 }
142 self
143 }
144
145 pub fn clear_where_clause(&mut self) -> &mut Self {
146 self.set_where_clause(None)
147 }
148
149 pub fn clone_builder(&self) -> Self {
150 let old_args = self.args.borrow().clone();
151 let args = Rc::new(RefCell::new(old_args));
152 let cond = Cond::with_args(args.clone());
153
154 let mut cloned = Self {
155 args,
156 cond,
157 tables: self.tables.clone(),
158 assignments: self.assignments.clone(),
159 where_clause: self.where_clause.clone(),
160 where_var: self.where_var.clone(),
161 cte_var: self.cte_var.clone(),
162 cte: self.cte.clone(),
163 order_by_cols: self.order_by_cols.clone(),
164 order: self.order,
165 limit_var: self.limit_var.clone(),
166 returning: self.returning.clone(),
167 injection: self.injection.clone(),
168 marker: self.marker,
169 };
170
171 if let (Some(wc), Some(ph)) = (&self.where_clause, &self.where_var) {
172 let new_wc = Rc::new(RefCell::new(wc.borrow().clone()));
173 cloned.where_clause = Some(new_wc.clone());
174 cloned
175 .args
176 .borrow_mut()
177 .replace(ph, Arg::Builder(Box::new(WhereClauseBuilder::new(new_wc))));
178 }
179
180 if let (Some(cte), Some(ph)) = (&self.cte, &self.cte_var) {
181 let cte_for_arg = cte.clone();
182 let cte_for_field = cte_for_arg.clone();
183 cloned.cte = Some(cte_for_field);
184 cloned
185 .args
186 .borrow_mut()
187 .replace(ph, Arg::Builder(Box::new(cte_for_arg)));
188 }
189
190 cloned
191 }
192
193 pub fn build(&self) -> (String, Vec<Arg>) {
194 Builder::build(self)
195 }
196
197 fn var(&self, v: impl Into<Arg>) -> String {
198 self.args.borrow_mut().add(v)
199 }
200
201 pub fn update<T>(&mut self, tables: T) -> &mut Self
202 where
203 T: IntoStrings,
204 {
205 self.tables = collect_into_strings(tables);
206 self.marker = UPDATE_MARKER_AFTER_UPDATE;
207 self
208 }
209
210 pub fn set<T>(&mut self, assignments: T) -> &mut Self
211 where
212 T: IntoStrings,
213 {
214 self.assignments = collect_into_strings(assignments);
215 self.marker = UPDATE_MARKER_AFTER_SET;
216 self
217 }
218
219 pub fn set_more(&mut self, assignments: impl IntoStrings) -> &mut Self {
220 self.assignments.extend(collect_into_strings(assignments));
221 self.marker = UPDATE_MARKER_AFTER_SET;
222 self
223 }
224
225 pub fn where_<T>(&mut self, and_expr: T) -> &mut Self
226 where
227 T: IntoStrings,
228 {
229 let exprs = collect_into_strings(and_expr);
230 if exprs.is_empty() || exprs.iter().all(|s| s.is_empty()) {
231 return self;
232 }
233
234 if self.where_clause.is_none() {
235 let wc = WhereClause::new();
236 let ph = self.var(Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))));
237 self.where_clause = Some(wc);
238 self.where_var = Some(ph);
239 }
240
241 self.where_clause
242 .as_ref()
243 .unwrap()
244 .borrow_mut()
245 .add_where_expr(self.args.clone(), exprs);
246 self.marker = UPDATE_MARKER_AFTER_WITH;
247 self
248 }
249
250 pub fn add_where_expr<T>(&mut self, args: ArgsRef, exprs: T) -> &mut Self
251 where
252 T: IntoStrings,
253 {
254 let exprs = collect_into_strings(exprs);
255 if exprs.is_empty() || exprs.iter().all(|s| s.is_empty()) {
256 return self;
257 }
258 if self.where_clause.is_none() {
259 let wc = WhereClause::new();
260 let ph = self.var(Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))));
261 self.where_clause = Some(wc);
262 self.where_var = Some(ph);
263 }
264 self.where_clause
265 .as_ref()
266 .unwrap()
267 .borrow_mut()
268 .add_where_expr(args, exprs);
269 self.marker = UPDATE_MARKER_AFTER_WHERE;
270 self
271 }
272
273 pub fn add_where_clause(&mut self, other: &WhereClause) -> &mut Self {
274 if self.where_clause.is_none() {
275 let wc = WhereClause::new();
276 let ph = self.var(Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))));
277 self.where_clause = Some(wc);
278 self.where_var = Some(ph);
279 }
280 self.where_clause
281 .as_ref()
282 .unwrap()
283 .borrow_mut()
284 .add_where_clause(other);
285 self
286 }
287
288 pub fn add_where_clause_ref(&mut self, other: &WhereClauseRef) -> &mut Self {
289 if self.where_clause.is_none() {
290 let wc = WhereClause::new();
291 let ph = self.var(Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))));
292 self.where_clause = Some(wc);
293 self.where_var = Some(ph);
294 }
295 self.where_clause
296 .as_ref()
297 .unwrap()
298 .borrow_mut()
299 .add_where_clause(&other.borrow());
300 self
301 }
302
303 pub fn assign(&self, field: &str, value: impl Into<Arg>) -> String {
304 format!("{} = {}", escape(field), self.var(value))
305 }
306
307 pub fn incr(&self, field: &str) -> String {
308 let f = escape(field);
309 format!("{f} = {f} + 1")
310 }
311
312 pub fn decr(&self, field: &str) -> String {
313 let f = escape(field);
314 format!("{f} = {f} - 1")
315 }
316
317 pub fn add_(&self, field: &str, value: impl Into<Arg>) -> String {
318 let f = escape(field);
319 format!("{f} = {f} + {}", self.var(value))
320 }
321
322 pub fn add(&self, field: &str, value: impl Into<Arg>) -> String {
324 self.add_(field, value)
325 }
326
327 pub fn sub(&self, field: &str, value: impl Into<Arg>) -> String {
328 let f = escape(field);
329 format!("{f} = {f} - {}", self.var(value))
330 }
331
332 pub fn mul(&self, field: &str, value: impl Into<Arg>) -> String {
333 let f = escape(field);
334 format!("{f} = {f} * {}", self.var(value))
335 }
336
337 pub fn div(&self, field: &str, value: impl Into<Arg>) -> String {
338 let f = escape(field);
339 format!("{f} = {f} / {}", self.var(value))
340 }
341
342 pub fn order_by<T>(&mut self, cols: T) -> &mut Self
343 where
344 T: IntoStrings,
345 {
346 self.order_by_cols = collect_into_strings(cols);
347 self.marker = UPDATE_MARKER_AFTER_ORDER_BY;
348 self
349 }
350
351 pub fn order_by_asc(&mut self, col: impl Into<String>) -> &mut Self {
352 self.order_by_cols.push(format!("{} ASC", col.into()));
353 self.marker = UPDATE_MARKER_AFTER_ORDER_BY;
354 self
355 }
356
357 pub fn order_by_desc(&mut self, col: impl Into<String>) -> &mut Self {
358 self.order_by_cols.push(format!("{} DESC", col.into()));
359 self.marker = UPDATE_MARKER_AFTER_ORDER_BY;
360 self
361 }
362
363 pub fn asc(&mut self) -> &mut Self {
364 self.order = Some("ASC");
365 self.marker = UPDATE_MARKER_AFTER_ORDER_BY;
366 self
367 }
368
369 pub fn desc(&mut self) -> &mut Self {
370 self.order = Some("DESC");
371 self.marker = UPDATE_MARKER_AFTER_ORDER_BY;
372 self
373 }
374
375 pub fn limit(&mut self, limit: i64) -> &mut Self {
376 if limit < 0 {
377 self.limit_var = None;
378 return self;
379 }
380 self.limit_var = Some(self.var(limit));
381 self.marker = UPDATE_MARKER_AFTER_LIMIT;
382 self
383 }
384
385 pub fn returning<T>(&mut self, cols: T) -> &mut Self
386 where
387 T: IntoStrings,
388 {
389 self.returning = collect_into_strings(cols);
390 self.marker = UPDATE_MARKER_AFTER_RETURNING;
391 self
392 }
393
394 pub fn num_assignment(&self) -> usize {
396 self.assignments.iter().filter(|s| !s.is_empty()).count()
397 }
398
399 pub fn sql(&mut self, sql: impl Into<String>) -> &mut Self {
400 self.injection.sql(self.marker, sql);
401 self
402 }
403}
404
405impl Builder for UpdateBuilder {
406 fn build_with_flavor(&self, flavor: Flavor, initial_arg: &[Arg]) -> (String, Vec<Arg>) {
407 let mut buf = StringBuilder::new();
408 write_injection(&mut buf, &self.injection, UPDATE_MARKER_INIT);
409
410 if let Some(ph) = &self.cte_var {
411 buf.write_leading(ph);
412 write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_WITH);
413 }
414
415 match flavor {
416 Flavor::MySQL => {
417 let table_names = self.table_names();
418 if !table_names.is_empty() {
419 buf.write_leading("UPDATE");
420 buf.write_str(" ");
421 buf.write_str(&table_names.join(", "));
422 }
423 }
424 _ => {
425 if !self.tables.is_empty() {
426 buf.write_leading("UPDATE");
427 buf.write_str(" ");
428 buf.write_str(&self.tables.join(", "));
429 }
430 }
431 }
432 write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_UPDATE);
433
434 let assigns: Vec<String> = self
435 .assignments
436 .iter()
437 .filter(|s| !s.is_empty())
438 .cloned()
439 .collect();
440 if !assigns.is_empty() {
441 buf.write_leading("SET");
442 buf.write_str(" ");
443 buf.write_str(&assigns.join(", "));
444 }
445 write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_SET);
446
447 if flavor != Flavor::MySQL
448 && let Some(cte) = &self.cte
449 {
450 let cte_table_names = cte.table_names_for_from();
451 if !cte_table_names.is_empty() {
452 buf.write_leading("FROM");
453 buf.write_str(" ");
454 buf.write_str(&cte_table_names.join(", "));
455 }
456 }
457
458 if flavor == Flavor::SQLServer && !self.returning.is_empty() {
459 buf.write_leading("OUTPUT");
460 buf.write_str(" ");
461 let prefixed: Vec<String> = self
462 .returning
463 .iter()
464 .map(|c| format!("INSERTED.{c}"))
465 .collect();
466 buf.write_str(&prefixed.join(", "));
467 write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_RETURNING);
468 }
469
470 if let Some(ph) = &self.where_var {
471 buf.write_leading(ph);
472 write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_WHERE);
473 }
474
475 if !self.order_by_cols.is_empty() {
476 buf.write_leading("ORDER BY");
477 buf.write_str(" ");
478 buf.write_str(&self.order_by_cols.join(", "));
479 if let Some(order) = self.order {
480 buf.write_str(" ");
481 buf.write_str(order);
482 }
483 write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_ORDER_BY);
484 }
485
486 if let Some(lim) = &self.limit_var {
487 buf.write_leading("LIMIT");
488 buf.write_str(" ");
489 buf.write_str(lim);
490 write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_LIMIT);
491 }
492
493 if (flavor == Flavor::PostgreSQL || flavor == Flavor::SQLite) && !self.returning.is_empty()
494 {
495 buf.write_leading("RETURNING");
496 buf.write_str(" ");
497 buf.write_str(&self.returning.join(", "));
498 write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_RETURNING);
499 }
500
501 self.args
502 .borrow()
503 .compile_with_flavor(&buf.into_string(), flavor, initial_arg)
504 }
505
506 fn flavor(&self) -> Flavor {
507 self.flavor()
508 }
509}
510
511fn write_injection(buf: &mut StringBuilder, inj: &Injection, marker: InjectionMarker) {
512 let sqls = inj.at(marker);
513 if sqls.is_empty() {
514 return;
515 }
516 buf.write_leading("");
517 buf.write_str(&sqls.join(" "));
518}