halo_space/
args.rs

1//! Args: hold parameters and compile `$`-style formats into final SQL.
2
3use crate::flavor::Flavor;
4use crate::flavor::default_flavor;
5use crate::modifiers::{Arg, Raw, SqlNamedArg};
6use crate::string_builder::StringBuilder;
7use std::collections::HashMap;
8
9#[derive(Debug, thiserror::Error, PartialEq, Eq)]
10pub enum CompileError {
11    #[error("builder invalid arg reference ${0}")]
12    InvalidArgRef(isize),
13}
14
15/// Args store SQL-related arguments and index mappings.
16#[derive(Debug, Clone)]
17pub struct Args {
18    /// Default flavor used by `compile`.
19    pub flavor: Flavor,
20
21    pub(crate) index_base: usize,
22    pub(crate) arg_values: Vec<Arg>,
23    pub(crate) named_args: HashMap<String, usize>,
24    pub(crate) sql_named_args: HashMap<String, usize>,
25    pub(crate) only_named: bool,
26}
27
28#[allow(clippy::derivable_impls)]
29impl Default for Args {
30    fn default() -> Self {
31        Self {
32            flavor: default_flavor(),
33            index_base: 0,
34            arg_values: Vec::new(),
35            named_args: HashMap::new(),
36            sql_named_args: HashMap::new(),
37            only_named: false,
38        }
39    }
40}
41
42impl Args {
43    /// Add: push an argument and return the internal placeholder (`$0/$1/...`).
44    pub fn add(&mut self, arg: impl Into<Arg>) -> String {
45        let idx = self.add_internal(arg.into());
46        format!("${idx}")
47    }
48
49    /// Replace: swap the value bound to a `$n` placeholder.
50    pub fn replace(&mut self, placeholder: &str, arg: impl Into<Arg>) {
51        if !placeholder.starts_with('$') {
52            return;
53        }
54        let digits = &placeholder[1..];
55        if digits.is_empty() || !digits.as_bytes().iter().all(|b| b.is_ascii_digit()) {
56            return;
57        }
58        if let Ok(i) = digits.parse::<usize>() {
59            let idx = i.saturating_sub(self.index_base);
60            if idx < self.arg_values.len() {
61                self.arg_values[idx] = arg.into();
62            }
63        }
64    }
65
66    /// Value: parse a placeholder with `$<n>` prefix (lenient match).
67    ///
68    /// - `placeholder` may have suffix (e.g. `"$0xxx"`); as long as it starts with `$<digits>` it will be parsed.
69    pub fn value(&self, placeholder: &str) -> Option<&Arg> {
70        let s = placeholder.strip_prefix('$')?;
71        let mut end = 0usize;
72        for b in s.as_bytes() {
73            if b.is_ascii_digit() {
74                end += 1;
75            } else {
76                break;
77            }
78        }
79        if end == 0 {
80            return None;
81        }
82        let n: usize = s[..end].parse().ok()?;
83        let idx = n.saturating_sub(self.index_base);
84        self.arg_values.get(idx)
85    }
86
87    fn add_internal(&mut self, mut arg: Arg) -> usize {
88        let idx = self.arg_values.len() + self.index_base;
89
90        match &mut arg {
91            Arg::SqlNamed(SqlNamedArg { name, value: _ }) => {
92                if let Some(&p) = self.sql_named_args.get(name) {
93                    arg = self.arg_values[p - self.index_base].clone();
94                } else {
95                    self.sql_named_args.insert(name.clone(), idx);
96                }
97                // fallthrough: push arg below
98            }
99            Arg::Named { name, arg: inner } => {
100                if let Some(&p) = self.named_args.get(name) {
101                    arg = self.arg_values[p - self.index_base].clone();
102                } else {
103                    // Add real argument first, then record name->idx
104                    let real_idx = self.add_internal((**inner).clone());
105                    self.named_args.insert(name.clone(), real_idx);
106                    return real_idx;
107                }
108            }
109            _ => {}
110        }
111
112        self.arg_values.push(arg);
113        idx
114    }
115
116    /// Compile: build SQL using the default flavor.
117    pub fn compile(&self, format: &str, initial_value: &[Arg]) -> (String, Vec<Arg>) {
118        self.compile_with_flavor(format, self.flavor, initial_value)
119    }
120
121    /// CompileWithFlavor: build SQL using a specific flavor.
122    pub fn compile_with_flavor(
123        &self,
124        format: &str,
125        flavor: Flavor,
126        initial_value: &[Arg],
127    ) -> (String, Vec<Arg>) {
128        let mut offset = 0usize;
129        let mut ctx = CompileContext {
130            buf: StringBuilder::new(),
131            flavor,
132            values: initial_value.to_vec(),
133            named_args: Vec::new(),
134        };
135
136        let mut rest = format;
137        while let Some(pos) = rest.find('$') {
138            if pos > 0 {
139                ctx.buf.write_str(&rest[..pos]);
140            }
141            rest = &rest[pos + 1..];
142
143            if rest.is_empty() {
144                ctx.buf.write_char('$');
145                break;
146            }
147
148            let b0 = rest.as_bytes()[0];
149            match b0 {
150                b'$' => {
151                    ctx.buf.write_char('$');
152                    rest = &rest[1..];
153                }
154                b'{' => {
155                    rest = self.compile_named(&mut ctx, rest);
156                }
157                b'0'..=b'9' if !self.only_named => {
158                    let (r, off) = self.compile_digits(&mut ctx, rest, offset);
159                    rest = r;
160                    offset = off;
161                }
162                b'?' if !self.only_named => {
163                    let (r, off) = self.compile_successive(&mut ctx, &rest[1..], offset);
164                    rest = r;
165                    offset = off;
166                }
167                _ => {
168                    ctx.buf.write_char('$');
169                }
170            }
171        }
172
173        if !rest.is_empty() {
174            ctx.buf.write_str(rest);
175        }
176
177        let sql = ctx.buf.into_string();
178        let values = self.merge_sql_named_args(ctx.values, ctx.named_args);
179        (sql, values)
180    }
181
182    fn compile_named<'a>(&self, ctx: &mut CompileContext, format: &'a str) -> &'a str {
183        // format[0] == '{'
184        if let Some(end) = format.find('}') {
185            let name = &format[1..end];
186            let rest = &format[end + 1..];
187            if let Some(&p) = self.named_args.get(name) {
188                let (r, _off) = self.compile_successive(ctx, rest, p - self.index_base);
189                return r;
190            }
191            return rest;
192        }
193        // invalid
194        format
195    }
196
197    fn compile_digits<'a>(
198        &self,
199        ctx: &mut CompileContext,
200        format: &'a str,
201        offset: usize,
202    ) -> (&'a str, usize) {
203        let mut i = 0usize;
204        for b in format.as_bytes() {
205            if b.is_ascii_digit() {
206                i += 1;
207            } else {
208                break;
209            }
210        }
211        let digits = &format[..i];
212        let rest = &format[i..];
213        if let Ok(pointer) = digits.parse::<usize>() {
214            return self.compile_successive(ctx, rest, pointer.saturating_sub(self.index_base));
215        }
216        (rest, offset)
217    }
218
219    fn compile_successive<'a>(
220        &self,
221        ctx: &mut CompileContext,
222        format: &'a str,
223        offset: usize,
224    ) -> (&'a str, usize) {
225        if offset >= self.arg_values.len() {
226            ctx.buf.write_str("/* INVALID ARG $");
227            ctx.buf.write_str(&offset.to_string());
228            ctx.buf.write_str(" */");
229            return (format, offset);
230        }
231        let arg = self.arg_values[offset].clone();
232        ctx.write_value(&arg);
233        (format, offset + 1)
234    }
235
236    fn merge_sql_named_args(&self, mut values: Vec<Arg>, named: Vec<SqlNamedArg>) -> Vec<Arg> {
237        if self.sql_named_args.is_empty() && named.is_empty() {
238            return values;
239        }
240
241        // Add named args encountered during parsing first, de-duplicated.
242        let mut seen = HashMap::<String, ()>::new();
243        for a in named {
244            if seen.insert(a.name.clone(), ()).is_none() {
245                values.push(Arg::SqlNamed(a));
246            }
247        }
248
249        // Then append named args added via Add() but not seen in parsing order.
250        let mut idxs: Vec<usize> = self
251            .sql_named_args
252            .iter()
253            .filter_map(|(n, &p)| if seen.contains_key(n) { None } else { Some(p) })
254            .collect();
255        idxs.sort_unstable();
256        for p in idxs {
257            values.push(self.arg_values[p - self.index_base].clone());
258        }
259
260        values
261    }
262}
263
264#[derive(Debug)]
265struct CompileContext {
266    buf: StringBuilder,
267    flavor: Flavor,
268    values: Vec<Arg>,
269    named_args: Vec<SqlNamedArg>,
270}
271
272impl CompileContext {
273    fn write_value(&mut self, arg: &Arg) {
274        match arg {
275            Arg::Builder(b) => {
276                let (sql, args) = b.build_with_flavor(self.flavor, &self.values);
277                self.buf.write_str(&sql);
278
279                let (values, named) = split_named_args(args);
280                self.values = values;
281                self.named_args.extend(named);
282            }
283            Arg::SqlNamed(SqlNamedArg { name, value }) => {
284                self.buf.write_char('@');
285                self.buf.write_str(name);
286                self.named_args.push(SqlNamedArg {
287                    name: name.clone(),
288                    value: value.clone(),
289                });
290            }
291            Arg::Raw(Raw { expr }) => self.buf.write_str(expr),
292            Arg::List { args, is_tuple } => {
293                if *is_tuple {
294                    self.buf.write_char('(');
295                }
296                for (i, a) in args.iter().enumerate() {
297                    if i > 0 {
298                        self.buf.write_str(", ");
299                    }
300                    self.write_value(a);
301                }
302                if *is_tuple {
303                    self.buf.write_char(')');
304                }
305            }
306            Arg::Named { .. } => {
307                // Named only takes effect when `${name}` is parsed; treat as a value here for predictability.
308                self.write_placeholder_and_push(arg.clone());
309            }
310            Arg::Valuer(_) => self.write_placeholder_and_push(arg.clone()),
311            Arg::Value(_) => self.write_placeholder_and_push(arg.clone()),
312        }
313    }
314
315    fn write_placeholder_and_push(&mut self, arg: Arg) {
316        match self.flavor {
317            Flavor::MySQL
318            | Flavor::SQLite
319            | Flavor::CQL
320            | Flavor::ClickHouse
321            | Flavor::Presto
322            | Flavor::Informix
323            | Flavor::Doris => {
324                self.buf.write_char('?');
325            }
326            Flavor::PostgreSQL => {
327                let idx = self.values.len() + 1;
328                self.buf.write_char('$');
329                self.buf.write_str(&idx.to_string());
330            }
331            Flavor::SQLServer => {
332                let idx = self.values.len() + 1;
333                self.buf.write_str(&format!("@p{idx}"));
334            }
335            Flavor::Oracle => {
336                let idx = self.values.len() + 1;
337                self.buf.write_char(':');
338                self.buf.write_str(&idx.to_string());
339            }
340        }
341        self.values.push(arg);
342    }
343}
344
345fn split_named_args(mut values: Vec<Arg>) -> (Vec<Arg>, Vec<SqlNamedArg>) {
346    if values.is_empty() {
347        return (values, Vec::new());
348    }
349
350    let mut named = Vec::new();
351    while let Some(Arg::SqlNamed(a)) = values.last().cloned() {
352        values.pop();
353        named.push(a);
354    }
355    named.reverse();
356    (values, named)
357}