1use std::collections::HashMap;
16use std::fmt::Debug;
17
18use databend_common_ast::parser::Dialect;
19
/// A value that can be written into SQL source text in place of a query placeholder.
///
/// Implementations return the literal exactly as it should appear in the final query
/// (for example `42`, `TRUE`, `'abc'` or `NULL`).
pub trait Param: Debug {
    /// Renders `self` as SQL source text.
    fn as_sql_string(&self) -> String;
}
23
/// Parameters substituted into an SQL query by `Params::replace`.
///
/// The stored strings are already-rendered SQL text (the output of `Param::as_sql_string`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Params {
    /// Positional values, filled in order into `?` placeholders or by 1-based index into
    /// `$1`, `$2`, … placeholders.
    QuestionParams(Vec<String>),
    /// Named values, filled into `:name` placeholders.
    NamedParams(HashMap<String, String>),
}
31
32impl Default for Params {
33 fn default() -> Self {
34 Params::QuestionParams(vec![])
35 }
36}
37
38impl Params {
39 pub fn len(&self) -> usize {
40 match self {
41 Params::QuestionParams(vec) => vec.len(),
42 Params::NamedParams(map) => map.len(),
43 }
44 }
45
46 pub fn is_empty(&self) -> bool {
47 self.len() == 0
48 }
49
50 pub fn get_by_index(&self, index: usize) -> Option<&String> {
52 if index == 0 {
53 return None;
54 }
55 match self {
56 Params::QuestionParams(vec) => vec.get(index - 1),
57 _ => None,
58 }
59 }
60
61 pub fn get_by_name(&self, name: &str) -> Option<&String> {
62 match self {
63 Params::NamedParams(map) => map.get(name),
64 _ => None,
65 }
66 }
67
68 pub fn merge(&mut self, other: Params) {
69 match (self, other) {
70 (Params::QuestionParams(vec1), Params::QuestionParams(vec2)) => {
71 vec1.extend(vec2);
72 }
73 (Params::NamedParams(map1), Params::NamedParams(map2)) => {
74 map1.extend(map2);
75 }
76 _ => panic!("Cannot merge QuestionParams with NamedParams"),
77 }
78 }
79
80 pub fn replace(&self, sql: &str) -> String {
81 if !self.is_empty() {
82 let tokens = databend_common_ast::parser::tokenize_sql(sql).unwrap();
83 if let Ok((stmt, _)) =
84 databend_common_ast::parser::parse_sql(&tokens, Dialect::PostgreSQL)
85 {
86 let mut v = super::placeholder::PlaceholderVisitor::new();
87 return v.replace_sql(self, &stmt, sql);
88 }
89 }
90 sql.to_string()
91 }
92}
93
94macro_rules! impl_param_for_integer {
96 ($($t:ty)*) => ($(
97 impl Param for $t {
98 fn as_sql_string(&self) -> String {
99 self.to_string()
100 }
101 }
102 )*)
103}
104
105impl_param_for_integer! { i8 i16 i32 i64 f32 f64 i128 isize u8 u16 u32 u64 u128 usize }
106
107impl Param for bool {
109 fn as_sql_string(&self) -> String {
110 if *self {
111 "TRUE".to_string()
112 } else {
113 "FALSE".to_string()
114 }
115 }
116}
117
118impl Param for String {
120 fn as_sql_string(&self) -> String {
121 format!("'{self}'")
122 }
123}
124
125impl Param for &str {
127 fn as_sql_string(&self) -> String {
128 format!("'{self}'")
129 }
130}
131
132impl Param for () {
134 fn as_sql_string(&self) -> String {
135 "NULL".to_string()
136 }
137}
138
139impl<T> Param for Option<T>
140where
141 T: Param,
142{
143 fn as_sql_string(&self) -> String {
144 match self {
145 Some(s) => s.as_sql_string(),
146 None => "NULL".to_string(),
147 }
148 }
149}
150
151impl Param for serde_json::Value {
152 fn as_sql_string(&self) -> String {
153 match self {
154 serde_json::Value::Number(n) => n.to_string(),
155 serde_json::Value::String(s) => format!("'{s}'"),
156 serde_json::Value::Bool(b) => b.to_string(),
157 serde_json::Value::Null => "NULL".to_string(),
158 serde_json::Value::Array(values) => {
159 let mut s = String::from("[");
160 for (i, v) in values.iter().enumerate() {
161 if i > 0 {
162 s.push_str(", ");
163 }
164 s.push_str(&v.as_sql_string());
165 }
166 s.push(']');
167 s
168 }
169 serde_json::Value::Object(map) => {
170 let mut s = String::from("'{");
171 for (i, (k, v)) in map.iter().enumerate() {
172 if i > 0 {
173 s.push_str(", ");
174 }
175 s.push_str(&format!("\"{k}\": {}", v.as_sql_string()));
176 }
177 s.push_str("}'::JSON");
178 s
179 }
180 }
181 }
182}
183
/// Builds a [`Params`] value.
///
/// * `params! {}` — no parameters.
/// * `params! {a => 1, b => "x"}` — named parameters, filled into `:a`, `:b`.
/// * `params! {1, "x"}` — positional parameters, filled in order into `?` (or `$1`, `$2`, …).
///
/// Every value must implement [`Param`]; it is rendered with [`Param::as_sql_string`].
#[macro_export]
macro_rules! params {
    () => {
        $crate::Params::default()
    };
    ($($key:ident => $value:expr),* $(,)?) => {
        $crate::Params::NamedParams({
            // Fully qualified so the macro also expands in crates that never imported
            // `HashMap` (this macro is exported).
            let mut map = ::std::collections::HashMap::new();
            $(
                map.insert(stringify!($key).to_string(), $crate::Param::as_sql_string(&$value));
            )*
            map
        })
    };
    ($($value:expr),* $(,)?) => {
        $crate::Params::QuestionParams(vec![
            $(
                $crate::Param::as_sql_string(&$value),
            )*
        ])
    };
}
213
214impl From<()> for Params {
215 fn from(_: ()) -> Self {
216 Params::default()
217 }
218}
219
220macro_rules! impl_from_tuple_for_params {
222 () => {};
224
225 ($head:ident, $($tail:ident),*) => {
227 #[allow(non_snake_case)]
228 impl<$head: Param, $($tail: Param),*> From<($head, $($tail),*)> for Params {
229 fn from(tuple: ($head, $($tail),*)) -> Self {
230 let (h, $($tail),*) = tuple;
231 let mut params = Params::QuestionParams(vec![h.as_sql_string()]);
232 $(params.merge(Params::QuestionParams(vec![$tail.as_sql_string()]));)*
233 params
234 }
235 }
236
237 impl_from_tuple_for_params!($($tail),*);
238 };
239
240 ($last:ident) => {
242 impl<$last: Param> From<($last,)> for Params {
243 fn from(tuple: ($last,)) -> Self {
244 Params::QuestionParams(vec![tuple.0.as_sql_string()])
245 }
246 }
247 };
248}
249
250impl_from_tuple_for_params! { T1, T2, T3, T4, T5, T6, T7, T8, T9, T10 }
251
252impl From<Option<serde_json::Value>> for Params {
253 fn from(value: Option<serde_json::Value>) -> Self {
254 match value {
255 Some(v) => v.into(),
256 None => Params::default(),
257 }
258 }
259}
260
261impl From<serde_json::Value> for Params {
262 fn from(value: serde_json::Value) -> Self {
263 match value {
264 serde_json::Value::Array(obj) => {
265 let mut array = Vec::new();
266 for v in obj {
267 array.push(v.as_sql_string());
268 }
269 Params::QuestionParams(array)
270 }
271 serde_json::Value::Object(obj) => {
272 let mut map = HashMap::new();
273 for (k, v) in obj {
274 map.insert(k, v.as_sql_string());
275 }
276 Params::NamedParams(map)
277 }
278 other => Params::QuestionParams(vec![other.as_sql_string()]),
279 }
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks how each construction route renders values into SQL text: the `params!`
    /// macro, the tuple `From` impls, and `serde_json::Value` `From`.
    #[test]
    fn test_params() {
        {
            // Named parameters via `params! {key => value}`; string values come back quoted.
            let name = "d";
            let age = 4;
            let params = params! {a => 1, b => age, c => name};
            match params {
                Params::NamedParams(map) => {
                    assert_eq!(map.get("a").unwrap(), "1");
                    assert_eq!(map.get("b").unwrap(), "4");
                    assert_eq!(map.get("c").unwrap(), "'d'");
                }
                _ => panic!("Expected NamedParams"),
            }
            let params = params! {};
            assert!(params.is_empty());
        }

        {
            // Positional parameters via `params! {value, ...}`, kept in argument order.
            let name = "d";
            let age = 4;
            let params = params! {name, age, 33u64};
            match params {
                Params::QuestionParams(vec) => {
                    assert_eq!(vec, vec!["'d'", "4", "33"]);
                }
                _ => panic!("Expected QuestionParams"),
            }
        }

        {
            // A tuple of mixed `Param` types converts to positional parameters.
            let params: Params = (1, "44", 2, 3, "55", "66").into();
            match params {
                Params::QuestionParams(vec) => {
                    assert_eq!(vec, vec!["1", "'44'", "2", "3", "'55'", "'66'"]);
                }
                _ => panic!("Expected QuestionParams"),
            }
        }

        {
            // `Option` tuple elements: `Some(v)` renders as `v`, `None` renders as `NULL`.
            let params: Params = (Some(1), None::<()>, Some("44"), None::<()>).into();
            match params {
                Params::QuestionParams(vec) => assert_eq!(vec, vec!["1", "NULL", "'44'", "NULL"]),
                _ => panic!("Expected QuestionParams"),
            }
        }

        {
            // A JSON object converts to named parameters keyed by field name.
            let params: Params = serde_json::json!({
                "a": 1,
                "b": "44",
                "c": 2,
                "d": 3,
                "e": "55",
                "f": "66",
            })
            .into();
            match params {
                Params::NamedParams(map) => {
                    assert_eq!(map.get("a").unwrap(), "1");
                    assert_eq!(map.get("b").unwrap(), "'44'");
                    assert_eq!(map.get("c").unwrap(), "2");
                    assert_eq!(map.get("d").unwrap(), "3");
                    assert_eq!(map.get("e").unwrap(), "'55'");
                    assert_eq!(map.get("f").unwrap(), "'66'");
                }
                _ => panic!("Expected NamedParams"),
            }
        }

        {
            // A JSON array converts to positional parameters; a nested object element is
            // rendered as a `'...'::JSON` literal.
            let params: Params =
                serde_json::json!([1, "44", 2, serde_json::json!({"a" : 1}), "55", "66"]).into();
            match params {
                Params::QuestionParams(vec) => {
                    assert_eq!(
                        vec,
                        vec!["1", "'44'", "2", "'{\"a\": 1}'::JSON", "'55'", "'66'"]
                    );
                }
                _ => panic!("Expected QuestionParams"),
            }
        }
    }

    /// Checks placeholder substitution for `?`, `:name` and `$n` placeholders. In every
    /// query, the `'?'` string literal must be left untouched.
    #[test]
    fn test_replace() {
        // `?` placeholders are filled in order from positional parameters.
        let params = params! {1, "44", 2, 3, "55", "66"};
        let sql =
            "SELECT * FROM table WHERE a = ? AND '?' = cj AND b = ? AND c = ? AND d = ? AND e = ? AND f = ?";
        let replaced_sql = params.replace(sql);
        assert_eq!(replaced_sql, "SELECT * FROM table WHERE a = 1 AND '?' = cj AND b = '44' AND c = 2 AND d = 3 AND e = '55' AND f = '66'");

        let params = params! {a => 1, b => "44", c => 2, d => 3, e => "55", f => "66"};

        {
            // `:name` placeholders are filled from named parameters.
            let sql = "SELECT * FROM table WHERE a = :a AND '?' = cj AND b = :b AND c = :c AND d = :d AND e = :e AND f = :f";
            let replaced_sql = params.replace(sql);
            assert_eq!(replaced_sql, "SELECT * FROM table WHERE a = 1 AND '?' = cj AND b = '44' AND c = 2 AND d = 3 AND e = '55' AND f = '66'");
        }

        {
            // Named placeholders may repeat and appear in any order.
            let sql = "SELECT b = :b, a = :a FROM table WHERE a = :a AND '?' = cj AND b = :b AND c = :c AND d = :d AND e = :e AND f = :f";
            let replaced_sql = params.replace(sql);
            assert_eq!(replaced_sql, "SELECT b = '44', a = 1 FROM table WHERE a = 1 AND '?' = cj AND b = '44' AND c = 2 AND d = 3 AND e = '55' AND f = '66'");
        }

        {
            // `$n` placeholders select positional parameters by 1-based index.
            let params = params! {1, "44", 2, 3, "55", "66"};
            let sql = "SELECT $3, $2, $1 FROM table WHERE a = $1 AND '?' = cj AND b = $2 AND c = $3 AND d = $4 AND e = $5 AND f = $6";
            let replaced_sql = params.replace(sql);
            assert_eq!(replaced_sql, "SELECT 2, '44', 1 FROM table WHERE a = 1 AND '?' = cj AND b = '44' AND c = 2 AND d = 3 AND e = '55' AND f = '66'");
        }
    }
}