halo_space/
update.rs

1//! UpdateBuilder: build UPDATE statements.
2
3use crate::args::Args;
4use crate::cond::{ArgsRef, Cond};
5use crate::cte::CTEBuilder;
6use crate::flavor::Flavor;
7use crate::injection::{Injection, InjectionMarker};
8use crate::macros::{IntoStrings, collect_into_strings};
9use crate::modifiers::{Arg, Builder, escape};
10use crate::string_builder::StringBuilder;
11use crate::where_clause::{WhereClause, WhereClauseBuilder, WhereClauseRef};
12use std::cell::RefCell;
13use std::ops::Deref;
14use std::rc::Rc;
15
16const UPDATE_MARKER_INIT: InjectionMarker = 0;
17const UPDATE_MARKER_AFTER_WITH: InjectionMarker = 1;
18const UPDATE_MARKER_AFTER_UPDATE: InjectionMarker = 2;
19const UPDATE_MARKER_AFTER_SET: InjectionMarker = 3;
20const UPDATE_MARKER_AFTER_WHERE: InjectionMarker = 4;
21const UPDATE_MARKER_AFTER_ORDER_BY: InjectionMarker = 5;
22const UPDATE_MARKER_AFTER_LIMIT: InjectionMarker = 6;
23const UPDATE_MARKER_AFTER_RETURNING: InjectionMarker = 7;
24
25#[derive(Debug)]
26pub struct UpdateBuilder {
27    args: ArgsRef,
28    cond: Cond,
29
30    tables: Vec<String>,
31    assignments: Vec<String>,
32
33    where_clause: Option<WhereClauseRef>,
34    where_var: Option<String>,
35    cte_var: Option<String>,
36    cte: Option<CTEBuilder>,
37
38    order_by_cols: Vec<String>,
39    order: Option<&'static str>,
40    limit_var: Option<String>,
41    returning: Vec<String>,
42
43    injection: Injection,
44    marker: InjectionMarker,
45}
46
47impl Deref for UpdateBuilder {
48    type Target = Cond;
49    fn deref(&self) -> &Self::Target {
50        &self.cond
51    }
52}
53
54impl Default for UpdateBuilder {
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
60impl Clone for UpdateBuilder {
61    fn clone(&self) -> Self {
62        self.clone_builder()
63    }
64}
65
66impl UpdateBuilder {
67    pub fn new() -> Self {
68        let args = Rc::new(RefCell::new(Args::default()));
69        let cond = Cond::with_args(args.clone());
70        Self {
71            args,
72            cond,
73            tables: Vec::new(),
74            assignments: Vec::new(),
75            where_clause: None,
76            where_var: None,
77            cte_var: None,
78            cte: None,
79            order_by_cols: Vec::new(),
80            order: None,
81            limit_var: None,
82            returning: Vec::new(),
83            injection: Injection::new(),
84            marker: UPDATE_MARKER_INIT,
85        }
86    }
87
88    pub fn set_flavor(&mut self, flavor: Flavor) -> Flavor {
89        let mut a = self.args.borrow_mut();
90        let old = a.flavor;
91        a.flavor = flavor;
92        old
93    }
94
95    pub fn flavor(&self) -> Flavor {
96        self.args.borrow().flavor
97    }
98
99    pub fn with(&mut self, cte: &CTEBuilder) -> &mut Self {
100        let cte_clone = cte.clone();
101        let ph = self.var(Arg::Builder(Box::new(cte.clone())));
102        self.cte = Some(cte_clone);
103        self.cte_var = Some(ph);
104        self.marker = UPDATE_MARKER_AFTER_WHERE; // temporarily? Wait we need new constant? run with marker?
105        self
106    }
107
108    fn table_names(&self) -> Vec<String> {
109        let mut table_names = Vec::new();
110        if !self.tables.is_empty() {
111            table_names.extend(self.tables.clone());
112        }
113        if let Some(cte) = &self.cte {
114            table_names.extend(cte.table_names_for_from());
115        }
116        table_names
117    }
118
119    pub fn where_clause(&self) -> Option<WhereClauseRef> {
120        self.where_clause.clone()
121    }
122
123    pub fn set_where_clause(&mut self, wc: Option<WhereClauseRef>) -> &mut Self {
124        match wc {
125            None => {
126                self.where_clause = None;
127                self.where_var = None;
128            }
129            Some(wc) => {
130                if let Some(ph) = &self.where_var {
131                    self.args.borrow_mut().replace(
132                        ph,
133                        Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))),
134                    );
135                } else {
136                    let ph = self.var(Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))));
137                    self.where_var = Some(ph);
138                }
139                self.where_clause = Some(wc);
140            }
141        }
142        self
143    }
144
145    pub fn clear_where_clause(&mut self) -> &mut Self {
146        self.set_where_clause(None)
147    }
148
149    pub fn clone_builder(&self) -> Self {
150        let old_args = self.args.borrow().clone();
151        let args = Rc::new(RefCell::new(old_args));
152        let cond = Cond::with_args(args.clone());
153
154        let mut cloned = Self {
155            args,
156            cond,
157            tables: self.tables.clone(),
158            assignments: self.assignments.clone(),
159            where_clause: self.where_clause.clone(),
160            where_var: self.where_var.clone(),
161            cte_var: self.cte_var.clone(),
162            cte: self.cte.clone(),
163            order_by_cols: self.order_by_cols.clone(),
164            order: self.order,
165            limit_var: self.limit_var.clone(),
166            returning: self.returning.clone(),
167            injection: self.injection.clone(),
168            marker: self.marker,
169        };
170
171        if let (Some(wc), Some(ph)) = (&self.where_clause, &self.where_var) {
172            let new_wc = Rc::new(RefCell::new(wc.borrow().clone()));
173            cloned.where_clause = Some(new_wc.clone());
174            cloned
175                .args
176                .borrow_mut()
177                .replace(ph, Arg::Builder(Box::new(WhereClauseBuilder::new(new_wc))));
178        }
179
180        if let (Some(cte), Some(ph)) = (&self.cte, &self.cte_var) {
181            let cte_for_arg = cte.clone();
182            let cte_for_field = cte_for_arg.clone();
183            cloned.cte = Some(cte_for_field);
184            cloned
185                .args
186                .borrow_mut()
187                .replace(ph, Arg::Builder(Box::new(cte_for_arg)));
188        }
189
190        cloned
191    }
192
193    pub fn build(&self) -> (String, Vec<Arg>) {
194        Builder::build(self)
195    }
196
197    fn var(&self, v: impl Into<Arg>) -> String {
198        self.args.borrow_mut().add(v)
199    }
200
201    pub fn update<T>(&mut self, tables: T) -> &mut Self
202    where
203        T: IntoStrings,
204    {
205        self.tables = collect_into_strings(tables);
206        self.marker = UPDATE_MARKER_AFTER_UPDATE;
207        self
208    }
209
210    pub fn set<T>(&mut self, assignments: T) -> &mut Self
211    where
212        T: IntoStrings,
213    {
214        self.assignments = collect_into_strings(assignments);
215        self.marker = UPDATE_MARKER_AFTER_SET;
216        self
217    }
218
219    pub fn set_more(&mut self, assignments: impl IntoStrings) -> &mut Self {
220        self.assignments.extend(collect_into_strings(assignments));
221        self.marker = UPDATE_MARKER_AFTER_SET;
222        self
223    }
224
225    pub fn where_<T>(&mut self, and_expr: T) -> &mut Self
226    where
227        T: IntoStrings,
228    {
229        let exprs = collect_into_strings(and_expr);
230        if exprs.is_empty() || exprs.iter().all(|s| s.is_empty()) {
231            return self;
232        }
233
234        if self.where_clause.is_none() {
235            let wc = WhereClause::new();
236            let ph = self.var(Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))));
237            self.where_clause = Some(wc);
238            self.where_var = Some(ph);
239        }
240
241        self.where_clause
242            .as_ref()
243            .unwrap()
244            .borrow_mut()
245            .add_where_expr(self.args.clone(), exprs);
246        self.marker = UPDATE_MARKER_AFTER_WITH;
247        self
248    }
249
250    pub fn add_where_expr<T>(&mut self, args: ArgsRef, exprs: T) -> &mut Self
251    where
252        T: IntoStrings,
253    {
254        let exprs = collect_into_strings(exprs);
255        if exprs.is_empty() || exprs.iter().all(|s| s.is_empty()) {
256            return self;
257        }
258        if self.where_clause.is_none() {
259            let wc = WhereClause::new();
260            let ph = self.var(Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))));
261            self.where_clause = Some(wc);
262            self.where_var = Some(ph);
263        }
264        self.where_clause
265            .as_ref()
266            .unwrap()
267            .borrow_mut()
268            .add_where_expr(args, exprs);
269        self.marker = UPDATE_MARKER_AFTER_WHERE;
270        self
271    }
272
273    pub fn add_where_clause(&mut self, other: &WhereClause) -> &mut Self {
274        if self.where_clause.is_none() {
275            let wc = WhereClause::new();
276            let ph = self.var(Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))));
277            self.where_clause = Some(wc);
278            self.where_var = Some(ph);
279        }
280        self.where_clause
281            .as_ref()
282            .unwrap()
283            .borrow_mut()
284            .add_where_clause(other);
285        self
286    }
287
288    pub fn add_where_clause_ref(&mut self, other: &WhereClauseRef) -> &mut Self {
289        if self.where_clause.is_none() {
290            let wc = WhereClause::new();
291            let ph = self.var(Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))));
292            self.where_clause = Some(wc);
293            self.where_var = Some(ph);
294        }
295        self.where_clause
296            .as_ref()
297            .unwrap()
298            .borrow_mut()
299            .add_where_clause(&other.borrow());
300        self
301    }
302
303    pub fn assign(&self, field: &str, value: impl Into<Arg>) -> String {
304        format!("{} = {}", escape(field), self.var(value))
305    }
306
307    pub fn incr(&self, field: &str) -> String {
308        let f = escape(field);
309        format!("{f} = {f} + 1")
310    }
311
312    pub fn decr(&self, field: &str) -> String {
313        let f = escape(field);
314        format!("{f} = {f} - 1")
315    }
316
317    pub fn add_(&self, field: &str, value: impl Into<Arg>) -> String {
318        let f = escape(field);
319        format!("{f} = {f} + {}", self.var(value))
320    }
321
322    /// Add an assignment expression.
323    pub fn add(&self, field: &str, value: impl Into<Arg>) -> String {
324        self.add_(field, value)
325    }
326
327    pub fn sub(&self, field: &str, value: impl Into<Arg>) -> String {
328        let f = escape(field);
329        format!("{f} = {f} - {}", self.var(value))
330    }
331
332    pub fn mul(&self, field: &str, value: impl Into<Arg>) -> String {
333        let f = escape(field);
334        format!("{f} = {f} * {}", self.var(value))
335    }
336
337    pub fn div(&self, field: &str, value: impl Into<Arg>) -> String {
338        let f = escape(field);
339        format!("{f} = {f} / {}", self.var(value))
340    }
341
342    pub fn order_by<T>(&mut self, cols: T) -> &mut Self
343    where
344        T: IntoStrings,
345    {
346        self.order_by_cols = collect_into_strings(cols);
347        self.marker = UPDATE_MARKER_AFTER_ORDER_BY;
348        self
349    }
350
351    pub fn order_by_asc(&mut self, col: impl Into<String>) -> &mut Self {
352        self.order_by_cols.push(format!("{} ASC", col.into()));
353        self.marker = UPDATE_MARKER_AFTER_ORDER_BY;
354        self
355    }
356
357    pub fn order_by_desc(&mut self, col: impl Into<String>) -> &mut Self {
358        self.order_by_cols.push(format!("{} DESC", col.into()));
359        self.marker = UPDATE_MARKER_AFTER_ORDER_BY;
360        self
361    }
362
363    pub fn asc(&mut self) -> &mut Self {
364        self.order = Some("ASC");
365        self.marker = UPDATE_MARKER_AFTER_ORDER_BY;
366        self
367    }
368
369    pub fn desc(&mut self) -> &mut Self {
370        self.order = Some("DESC");
371        self.marker = UPDATE_MARKER_AFTER_ORDER_BY;
372        self
373    }
374
375    pub fn limit(&mut self, limit: i64) -> &mut Self {
376        if limit < 0 {
377            self.limit_var = None;
378            return self;
379        }
380        self.limit_var = Some(self.var(limit));
381        self.marker = UPDATE_MARKER_AFTER_LIMIT;
382        self
383    }
384
385    pub fn returning<T>(&mut self, cols: T) -> &mut Self
386    where
387        T: IntoStrings,
388    {
389        self.returning = collect_into_strings(cols);
390        self.marker = UPDATE_MARKER_AFTER_RETURNING;
391        self
392    }
393
394    /// Number of assignment expressions.
395    pub fn num_assignment(&self) -> usize {
396        self.assignments.iter().filter(|s| !s.is_empty()).count()
397    }
398
399    pub fn sql(&mut self, sql: impl Into<String>) -> &mut Self {
400        self.injection.sql(self.marker, sql);
401        self
402    }
403}
404
405impl Builder for UpdateBuilder {
406    fn build_with_flavor(&self, flavor: Flavor, initial_arg: &[Arg]) -> (String, Vec<Arg>) {
407        let mut buf = StringBuilder::new();
408        write_injection(&mut buf, &self.injection, UPDATE_MARKER_INIT);
409
410        if let Some(ph) = &self.cte_var {
411            buf.write_leading(ph);
412            write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_WITH);
413        }
414
415        match flavor {
416            Flavor::MySQL => {
417                let table_names = self.table_names();
418                if !table_names.is_empty() {
419                    buf.write_leading("UPDATE");
420                    buf.write_str(" ");
421                    buf.write_str(&table_names.join(", "));
422                }
423            }
424            _ => {
425                if !self.tables.is_empty() {
426                    buf.write_leading("UPDATE");
427                    buf.write_str(" ");
428                    buf.write_str(&self.tables.join(", "));
429                }
430            }
431        }
432        write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_UPDATE);
433
434        let assigns: Vec<String> = self
435            .assignments
436            .iter()
437            .filter(|s| !s.is_empty())
438            .cloned()
439            .collect();
440        if !assigns.is_empty() {
441            buf.write_leading("SET");
442            buf.write_str(" ");
443            buf.write_str(&assigns.join(", "));
444        }
445        write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_SET);
446
447        if flavor != Flavor::MySQL
448            && let Some(cte) = &self.cte
449        {
450            let cte_table_names = cte.table_names_for_from();
451            if !cte_table_names.is_empty() {
452                buf.write_leading("FROM");
453                buf.write_str(" ");
454                buf.write_str(&cte_table_names.join(", "));
455            }
456        }
457
458        if flavor == Flavor::SQLServer && !self.returning.is_empty() {
459            buf.write_leading("OUTPUT");
460            buf.write_str(" ");
461            let prefixed: Vec<String> = self
462                .returning
463                .iter()
464                .map(|c| format!("INSERTED.{c}"))
465                .collect();
466            buf.write_str(&prefixed.join(", "));
467            write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_RETURNING);
468        }
469
470        if let Some(ph) = &self.where_var {
471            buf.write_leading(ph);
472            write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_WHERE);
473        }
474
475        if !self.order_by_cols.is_empty() {
476            buf.write_leading("ORDER BY");
477            buf.write_str(" ");
478            buf.write_str(&self.order_by_cols.join(", "));
479            if let Some(order) = self.order {
480                buf.write_str(" ");
481                buf.write_str(order);
482            }
483            write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_ORDER_BY);
484        }
485
486        if let Some(lim) = &self.limit_var {
487            buf.write_leading("LIMIT");
488            buf.write_str(" ");
489            buf.write_str(lim);
490            write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_LIMIT);
491        }
492
493        if (flavor == Flavor::PostgreSQL || flavor == Flavor::SQLite) && !self.returning.is_empty()
494        {
495            buf.write_leading("RETURNING");
496            buf.write_str(" ");
497            buf.write_str(&self.returning.join(", "));
498            write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_RETURNING);
499        }
500
501        self.args
502            .borrow()
503            .compile_with_flavor(&buf.into_string(), flavor, initial_arg)
504    }
505
506    fn flavor(&self) -> Flavor {
507        self.flavor()
508    }
509}
510
511fn write_injection(buf: &mut StringBuilder, inj: &Injection, marker: InjectionMarker) {
512    let sqls = inj.at(marker);
513    if sqls.is_empty() {
514        return;
515    }
516    buf.write_leading("");
517    buf.write_str(&sqls.join(" "));
518}