1use crate::flavor::Flavor;
4use crate::flavor::default_flavor;
5use crate::modifiers::{Arg, Raw, SqlNamedArg};
6use crate::string_builder::StringBuilder;
7use std::collections::HashMap;
8
9#[derive(Debug, thiserror::Error, PartialEq, Eq)]
10pub enum CompileError {
11 #[error("builder invalid arg reference ${0}")]
12 InvalidArgRef(isize),
13}
14
15#[derive(Debug, Clone)]
17pub struct Args {
18 pub flavor: Flavor,
20
21 pub(crate) index_base: usize,
22 pub(crate) arg_values: Vec<Arg>,
23 pub(crate) named_args: HashMap<String, usize>,
24 pub(crate) sql_named_args: HashMap<String, usize>,
25 pub(crate) only_named: bool,
26}
27
28#[allow(clippy::derivable_impls)]
29impl Default for Args {
30 fn default() -> Self {
31 Self {
32 flavor: default_flavor(),
33 index_base: 0,
34 arg_values: Vec::new(),
35 named_args: HashMap::new(),
36 sql_named_args: HashMap::new(),
37 only_named: false,
38 }
39 }
40}
41
42impl Args {
43 pub fn add(&mut self, arg: impl Into<Arg>) -> String {
45 let idx = self.add_internal(arg.into());
46 format!("${idx}")
47 }
48
49 pub fn replace(&mut self, placeholder: &str, arg: impl Into<Arg>) {
51 if !placeholder.starts_with('$') {
52 return;
53 }
54 let digits = &placeholder[1..];
55 if digits.is_empty() || !digits.as_bytes().iter().all(|b| b.is_ascii_digit()) {
56 return;
57 }
58 if let Ok(i) = digits.parse::<usize>() {
59 let idx = i.saturating_sub(self.index_base);
60 if idx < self.arg_values.len() {
61 self.arg_values[idx] = arg.into();
62 }
63 }
64 }
65
66 pub fn value(&self, placeholder: &str) -> Option<&Arg> {
70 let s = placeholder.strip_prefix('$')?;
71 let mut end = 0usize;
72 for b in s.as_bytes() {
73 if b.is_ascii_digit() {
74 end += 1;
75 } else {
76 break;
77 }
78 }
79 if end == 0 {
80 return None;
81 }
82 let n: usize = s[..end].parse().ok()?;
83 let idx = n.saturating_sub(self.index_base);
84 self.arg_values.get(idx)
85 }
86
87 fn add_internal(&mut self, mut arg: Arg) -> usize {
88 let idx = self.arg_values.len() + self.index_base;
89
90 match &mut arg {
91 Arg::SqlNamed(SqlNamedArg { name, value: _ }) => {
92 if let Some(&p) = self.sql_named_args.get(name) {
93 arg = self.arg_values[p - self.index_base].clone();
94 } else {
95 self.sql_named_args.insert(name.clone(), idx);
96 }
97 }
99 Arg::Named { name, arg: inner } => {
100 if let Some(&p) = self.named_args.get(name) {
101 arg = self.arg_values[p - self.index_base].clone();
102 } else {
103 let real_idx = self.add_internal((**inner).clone());
105 self.named_args.insert(name.clone(), real_idx);
106 return real_idx;
107 }
108 }
109 _ => {}
110 }
111
112 self.arg_values.push(arg);
113 idx
114 }
115
116 pub fn compile(&self, format: &str, initial_value: &[Arg]) -> (String, Vec<Arg>) {
118 self.compile_with_flavor(format, self.flavor, initial_value)
119 }
120
121 pub fn compile_with_flavor(
123 &self,
124 format: &str,
125 flavor: Flavor,
126 initial_value: &[Arg],
127 ) -> (String, Vec<Arg>) {
128 let mut offset = 0usize;
129 let mut ctx = CompileContext {
130 buf: StringBuilder::new(),
131 flavor,
132 values: initial_value.to_vec(),
133 named_args: Vec::new(),
134 };
135
136 let mut rest = format;
137 while let Some(pos) = rest.find('$') {
138 if pos > 0 {
139 ctx.buf.write_str(&rest[..pos]);
140 }
141 rest = &rest[pos + 1..];
142
143 if rest.is_empty() {
144 ctx.buf.write_char('$');
145 break;
146 }
147
148 let b0 = rest.as_bytes()[0];
149 match b0 {
150 b'$' => {
151 ctx.buf.write_char('$');
152 rest = &rest[1..];
153 }
154 b'{' => {
155 rest = self.compile_named(&mut ctx, rest);
156 }
157 b'0'..=b'9' if !self.only_named => {
158 let (r, off) = self.compile_digits(&mut ctx, rest, offset);
159 rest = r;
160 offset = off;
161 }
162 b'?' if !self.only_named => {
163 let (r, off) = self.compile_successive(&mut ctx, &rest[1..], offset);
164 rest = r;
165 offset = off;
166 }
167 _ => {
168 ctx.buf.write_char('$');
169 }
170 }
171 }
172
173 if !rest.is_empty() {
174 ctx.buf.write_str(rest);
175 }
176
177 let sql = ctx.buf.into_string();
178 let values = self.merge_sql_named_args(ctx.values, ctx.named_args);
179 (sql, values)
180 }
181
182 fn compile_named<'a>(&self, ctx: &mut CompileContext, format: &'a str) -> &'a str {
183 if let Some(end) = format.find('}') {
185 let name = &format[1..end];
186 let rest = &format[end + 1..];
187 if let Some(&p) = self.named_args.get(name) {
188 let (r, _off) = self.compile_successive(ctx, rest, p - self.index_base);
189 return r;
190 }
191 return rest;
192 }
193 format
195 }
196
197 fn compile_digits<'a>(
198 &self,
199 ctx: &mut CompileContext,
200 format: &'a str,
201 offset: usize,
202 ) -> (&'a str, usize) {
203 let mut i = 0usize;
204 for b in format.as_bytes() {
205 if b.is_ascii_digit() {
206 i += 1;
207 } else {
208 break;
209 }
210 }
211 let digits = &format[..i];
212 let rest = &format[i..];
213 if let Ok(pointer) = digits.parse::<usize>() {
214 return self.compile_successive(ctx, rest, pointer.saturating_sub(self.index_base));
215 }
216 (rest, offset)
217 }
218
219 fn compile_successive<'a>(
220 &self,
221 ctx: &mut CompileContext,
222 format: &'a str,
223 offset: usize,
224 ) -> (&'a str, usize) {
225 if offset >= self.arg_values.len() {
226 ctx.buf.write_str("/* INVALID ARG $");
227 ctx.buf.write_str(&offset.to_string());
228 ctx.buf.write_str(" */");
229 return (format, offset);
230 }
231 let arg = self.arg_values[offset].clone();
232 ctx.write_value(&arg);
233 (format, offset + 1)
234 }
235
236 fn merge_sql_named_args(&self, mut values: Vec<Arg>, named: Vec<SqlNamedArg>) -> Vec<Arg> {
237 if self.sql_named_args.is_empty() && named.is_empty() {
238 return values;
239 }
240
241 let mut seen = HashMap::<String, ()>::new();
243 for a in named {
244 if seen.insert(a.name.clone(), ()).is_none() {
245 values.push(Arg::SqlNamed(a));
246 }
247 }
248
249 let mut idxs: Vec<usize> = self
251 .sql_named_args
252 .iter()
253 .filter_map(|(n, &p)| if seen.contains_key(n) { None } else { Some(p) })
254 .collect();
255 idxs.sort_unstable();
256 for p in idxs {
257 values.push(self.arg_values[p - self.index_base].clone());
258 }
259
260 values
261 }
262}
263
264#[derive(Debug)]
265struct CompileContext {
266 buf: StringBuilder,
267 flavor: Flavor,
268 values: Vec<Arg>,
269 named_args: Vec<SqlNamedArg>,
270}
271
272impl CompileContext {
273 fn write_value(&mut self, arg: &Arg) {
274 match arg {
275 Arg::Builder(b) => {
276 let (sql, args) = b.build_with_flavor(self.flavor, &self.values);
277 self.buf.write_str(&sql);
278
279 let (values, named) = split_named_args(args);
280 self.values = values;
281 self.named_args.extend(named);
282 }
283 Arg::SqlNamed(SqlNamedArg { name, value }) => {
284 self.buf.write_char('@');
285 self.buf.write_str(name);
286 self.named_args.push(SqlNamedArg {
287 name: name.clone(),
288 value: value.clone(),
289 });
290 }
291 Arg::Raw(Raw { expr }) => self.buf.write_str(expr),
292 Arg::List { args, is_tuple } => {
293 if *is_tuple {
294 self.buf.write_char('(');
295 }
296 for (i, a) in args.iter().enumerate() {
297 if i > 0 {
298 self.buf.write_str(", ");
299 }
300 self.write_value(a);
301 }
302 if *is_tuple {
303 self.buf.write_char(')');
304 }
305 }
306 Arg::Named { .. } => {
307 self.write_placeholder_and_push(arg.clone());
310 }
311 Arg::Valuer(_) => self.write_placeholder_and_push(arg.clone()),
312 Arg::Value(_) => self.write_placeholder_and_push(arg.clone()),
313 }
314 }
315
316 fn write_placeholder_and_push(&mut self, arg: Arg) {
317 match self.flavor {
318 Flavor::MySQL
319 | Flavor::SQLite
320 | Flavor::CQL
321 | Flavor::ClickHouse
322 | Flavor::Presto
323 | Flavor::Informix
324 | Flavor::Doris => {
325 self.buf.write_char('?');
326 }
327 Flavor::PostgreSQL => {
328 let idx = self.values.len() + 1;
329 self.buf.write_char('$');
330 self.buf.write_str(&idx.to_string());
331 }
332 Flavor::SQLServer => {
333 let idx = self.values.len() + 1;
334 self.buf.write_str(&format!("@p{idx}"));
335 }
336 Flavor::Oracle => {
337 let idx = self.values.len() + 1;
338 self.buf.write_char(':');
339 self.buf.write_str(&idx.to_string());
340 }
341 }
342 self.values.push(arg);
343 }
344}
345
346fn split_named_args(mut values: Vec<Arg>) -> (Vec<Arg>, Vec<SqlNamedArg>) {
347 if values.is_empty() {
348 return (values, Vec::new());
349 }
350
351 let mut named = Vec::new();
352 while let Some(Arg::SqlNamed(a)) = values.last().cloned() {
353 values.pop();
354 named.push(a);
355 }
356 named.reverse();
357 (values, named)
358}