1use core::ops::ControlFlow;
21
22use sqlparser::ast::{Statement, Value, VisitMut, VisitorMut};
23
/// One positional parameter value that can be bound into a parsed statement.
///
/// Parameters are supplied as a slice; the element at index `i` is the value
/// substituted for the placeholder `$<i + 1>`.
#[derive(Clone, Debug)]
pub enum ParamValue {
    /// The SQL `NULL` value.
    Null,
    /// A boolean value.
    Bool(bool),
    /// A 64-bit signed integer.
    Int64(i64),
    /// A 64-bit floating-point number.
    Float64(f64),
    /// A text value, bound as a single-quoted string literal.
    Text(String),
}
35
36pub fn bind_params(stmt: &mut Statement, params: &[ParamValue]) {
38 if params.is_empty() {
39 return;
40 }
41 let mut binder = ParamBinder { params };
42 let _ = stmt.visit(&mut binder);
43}
44
45struct ParamBinder<'a> {
51 params: &'a [ParamValue],
52}
53
54impl VisitorMut for ParamBinder<'_> {
55 type Break = ();
56
57 fn pre_visit_value(&mut self, value: &mut Value) -> ControlFlow<Self::Break> {
58 if let Value::Placeholder(p) = value
59 && let Some(v) = placeholder_to_value(p, self.params)
60 {
61 *value = v;
62 }
63 ControlFlow::Continue(())
64 }
65}
66
67fn placeholder_to_value(placeholder: &str, params: &[ParamValue]) -> Option<Value> {
68 let idx_str = placeholder.strip_prefix('$')?;
69 let idx: usize = idx_str.parse().ok()?;
70 let param = params.get(idx.checked_sub(1)?)?;
71 Some(match param {
72 ParamValue::Null => Value::Null,
73 ParamValue::Bool(true) => Value::Boolean(true),
74 ParamValue::Bool(false) => Value::Boolean(false),
75 ParamValue::Int64(n) => Value::Number(n.to_string(), false),
76 ParamValue::Float64(f) => Value::Number(f.to_string(), false),
77 ParamValue::Text(s) => Value::SingleQuotedString(s.clone()),
78 })
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::statement::parse_sql;

    /// Parses `sql`, binds `params` into every statement, and prints the
    /// statements joined with `"; "`.
    fn bind_and_format(sql: &str, params: &[ParamValue]) -> String {
        let mut stmts = parse_sql(sql).unwrap();
        for stmt in &mut stmts {
            bind_params(stmt, params);
        }
        stmts
            .iter()
            .map(|s| s.to_string())
            .collect::<Vec<_>>()
            .join("; ")
    }

    #[test]
    fn bind_select_where() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE id = $1",
            &[ParamValue::Int64(42)],
        );
        assert!(result.contains("id = 42"), "got: {result}");
    }

    #[test]
    fn bind_string_param() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE name = $1",
            &[ParamValue::Text("alice".into())],
        );
        assert!(result.contains("name = 'alice'"), "got: {result}");
    }

    #[test]
    fn bind_string_param_doubles_quote() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE name = $1",
            &[ParamValue::Text("o'brien".into())],
        );
        assert!(result.contains("name = 'o''brien'"), "got: {result}");
    }

    #[test]
    fn bind_null_param() {
        let result = bind_and_format("SELECT * FROM users WHERE name = $1", &[ParamValue::Null]);
        assert!(result.contains("name = NULL"), "got: {result}");
    }

    #[test]
    fn bind_multiple_params() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE age > $1 AND name = $2",
            &[ParamValue::Int64(18), ParamValue::Text("bob".into())],
        );
        assert!(result.contains("age > 18"), "got: {result}");
        assert!(result.contains("name = 'bob'"), "got: {result}");
    }

    #[test]
    fn bind_insert_values() {
        let result = bind_and_format(
            "INSERT INTO users (id, name) VALUES ($1, $2)",
            &[ParamValue::Int64(1), ParamValue::Text("eve".into())],
        );
        assert!(result.contains("1, 'eve'"), "got: {result}");
    }

    #[test]
    fn bind_bool_param() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE active = $1",
            &[ParamValue::Bool(true)],
        );
        assert!(result.contains("active = true"), "got: {result}");
    }

    #[test]
    fn bind_negative_int_param() {
        let result = bind_and_format(
            "SELECT * FROM t WHERE x = $1",
            &[ParamValue::Int64(-5)],
        );
        assert!(result.contains("x = -5"), "got: {result}");
    }

    #[test]
    fn bind_fractional_float_param() {
        let result = bind_and_format(
            "SELECT * FROM t WHERE x = $1",
            &[ParamValue::Float64(2.5)],
        );
        assert!(result.contains("x = 2.5"), "got: {result}");
    }

    #[test]
    fn no_params_noop() {
        let result = bind_and_format("SELECT 1", &[]);
        assert!(result.contains("SELECT 1"), "got: {result}");
    }

    #[test]
    fn no_params_leaves_placeholders() {
        let result = bind_and_format("SELECT * FROM t WHERE id = $1", &[]);
        assert!(result.contains("id = $1"), "got: {result}");
    }

    #[test]
    fn placeholder_beyond_params_is_left_intact() {
        let result = bind_and_format(
            "SELECT * FROM t WHERE a = $1 AND b = $2",
            &[ParamValue::Int64(7)],
        );
        assert!(result.contains("a = 7"), "got: {result}");
        assert!(result.contains("b = $2"), "got: {result}");
    }

    #[test]
    fn placeholder_zero_is_left_intact() {
        let result = bind_and_format("SELECT * FROM t WHERE a = $0", &[ParamValue::Int64(7)]);
        assert!(result.contains("a = $0"), "got: {result}");
    }

    #[test]
    fn bind_cte_body_placeholders() {
        let result = bind_and_format(
            "WITH x AS (SELECT $1 AS v) SELECT v FROM x",
            &[ParamValue::Int64(42)],
        );
        assert!(
            result.contains("SELECT 42"),
            "CTE body placeholder not substituted: {result}"
        );
        assert!(!result.contains("$1"), "placeholder survived: {result}");
    }

    #[test]
    fn bind_recursive_cte_placeholders() {
        let result = bind_and_format(
            "WITH RECURSIVE chain AS (SELECT $1 AS id UNION ALL SELECT chain.id + 1 FROM chain WHERE chain.id < $2) SELECT * FROM chain",
            &[ParamValue::Int64(1), ParamValue::Int64(10)],
        );
        assert!(
            !result.contains("$1") && !result.contains("$2"),
            "got: {result}"
        );
    }

    #[test]
    fn bind_array_elements() {
        let result = bind_and_format(
            "SELECT ARRAY[$1, $2, $3]",
            &[
                ParamValue::Int64(1),
                ParamValue::Int64(2),
                ParamValue::Int64(3),
            ],
        );
        // Checking for the digits alone would pass even if `$2`/`$3` survived,
        // because those placeholders contain the same digits.
        assert!(!result.contains('$'), "array placeholder survived: {result}");
        assert!(result.contains("1, 2, 3"), "got: {result}");
    }

    #[test]
    fn bind_any_op_rhs() {
        let result = bind_and_format(
            "SELECT * FROM t WHERE id = ANY($1)",
            &[ParamValue::Text("{a,b}".into())],
        );
        assert!(!result.contains("$1"), "got: {result}");
    }

    #[test]
    fn bind_all_op_rhs() {
        let result = bind_and_format(
            "SELECT * FROM t WHERE id = ALL($1)",
            &[ParamValue::Text("{a,b}".into())],
        );
        assert!(!result.contains("$1"), "got: {result}");
    }

    #[test]
    fn bind_update_returning() {
        let result = bind_and_format(
            "UPDATE t SET n = 1 WHERE id = $1 RETURNING $2 AS tag",
            &[ParamValue::Int64(7), ParamValue::Text("note".into())],
        );
        assert!(result.contains("id = 7"), "got: {result}");
        assert!(result.contains("'note'"), "got: {result}");
        assert!(!result.contains("$2"), "got: {result}");
    }

    #[test]
    fn bind_delete_returning() {
        let result = bind_and_format(
            "DELETE FROM t WHERE id = $1 RETURNING $2 AS tag",
            &[ParamValue::Int64(3), ParamValue::Text("gone".into())],
        );
        assert!(
            result.contains("id = 3") && result.contains("'gone'"),
            "got: {result}"
        );
        assert!(!result.contains("$2"), "got: {result}");
    }

    #[test]
    fn bind_update_limit() {
        let result = bind_and_format(
            "UPDATE t SET n = 1 WHERE id > 0 LIMIT $1",
            &[ParamValue::Int64(5)],
        );
        assert!(!result.contains("$1"), "got: {result}");
    }

    #[test]
    fn bind_interval_value_placeholder() {
        let result = bind_and_format(
            "SELECT now() - INTERVAL $1",
            &[ParamValue::Text("1 day".into())],
        );
        assert!(!result.contains("$1"), "got: {result}");
    }

    #[test]
    fn bind_update_from_subquery() {
        let result = bind_and_format(
            "UPDATE t SET n = s.y FROM (SELECT $1 AS y) s WHERE t.id = s.y",
            &[ParamValue::Int64(7)],
        );
        assert!(!result.contains("$1"), "got: {result}");
    }

    #[test]
    fn bind_insert_on_conflict_update_placeholder() {
        let result = bind_and_format(
            "INSERT INTO t (id, n) VALUES ($1, $2) ON CONFLICT (id) DO UPDATE SET n = $3",
            &[
                ParamValue::Int64(1),
                ParamValue::Int64(2),
                ParamValue::Int64(99),
            ],
        );
        assert!(!result.contains('$'), "placeholder survived: {result}");
        assert!(result.contains("99"), "got: {result}");
    }

    #[test]
    fn bind_window_partition_by_placeholder() {
        let result = bind_and_format(
            "SELECT LAG(x) OVER (PARTITION BY $1 ORDER BY b) FROM t",
            &[ParamValue::Text("user_id".into())],
        );
        assert!(!result.contains("$1"), "got: {result}");
    }

    #[test]
    fn bind_window_order_by_placeholder() {
        let result = bind_and_format(
            "SELECT LAG(x) OVER (PARTITION BY a ORDER BY $1) FROM t",
            &[ParamValue::Text("ts".into())],
        );
        assert!(!result.contains("$1"), "got: {result}");
    }
}