1use crate::args::Args;
4use crate::flavor::Flavor;
5use crate::injection::{Injection, InjectionMarker};
6use crate::modifiers::{Arg, Builder};
7use crate::string_builder::StringBuilder;
8use std::cell::RefCell;
9use std::rc::Rc;
10
/// Separator written between queries by `union` (plain `UNION`, duplicate rows removed).
const UNION_DISTINCT: &str = " UNION ";
/// Separator written between queries by `union_all` (`UNION ALL`, duplicate rows kept).
const UNION_ALL: &str = " UNION ALL ";

// Injection markers: positions in the rendered SQL where text passed to
// `UnionBuilder::sql` is emitted. Each clause-setting method moves the
// builder's current marker to the position right after that clause.

/// Very start of the statement, before the unioned queries.
const UNION_MARKER_INIT: InjectionMarker = 0;
/// Right after the unioned queries (after the nested-select closing paren, if any).
const UNION_MARKER_AFTER_UNION: InjectionMarker = 1;
/// Right after the ORDER BY clause; only emitted when ORDER BY columns exist.
const UNION_MARKER_AFTER_ORDER_BY: InjectionMarker = 2;
/// Right after the LIMIT/OFFSET clause; only emitted when a limit is set.
const UNION_MARKER_AFTER_LIMIT: InjectionMarker = 3;
18
/// Builds a compound `UNION` / `UNION ALL` query out of other [`Builder`]s,
/// with optional ORDER BY, LIMIT/OFFSET and raw SQL injections.
///
/// Nested builders and limit/offset values are registered in an [`Args`]
/// store and referenced by placeholder strings until `build_with_flavor`
/// compiles the final SQL.
#[derive(Debug)]
pub struct UnionBuilder {
    /// Separator written between queries: `UNION_DISTINCT` or `UNION_ALL`.
    opt: &'static str,
    /// ORDER BY expressions, joined with ", " when rendered.
    order_by_cols: Vec<String>,
    /// Direction (`"ASC"` / `"DESC"`) appended once after the whole column list.
    order: Option<&'static str>,
    /// Placeholder for the LIMIT value; `None` when unset or cleared by a negative limit.
    limit_var: Option<String>,
    /// Placeholder for the OFFSET value; `None` when unset or cleared by a negative offset.
    offset_var: Option<String>,

    /// Placeholders for the nested builders, in union order.
    builder_vars: Vec<String>,
    /// Argument store behind every placeholder above. Interior mutability lets
    /// `var(&self)` register new args; `clone_builder` deep-copies the store.
    args: Rc<RefCell<Args>>,

