halo/
update.rs

1//! UpdateBuilder:构建 UPDATE 语句(对齐 go-sqlbuilder `update.go` 的核心行为)。
2
3use crate::args::Args;
4use crate::cond::{ArgsRef, Cond};
5use crate::cte::CTEBuilder;
6use crate::flavor::Flavor;
7use crate::injection::{Injection, InjectionMarker};
8use crate::modifiers::{Arg, Builder, escape};
9use crate::string_builder::StringBuilder;
10use crate::where_clause::{WhereClause, WhereClauseBuilder, WhereClauseRef};
11use std::cell::RefCell;
12use std::ops::Deref;
13use std::rc::Rc;
14
15const UPDATE_MARKER_INIT: InjectionMarker = 0;
16const UPDATE_MARKER_AFTER_WITH: InjectionMarker = 1;
17const UPDATE_MARKER_AFTER_UPDATE: InjectionMarker = 2;
18const UPDATE_MARKER_AFTER_SET: InjectionMarker = 3;
19const UPDATE_MARKER_AFTER_WHERE: InjectionMarker = 4;
20const UPDATE_MARKER_AFTER_ORDER_BY: InjectionMarker = 5;
21const UPDATE_MARKER_AFTER_LIMIT: InjectionMarker = 6;
22const UPDATE_MARKER_AFTER_RETURNING: InjectionMarker = 7;
23
24#[derive(Debug)]
25pub struct UpdateBuilder {
26    args: ArgsRef,
27    cond: Cond,
28
29    tables: Vec<String>,
30    assignments: Vec<String>,
31
32    where_clause: Option<WhereClauseRef>,
33    where_var: Option<String>,
34    cte_var: Option<String>,
35    cte: Option<CTEBuilder>,
36
37    order_by_cols: Vec<String>,
38    order: Option<&'static str>,
39    limit_var: Option<String>,
40    returning: Vec<String>,
41
42    injection: Injection,
43    marker: InjectionMarker,
44}
45
46impl Deref for UpdateBuilder {
47    type Target = Cond;
48    fn deref(&self) -> &Self::Target {
49        &self.cond
50    }
51}
52
53impl Default for UpdateBuilder {
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59impl Clone for UpdateBuilder {
60    fn clone(&self) -> Self {
61        self.clone_builder()
62    }
63}
64
65impl UpdateBuilder {
66    pub fn new() -> Self {
67        let args = Rc::new(RefCell::new(Args::default()));
68        let cond = Cond::with_args(args.clone());
69        Self {
70            args,
71            cond,
72            tables: Vec::new(),
73            assignments: Vec::new(),
74            where_clause: None,
75            where_var: None,
76            cte_var: None,
77            cte: None,
78            order_by_cols: Vec::new(),
79            order: None,
80            limit_var: None,
81            returning: Vec::new(),
82            injection: Injection::new(),
83            marker: UPDATE_MARKER_INIT,
84        }
85    }
86
87    pub fn set_flavor(&mut self, flavor: Flavor) -> Flavor {
88        let mut a = self.args.borrow_mut();
89        let old = a.flavor;
90        a.flavor = flavor;
91        old
92    }
93
94    pub fn flavor(&self) -> Flavor {
95        self.args.borrow().flavor
96    }
97
98    pub fn with(&mut self, cte: &CTEBuilder) -> &mut Self {
99        let cte_clone = cte.clone();
100        let ph = self.var(Arg::Builder(Box::new(cte.clone())));
101        self.cte = Some(cte_clone);
102        self.cte_var = Some(ph);
103        self.marker = UPDATE_MARKER_AFTER_WHERE; // temporarily? Wait we need new constant? run with marker?
104        self
105    }
106
107    fn table_names(&self) -> Vec<String> {
108        let mut table_names = Vec::new();
109        if !self.tables.is_empty() {
110            table_names.extend(self.tables.clone());
111        }
112        if let Some(cte) = &self.cte {
113            table_names.extend(cte.table_names_for_from());
114        }
115        table_names
116    }
117
118    pub fn where_clause(&self) -> Option<WhereClauseRef> {
119        self.where_clause.clone()
120    }
121
122    pub fn set_where_clause(&mut self, wc: Option<WhereClauseRef>) -> &mut Self {
123        match wc {
124            None => {
125                self.where_clause = None;
126                self.where_var = None;
127            }
128            Some(wc) => {
129                if let Some(ph) = &self.where_var {
130                    self.args.borrow_mut().replace(
131                        ph,
132                        Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))),
133                    );
134                } else {
135                    let ph = self.var(Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))));
136                    self.where_var = Some(ph);
137                }
138                self.where_clause = Some(wc);
139            }
140        }
141        self
142    }
143
144    pub fn clear_where_clause(&mut self) -> &mut Self {
145        self.set_where_clause(None)
146    }
147
148    pub fn clone_builder(&self) -> Self {
149        let old_args = self.args.borrow().clone();
150        let args = Rc::new(RefCell::new(old_args));
151        let cond = Cond::with_args(args.clone());
152
153        let mut cloned = Self {
154            args,
155            cond,
156            tables: self.tables.clone(),
157            assignments: self.assignments.clone(),
158            where_clause: self.where_clause.clone(),
159            where_var: self.where_var.clone(),
160            cte_var: self.cte_var.clone(),
161            cte: self.cte.clone(),
162            order_by_cols: self.order_by_cols.clone(),
163            order: self.order,
164            limit_var: self.limit_var.clone(),
165            returning: self.returning.clone(),
166            injection: self.injection.clone(),
167            marker: self.marker,
168        };
169
170        if let (Some(wc), Some(ph)) = (&self.where_clause, &self.where_var) {
171            let new_wc = Rc::new(RefCell::new(wc.borrow().clone()));
172            cloned.where_clause = Some(new_wc.clone());
173            cloned
174                .args
175                .borrow_mut()
176                .replace(ph, Arg::Builder(Box::new(WhereClauseBuilder::new(new_wc))));
177        }
178
179        if let (Some(cte), Some(ph)) = (&self.cte, &self.cte_var) {
180            let cte_for_arg = cte.clone();
181            let cte_for_field = cte_for_arg.clone();
182            cloned.cte = Some(cte_for_field);
183            cloned
184                .args
185                .borrow_mut()
186                .replace(ph, Arg::Builder(Box::new(cte_for_arg)));
187        }
188
189        cloned
190    }
191
192    fn var(&self, v: impl Into<Arg>) -> String {
193        self.args.borrow_mut().add(v)
194    }
195
196    pub fn update(&mut self, tables: impl IntoIterator<Item = impl Into<String>>) -> &mut Self {
197        self.tables = tables.into_iter().map(Into::into).collect();
198        self.marker = UPDATE_MARKER_AFTER_UPDATE;
199        self
200    }
201
202    pub fn set(&mut self, assignments: impl IntoIterator<Item = impl Into<String>>) -> &mut Self {
203        self.assignments = assignments.into_iter().map(Into::into).collect();
204        self.marker = UPDATE_MARKER_AFTER_SET;
205        self
206    }
207
208    pub fn set_more(
209        &mut self,
210        assignments: impl IntoIterator<Item = impl Into<String>>,
211    ) -> &mut Self {
212        self.assignments
213            .extend(assignments.into_iter().map(Into::into));
214        self.marker = UPDATE_MARKER_AFTER_SET;
215        self
216    }
217
218    pub fn where_(&mut self, and_expr: impl IntoIterator<Item = impl Into<String>>) -> &mut Self {
219        let exprs: Vec<String> = and_expr.into_iter().map(Into::into).collect();
220        if exprs.is_empty() || exprs.iter().all(|s| s.is_empty()) {
221            return self;
222        }
223
224        if self.where_clause.is_none() {
225            let wc = WhereClause::new();
226            let ph = self.var(Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))));
227            self.where_clause = Some(wc);
228            self.where_var = Some(ph);
229        }
230
231        self.where_clause
232            .as_ref()
233            .unwrap()
234            .borrow_mut()
235            .add_where_expr(self.args.clone(), exprs);
236        self.marker = UPDATE_MARKER_AFTER_WITH;
237        self
238    }
239
240    pub fn add_where_expr(
241        &mut self,
242        args: ArgsRef,
243        exprs: impl IntoIterator<Item = impl Into<String>>,
244    ) -> &mut Self {
245        let exprs: Vec<String> = exprs.into_iter().map(Into::into).collect();
246        if exprs.is_empty() || exprs.iter().all(|s| s.is_empty()) {
247            return self;
248        }
249        if self.where_clause.is_none() {
250            let wc = WhereClause::new();
251            let ph = self.var(Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))));
252            self.where_clause = Some(wc);
253            self.where_var = Some(ph);
254        }
255        self.where_clause
256            .as_ref()
257            .unwrap()
258            .borrow_mut()
259            .add_where_expr(args, exprs);
260        self.marker = UPDATE_MARKER_AFTER_WHERE;
261        self
262    }
263
264    pub fn add_where_clause(&mut self, other: &WhereClause) -> &mut Self {
265        if self.where_clause.is_none() {
266            let wc = WhereClause::new();
267            let ph = self.var(Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))));
268            self.where_clause = Some(wc);
269            self.where_var = Some(ph);
270        }
271        self.where_clause
272            .as_ref()
273            .unwrap()
274            .borrow_mut()
275            .add_where_clause(other);
276        self
277    }
278
279    pub fn add_where_clause_ref(&mut self, other: &WhereClauseRef) -> &mut Self {
280        if self.where_clause.is_none() {
281            let wc = WhereClause::new();
282            let ph = self.var(Arg::Builder(Box::new(WhereClauseBuilder::new(wc.clone()))));
283            self.where_clause = Some(wc);
284            self.where_var = Some(ph);
285        }
286        self.where_clause
287            .as_ref()
288            .unwrap()
289            .borrow_mut()
290            .add_where_clause(&other.borrow());
291        self
292    }
293
294    pub fn assign(&self, field: &str, value: impl Into<Arg>) -> String {
295        format!("{} = {}", escape(field), self.var(value))
296    }
297
298    pub fn incr(&self, field: &str) -> String {
299        let f = escape(field);
300        format!("{f} = {f} + 1")
301    }
302
303    pub fn decr(&self, field: &str) -> String {
304        let f = escape(field);
305        format!("{f} = {f} - 1")
306    }
307
308    pub fn add_(&self, field: &str, value: impl Into<Arg>) -> String {
309        let f = escape(field);
310        format!("{f} = {f} + {}", self.var(value))
311    }
312
313    /// Add:对齐 go-sqlbuilder `UpdateBuilder.Add`。
314    pub fn add(&self, field: &str, value: impl Into<Arg>) -> String {
315        self.add_(field, value)
316    }
317
318    pub fn sub(&self, field: &str, value: impl Into<Arg>) -> String {
319        let f = escape(field);
320        format!("{f} = {f} - {}", self.var(value))
321    }
322
323    pub fn mul(&self, field: &str, value: impl Into<Arg>) -> String {
324        let f = escape(field);
325        format!("{f} = {f} * {}", self.var(value))
326    }
327
328    pub fn div(&self, field: &str, value: impl Into<Arg>) -> String {
329        let f = escape(field);
330        format!("{f} = {f} / {}", self.var(value))
331    }
332
333    pub fn order_by(&mut self, cols: impl IntoIterator<Item = impl Into<String>>) -> &mut Self {
334        self.order_by_cols = cols.into_iter().map(Into::into).collect();
335        self.marker = UPDATE_MARKER_AFTER_ORDER_BY;
336        self
337    }
338
339    pub fn order_by_asc(&mut self, col: impl Into<String>) -> &mut Self {
340        self.order_by_cols.push(format!("{} ASC", col.into()));
341        self.marker = UPDATE_MARKER_AFTER_ORDER_BY;
342        self
343    }
344
345    pub fn order_by_desc(&mut self, col: impl Into<String>) -> &mut Self {
346        self.order_by_cols.push(format!("{} DESC", col.into()));
347        self.marker = UPDATE_MARKER_AFTER_ORDER_BY;
348        self
349    }
350
351    pub fn asc(&mut self) -> &mut Self {
352        self.order = Some("ASC");
353        self.marker = UPDATE_MARKER_AFTER_ORDER_BY;
354        self
355    }
356
357    pub fn desc(&mut self) -> &mut Self {
358        self.order = Some("DESC");
359        self.marker = UPDATE_MARKER_AFTER_ORDER_BY;
360        self
361    }
362
363    pub fn limit(&mut self, limit: i64) -> &mut Self {
364        if limit < 0 {
365            self.limit_var = None;
366            return self;
367        }
368        self.limit_var = Some(self.var(limit));
369        self.marker = UPDATE_MARKER_AFTER_LIMIT;
370        self
371    }
372
373    pub fn returning(&mut self, cols: impl IntoIterator<Item = impl Into<String>>) -> &mut Self {
374        self.returning = cols.into_iter().map(Into::into).collect();
375        self.marker = UPDATE_MARKER_AFTER_RETURNING;
376        self
377    }
378
379    /// NumAssignment:对齐 go-sqlbuilder `UpdateBuilder.NumAssignment()`。
380    pub fn num_assignment(&self) -> usize {
381        self.assignments.iter().filter(|s| !s.is_empty()).count()
382    }
383
384    pub fn sql(&mut self, sql: impl Into<String>) -> &mut Self {
385        self.injection.sql(self.marker, sql);
386        self
387    }
388}
389
390impl Builder for UpdateBuilder {
391    fn build_with_flavor(&self, flavor: Flavor, initial_arg: &[Arg]) -> (String, Vec<Arg>) {
392        let mut buf = StringBuilder::new();
393        write_injection(&mut buf, &self.injection, UPDATE_MARKER_INIT);
394
395        if let Some(ph) = &self.cte_var {
396            buf.write_leading(ph);
397            write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_WITH);
398        }
399
400        match flavor {
401            Flavor::MySQL => {
402                let table_names = self.table_names();
403                if !table_names.is_empty() {
404                    buf.write_leading("UPDATE");
405                    buf.write_str(" ");
406                    buf.write_str(&table_names.join(", "));
407                }
408            }
409            _ => {
410                if !self.tables.is_empty() {
411                    buf.write_leading("UPDATE");
412                    buf.write_str(" ");
413                    buf.write_str(&self.tables.join(", "));
414                }
415            }
416        }
417        write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_UPDATE);
418
419        let assigns: Vec<String> = self
420            .assignments
421            .iter()
422            .filter(|s| !s.is_empty())
423            .cloned()
424            .collect();
425        if !assigns.is_empty() {
426            buf.write_leading("SET");
427            buf.write_str(" ");
428            buf.write_str(&assigns.join(", "));
429        }
430        write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_SET);
431
432        if flavor != Flavor::MySQL
433            && let Some(cte) = &self.cte
434        {
435            let cte_table_names = cte.table_names_for_from();
436            if !cte_table_names.is_empty() {
437                buf.write_leading("FROM");
438                buf.write_str(" ");
439                buf.write_str(&cte_table_names.join(", "));
440            }
441        }
442
443        if flavor == Flavor::SQLServer && !self.returning.is_empty() {
444            buf.write_leading("OUTPUT");
445            buf.write_str(" ");
446            let prefixed: Vec<String> = self
447                .returning
448                .iter()
449                .map(|c| format!("INSERTED.{c}"))
450                .collect();
451            buf.write_str(&prefixed.join(", "));
452            write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_RETURNING);
453        }
454
455        if let Some(ph) = &self.where_var {
456            buf.write_leading(ph);
457            write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_WHERE);
458        }
459
460        if !self.order_by_cols.is_empty() {
461            buf.write_leading("ORDER BY");
462            buf.write_str(" ");
463            buf.write_str(&self.order_by_cols.join(", "));
464            if let Some(order) = self.order {
465                buf.write_str(" ");
466                buf.write_str(order);
467            }
468            write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_ORDER_BY);
469        }
470
471        if let Some(lim) = &self.limit_var {
472            buf.write_leading("LIMIT");
473            buf.write_str(" ");
474            buf.write_str(lim);
475            write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_LIMIT);
476        }
477
478        if (flavor == Flavor::PostgreSQL || flavor == Flavor::SQLite) && !self.returning.is_empty()
479        {
480            buf.write_leading("RETURNING");
481            buf.write_str(" ");
482            buf.write_str(&self.returning.join(", "));
483            write_injection(&mut buf, &self.injection, UPDATE_MARKER_AFTER_RETURNING);
484        }
485
486        self.args
487            .borrow()
488            .compile_with_flavor(&buf.into_string(), flavor, initial_arg)
489    }
490
491    fn flavor(&self) -> Flavor {
492        self.flavor()
493    }
494}
495
496fn write_injection(buf: &mut StringBuilder, inj: &Injection, marker: InjectionMarker) {
497    let sqls = inj.at(marker);
498    if sqls.is_empty() {
499        return;
500    }
501    buf.write_leading("");
502    buf.write_str(&sqls.join(" "));
503}