Skip to main content

databend_driver/
params.rs

// Copyright 2021 Datafuse Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt::Debug;
17
18use databend_common_ast::parser::Dialect;
19
20pub trait Param: Debug {
21    fn as_json_value(&self) -> serde_json::Value;
22
23    fn as_sql_string(&self) -> String {
24        json_value_to_sql_string(&self.as_json_value())
25    }
26}
27
28#[derive(Debug)]
29pub enum Params {
30    // ?, ?
31    QuestionParams(Vec<serde_json::Value>),
32    // :name, :age
33    NamedParams(HashMap<String, serde_json::Value>),
34}
35
36impl Default for Params {
37    fn default() -> Self {
38        Params::QuestionParams(vec![])
39    }
40}
41
42/// Convert a `serde_json::Value` to a SQL string representation for client-side binding.
43pub fn json_value_to_sql_string(v: &serde_json::Value) -> String {
44    match v {
45        serde_json::Value::Null => "NULL".to_string(),
46        serde_json::Value::Bool(b) => {
47            if *b {
48                "TRUE".to_string()
49            } else {
50                "FALSE".to_string()
51            }
52        }
53        serde_json::Value::Number(n) => n.to_string(),
54        serde_json::Value::String(s) => format!("'{}'", s.replace('\'', "''")),
55        serde_json::Value::Array(arr) => {
56            let items: Vec<String> = arr.iter().map(json_value_to_sql_string).collect();
57            format!("[{}]", items.join(", "))
58        }
59        serde_json::Value::Object(map) => {
60            let mut s = String::from("'{");
61            for (i, (k, v)) in map.iter().enumerate() {
62                if i > 0 {
63                    s.push_str(", ");
64                }
65                s.push_str(&format!("\"{k}\": {}", json_value_to_sql_string(v)));
66            }
67            s.push_str("}'::JSON");
68            s
69        }
70    }
71}
72
73impl Params {
74    pub fn len(&self) -> usize {
75        match self {
76            Params::QuestionParams(vec) => vec.len(),
77            Params::NamedParams(map) => map.len(),
78        }
79    }
80
81    pub fn is_empty(&self) -> bool {
82        self.len() == 0
83    }
84
85    // index based from 1
86    pub fn get_by_index(&self, index: usize) -> Option<&serde_json::Value> {
87        if index == 0 {
88            return None;
89        }
90        match self {
91            Params::QuestionParams(vec) => vec.get(index - 1),
92            _ => None,
93        }
94    }
95
96    pub fn get_by_name(&self, name: &str) -> Option<&serde_json::Value> {
97        match self {
98            Params::NamedParams(map) => map.get(name),
99            _ => None,
100        }
101    }
102
103    pub fn merge(&mut self, other: Params) {
104        match (self, other) {
105            (Params::QuestionParams(vec1), Params::QuestionParams(vec2)) => {
106                vec1.extend(vec2);
107            }
108            (Params::NamedParams(map1), Params::NamedParams(map2)) => {
109                map1.extend(map2);
110            }
111            _ => panic!("Cannot merge QuestionParams with NamedParams"),
112        }
113    }
114
115    /// Convert params to a JSON value suitable for server-side parameter binding.
116    /// `QuestionParams` → `Value::Array`, `NamedParams` → `Value::Object`.
117    pub fn to_json_value(&self) -> serde_json::Value {
118        match self {
119            Params::QuestionParams(vec) => serde_json::Value::Array(vec.clone()),
120            Params::NamedParams(map) => {
121                let obj: serde_json::Map<String, serde_json::Value> =
122                    map.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
123                serde_json::Value::Object(obj)
124            }
125        }
126    }
127
128    pub fn replace(&self, sql: &str) -> String {
129        if !self.is_empty() {
130            let tokens = databend_common_ast::parser::tokenize_sql(sql).unwrap();
131            if let Ok((stmt, _)) =
132                databend_common_ast::parser::parse_sql(&tokens, Dialect::PostgreSQL)
133            {
134                let mut v = super::placeholder::PlaceholderVisitor::new();
135                return v.replace_sql(self, &stmt, sql);
136            }
137        }
138        sql.to_string()
139    }
140}
141
142// Implement Param for numeric types that fit in serde_json::Number
143macro_rules! impl_param_for_json_number {
144    ($($t:ty)*) => ($(
145        impl Param for $t {
146            fn as_json_value(&self) -> serde_json::Value {
147                serde_json::json!(self)
148            }
149        }
150    )*)
151}
152
153impl_param_for_json_number! { i8 i16 i32 i64 isize u8 u16 u32 u64 usize f32 f64 }
154
155// i128/u128 cannot be represented in JSON numbers, store as string
156impl Param for i128 {
157    fn as_json_value(&self) -> serde_json::Value {
158        // If it fits in i64, use a number; otherwise use a string to avoid precision loss
159        if *self >= i128::from(i64::MIN) && *self <= i128::from(i64::MAX) {
160            serde_json::json!(*self as i64)
161        } else {
162            serde_json::Value::String(self.to_string())
163        }
164    }
165}
166
167impl Param for u128 {
168    fn as_json_value(&self) -> serde_json::Value {
169        // If it fits in u64, use a number; otherwise use a string to avoid precision loss
170        if *self <= u128::from(u64::MAX) {
171            serde_json::json!(*self as u64)
172        } else {
173            serde_json::Value::String(self.to_string())
174        }
175    }
176}
177
178impl Param for bool {
179    fn as_json_value(&self) -> serde_json::Value {
180        serde_json::Value::Bool(*self)
181    }
182}
183
184impl Param for String {
185    fn as_json_value(&self) -> serde_json::Value {
186        serde_json::Value::String(self.clone())
187    }
188}
189
190impl Param for &str {
191    fn as_json_value(&self) -> serde_json::Value {
192        serde_json::Value::String(self.to_string())
193    }
194}
195
196impl Param for () {
197    fn as_json_value(&self) -> serde_json::Value {
198        serde_json::Value::Null
199    }
200}
201
202impl<T> Param for Option<T>
203where
204    T: Param,
205{
206    fn as_json_value(&self) -> serde_json::Value {
207        match self {
208            Some(s) => s.as_json_value(),
209            None => serde_json::Value::Null,
210        }
211    }
212}
213
214impl Param for serde_json::Value {
215    fn as_json_value(&self) -> serde_json::Value {
216        self.clone()
217    }
218}
219
/// Build a `Params` value from a list of parameters.
///
/// ```text
/// let name = "d";
/// let age = 4;
/// params!{ a => 1, b => 2, c => name }  // Params::NamedParams {"a": 1, "b": 2, "c": "d"}
/// params!{ name, age }                  // Params::QuestionParams ["d", 4]
/// params!{}                             // Params::default() (empty positional list)
/// ```
///
/// Every value must implement `Param`; it is converted with `Param::as_json_value`.
#[macro_export]
macro_rules! params {
    // No parameters.
    () => {
        $crate::Params::default()
    };
    // Named parameters: `key => value`, keyed by the identifier's text.
    ($($key:ident => $value:expr),* $(,)?) => {
        $crate::Params::NamedParams({
            // Fully qualified so callers do not need `HashMap` in scope themselves.
            // `mut` is unused only for the degenerate `params!{,}` form.
            #[allow(unused_mut)]
            let mut map = ::std::collections::HashMap::new();

            $(
                map.insert(stringify!($key).to_string(), $crate::Param::as_json_value(&$value));
            )*
            map
        })
    };
    // Positional parameters, kept in the order written.
    ($($value:expr),* $(,)?) => {
        $crate::Params::QuestionParams(vec![
            $(
                $crate::Param::as_json_value(&$value),
            )*
        ])
    };
}
249
250impl From<()> for Params {
251    fn from(_: ()) -> Self {
252        Params::default()
253    }
254}
255
256// impl From Tuple(A, B, C, D....) for Params where A, B, C, D: Param
257macro_rules! impl_from_tuple_for_params {
258    // empty tuple
259    () => {};
260
261    // recursive impl
262    ($head:ident, $($tail:ident),*) => {
263	#[allow(non_snake_case)]
264        impl<$head: Param, $($tail: Param),*> From<($head, $($tail),*)> for Params {
265            fn from(tuple: ($head, $($tail),*)) -> Self {
266                let (h, $($tail),*) = tuple;
267                let mut params = Params::QuestionParams(vec![h.as_json_value()]);
268                $(params.merge(Params::QuestionParams(vec![$tail.as_json_value()]));)*
269                params
270            }
271        }
272
273        impl_from_tuple_for_params!($($tail),*);
274    };
275
276    // single element tuple
277    ($last:ident) => {
278        impl<$last: Param> From<($last,)> for Params {
279            fn from(tuple: ($last,)) -> Self {
280                Params::QuestionParams(vec![tuple.0.as_json_value()])
281            }
282        }
283    };
284}
285
286impl_from_tuple_for_params! { T1, T2, T3, T4, T5, T6, T7, T8, T9, T10 }
287
288impl From<Option<serde_json::Value>> for Params {
289    fn from(value: Option<serde_json::Value>) -> Self {
290        match value {
291            Some(v) => v.into(),
292            None => Params::default(),
293        }
294    }
295}
296
297impl From<serde_json::Value> for Params {
298    fn from(value: serde_json::Value) -> Self {
299        match value {
300            serde_json::Value::Array(arr) => Params::QuestionParams(arr),
301            serde_json::Value::Object(obj) => Params::NamedParams(obj.into_iter().collect()),
302            other => Params::QuestionParams(vec![other]),
303        }
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_params() {
        // Test named parameters
        {
            let name = "d";
            let age = 4;
            let params = params! {a => 1, b => age, c => name};
            match params {
                Params::NamedParams(map) => {
                    assert_eq!(map.get("a").unwrap(), &serde_json::json!(1));
                    assert_eq!(map.get("b").unwrap(), &serde_json::json!(4));
                    assert_eq!(map.get("c").unwrap(), &serde_json::json!("d"));
                }
                _ => panic!("Expected NamedParams"),
            }
            let params = params! {};
            assert!(params.is_empty());
        }

        // Test positional parameters
        {
            let name = "d";
            let age = 4;
            let params = params! {name, age, 33u64};
            match params {
                Params::QuestionParams(vec) => {
                    assert_eq!(
                        vec,
                        vec![
                            serde_json::json!("d"),
                            serde_json::json!(4),
                            serde_json::json!(33u64)
                        ]
                    );
                }
                _ => panic!("Expected QuestionParams"),
            }
        }

        // Test into params for tuple
        {
            let params: Params = (1, "44", 2, 3, "55", "66").into();
            match params {
                Params::QuestionParams(vec) => {
                    assert_eq!(
                        vec,
                        vec![
                            serde_json::json!(1),
                            serde_json::json!("44"),
                            serde_json::json!(2),
                            serde_json::json!(3),
                            serde_json::json!("55"),
                            serde_json::json!("66"),
                        ]
                    );
                }
                _ => panic!("Expected QuestionParams"),
            }
        }

        // Test Option<T>
        {
            let params: Params = (Some(1), None::<()>, Some("44"), None::<()>).into();
            match params {
                Params::QuestionParams(vec) => assert_eq!(
                    vec,
                    vec![
                        serde_json::json!(1),
                        serde_json::Value::Null,
                        serde_json::json!("44"),
                        serde_json::Value::Null,
                    ]
                ),
                _ => panic!("Expected QuestionParams"),
            }
        }

        // Test into params for serde_json
        {
            let params: Params = serde_json::json!({
            "a": 1,
            "b": "44",
            "c": 2,
            "d": 3,
            "e": "55",
            "f": "66",
            })
            .into();
            match params {
                Params::NamedParams(map) => {
                    assert_eq!(map.get("a").unwrap(), &serde_json::json!(1));
                    assert_eq!(map.get("b").unwrap(), &serde_json::json!("44"));
                    assert_eq!(map.get("c").unwrap(), &serde_json::json!(2));
                    assert_eq!(map.get("d").unwrap(), &serde_json::json!(3));
                    assert_eq!(map.get("e").unwrap(), &serde_json::json!("55"));
                    assert_eq!(map.get("f").unwrap(), &serde_json::json!("66"));
                }
                _ => panic!("Expected NamedParams"),
            }
        }

        // Test into params for serde_json::Value::Array
        {
            let params: Params =
                serde_json::json!([1, "44", 2, serde_json::json!({"a" : 1}), "55", "66"]).into();
            match params {
                Params::QuestionParams(vec) => {
                    assert_eq!(
                        vec,
                        vec![
                            serde_json::json!(1),
                            serde_json::json!("44"),
                            serde_json::json!(2),
                            serde_json::json!({"a": 1}),
                            serde_json::json!("55"),
                            serde_json::json!("66"),
                        ]
                    );
                }
                _ => panic!("Expected QuestionParams"),
            }
        }
    }

    #[test]
    fn test_to_json_value() {
        // Test positional params
        let params = params! {1, "hello", 9.99};
        let json = params.to_json_value();
        assert_eq!(json, serde_json::json!([1, "hello", 9.99]));

        // Test named params
        let params = params! {a => 1, b => "hello", c => true};
        let json = params.to_json_value();
        let obj = json.as_object().unwrap();
        assert_eq!(obj.get("a").unwrap(), &serde_json::json!(1));
        assert_eq!(obj.get("b").unwrap(), &serde_json::json!("hello"));
        assert_eq!(obj.get("c").unwrap(), &serde_json::json!(true));

        // Test NULL
        let params = params! {()};
        let json = params.to_json_value();
        assert_eq!(json, serde_json::json!([null]));

        // Test Option<T>
        let params: Params = (Some(42), None::<()>, Some("world")).into();
        let json = params.to_json_value();
        assert_eq!(json, serde_json::json!([42, null, "world"]));

        // Test lowercase bool (from serde_json::Value::Bool)
        let params: Params = serde_json::json!([true, false]).into();
        let json = params.to_json_value();
        assert_eq!(json, serde_json::json!([true, false]));

        // Test large u64 above i64::MAX
        let big: u64 = u64::MAX;
        let params: Params = (big,).into();
        let json = params.to_json_value();
        assert_eq!(json, serde_json::json!([big]));
    }

    #[test]
    fn test_json_value_to_sql_string() {
        use serde_json::json;

        // Scalars.
        assert_eq!(json_value_to_sql_string(&json!(null)), "NULL");
        assert_eq!(json_value_to_sql_string(&json!(true)), "TRUE");
        assert_eq!(json_value_to_sql_string(&json!(false)), "FALSE");
        assert_eq!(json_value_to_sql_string(&json!(42)), "42");
        assert_eq!(json_value_to_sql_string(&json!(-1.5)), "-1.5");

        // Strings: embedded single quotes are doubled.
        assert_eq!(json_value_to_sql_string(&json!("it's")), "'it''s'");

        // Arrays render their items as SQL.
        assert_eq!(
            json_value_to_sql_string(&json!([1, "a", null])),
            "[1, 'a', NULL]"
        );

        // Objects render as a JSON literal cast to JSON.
        assert_eq!(
            json_value_to_sql_string(&json!({"a": 1})),
            "'{\"a\": 1}'::JSON"
        );
        assert_eq!(
            json_value_to_sql_string(&json!({"a": [1, 2]})),
            "'{\"a\": [1, 2]}'::JSON"
        );
    }

    #[test]
    fn test_param_conversions() {
        // as_sql_string goes through json_value_to_sql_string.
        assert_eq!("x".as_sql_string(), "'x'");
        assert_eq!(Some(3).as_sql_string(), "3");
        assert_eq!(None::<i32>.as_sql_string(), "NULL");

        // 128-bit integers: numbers when they fit in 64 bits, strings otherwise.
        assert_eq!((-5i128).as_json_value(), serde_json::json!(-5));
        assert_eq!(
            i128::MAX.as_json_value(),
            serde_json::json!(i128::MAX.to_string())
        );
        assert_eq!(7u128.as_json_value(), serde_json::json!(7));
        assert_eq!(
            u128::MAX.as_json_value(),
            serde_json::json!(u128::MAX.to_string())
        );
    }

    #[test]
    fn test_merge_and_lookup() {
        // Positional merge appends in order; lookups are 1-based.
        let mut params = params! {1, "a"};
        params.merge(params! {2});
        assert_eq!(params.len(), 3);
        assert_eq!(params.get_by_index(0), None);
        assert_eq!(params.get_by_index(1), Some(&serde_json::json!(1)));
        assert_eq!(params.get_by_index(3), Some(&serde_json::json!(2)));
        assert_eq!(params.get_by_index(4), None);
        assert_eq!(params.get_by_name("a"), None);

        // Named merge: later values overwrite earlier ones with the same key.
        let mut named = params! {a => 1};
        named.merge(params! {b => "x", a => 2});
        assert_eq!(named.len(), 2);
        assert_eq!(named.get_by_name("a"), Some(&serde_json::json!(2)));
        assert_eq!(named.get_by_name("b"), Some(&serde_json::json!("x")));
        assert_eq!(named.get_by_index(1), None);
    }

    #[test]
    fn test_replace() {
        // With no parameters the SQL is returned untouched.
        assert_eq!(Params::default().replace("SELECT ?"), "SELECT ?");

        let params = params! {1, "44", 2, 3, "55", "66"};
        let sql =
            "SELECT * FROM table WHERE a = ? AND '?' = cj AND b = ? AND c = ? AND d = ? AND e = ? AND f = ?";
        let replaced_sql = params.replace(sql);
        assert_eq!(replaced_sql, "SELECT * FROM table WHERE a = 1 AND '?' = cj AND b = '44' AND c = 2 AND d = 3 AND e = '55' AND f = '66'");

        let params = params! {a => 1, b => "44", c => 2, d => 3, e => "55", f => "66"};

        {
            let sql = "SELECT * FROM table WHERE a = :a AND '?' = cj AND b = :b AND c = :c AND d = :d AND e = :e AND f = :f";
            let replaced_sql = params.replace(sql);
            assert_eq!(replaced_sql, "SELECT * FROM table WHERE a = 1 AND '?' = cj AND b = '44' AND c = 2 AND d = 3 AND e = '55' AND f = '66'");
        }

        {
            let sql = "SELECT b = :b, a = :a FROM table WHERE a = :a AND '?' = cj AND b = :b AND c = :c AND d = :d AND e = :e AND f = :f";
            let replaced_sql = params.replace(sql);
            assert_eq!(replaced_sql, "SELECT b = '44', a = 1 FROM table WHERE a = 1 AND '?' = cj AND b = '44' AND c = 2 AND d = 3 AND e = '55' AND f = '66'");
        }

        {
            let params = params! {1, "44", 2, 3, "55", "66"};
            let sql = "SELECT $3, $2, $1 FROM table WHERE a = $1 AND '?' = cj AND b = $2 AND c = $3 AND d = $4 AND e = $5 AND f = $6";
            let replaced_sql = params.replace(sql);
            assert_eq!(replaced_sql, "SELECT 2, '44', 1 FROM table WHERE a = 1 AND '?' = cj AND b = '44' AND c = 2 AND d = 3 AND e = '55' AND f = '66'");
        }
    }
}