    /// Raw SQL fragments added via `sql`, keyed by injection marker.
    injection: Injection,
    /// Marker at which the next `sql` call injects; advanced by each clause method.
    marker: InjectionMarker,
}
33
34impl Default for UnionBuilder {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl Clone for UnionBuilder {
41 fn clone(&self) -> Self {
42 self.clone_builder()
43 }
44}
45
46impl UnionBuilder {
47 pub fn new() -> Self {
48 Self {
49 opt: UNION_DISTINCT,
50 order_by_cols: Vec::new(),
51 order: None,
52 limit_var: None,
53 offset_var: None,
54 builder_vars: Vec::new(),
55 args: Rc::new(RefCell::new(Args::default())),
56 injection: Injection::new(),
57 marker: UNION_MARKER_INIT,
58 }
59 }
60
61 pub fn set_flavor(&mut self, flavor: Flavor) -> Flavor {
62 let mut a = self.args.borrow_mut();
63 let old = a.flavor;
64 a.flavor = flavor;
65 old
66 }
67
68 pub fn flavor(&self) -> Flavor {
69 self.args.borrow().flavor
70 }
71
72 pub fn clone_builder(&self) -> Self {
73 Self {
74 opt: self.opt,
75 order_by_cols: self.order_by_cols.clone(),
76 order: self.order,
77 limit_var: self.limit_var.clone(),
78 offset_var: self.offset_var.clone(),
79 builder_vars: self.builder_vars.clone(),
80 args: Rc::new(RefCell::new(self.args.borrow().clone())),
81 injection: self.injection.clone(),
82 marker: self.marker,
83 }
84 }
85
86 fn var(&self, v: impl Into<Arg>) -> String {
87 self.args.borrow_mut().add(v)
88 }
89
90 pub fn union(
91 &mut self,
92 builders: impl IntoIterator<Item = impl Builder + 'static>,
93 ) -> &mut Self {
94 self.union_impl(UNION_DISTINCT, builders)
95 }
96
97 pub fn union_all(
98 &mut self,
99 builders: impl IntoIterator<Item = impl Builder + 'static>,
100 ) -> &mut Self {
101 self.union_impl(UNION_ALL, builders)
102 }
103
104 fn union_impl(
105 &mut self,
106 opt: &'static str,
107 builders: impl IntoIterator<Item = impl Builder + 'static>,
108 ) -> &mut Self {
109 self.opt = opt;
110 self.builder_vars = builders
111 .into_iter()
112 .map(|b| self.var(Arg::Builder(Box::new(b))))
113 .collect();
114 self.marker = UNION_MARKER_AFTER_UNION;
115 self
116 }
117
118 pub fn order_by(&mut self, cols: impl IntoIterator<Item = impl Into<String>>) -> &mut Self {
119 self.order_by_cols = cols.into_iter().map(Into::into).collect();
120 self.marker = UNION_MARKER_AFTER_ORDER_BY;
121 self
122 }
123
124 pub fn order_by_asc(&mut self, col: impl Into<String>) -> &mut Self {
125 self.order_by_cols.push(format!("{} ASC", col.into()));
126 self.marker = UNION_MARKER_AFTER_ORDER_BY;
127 self
128 }
129
130 pub fn order_by_desc(&mut self, col: impl Into<String>) -> &mut Self {
131 self.order_by_cols.push(format!("{} DESC", col.into()));
132 self.marker = UNION_MARKER_AFTER_ORDER_BY;
133 self
134 }
135
136 pub fn asc(&mut self) -> &mut Self {
137 self.order = Some("ASC");
138 self.marker = UNION_MARKER_AFTER_ORDER_BY;
139 self
140 }
141
142 pub fn desc(&mut self) -> &mut Self {
143 self.order = Some("DESC");
144 self.marker = UNION_MARKER_AFTER_ORDER_BY;
145 self
146 }
147
148 pub fn limit(&mut self, limit: i64) -> &mut Self {
149 if limit < 0 {
150 self.limit_var = None;
151 return self;
152 }
153 self.limit_var = Some(self.var(limit));
154 self.marker = UNION_MARKER_AFTER_LIMIT;
155 self
156 }
157
158 pub fn offset(&mut self, offset: i64) -> &mut Self {
159 if offset < 0 {
160 self.offset_var = None;
161 return self;
162 }
163 self.offset_var = Some(self.var(offset));
164 self.marker = UNION_MARKER_AFTER_LIMIT;
165 self
166 }
167
168 pub fn sql(&mut self, sql: impl Into<String>) -> &mut Self {
169 self.injection.sql(self.marker, sql);
170 self
171 }
172}
173
174impl Builder for UnionBuilder {
175 fn build_with_flavor(&self, flavor: Flavor, initial_arg: &[Arg]) -> (String, Vec<Arg>) {
176 let mut buf = StringBuilder::new();
177 write_injection(&mut buf, &self.injection, UNION_MARKER_INIT);
178
179 let nested_select = (flavor == Flavor::Oracle
180 && (self.limit_var.is_some() || self.offset_var.is_some()))
181 || (flavor == Flavor::Informix && self.limit_var.is_some());
182
183 if !self.builder_vars.is_empty() {
184 let need_paren = flavor != Flavor::SQLite;
185
186 if nested_select {
187 buf.write_leading("SELECT * FROM (");
188 }
189
190 if need_paren {
192 buf.write_leading("(");
193 buf.write_str(&self.builder_vars[0]);
194 buf.write_str(")");
195 } else {
196 buf.write_leading(&self.builder_vars[0]);
197 }
198
199 for b in self.builder_vars.iter().skip(1) {
200 buf.write_str(self.opt);
201 if need_paren {
202 buf.write_str("(");
203 }
204 buf.write_str(b);
205 if need_paren {
206 buf.write_str(")");
207 }
208 }
209
210 if nested_select {
211 buf.write_leading(")");
212 }
213 }
214
215 write_injection(&mut buf, &self.injection, UNION_MARKER_AFTER_UNION);
216
217 if !self.order_by_cols.is_empty() {
218 buf.write_leading("ORDER BY");
219 buf.write_str(" ");
220 buf.write_str(&self.order_by_cols.join(", "));
221 if let Some(order) = self.order {
222 buf.write_str(" ");
223 buf.write_str(order);
224 }
225 write_injection(&mut buf, &self.injection, UNION_MARKER_AFTER_ORDER_BY);
226 }
227
228 match flavor {
229 Flavor::MySQL | Flavor::SQLite | Flavor::ClickHouse => {
230 if let Some(lim) = &self.limit_var {
231 buf.write_leading("LIMIT");
232 buf.write_str(" ");
233 buf.write_str(lim);
234 if let Some(off) = &self.offset_var {
235 buf.write_leading("OFFSET");
236 buf.write_str(" ");
237 buf.write_str(off);
238 }
239 }
240 }
241 Flavor::CQL => {
242 if let Some(lim) = &self.limit_var {
243 buf.write_leading("LIMIT");
244 buf.write_str(" ");
245 buf.write_str(lim);
246 }
247 }
248 Flavor::PostgreSQL => {
249 if let Some(lim) = &self.limit_var {
250 buf.write_leading("LIMIT");
251 buf.write_str(" ");
252 buf.write_str(lim);
253 }
254 if let Some(off) = &self.offset_var {
255 buf.write_leading("OFFSET");
256 buf.write_str(" ");
257 buf.write_str(off);
258 }
259 }
260 Flavor::Presto => {
261 if let Some(off) = &self.offset_var {
262 buf.write_leading("OFFSET");
263 buf.write_str(" ");
264 buf.write_str(off);
265 }
266 if let Some(lim) = &self.limit_var {
267 buf.write_leading("LIMIT");
268 buf.write_str(" ");
269 buf.write_str(lim);
270 }
271 }
272 Flavor::SQLServer => {
273 if self.order_by_cols.is_empty()
274 && (self.limit_var.is_some() || self.offset_var.is_some())
275 {
276 buf.write_leading("ORDER BY 1");
277 }
278 if let Some(off) = &self.offset_var {
279 buf.write_leading("OFFSET");
280 buf.write_str(" ");
281 buf.write_str(off);
282 buf.write_str(" ROWS");
283 }
284 if let Some(lim) = &self.limit_var {
285 if self.offset_var.is_none() {
286 buf.write_leading("OFFSET 0 ROWS");
287 }
288 buf.write_leading("FETCH NEXT");
289 buf.write_str(" ");
290 buf.write_str(lim);
291 buf.write_str(" ROWS ONLY");
292 }
293 }
294 Flavor::Oracle => {
295 if let Some(off) = &self.offset_var {
296 buf.write_leading("OFFSET");
297 buf.write_str(" ");
298 buf.write_str(off);
299 buf.write_str(" ROWS");
300 }
301 if let Some(lim) = &self.limit_var {
302 if self.offset_var.is_none() {
303 buf.write_leading("OFFSET 0 ROWS");
304 }
305 buf.write_leading("FETCH NEXT");
306 buf.write_str(" ");
307 buf.write_str(lim);
308 buf.write_str(" ROWS ONLY");
309 }
310 }
311 Flavor::Informix => {
312 if let Some(lim) = &self.limit_var {
316 if let Some(off) = &self.offset_var {
317 buf.write_leading("SKIP");
318 buf.write_str(" ");
319 buf.write_str(off);
320 buf.write_leading("FIRST");
321 buf.write_str(" ");
322 buf.write_str(lim);
323 } else {
324 buf.write_leading("FIRST");
325 buf.write_str(" ");
326 buf.write_str(lim);
327 }
328 }
329 }
330 Flavor::Doris => {
331 if let Some(lim_ph) = &self.limit_var {
335 if let Some(n) = extract_i64(&self.args.borrow(), lim_ph) {
336 buf.write_leading("LIMIT");
337 buf.write_str(" ");
338 buf.write_str(&n.to_string());
339 if let Some(off_ph) = &self.offset_var
340 && let Some(off) = extract_i64(&self.args.borrow(), off_ph)
341 {
342 buf.write_leading("OFFSET");
343 buf.write_str(" ");
344 buf.write_str(&off.to_string());
345 }
346 } else {
347 buf.write_leading("LIMIT");
349 buf.write_str(" ");
350 buf.write_str(lim_ph);
351 }
352 }
353 }
354 }
355
356 if self.limit_var.is_some() {
357 write_injection(&mut buf, &self.injection, UNION_MARKER_AFTER_LIMIT);
358 }
359
360 self.args
361 .borrow()
362 .compile_with_flavor(&buf.into_string(), flavor, initial_arg)
363 }
364
365 fn flavor(&self) -> Flavor {
366 self.flavor()
367 }
368}
369
370fn extract_i64(args: &Args, placeholder: &str) -> Option<i64> {
371 let a = args.value(placeholder)?;
372 match a {
373 Arg::Value(crate::value::SqlValue::I64(v)) => Some(*v),
374 Arg::Value(crate::value::SqlValue::U64(v)) => i64::try_from(*v).ok(),
375 _ => None,
376 }
377}
378
379fn write_injection(buf: &mut StringBuilder, inj: &Injection, marker: InjectionMarker) {
380 let sqls = inj.at(marker);
381 if sqls.is_empty() {
382 return;
383 }
384 buf.write_leading("");
385 buf.write_str(&sqls.join(" "));
386}