halo/
args.rs

//! Args: stores arguments and compiles a format string that uses the `$` syntax into the final SQL (mirrors go-sqlbuilder's `args.go`).
2
3use crate::flavor::Flavor;
4use crate::flavor::default_flavor;
5use crate::modifiers::{Arg, Raw, SqlNamedArg};
6use crate::string_builder::StringBuilder;
7use std::collections::HashMap;
8
9#[derive(Debug, thiserror::Error, PartialEq, Eq)]
10pub enum CompileError {
11    #[error("sql_builder: invalid arg reference ${0}")]
12    InvalidArgRef(isize),
13}
14
15/// Args 存储 SQL 相关参数。
16#[derive(Debug, Clone)]
17pub struct Args {
18    /// 默认 flavor,用于 `compile`。
19    pub flavor: Flavor,
20
21    pub(crate) index_base: usize,
22    pub(crate) arg_values: Vec<Arg>,
23    pub(crate) named_args: HashMap<String, usize>,
24    pub(crate) sql_named_args: HashMap<String, usize>,
25    pub(crate) only_named: bool,
26}
27
28#[allow(clippy::derivable_impls)]
29impl Default for Args {
30    fn default() -> Self {
31        Self {
32            flavor: default_flavor(),
33            index_base: 0,
34            arg_values: Vec::new(),
35            named_args: HashMap::new(),
36            sql_named_args: HashMap::new(),
37            only_named: false,
38        }
39    }
40}
41
42impl Args {
43    /// Add:追加一个参数并返回内部占位符(`$0/$1/...`)。
44    pub fn add(&mut self, arg: impl Into<Arg>) -> String {
45        let idx = self.add_internal(arg.into());
46        format!("${idx}")
47    }
48
49    /// Replace:用新参数替换某个 `$n` 占位符对应的值(对齐 go-sqlbuilder `Args.Replace`)。
50    pub fn replace(&mut self, placeholder: &str, arg: impl Into<Arg>) {
51        if !placeholder.starts_with('$') {
52            return;
53        }
54        let digits = &placeholder[1..];
55        if digits.is_empty() || !digits.as_bytes().iter().all(|b| b.is_ascii_digit()) {
56            return;
57        }
58        if let Ok(i) = digits.parse::<usize>() {
59            let idx = i.saturating_sub(self.index_base);
60            if idx < self.arg_values.len() {
61                self.arg_values[idx] = arg.into();
62            }
63        }
64    }
65
66    /// Value:按 `$<n>` 前缀解析参数值(对齐 go-sqlbuilder `Args.Value` 的“宽松匹配”)。
67    ///
68    /// - `placeholder` 可以带后缀(如 `"$0xxx"`),只要以 `$<digits>` 开头就会解析。
69    pub fn value(&self, placeholder: &str) -> Option<&Arg> {
70        let s = placeholder.strip_prefix('$')?;
71        let mut end = 0usize;
72        for b in s.as_bytes() {
73            if b.is_ascii_digit() {
74                end += 1;
75            } else {
76                break;
77            }
78        }
79        if end == 0 {
80            return None;
81        }
82        let n: usize = s[..end].parse().ok()?;
83        let idx = n.saturating_sub(self.index_base);
84        self.arg_values.get(idx)
85    }
86
87    fn add_internal(&mut self, mut arg: Arg) -> usize {
88        let idx = self.arg_values.len() + self.index_base;
89
90        match &mut arg {
91            Arg::SqlNamed(SqlNamedArg { name, value: _ }) => {
92                if let Some(&p) = self.sql_named_args.get(name) {
93                    arg = self.arg_values[p - self.index_base].clone();
94                } else {
95                    self.sql_named_args.insert(name.clone(), idx);
96                }
97                // fallthrough: push arg below
98            }
99            Arg::Named { name, arg: inner } => {
100                if let Some(&p) = self.named_args.get(name) {
101                    arg = self.arg_values[p - self.index_base].clone();
102                } else {
103                    // 先把真实参数加入,再记录 name->idx
104                    let real_idx = self.add_internal((**inner).clone());
105                    self.named_args.insert(name.clone(), real_idx);
106                    return real_idx;
107                }
108            }
109            _ => {}
110        }
111
112        self.arg_values.push(arg);
113        idx
114    }
115
116    /// Compile:按默认 flavor 编译 format。
117    pub fn compile(&self, format: &str, initial_value: &[Arg]) -> (String, Vec<Arg>) {
118        self.compile_with_flavor(format, self.flavor, initial_value)
119    }
120
121    /// CompileWithFlavor:编译 format,并用 `flavor` 输出最终占位符。
122    pub fn compile_with_flavor(
123        &self,
124        format: &str,
125        flavor: Flavor,
126        initial_value: &[Arg],
127    ) -> (String, Vec<Arg>) {
128        let mut offset = 0usize;
129        let mut ctx = CompileContext {
130            buf: StringBuilder::new(),
131            flavor,
132            values: initial_value.to_vec(),
133            named_args: Vec::new(),
134        };
135
136        let mut rest = format;
137        while let Some(pos) = rest.find('$') {
138            if pos > 0 {
139                ctx.buf.write_str(&rest[..pos]);
140            }
141            rest = &rest[pos + 1..];
142
143            if rest.is_empty() {
144                ctx.buf.write_char('$');
145                break;
146            }
147
148            let b0 = rest.as_bytes()[0];
149            match b0 {
150                b'$' => {
151                    ctx.buf.write_char('$');
152                    rest = &rest[1..];
153                }
154                b'{' => {
155                    rest = self.compile_named(&mut ctx, rest);
156                }
157                b'0'..=b'9' if !self.only_named => {
158                    let (r, off) = self.compile_digits(&mut ctx, rest, offset);
159                    rest = r;
160                    offset = off;
161                }
162                b'?' if !self.only_named => {
163                    let (r, off) = self.compile_successive(&mut ctx, &rest[1..], offset);
164                    rest = r;
165                    offset = off;
166                }
167                _ => {
168                    ctx.buf.write_char('$');
169                }
170            }
171        }
172
173        if !rest.is_empty() {
174            ctx.buf.write_str(rest);
175        }
176
177        let sql = ctx.buf.into_string();
178        let values = self.merge_sql_named_args(ctx.values, ctx.named_args);
179        (sql, values)
180    }
181
182    fn compile_named<'a>(&self, ctx: &mut CompileContext, format: &'a str) -> &'a str {
183        // format[0] == '{'
184        if let Some(end) = format.find('}') {
185            let name = &format[1..end];
186            let rest = &format[end + 1..];
187            if let Some(&p) = self.named_args.get(name) {
188                let (r, _off) = self.compile_successive(ctx, rest, p - self.index_base);
189                return r;
190            }
191            return rest;
192        }
193        // invalid
194        format
195    }
196
197    fn compile_digits<'a>(
198        &self,
199        ctx: &mut CompileContext,
200        format: &'a str,
201        offset: usize,
202    ) -> (&'a str, usize) {
203        let mut i = 0usize;
204        for b in format.as_bytes() {
205            if b.is_ascii_digit() {
206                i += 1;
207            } else {
208                break;
209            }
210        }
211        let digits = &format[..i];
212        let rest = &format[i..];
213        if let Ok(pointer) = digits.parse::<usize>() {
214            return self.compile_successive(ctx, rest, pointer.saturating_sub(self.index_base));
215        }
216        (rest, offset)
217    }
218
219    fn compile_successive<'a>(
220        &self,
221        ctx: &mut CompileContext,
222        format: &'a str,
223        offset: usize,
224    ) -> (&'a str, usize) {
225        if offset >= self.arg_values.len() {
226            ctx.buf.write_str("/* INVALID ARG $");
227            ctx.buf.write_str(&offset.to_string());
228            ctx.buf.write_str(" */");
229            return (format, offset);
230        }
231        let arg = self.arg_values[offset].clone();
232        ctx.write_value(&arg);
233        (format, offset + 1)
234    }
235
236    fn merge_sql_named_args(&self, mut values: Vec<Arg>, named: Vec<SqlNamedArg>) -> Vec<Arg> {
237        if self.sql_named_args.is_empty() && named.is_empty() {
238            return values;
239        }
240
241        // 先追加 ctx 中遇到的 named args,并去重
242        let mut seen = HashMap::<String, ()>::new();
243        for a in named {
244            if seen.insert(a.name.clone(), ()).is_none() {
245                values.push(Arg::SqlNamed(a));
246            }
247        }
248
249        // 再追加 Add() 时出现但 ctx 中未出现的 named args,按位置稳定排序
250        let mut idxs: Vec<usize> = self
251            .sql_named_args
252            .iter()
253            .filter_map(|(n, &p)| if seen.contains_key(n) { None } else { Some(p) })
254            .collect();
255        idxs.sort_unstable();
256        for p in idxs {
257            values.push(self.arg_values[p - self.index_base].clone());
258        }
259
260        values
261    }
262}
263
264#[derive(Debug)]
265struct CompileContext {
266    buf: StringBuilder,
267    flavor: Flavor,
268    values: Vec<Arg>,
269    named_args: Vec<SqlNamedArg>,
270}
271
272impl CompileContext {
273    fn write_value(&mut self, arg: &Arg) {
274        match arg {
275            Arg::Builder(b) => {
276                let (sql, args) = b.build_with_flavor(self.flavor, &self.values);
277                self.buf.write_str(&sql);
278
279                let (values, named) = split_named_args(args);
280                self.values = values;
281                self.named_args.extend(named);
282            }
283            Arg::SqlNamed(SqlNamedArg { name, value }) => {
284                self.buf.write_char('@');
285                self.buf.write_str(name);
286                self.named_args.push(SqlNamedArg {
287                    name: name.clone(),
288                    value: value.clone(),
289                });
290            }
291            Arg::Raw(Raw { expr }) => self.buf.write_str(expr),
292            Arg::List { args, is_tuple } => {
293                if *is_tuple {
294                    self.buf.write_char('(');
295                }
296                for (i, a) in args.iter().enumerate() {
297                    if i > 0 {
298                        self.buf.write_str(", ");
299                    }
300                    self.write_value(a);
301                }
302                if *is_tuple {
303                    self.buf.write_char(')');
304                }
305            }
306            Arg::Named { .. } => {
307                // Named 只在 `${name}` 被解析到时才会真正生效;
308                // 这里按普通值处理(保持行为可预测)。
309                self.write_placeholder_and_push(arg.clone());
310            }
311            Arg::Valuer(_) => self.write_placeholder_and_push(arg.clone()),
312            Arg::Value(_) => self.write_placeholder_and_push(arg.clone()),
313        }
314    }
315
316    fn write_placeholder_and_push(&mut self, arg: Arg) {
317        match self.flavor {
318            Flavor::MySQL
319            | Flavor::SQLite
320            | Flavor::CQL
321            | Flavor::ClickHouse
322            | Flavor::Presto
323            | Flavor::Informix
324            | Flavor::Doris => {
325                self.buf.write_char('?');
326            }
327            Flavor::PostgreSQL => {
328                let idx = self.values.len() + 1;
329                self.buf.write_char('$');
330                self.buf.write_str(&idx.to_string());
331            }
332            Flavor::SQLServer => {
333                let idx = self.values.len() + 1;
334                self.buf.write_str(&format!("@p{idx}"));
335            }
336            Flavor::Oracle => {
337                let idx = self.values.len() + 1;
338                self.buf.write_char(':');
339                self.buf.write_str(&idx.to_string());
340            }
341        }
342        self.values.push(arg);
343    }
344}
345
346fn split_named_args(mut values: Vec<Arg>) -> (Vec<Arg>, Vec<SqlNamedArg>) {
347    if values.is_empty() {
348        return (values, Vec::new());
349    }
350
351    let mut named = Vec::new();
352    while let Some(Arg::SqlNamed(a)) = values.last().cloned() {
353        values.pop();
354        named.push(a);
355    }
356    named.reverse();
357    (values, named)
358}