halo/
union.rs

1//! UnionBuilder:构建 UNION / UNION ALL(对齐 go-sqlbuilder `union.go` 的核心行为)。
2
3use crate::args::Args;
4use crate::flavor::Flavor;
5use crate::injection::{Injection, InjectionMarker};
6use crate::modifiers::{Arg, Builder};
7use crate::string_builder::StringBuilder;
8use std::cell::RefCell;
9use std::rc::Rc;
10
/// Separator written between consecutive sub-queries for a plain `UNION`.
/// This is the separator a freshly created builder starts with.
const UNION_DISTINCT: &str = " UNION ";
/// Separator written between consecutive sub-queries for `UNION ALL`.
const UNION_ALL: &str = " UNION ALL ";

/// Injection position emitted at the very start of the statement.
const UNION_MARKER_INIT: InjectionMarker = 0;
/// Injection position emitted right after the combined sub-queries.
const UNION_MARKER_AFTER_UNION: InjectionMarker = 1;
/// Injection position emitted right after the ORDER BY clause
/// (only when at least one ORDER BY column is set).
const UNION_MARKER_AFTER_ORDER_BY: InjectionMarker = 2;
/// Injection position emitted after the LIMIT/OFFSET section
/// (only when a limit is set).
const UNION_MARKER_AFTER_LIMIT: InjectionMarker = 3;
18
/// Builder for `UNION` / `UNION ALL` statements combining several sub-builders,
/// with optional ORDER BY, LIMIT and OFFSET parts and raw SQL injections.
#[derive(Debug)]
pub struct UnionBuilder {
    /// Separator written between consecutive sub-queries
    /// (`" UNION "` or `" UNION ALL "`).
    opt: &'static str,
    /// ORDER BY terms; joined with `", "` when the statement is built.
    order_by_cols: Vec<String>,
    /// Optional direction keyword (`"ASC"` / `"DESC"`) appended after the
    /// whole ORDER BY column list.
    order: Option<&'static str>,
    /// Placeholder registered in `args` for the limit value; `None` means no limit.
    limit_var: Option<String>,
    /// Placeholder registered in `args` for the offset value; `None` means no offset.
    offset_var: Option<String>,

    /// Placeholders registered in `args` for the sub-builders, in union order.
    builder_vars: Vec<String>,
    /// Argument store that holds the flavor and every registered value and
    /// compiles the final SQL text.
    args: Rc<RefCell<Args>>,

