1use core::ops::ControlFlow;
23
24use sqlparser::ast::{Statement, Value, VisitMut, VisitorMut};
25
26#[derive(Debug, Clone)]
30pub enum ParamValue {
31 Null,
32 Bool(bool),
33 Int64(i64),
34 Float64(f64),
35 Decimal(rust_decimal::Decimal),
37 Text(String),
38 Timestamp(nodedb_types::datetime::NdbDateTime),
40 Timestamptz(nodedb_types::datetime::NdbDateTime),
42}
43
44pub fn bind_params(stmt: &mut Statement, params: &[ParamValue]) {
46 if params.is_empty() {
47 return;
48 }
49 let mut binder = ParamBinder { params };
50 let _ = stmt.visit(&mut binder);
51}
52
53struct ParamBinder<'a> {
59 params: &'a [ParamValue],
60}
61
62impl VisitorMut for ParamBinder<'_> {
63 type Break = ();
64
65 fn pre_visit_value(&mut self, value: &mut Value) -> ControlFlow<Self::Break> {
66 if let Value::Placeholder(p) = value
67 && let Some(v) = placeholder_to_value(p, self.params)
68 {
69 *value = v;
70 }
71 ControlFlow::Continue(())
72 }
73}
74
75fn placeholder_to_value(placeholder: &str, params: &[ParamValue]) -> Option<Value> {
76 let idx_str = placeholder.strip_prefix('$')?;
77 let idx: usize = idx_str.parse().ok()?;
78 let param = params.get(idx.checked_sub(1)?)?;
79 Some(match param {
80 ParamValue::Null => Value::Null,
81 ParamValue::Bool(true) => Value::Boolean(true),
82 ParamValue::Bool(false) => Value::Boolean(false),
83 ParamValue::Int64(n) => Value::Number(n.to_string(), false),
84 ParamValue::Float64(f) => Value::Number(f.to_string(), false),
85 ParamValue::Decimal(d) => Value::Number(d.to_string(), false),
86 ParamValue::Text(s) => Value::SingleQuotedString(s.clone()),
87 ParamValue::Timestamp(dt) => Value::SingleQuotedString(dt.to_iso8601()),
90 ParamValue::Timestamptz(dt) => Value::SingleQuotedString(dt.to_iso8601()),
91 })
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::statement::parse_sql;

    /// Parses `sql`, binds `params` into every resulting statement, and renders
    /// the statements back to SQL text joined with `"; "`.
    fn bind_and_format(sql: &str, params: &[ParamValue]) -> String {
        let mut statements = parse_sql(sql).unwrap();
        statements
            .iter_mut()
            .for_each(|statement| bind_params(statement, params));
        let rendered: Vec<String> = statements.iter().map(ToString::to_string).collect();
        rendered.join("; ")
    }

    #[test]
    fn bind_select_where() {
        let params = [ParamValue::Int64(42)];
        let result = bind_and_format("SELECT * FROM users WHERE id = $1", &params);
        assert!(result.contains("id = 42"), "got: {result}");
    }

    #[test]
    fn bind_string_param() {
        let params = [ParamValue::Text("alice".into())];
        let result = bind_and_format("SELECT * FROM users WHERE name = $1", &params);
        assert!(result.contains("name = 'alice'"), "got: {result}");
    }

    #[test]
    fn bind_null_param() {
        let params = [ParamValue::Null];
        let result = bind_and_format("SELECT * FROM users WHERE name = $1", &params);
        assert!(result.contains("name = NULL"), "got: {result}");
    }

    #[test]
    fn bind_multiple_params() {
        let params = [ParamValue::Int64(18), ParamValue::Text("bob".into())];
        let result = bind_and_format(
            "SELECT * FROM users WHERE age > $1 AND name = $2",
            &params,
        );
        assert!(result.contains("age > 18"), "got: {result}");
        assert!(result.contains("name = 'bob'"), "got: {result}");
    }

    #[test]
    fn bind_insert_values() {
        let params = [ParamValue::Int64(1), ParamValue::Text("eve".into())];
        let result = bind_and_format("INSERT INTO users (id, name) VALUES ($1, $2)", &params);
        assert!(result.contains("1, 'eve'"), "got: {result}");
    }

    #[test]
    fn bind_bool_param() {
        let params = [ParamValue::Bool(true)];
        let result = bind_and_format("SELECT * FROM users WHERE active = $1", &params);
        assert!(result.contains("active = true"), "got: {result}");
    }

    #[test]
    fn no_params_noop() {
        let params: [ParamValue; 0] = [];
        let result = bind_and_format("SELECT 1", &params);
        assert!(result.contains("SELECT 1"));
    }

    #[test]
    fn bind_cte_body_placeholders() {
        let params = [ParamValue::Int64(42)];
        let result = bind_and_format("WITH x AS (SELECT $1 AS v) SELECT v FROM x", &params);
        assert!(
            result.contains("SELECT 42"),
            "CTE body placeholder not substituted: {result}"
        );
        assert!(!result.contains("$1"), "placeholder survived: {result}");
    }

    #[test]
    fn bind_recursive_cte_placeholders() {
        let params = [ParamValue::Int64(1), ParamValue::Int64(10)];
        let result = bind_and_format(
            "WITH RECURSIVE chain AS (SELECT $1 AS id UNION ALL SELECT chain.id + 1 FROM chain WHERE chain.id < $2) SELECT * FROM chain",
            &params,
        );
        assert!(
            !(result.contains("$1") || result.contains("$2")),
            "got: {result}"
        );
    }

    #[test]
    fn bind_array_elements() {
        let params = [
            ParamValue::Int64(1),
            ParamValue::Int64(2),
            ParamValue::Int64(3),
        ];
        let result = bind_and_format("SELECT ARRAY[$1, $2, $3]", &params);
        assert!(
            !result.contains("$1"),
            "array placeholder survived: {result}"
        );
        assert!(
            ['1', '2', '3'].iter().all(|digit| result.contains(*digit)),
            "got: {result}"
        );
    }

    #[test]
    fn bind_any_op_rhs() {
        let params = [ParamValue::Text("{a,b}".into())];
        let result = bind_and_format("SELECT * FROM t WHERE id = ANY($1)", &params);
        assert!(!result.contains("$1"), "got: {result}");
    }

    #[test]
    fn bind_all_op_rhs() {
        let params = [ParamValue::Text("{a,b}".into())];
        let result = bind_and_format("SELECT * FROM t WHERE id = ALL($1)", &params);
        assert!(!result.contains("$1"), "got: {result}");
    }

    #[test]
    fn bind_update_returning() {
        let params = [ParamValue::Int64(7), ParamValue::Text("note".into())];
        let result = bind_and_format(
            "UPDATE t SET n = 1 WHERE id = $1 RETURNING $2 AS tag",
            &params,
        );
        assert!(result.contains("id = 7"), "got: {result}");
        assert!(result.contains("'note'"), "got: {result}");
        assert!(!result.contains("$2"), "got: {result}");
    }

    #[test]
    fn bind_delete_returning() {
        let params = [ParamValue::Int64(3), ParamValue::Text("gone".into())];
        let result = bind_and_format(
            "DELETE FROM t WHERE id = $1 RETURNING $2 AS tag",
            &params,
        );
        assert!(
            result.contains("id = 3") && result.contains("'gone'"),
            "got: {result}"
        );
        assert!(!result.contains("$2"), "got: {result}");
    }

    #[test]
    fn bind_update_limit() {
        let params = [ParamValue::Int64(5)];
        let result = bind_and_format("UPDATE t SET n = 1 WHERE id > 0 LIMIT $1", &params);
        assert!(!result.contains("$1"), "got: {result}");
    }

    #[test]
    fn bind_interval_value_placeholder() {
        let params = [ParamValue::Text("1 day".into())];
        let result = bind_and_format("SELECT now() - INTERVAL $1", &params);
        assert!(!result.contains("$1"), "got: {result}");
    }

    #[test]
    fn bind_update_from_subquery() {
        let params = [ParamValue::Int64(7)];
        let result = bind_and_format(
            "UPDATE t SET n = s.y FROM (SELECT $1 AS y) s WHERE t.id = s.y",
            &params,
        );
        assert!(!result.contains("$1"), "got: {result}");
    }

    #[test]
    fn bind_insert_on_conflict_update_placeholder() {
        let params = [
            ParamValue::Int64(1),
            ParamValue::Int64(2),
            ParamValue::Int64(99),
        ];
        let result = bind_and_format(
            "INSERT INTO t (id, n) VALUES ($1, $2) ON CONFLICT (id) DO UPDATE SET n = $3",
            &params,
        );
        assert!(!result.contains("$3"), "got: {result}");
    }

    #[test]
    fn bind_window_partition_by_placeholder() {
        let params = [ParamValue::Text("user_id".into())];
        let result = bind_and_format(
            "SELECT LAG(x) OVER (PARTITION BY $1 ORDER BY b) FROM t",
            &params,
        );
        assert!(!result.contains("$1"), "got: {result}");
    }

    #[test]
    fn bind_window_order_by_placeholder() {
        let params = [ParamValue::Text("ts".into())];
        let result = bind_and_format(
            "SELECT LAG(x) OVER (PARTITION BY a ORDER BY $1) FROM t",
            &params,
        );
        assert!(!result.contains("$1"), "got: {result}");
    }
}
313}