databend_driver/
params.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::fmt::Debug;
17
18use databend_common_ast::parser::Dialect;
19
/// A value that can be bound to a placeholder of a SQL statement.
///
/// Implementors render themselves as SQL literal text (for example `42`, `'text'`
/// or `NULL`); that text is what `Params` stores and substitutes into the query.
pub trait Param
where
    Self: Debug,
{
    /// Returns this value rendered as SQL literal text.
    fn as_sql_string(&self) -> String;
}
23
/// Values bound to a SQL statement, in one of two placeholder styles.
///
/// Every value is stored already rendered as SQL literal text (the output of
/// `Param::as_sql_string`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Params {
    /// Positional values, in order, for `?` placeholders (and 1-based `$N` placeholders).
    QuestionParams(Vec<String>),
    /// Named values for `:name` placeholders, keyed by the name without the leading `:`.
    NamedParams(HashMap<String, String>),
}
31
32impl Default for Params {
33    fn default() -> Self {
34        Params::QuestionParams(vec![])
35    }
36}
37
38impl Params {
39    pub fn len(&self) -> usize {
40        match self {
41            Params::QuestionParams(vec) => vec.len(),
42            Params::NamedParams(map) => map.len(),
43        }
44    }
45
46    pub fn is_empty(&self) -> bool {
47        self.len() == 0
48    }
49
50    // index based from 1
51    pub fn get_by_index(&self, index: usize) -> Option<&String> {
52        if index == 0 {
53            return None;
54        }
55        match self {
56            Params::QuestionParams(vec) => vec.get(index - 1),
57            _ => None,
58        }
59    }
60
61    pub fn get_by_name(&self, name: &str) -> Option<&String> {
62        match self {
63            Params::NamedParams(map) => map.get(name),
64            _ => None,
65        }
66    }
67
68    pub fn merge(&mut self, other: Params) {
69        match (self, other) {
70            (Params::QuestionParams(vec1), Params::QuestionParams(vec2)) => {
71                vec1.extend(vec2);
72            }
73            (Params::NamedParams(map1), Params::NamedParams(map2)) => {
74                map1.extend(map2);
75            }
76            _ => panic!("Cannot merge QuestionParams with NamedParams"),
77        }
78    }
79
80    pub fn replace(&self, sql: &str) -> String {
81        if !self.is_empty() {
82            let tokens = databend_common_ast::parser::tokenize_sql(sql).unwrap();
83            if let Ok((stmt, _)) =
84                databend_common_ast::parser::parse_sql(&tokens, Dialect::PostgreSQL)
85            {
86                let mut v = super::placeholder::PlaceholderVisitor::new();
87                return v.replace_sql(self, &stmt, sql);
88            }
89        }
90        sql.to_string()
91    }
92}
93
94// impl param for all integer types and string types
95macro_rules! impl_param_for_integer {
96    ($($t:ty)*) => ($(
97        impl Param for $t {
98            fn as_sql_string(&self) -> String {
99                self.to_string()
100            }
101        }
102    )*)
103}
104
105impl_param_for_integer! { i8 i16 i32 i64 f32 f64 i128 isize u8 u16 u32 u64 u128 usize }
106
107// Implement Param for String
108impl Param for bool {
109    fn as_sql_string(&self) -> String {
110        if *self {
111            "TRUE".to_string()
112        } else {
113            "FALSE".to_string()
114        }
115    }
116}
117
118// Implement Param for String
119impl Param for String {
120    fn as_sql_string(&self) -> String {
121        format!("'{self}'")
122    }
123}
124
125// Implement Param for &str
126impl Param for &str {
127    fn as_sql_string(&self) -> String {
128        format!("'{self}'")
129    }
130}
131
132// Impl Param for None
133impl Param for () {
134    fn as_sql_string(&self) -> String {
135        "NULL".to_string()
136    }
137}
138
139impl<T> Param for Option<T>
140where
141    T: Param,
142{
143    fn as_sql_string(&self) -> String {
144        match self {
145            Some(s) => s.as_sql_string(),
146            None => "NULL".to_string(),
147        }
148    }
149}
150
151impl Param for serde_json::Value {
152    fn as_sql_string(&self) -> String {
153        match self {
154            serde_json::Value::Number(n) => n.to_string(),
155            serde_json::Value::String(s) => format!("'{s}'"),
156            serde_json::Value::Bool(b) => b.to_string(),
157            serde_json::Value::Null => "NULL".to_string(),
158            serde_json::Value::Array(values) => {
159                let mut s = String::from("[");
160                for (i, v) in values.iter().enumerate() {
161                    if i > 0 {
162                        s.push_str(", ");
163                    }
164                    s.push_str(&v.as_sql_string());
165                }
166                s.push(']');
167                s
168            }
169            serde_json::Value::Object(map) => {
170                let mut s = String::from("'{");
171                for (i, (k, v)) in map.iter().enumerate() {
172                    if i > 0 {
173                        s.push_str(", ");
174                    }
175                    s.push_str(&format!("\"{k}\": {}", v.as_sql_string()));
176                }
177                s.push_str("}'::JSON");
178                s
179            }
180        }
181    }
182}
183
/// Builds a `Params` value from a list of `Param` values.
///
/// ```ignore
/// let name = "d";
/// let age = 4;
/// // Named (`:a`, `:b`, `:c`): Params::NamedParams({"a": "1", "b": "4", "c": "'d'"})
/// let named = params! {a => 1, b => age, c => name};
/// // Positional (`?`): Params::QuestionParams(vec!["'d'", "4"])
/// let positional = params! {name, age};
/// // Empty: Params::default()
/// let empty = params! {};
/// ```
#[macro_export]
macro_rules! params {
    // No parameters.
    () => {
        $crate::Params::default()
    };
    // Named parameters: `key => value, ...`.
    ($($key:ident => $value:expr),* $(,)?) => {
        $crate::Params::NamedParams({
            // Fully qualified so callers need not import `HashMap` themselves.
            let mut map = ::std::collections::HashMap::new();
            $(
                map.insert(stringify!($key).to_string(), $crate::Param::as_sql_string(&$value));
            )*
            map
        })
    };
    // Positional parameters: `value, ...`.
    ($($value:expr),* $(,)?) => {
        $crate::Params::QuestionParams(vec![
            $(
                $crate::Param::as_sql_string(&$value),
            )*
        ])
    };
}
213
214impl From<()> for Params {
215    fn from(_: ()) -> Self {
216        Params::default()
217    }
218}
219
220// impl From Tuple(A, B, C, D....) for Params where A, B, C, D: Param
221macro_rules! impl_from_tuple_for_params {
222    // empty tuple
223    () => {};
224
225    // recursive impl
226    ($head:ident, $($tail:ident),*) => {
227	#[allow(non_snake_case)]
228        impl<$head: Param, $($tail: Param),*> From<($head, $($tail),*)> for Params {
229            fn from(tuple: ($head, $($tail),*)) -> Self {
230                let (h, $($tail),*) = tuple;
231                let mut params = Params::QuestionParams(vec![h.as_sql_string()]);
232                $(params.merge(Params::QuestionParams(vec![$tail.as_sql_string()]));)*
233                params
234            }
235        }
236
237        impl_from_tuple_for_params!($($tail),*);
238    };
239
240    // single element tuple
241    ($last:ident) => {
242        impl<$last: Param> From<($last,)> for Params {
243            fn from(tuple: ($last,)) -> Self {
244                Params::QuestionParams(vec![tuple.0.as_sql_string()])
245            }
246        }
247    };
248}
249
250impl_from_tuple_for_params! { T1, T2, T3, T4, T5, T6, T7, T8, T9, T10 }
251
252impl From<Option<serde_json::Value>> for Params {
253    fn from(value: Option<serde_json::Value>) -> Self {
254        match value {
255            Some(v) => v.into(),
256            None => Params::default(),
257        }
258    }
259}
260
261impl From<serde_json::Value> for Params {
262    fn from(value: serde_json::Value) -> Self {
263        match value {
264            serde_json::Value::Array(obj) => {
265                let mut array = Vec::new();
266                for v in obj {
267                    array.push(v.as_sql_string());
268                }
269                Params::QuestionParams(array)
270            }
271            serde_json::Value::Object(obj) => {
272                let mut map = HashMap::new();
273                for (k, v) in obj {
274                    map.insert(k, v.as_sql_string());
275                }
276                Params::NamedParams(map)
277            }
278            other => Params::QuestionParams(vec![other.as_sql_string()]),
279        }
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    // Unwraps `Params::NamedParams`, failing the test for the other variant.
    fn named(params: Params) -> HashMap<String, String> {
        match params {
            Params::NamedParams(map) => map,
            _ => panic!("Expected NamedParams"),
        }
    }

    // Unwraps `Params::QuestionParams`, failing the test for the other variant.
    fn positional(params: Params) -> Vec<String> {
        match params {
            Params::QuestionParams(vec) => vec,
            _ => panic!("Expected QuestionParams"),
        }
    }

    #[test]
    fn test_params() {
        let name = "d";
        let age = 4;

        // Named parameters via `params! { key => value, ... }`.
        let map = named(params! {a => 1, b => age, c => name});
        assert_eq!(map.get("a").unwrap(), "1");
        assert_eq!(map.get("b").unwrap(), "4");
        assert_eq!(map.get("c").unwrap(), "'d'");

        let empty = params! {};
        assert!(empty.is_empty());

        // Positional parameters via `params! { value, ... }`.
        assert_eq!(positional(params! {name, age, 33u64}), vec!["'d'", "4", "33"]);

        // Tuples convert into positional parameters.
        let from_tuple: Params = (1, "44", 2, 3, "55", "66").into();
        assert_eq!(
            positional(from_tuple),
            vec!["1", "'44'", "2", "3", "'55'", "'66'"]
        );

        // `Option<T>`: `Some` renders the inner value, `None` renders NULL.
        let from_options: Params = (Some(1), None::<()>, Some("44"), None::<()>).into();
        assert_eq!(positional(from_options), vec!["1", "NULL", "'44'", "NULL"]);

        // A JSON object converts into named parameters.
        let from_object: Params = serde_json::json!({
            "a": 1,
            "b": "44",
            "c": 2,
            "d": 3,
            "e": "55",
            "f": "66",
        })
        .into();
        let map = named(from_object);
        let expected = [
            ("a", "1"),
            ("b", "'44'"),
            ("c", "2"),
            ("d", "3"),
            ("e", "'55'"),
            ("f", "'66'"),
        ];
        for (key, value) in expected.iter() {
            assert_eq!(map.get(*key).unwrap(), *value);
        }

        // A JSON array converts into positional parameters; a nested object becomes
        // a JSON literal.
        let from_array: Params =
            serde_json::json!([1, "44", 2, serde_json::json!({"a" : 1}), "55", "66"]).into();
        assert_eq!(
            positional(from_array),
            vec!["1", "'44'", "2", "'{\"a\": 1}'::JSON", "'55'", "'66'"]
        );
    }

    #[test]
    fn test_replace() {
        // `?` placeholders are filled in order; the quoted '?' literal is left alone.
        let positional_params = params! {1, "44", 2, 3, "55", "66"};
        let sql =
            "SELECT * FROM table WHERE a = ? AND '?' = cj AND b = ? AND c = ? AND d = ? AND e = ? AND f = ?";
        assert_eq!(positional_params.replace(sql), "SELECT * FROM table WHERE a = 1 AND '?' = cj AND b = '44' AND c = 2 AND d = 3 AND e = '55' AND f = '66'");

        // `:name` placeholders, including names that occur more than once.
        let named_params = params! {a => 1, b => "44", c => 2, d => 3, e => "55", f => "66"};
        let cases = [
            (
                "SELECT * FROM table WHERE a = :a AND '?' = cj AND b = :b AND c = :c AND d = :d AND e = :e AND f = :f",
                "SELECT * FROM table WHERE a = 1 AND '?' = cj AND b = '44' AND c = 2 AND d = 3 AND e = '55' AND f = '66'",
            ),
            (
                "SELECT b = :b, a = :a FROM table WHERE a = :a AND '?' = cj AND b = :b AND c = :c AND d = :d AND e = :e AND f = :f",
                "SELECT b = '44', a = 1 FROM table WHERE a = 1 AND '?' = cj AND b = '44' AND c = 2 AND d = 3 AND e = '55' AND f = '66'",
            ),
        ];
        for (sql, expected) in cases.iter() {
            assert_eq!(named_params.replace(sql), *expected);
        }

        // `$N` placeholders index the positional values (1-based) and may repeat.
        let sql = "SELECT $3, $2, $1 FROM table WHERE a = $1 AND '?' = cj AND b = $2 AND c = $3 AND d = $4 AND e = $5 AND f = $6";
        assert_eq!(positional_params.replace(sql), "SELECT 2, '44', 1 FROM table WHERE a = 1 AND '?' = cj AND b = '44' AND c = 2 AND d = 3 AND e = '55' AND f = '66'");
    }
}