    /// Raw SQL fragments grouped by injection marker.
    injection: Injection,
    /// Marker under which the next `sql(...)` call stores its fragment.
    marker: InjectionMarker,
}
33
34impl Default for UnionBuilder {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40impl Clone for UnionBuilder {
41    fn clone(&self) -> Self {
42        self.clone_builder()
43    }
44}
45
46impl UnionBuilder {
47    pub fn new() -> Self {
48        Self {
49            opt: UNION_DISTINCT,
50            order_by_cols: Vec::new(),
51            order: None,
52            limit_var: None,
53            offset_var: None,
54            builder_vars: Vec::new(),
55            args: Rc::new(RefCell::new(Args::default())),
56            injection: Injection::new(),
57            marker: UNION_MARKER_INIT,
58        }
59    }
60
61    pub fn set_flavor(&mut self, flavor: Flavor) -> Flavor {
62        let mut a = self.args.borrow_mut();
63        let old = a.flavor;
64        a.flavor = flavor;
65        old
66    }
67
68    pub fn flavor(&self) -> Flavor {
69        self.args.borrow().flavor
70    }
71
72    pub fn clone_builder(&self) -> Self {
73        Self {
74            opt: self.opt,
75            order_by_cols: self.order_by_cols.clone(),
76            order: self.order,
77            limit_var: self.limit_var.clone(),
78            offset_var: self.offset_var.clone(),
79            builder_vars: self.builder_vars.clone(),
80            args: Rc::new(RefCell::new(self.args.borrow().clone())),
81            injection: self.injection.clone(),
82            marker: self.marker,
83        }
84    }
85
86    fn var(&self, v: impl Into<Arg>) -> String {
87        self.args.borrow_mut().add(v)
88    }
89
90    pub fn union(
91        &mut self,
92        builders: impl IntoIterator<Item = impl Builder + 'static>,
93    ) -> &mut Self {
94        self.union_impl(UNION_DISTINCT, builders)
95    }
96
97    pub fn union_all(
98        &mut self,
99        builders: impl IntoIterator<Item = impl Builder + 'static>,
100    ) -> &mut Self {
101        self.union_impl(UNION_ALL, builders)
102    }
103
104    fn union_impl(
105        &mut self,
106        opt: &'static str,
107        builders: impl IntoIterator<Item = impl Builder + 'static>,
108    ) -> &mut Self {
109        self.opt = opt;
110        self.builder_vars = builders
111            .into_iter()
112            .map(|b| self.var(Arg::Builder(Box::new(b))))
113            .collect();
114        self.marker = UNION_MARKER_AFTER_UNION;
115        self
116    }
117
118    pub fn order_by(&mut self, cols: impl IntoIterator<Item = impl Into<String>>) -> &mut Self {
119        self.order_by_cols = cols.into_iter().map(Into::into).collect();
120        self.marker = UNION_MARKER_AFTER_ORDER_BY;
121        self
122    }
123
124    pub fn order_by_asc(&mut self, col: impl Into<String>) -> &mut Self {
125        self.order_by_cols.push(format!("{} ASC", col.into()));
126        self.marker = UNION_MARKER_AFTER_ORDER_BY;
127        self
128    }
129
130    pub fn order_by_desc(&mut self, col: impl Into<String>) -> &mut Self {
131        self.order_by_cols.push(format!("{} DESC", col.into()));
132        self.marker = UNION_MARKER_AFTER_ORDER_BY;
133        self
134    }
135
136    pub fn asc(&mut self) -> &mut Self {
137        self.order = Some("ASC");
138        self.marker = UNION_MARKER_AFTER_ORDER_BY;
139        self
140    }
141
142    pub fn desc(&mut self) -> &mut Self {
143        self.order = Some("DESC");
144        self.marker = UNION_MARKER_AFTER_ORDER_BY;
145        self
146    }
147
148    pub fn limit(&mut self, limit: i64) -> &mut Self {
149        if limit < 0 {
150            self.limit_var = None;
151            return self;
152        }
153        self.limit_var = Some(self.var(limit));
154        self.marker = UNION_MARKER_AFTER_LIMIT;
155        self
156    }
157
158    pub fn offset(&mut self, offset: i64) -> &mut Self {
159        if offset < 0 {
160            self.offset_var = None;
161            return self;
162        }
163        self.offset_var = Some(self.var(offset));
164        self.marker = UNION_MARKER_AFTER_LIMIT;
165        self
166    }
167
168    pub fn sql(&mut self, sql: impl Into<String>) -> &mut Self {
169        self.injection.sql(self.marker, sql);
170        self
171    }
172}
173
174impl Builder for UnionBuilder {
175    fn build_with_flavor(&self, flavor: Flavor, initial_arg: &[Arg]) -> (String, Vec<Arg>) {
176        let mut buf = StringBuilder::new();
177        write_injection(&mut buf, &self.injection, UNION_MARKER_INIT);
178
179        let nested_select = (flavor == Flavor::Oracle
180            && (self.limit_var.is_some() || self.offset_var.is_some()))
181            || (flavor == Flavor::Informix && self.limit_var.is_some());
182
183        if !self.builder_vars.is_empty() {
184            let need_paren = flavor != Flavor::SQLite;
185
186            if nested_select {
187                buf.write_leading("SELECT * FROM (");
188            }
189
190            // first
191            if need_paren {
192                buf.write_leading("(");
193                buf.write_str(&self.builder_vars[0]);
194                buf.write_str(")");
195            } else {
196                buf.write_leading(&self.builder_vars[0]);
197            }
198
199            for b in self.builder_vars.iter().skip(1) {
200                buf.write_str(self.opt);
201                if need_paren {
202                    buf.write_str("(");
203                }
204                buf.write_str(b);
205                if need_paren {
206                    buf.write_str(")");
207                }
208            }
209
210            if nested_select {
211                buf.write_leading(")");
212            }
213        }
214
215        write_injection(&mut buf, &self.injection, UNION_MARKER_AFTER_UNION);
216
217        if !self.order_by_cols.is_empty() {
218            buf.write_leading("ORDER BY");
219            buf.write_str(" ");
220            buf.write_str(&self.order_by_cols.join(", "));
221            if let Some(order) = self.order {
222                buf.write_str(" ");
223                buf.write_str(order);
224            }
225            write_injection(&mut buf, &self.injection, UNION_MARKER_AFTER_ORDER_BY);
226        }
227
228        match flavor {
229            Flavor::MySQL | Flavor::SQLite | Flavor::ClickHouse => {
230                if let Some(lim) = &self.limit_var {
231                    buf.write_leading("LIMIT");
232                    buf.write_str(" ");
233                    buf.write_str(lim);
234                    if let Some(off) = &self.offset_var {
235                        buf.write_leading("OFFSET");
236                        buf.write_str(" ");
237                        buf.write_str(off);
238                    }
239                }
240            }
241            Flavor::CQL => {
242                if let Some(lim) = &self.limit_var {
243                    buf.write_leading("LIMIT");
244                    buf.write_str(" ");
245                    buf.write_str(lim);
246                }
247            }
248            Flavor::PostgreSQL => {
249                if let Some(lim) = &self.limit_var {
250                    buf.write_leading("LIMIT");
251                    buf.write_str(" ");
252                    buf.write_str(lim);
253                }
254                if let Some(off) = &self.offset_var {
255                    buf.write_leading("OFFSET");
256                    buf.write_str(" ");
257                    buf.write_str(off);
258                }
259            }
260            Flavor::Presto => {
261                if let Some(off) = &self.offset_var {
262                    buf.write_leading("OFFSET");
263                    buf.write_str(" ");
264                    buf.write_str(off);
265                }
266                if let Some(lim) = &self.limit_var {
267                    buf.write_leading("LIMIT");
268                    buf.write_str(" ");
269                    buf.write_str(lim);
270                }
271            }
272            Flavor::SQLServer => {
273                if self.order_by_cols.is_empty()
274                    && (self.limit_var.is_some() || self.offset_var.is_some())
275                {
276                    buf.write_leading("ORDER BY 1");
277                }
278                if let Some(off) = &self.offset_var {
279                    buf.write_leading("OFFSET");
280                    buf.write_str(" ");
281                    buf.write_str(off);
282                    buf.write_str(" ROWS");
283                }
284                if let Some(lim) = &self.limit_var {
285                    if self.offset_var.is_none() {
286                        buf.write_leading("OFFSET 0 ROWS");
287                    }
288                    buf.write_leading("FETCH NEXT");
289                    buf.write_str(" ");
290                    buf.write_str(lim);
291                    buf.write_str(" ROWS ONLY");
292                }
293            }
294            Flavor::Oracle => {
295                if let Some(off) = &self.offset_var {
296                    buf.write_leading("OFFSET");
297                    buf.write_str(" ");
298                    buf.write_str(off);
299                    buf.write_str(" ROWS");
300                }
301                if let Some(lim) = &self.limit_var {
302                    if self.offset_var.is_none() {
303                        buf.write_leading("OFFSET 0 ROWS");
304                    }
305                    buf.write_leading("FETCH NEXT");
306                    buf.write_str(" ");
307                    buf.write_str(lim);
308                    buf.write_str(" ROWS ONLY");
309                }
310            }
311            Flavor::Informix => {
312                // Informix:
313                // - offset 无 limit 时忽略
314                // - limit/offset 使用 `SKIP ? FIRST ?`
315                if let Some(lim) = &self.limit_var {
316                    if let Some(off) = &self.offset_var {
317                        buf.write_leading("SKIP");
318                        buf.write_str(" ");
319                        buf.write_str(off);
320                        buf.write_leading("FIRST");
321                        buf.write_str(" ");
322                        buf.write_str(lim);
323                    } else {
324                        buf.write_leading("FIRST");
325                        buf.write_str(" ");
326                        buf.write_str(lim);
327                    }
328                }
329            }
330            Flavor::Doris => {
331                // Doris:
332                // - offset 无 limit 时忽略
333                // - limit/offset 使用字面量(不参数化)
334                if let Some(lim_ph) = &self.limit_var {
335                    if let Some(n) = extract_i64(&self.args.borrow(), lim_ph) {
336                        buf.write_leading("LIMIT");
337                        buf.write_str(" ");
338                        buf.write_str(&n.to_string());
339                        if let Some(off_ph) = &self.offset_var
340                            && let Some(off) = extract_i64(&self.args.borrow(), off_ph)
341                        {
342                            buf.write_leading("OFFSET");
343                            buf.write_str(" ");
344                            buf.write_str(&off.to_string());
345                        }
346                    } else {
347                        // fallback:仍使用占位符
348                        buf.write_leading("LIMIT");
349                        buf.write_str(" ");
350                        buf.write_str(lim_ph);
351                    }
352                }
353            }
354        }
355
356        if self.limit_var.is_some() {
357            write_injection(&mut buf, &self.injection, UNION_MARKER_AFTER_LIMIT);
358        }
359
360        self.args
361            .borrow()
362            .compile_with_flavor(&buf.into_string(), flavor, initial_arg)
363    }
364
365    fn flavor(&self) -> Flavor {
366        self.flavor()
367    }
368}
369
370fn extract_i64(args: &Args, placeholder: &str) -> Option<i64> {
371    let a = args.value(placeholder)?;
372    match a {
373        Arg::Value(crate::value::SqlValue::I64(v)) => Some(*v),
374        Arg::Value(crate::value::SqlValue::U64(v)) => i64::try_from(*v).ok(),
375        _ => None,
376    }
377}
378
379fn write_injection(buf: &mut StringBuilder, inj: &Injection, marker: InjectionMarker) {
380    let sqls = inj.at(marker);
381    if sqls.is_empty() {
382        return;
383    }
384    buf.write_leading("");
385    buf.write_str(&sqls.join(" "));
386}