// File: halo/modifiers.rs

//! 参数修饰器与辅助函数(对齐 go-sqlbuilder `modifiers.go`)。
//! Parameter modifiers and helper functions (mirrors go-sqlbuilder's `modifiers.go`).

3use crate::flavor::Flavor;
4use crate::value::SqlValue;
5use crate::valuer::SqlValuer;
6use dyn_clone::DynClone;
7use std::cell::RefCell;
8use std::rc::Rc;
9
/// Escapes `ident` by doubling every `$` into `$$`.
///
/// `Args::compile` treats `$` as the start of an expression, so text that contains a
/// literal `$` must be escaped before being used in a format string.
pub fn escape(ident: &str) -> String {
    let mut escaped = String::with_capacity(ident.len());
    for ch in ident.chars() {
        match ch {
            '$' => escaped.push_str("$$"),
            other => escaped.push(other),
        }
    }
    escaped
}
14
15/// EscapeAll:批量 Escape。
16pub fn escape_all(idents: impl IntoIterator<Item = impl AsRef<str>>) -> Vec<String> {
17    idents.into_iter().map(|s| escape(s.as_ref())).collect()
18}
19
/// A SQL fragment that is spliced into the generated SQL verbatim.
///
/// Unlike ordinary arguments it never becomes a parameter placeholder, so callers
/// should not put untrusted input in it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Raw {
    /// The literal SQL text.
    pub(crate) expr: String,
}
25
26pub fn raw(expr: impl Into<String>) -> Arg {
27    Arg::Raw(Raw { expr: expr.into() })
28}
29
30/// List:标记为参数列表,会展开成 `?, ?, ?`(或对应 flavor 占位符序列)。
31pub fn list<T: FlattenIntoArgs>(arg: T) -> Arg {
32    let mut out = Vec::new();
33    arg.flatten_into(&mut out);
34    Arg::List {
35        args: out,
36        is_tuple: false,
37    }
38}
39
40/// Tuple:标记为元组,会展开成 `(?, ?)`(或对应 flavor 占位符序列)。
41pub fn tuple<T: FlattenIntoArgs>(values: T) -> Arg {
42    let mut out = Vec::new();
43    values.flatten_into(&mut out);
44    Arg::List {
45        args: out,
46        is_tuple: true,
47    }
48}
49
/// Builds a parenthesised, comma-separated column list such as `(a, b, c)`.
///
/// Names are inserted verbatim; no escaping is applied. An empty input yields `()`.
pub fn tuple_names(names: impl IntoIterator<Item = impl AsRef<str>>) -> String {
    let mut joined = String::new();
    joined.push('(');
    for (index, name) in names.into_iter().enumerate() {
        if index > 0 {
            joined.push_str(", ");
        }
        joined.push_str(name.as_ref());
    }
    joined.push(')');
    joined
}
64
65/// Flatten:对齐 go-sqlbuilder `Flatten` 的“递归展开”体验(Rust 版用 trait 代替反射)。
66pub fn flatten<T: FlattenIntoArgs>(v: T) -> Vec<Arg> {
67    let mut out = Vec::new();
68    v.flatten_into(&mut out);
69    out
70}
71
72/// Named:命名参数(仅用于 `Build/BuildNamed` 的 `${name}` 引用)。
73pub fn named(name: impl Into<String>, arg: impl Into<Arg>) -> Arg {
74    Arg::Named {
75        name: name.into(),
76        arg: Box::new(arg.into()),
77    }
78}
79
80/// 对齐 go 的 `sql.NamedArg`:用于在 SQL 中以 `@name` 占位复用。
81#[derive(Debug, Clone, PartialEq)]
82pub struct SqlNamedArg {
83    pub name: String,
84    pub value: Box<Arg>,
85}
86
87impl SqlNamedArg {
88    pub fn new(name: impl Into<String>, value: impl Into<Arg>) -> Self {
89        Self {
90            name: name.into(),
91            value: Box::new(value.into()),
92        }
93    }
94}
95
/// Dynamic argument type used throughout the Builder/Args machinery.
///
/// `Debug` and `PartialEq` are implemented manually rather than derived, because the
/// `Valuer` and `Builder` variants hold trait objects.
#[derive(Clone)]
pub enum Arg {
    /// A plain SQL value.
    Value(SqlValue),
    /// A value supplied through a user-provided [`SqlValuer`] trait object.
    Valuer(Box<dyn SqlValuer>),
    /// A Go `sql.NamedArg`-style argument, referenced as `@name` in SQL.
    SqlNamed(SqlNamedArg),
    /// SQL text spliced in verbatim (never turned into a placeholder).
    Raw(Raw),
    /// Unified representation of List and Tuple.
    /// `is_tuple` is `false` for `list(..)` (`?, ?`) and `true` for `tuple(..)` (`(?, ?)`).
    List {
        args: Vec<Arg>,
        is_tuple: bool,
    },
    /// Named(name, arg) — only takes effect on the `${name}` path of Build/BuildNamed.
    Named {
        name: String,
        arg: Box<Arg>,
    },
    /// A nested builder, e.g. a subquery.
    Builder(Box<dyn Builder>),
}
115
116impl std::fmt::Debug for Arg {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        match self {
119            Self::Value(v) => f.debug_tuple("Value").field(v).finish(),
120            Self::Valuer(_) => f.write_str("Valuer(..)"),
121            Self::SqlNamed(v) => f.debug_tuple("SqlNamed").field(v).finish(),
122            Self::Raw(v) => f.debug_tuple("Raw").field(v).finish(),
123            Self::List { args, is_tuple } => f
124                .debug_struct("List")
125                .field("args", args)
126                .field("is_tuple", is_tuple)
127                .finish(),
128            Self::Named { name, arg } => f
129                .debug_struct("Named")
130                .field("name", name)
131                .field("arg", arg)
132                .finish(),
133            Self::Builder(_) => f.write_str("Builder(..)"),
134        }
135    }
136}
137
138impl PartialEq for Arg {
139    fn eq(&self, other: &Self) -> bool {
140        match (self, other) {
141            (Self::Value(a), Self::Value(b)) => a == b,
142            (Self::Valuer(_), _) | (_, Self::Valuer(_)) => false,
143            (Self::SqlNamed(a), Self::SqlNamed(b)) => a == b,
144            (Self::Raw(a), Self::Raw(b)) => a == b,
145            (
146                Self::List {
147                    args: a,
148                    is_tuple: at,
149                },
150                Self::List {
151                    args: b,
152                    is_tuple: bt,
153                },
154            ) => at == bt && a == b,
155            (Self::Named { name: an, arg: aa }, Self::Named { name: bn, arg: ba }) => {
156                an == bn && aa == ba
157            }
158            (Self::Builder(_), _) | (_, Self::Builder(_)) => false,
159            _ => false,
160        }
161    }
162}
163
164impl From<Box<dyn Builder>> for Arg {
165    fn from(v: Box<dyn Builder>) -> Self {
166        Self::Builder(v)
167    }
168}
169
170impl From<Box<dyn SqlValuer>> for Arg {
171    fn from(v: Box<dyn SqlValuer>) -> Self {
172        Self::Valuer(v)
173    }
174}
175
176impl Builder for Box<dyn Builder> {
177    fn build_with_flavor(&self, flavor: Flavor, initial_arg: &[Arg]) -> (String, Vec<Arg>) {
178        (**self).build_with_flavor(flavor, initial_arg)
179    }
180
181    fn flavor(&self) -> Flavor {
182        (**self).flavor()
183    }
184}
185
186/// 对齐 go-sqlbuilder `Builder`:可嵌套构建 SQL。
187pub trait Builder: DynClone {
188    fn build(&self) -> (String, Vec<Arg>) {
189        self.build_with_flavor(self.flavor(), &[])
190    }
191
192    fn build_with_flavor(&self, flavor: Flavor, initial_arg: &[Arg]) -> (String, Vec<Arg>);
193
194    fn flavor(&self) -> Flavor;
195}
196
197dyn_clone::clone_trait_object!(Builder);
198
/// Wraps an `Rc<RefCell<T>>` so it can be used as a [`Builder`].
///
/// This mirrors go-sqlbuilder's "shared builder pointer" semantics.
///
/// Typical use: pass a `SelectBuilder` as a subquery argument while keeping the ability
/// to modify the original builder afterwards. Because the inner builder is only borrowed
/// when building, the final build sees the latest state (late binding).
///
/// The struct definition carries no trait bound on `T`; the `Builder` requirement lives
/// on the impl blocks that actually need it (Rust API guideline C-STRUCT-BOUNDS).
#[derive(Debug)]
pub struct RcBuilder<T> {
    inner: Rc<RefCell<T>>,
}
207
208impl<T: Builder> Clone for RcBuilder<T> {
209    fn clone(&self) -> Self {
210        Self {
211            inner: self.inner.clone(),
212        }
213    }
214}
215
216impl<T: Builder> RcBuilder<T> {
217    pub fn new(inner: Rc<RefCell<T>>) -> Self {
218        Self { inner }
219    }
220
221    pub fn inner(&self) -> Rc<RefCell<T>> {
222        self.inner.clone()
223    }
224}
225
226impl<T: Builder> Builder for RcBuilder<T> {
227    fn build_with_flavor(&self, flavor: Flavor, initial_arg: &[Arg]) -> (String, Vec<Arg>) {
228        self.inner.borrow().build_with_flavor(flavor, initial_arg)
229    }
230
231    fn flavor(&self) -> Flavor {
232        self.inner.borrow().flavor()
233    }
234}
235
236pub fn rc_builder<T: Builder>(inner: Rc<RefCell<T>>) -> RcBuilder<T> {
237    RcBuilder::new(inner)
238}
239
240impl From<SqlValue> for Arg {
241    fn from(v: SqlValue) -> Self {
242        Self::Value(v)
243    }
244}
245
246impl From<i64> for Arg {
247    fn from(v: i64) -> Self {
248        SqlValue::I64(v).into()
249    }
250}
251impl From<i32> for Arg {
252    fn from(v: i32) -> Self {
253        SqlValue::I64(v as i64).into()
254    }
255}
256impl From<u64> for Arg {
257    fn from(v: u64) -> Self {
258        SqlValue::U64(v).into()
259    }
260}
261impl From<u16> for Arg {
262    fn from(v: u16) -> Self {
263        SqlValue::U64(v as u64).into()
264    }
265}
266impl From<bool> for Arg {
267    fn from(v: bool) -> Self {
268        SqlValue::Bool(v).into()
269    }
270}
271impl From<f64> for Arg {
272    fn from(v: f64) -> Self {
273        SqlValue::F64(v).into()
274    }
275}
276impl From<&'static str> for Arg {
277    fn from(v: &'static str) -> Self {
278        SqlValue::from(v).into()
279    }
280}
281impl From<String> for Arg {
282    fn from(v: String) -> Self {
283        SqlValue::from(v).into()
284    }
285}
286impl From<Vec<u8>> for Arg {
287    fn from(v: Vec<u8>) -> Self {
288        SqlValue::Bytes(v).into()
289    }
290}
291
292impl<T> From<Option<T>> for Arg
293where
294    T: Into<SqlValue>,
295{
296    fn from(v: Option<T>) -> Self {
297        match v {
298            Some(x) => x.into().into(),
299            None => SqlValue::Null.into(),
300        }
301    }
302}
303
304impl From<time::OffsetDateTime> for Arg {
305    fn from(v: time::OffsetDateTime) -> Self {
306        SqlValue::from(v).into()
307    }
308}
309impl From<SqlNamedArg> for Arg {
310    fn from(v: SqlNamedArg) -> Self {
311        Self::SqlNamed(v)
312    }
313}
314
315/// 用 trait 实现 go-sqlbuilder `Flatten` 的“递归展开”体验。
316pub trait FlattenIntoArgs {
317    fn flatten_into(self, out: &mut Vec<Arg>);
318}
319
320impl<T: Into<Arg>> FlattenIntoArgs for T {
321    fn flatten_into(self, out: &mut Vec<Arg>) {
322        out.push(self.into());
323    }
324}
325
326impl<T: FlattenIntoArgs> FlattenIntoArgs for Vec<T> {
327    fn flatten_into(self, out: &mut Vec<Arg>) {
328        for v in self {
329            v.flatten_into(out);
330        }
331    }
332}
333
334impl<T: FlattenIntoArgs, const N: usize> FlattenIntoArgs for [T; N] {
335    fn flatten_into(self, out: &mut Vec<Arg>) {
336        for v in self {
337            v.flatten_into(out);
338        }
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Unpacks an `Arg::List`, panicking with "expected list" for any other variant.
    fn unpack_list(arg: Arg) -> (Vec<Arg>, bool) {
        match arg {
            Arg::List { args, is_tuple } => (args, is_tuple),
            _ => panic!("expected list"),
        }
    }

    #[test]
    fn test_escape() {
        for (input, expected) in [("foo", "foo"), ("$foo", "$$foo"), ("$$$", "$$$$$$")] {
            assert_eq!(escape(input), expected);
        }
    }

    #[test]
    fn test_escape_all() {
        let expected: Vec<String> = ["foo", "$$foo"].iter().map(|s| s.to_string()).collect();
        assert_eq!(escape_all(["foo", "$foo"]), expected);
    }

    #[test]
    fn tuple_names_basic() {
        assert_eq!(tuple_names(["a", "b"]), "(a, b)");
    }

    #[test]
    fn flatten_vec_and_array() {
        let (from_vec, vec_is_tuple) = unpack_list(list(vec![1_i64, 2, 3]));
        assert!(!vec_is_tuple);
        assert_eq!(from_vec.len(), 3);

        let (from_array, array_is_tuple) = unpack_list(list([1_i64, 2, 3]));
        assert!(!array_is_tuple);
        assert_eq!(from_array.len(), 3);
    }
}