Skip to main content

nodedb_sql/
params.rs

1//! AST-level parameter binding for prepared statements.
2//!
3//! Every `Value::Placeholder("$N")` in a parsed statement is rewritten to
4//! a concrete literal value via sqlparser's `VisitorMut`. The visitor
5//! traverses every expression position the AST defines — CTE bodies,
6//! window specs, `Update.from/returning/limit`, `Delete.returning`,
7//! `Insert.on_conflict`, `Expr::Array` elements, `Expr::AnyOp`/`AllOp`
8//! right-hand sides, `Expr::Interval`, and any variant sqlparser adds in
9//! the future — without us maintaining a hand-written walker.
10//!
11//! # Why a visitor, not a hand-written walker
12//!
13//! The previous implementation was a recursive match on ~20 `Expr` /
14//! `Statement` / `Query` variants. Every new sqlparser variant it didn't
15//! enumerate was a silent bug: placeholders survived into the planner and
16//! surfaced as "unsupported expression: $1" in the resolver. The walker
17//! was opt-in where it had to be exhaustive to be correct. `VisitorMut`
18//! moves the exhaustiveness burden to sqlparser itself.
19
20use core::ops::ControlFlow;
21
22use sqlparser::ast::{Statement, Value, VisitMut, VisitorMut};
23
/// Parameter value for AST substitution.
///
/// Converted from pgwire binary parameters + type info in the Control Plane.
/// When a `$N` placeholder is bound, each variant is rendered as the matching
/// SQL literal.
#[derive(Debug, Clone, PartialEq)]
pub enum ParamValue {
    /// SQL `NULL`.
    Null,
    /// Boolean literal (`true` / `false`).
    Bool(bool),
    /// 64-bit signed integer, rendered as a numeric literal.
    Int64(i64),
    /// 64-bit float, rendered as a numeric literal.
    Float64(f64),
    /// Text, rendered as a single-quoted string literal.
    Text(String),
}
35
36/// Substitute all `$N` placeholders in a parsed statement with concrete values.
37pub fn bind_params(stmt: &mut Statement, params: &[ParamValue]) {
38    if params.is_empty() {
39        return;
40    }
41    let mut binder = ParamBinder { params };
42    let _ = stmt.visit(&mut binder);
43}
44
45/// Visitor that rewrites every `Value::Placeholder("$N")` it encounters.
46///
47/// sqlparser's `VisitMut` impls take us into every expression position of
48/// every `Expr`, `Statement`, `Query`, `SetExpr`, `TableFactor`, etc. —
49/// so we only care about the leaf: the `Value` itself.
50struct ParamBinder<'a> {
51    params: &'a [ParamValue],
52}
53
54impl VisitorMut for ParamBinder<'_> {
55    type Break = ();
56
57    fn pre_visit_value(&mut self, value: &mut Value) -> ControlFlow<Self::Break> {
58        if let Value::Placeholder(p) = value
59            && let Some(v) = placeholder_to_value(p, self.params)
60        {
61            *value = v;
62        }
63        ControlFlow::Continue(())
64    }
65}
66
67fn placeholder_to_value(placeholder: &str, params: &[ParamValue]) -> Option<Value> {
68    let idx_str = placeholder.strip_prefix('$')?;
69    let idx: usize = idx_str.parse().ok()?;
70    let param = params.get(idx.checked_sub(1)?)?;
71    Some(match param {
72        ParamValue::Null => Value::Null,
73        ParamValue::Bool(true) => Value::Boolean(true),
74        ParamValue::Bool(false) => Value::Boolean(false),
75        ParamValue::Int64(n) => Value::Number(n.to_string(), false),
76        ParamValue::Float64(f) => Value::Number(f.to_string(), false),
77        ParamValue::Text(s) => Value::SingleQuotedString(s.clone()),
78    })
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::statement::parse_sql;

    /// Parses `sql`, binds `params` into every resulting statement, and renders
    /// the statements back to SQL joined with `"; "`.
    fn bind_and_format(sql: &str, params: &[ParamValue]) -> String {
        let mut stmts = parse_sql(sql).unwrap();
        for stmt in &mut stmts {
            bind_params(stmt, params);
        }
        stmts
            .iter()
            .map(|s| s.to_string())
            .collect::<Vec<_>>()
            .join("; ")
    }

    #[test]
    fn bind_select_where() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE id = $1",
            &[ParamValue::Int64(42)],
        );
        assert!(result.contains("id = 42"), "got: {result}");
    }

    #[test]
    fn bind_string_param() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE name = $1",
            &[ParamValue::Text("alice".into())],
        );
        assert!(result.contains("name = 'alice'"), "got: {result}");
    }

    #[test]
    fn bind_null_param() {
        let result = bind_and_format("SELECT * FROM users WHERE name = $1", &[ParamValue::Null]);
        assert!(result.contains("name = NULL"), "got: {result}");
    }

    #[test]
    fn bind_multiple_params() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE age > $1 AND name = $2",
            &[ParamValue::Int64(18), ParamValue::Text("bob".into())],
        );
        assert!(result.contains("age > 18"), "got: {result}");
        assert!(result.contains("name = 'bob'"), "got: {result}");
    }

    #[test]
    fn bind_insert_values() {
        let result = bind_and_format(
            "INSERT INTO users (id, name) VALUES ($1, $2)",
            &[ParamValue::Int64(1), ParamValue::Text("eve".into())],
        );
        assert!(result.contains("1, 'eve'"), "got: {result}");
    }

    #[test]
    fn bind_bool_param() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE active = $1",
            &[ParamValue::Bool(true)],
        );
        assert!(result.contains("active = true"), "got: {result}");
    }

    #[test]
    fn bind_float_param() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE score > $1",
            &[ParamValue::Float64(1.5)],
        );
        assert!(result.contains("score > 1.5"), "got: {result}");
    }

    #[test]
    fn no_params_noop() {
        let result = bind_and_format("SELECT 1", &[]);
        assert!(result.contains("SELECT 1"));
    }

    #[test]
    fn unbound_placeholder_left_intact() {
        // `$2` has no matching parameter, so it must survive verbatim rather
        // than being bound to something arbitrary.
        let result = bind_and_format(
            "SELECT * FROM users WHERE a = $1 AND b = $2",
            &[ParamValue::Int64(5)],
        );
        assert!(result.contains("a = 5"), "got: {result}");
        assert!(result.contains("b = $2"), "got: {result}");
    }

    #[test]
    fn bind_cte_body_placeholders() {
        let result = bind_and_format(
            "WITH x AS (SELECT $1 AS v) SELECT v FROM x",
            &[ParamValue::Int64(42)],
        );
        assert!(
            result.contains("SELECT 42"),
            "CTE body placeholder not substituted: {result}"
        );
        assert!(!result.contains("$1"), "placeholder survived: {result}");
    }

    #[test]
    fn bind_recursive_cte_placeholders() {
        let result = bind_and_format(
            "WITH RECURSIVE chain AS (SELECT $1 AS id UNION ALL SELECT chain.id + 1 FROM chain WHERE chain.id < $2) SELECT * FROM chain",
            &[ParamValue::Int64(1), ParamValue::Int64(10)],
        );
        assert!(!result.contains('$'), "placeholder survived: {result}");
        assert!(result.contains("SELECT 1 AS id"), "got: {result}");
        assert!(result.contains("< 10"), "got: {result}");
    }

    #[test]
    fn bind_array_elements() {
        let result = bind_and_format(
            "SELECT ARRAY[$1, $2, $3]",
            &[
                ParamValue::Int64(1),
                ParamValue::Int64(2),
                ParamValue::Int64(3),
            ],
        );
        // Checking for the bare digits alone would pass even if `$2`/`$3`
        // survived, so assert that no placeholder remains at all.
        assert!(
            !result.contains('$'),
            "array placeholder survived: {result}"
        );
        assert!(result.contains("[1, 2, 3]"), "got: {result}");
    }

    #[test]
    fn bind_any_op_rhs() {
        let result = bind_and_format(
            "SELECT * FROM t WHERE id = ANY($1)",
            &[ParamValue::Text("{a,b}".into())],
        );
        assert!(!result.contains("$1"), "got: {result}");
        assert!(result.contains("'{a,b}'"), "got: {result}");
    }

    #[test]
    fn bind_all_op_rhs() {
        let result = bind_and_format(
            "SELECT * FROM t WHERE id = ALL($1)",
            &[ParamValue::Text("{a,b}".into())],
        );
        assert!(!result.contains("$1"), "got: {result}");
        assert!(result.contains("'{a,b}'"), "got: {result}");
    }

    #[test]
    fn bind_update_returning() {
        let result = bind_and_format(
            "UPDATE t SET n = 1 WHERE id = $1 RETURNING $2 AS tag",
            &[ParamValue::Int64(7), ParamValue::Text("note".into())],
        );
        assert!(result.contains("id = 7"), "got: {result}");
        assert!(result.contains("'note'"), "got: {result}");
        assert!(!result.contains("$2"), "got: {result}");
    }

    #[test]
    fn bind_delete_returning() {
        let result = bind_and_format(
            "DELETE FROM t WHERE id = $1 RETURNING $2 AS tag",
            &[ParamValue::Int64(3), ParamValue::Text("gone".into())],
        );
        assert!(
            result.contains("id = 3") && result.contains("'gone'"),
            "got: {result}"
        );
        assert!(!result.contains("$2"), "got: {result}");
    }

    #[test]
    fn bind_update_limit() {
        let result = bind_and_format(
            "UPDATE t SET n = 1 WHERE id > 0 LIMIT $1",
            &[ParamValue::Int64(5)],
        );
        assert!(!result.contains("$1"), "got: {result}");
        // `5` appears nowhere else in the statement, so its presence proves the
        // LIMIT placeholder was substituted.
        assert!(result.contains('5'), "got: {result}");
    }

    #[test]
    fn bind_interval_value_placeholder() {
        let result = bind_and_format(
            "SELECT now() - INTERVAL $1",
            &[ParamValue::Text("1 day".into())],
        );
        assert!(!result.contains("$1"), "got: {result}");
        assert!(result.contains("'1 day'"), "got: {result}");
    }

    #[test]
    fn bind_update_from_subquery() {
        let result = bind_and_format(
            "UPDATE t SET n = s.y FROM (SELECT $1 AS y) s WHERE t.id = s.y",
            &[ParamValue::Int64(7)],
        );
        assert!(!result.contains("$1"), "got: {result}");
        assert!(result.contains("7 AS y"), "got: {result}");
    }

    #[test]
    fn bind_insert_on_conflict_update_placeholder() {
        let result = bind_and_format(
            "INSERT INTO t (id, n) VALUES ($1, $2) ON CONFLICT (id) DO UPDATE SET n = $3",
            &[
                ParamValue::Int64(1),
                ParamValue::Int64(2),
                ParamValue::Int64(99),
            ],
        );
        assert!(!result.contains('$'), "placeholder survived: {result}");
        assert!(result.contains("(1, 2)"), "got: {result}");
        assert!(result.contains("n = 99"), "got: {result}");
    }

    #[test]
    fn bind_window_partition_by_placeholder() {
        let result = bind_and_format(
            "SELECT LAG(x) OVER (PARTITION BY $1 ORDER BY b) FROM t",
            &[ParamValue::Text("user_id".into())],
        );
        assert!(!result.contains("$1"), "got: {result}");
        assert!(result.contains("'user_id'"), "got: {result}");
    }

    #[test]
    fn bind_window_order_by_placeholder() {
        let result = bind_and_format(
            "SELECT LAG(x) OVER (PARTITION BY a ORDER BY $1) FROM t",
            &[ParamValue::Text("ts".into())],
        );
        assert!(!result.contains("$1"), "got: {result}");
        assert!(result.contains("'ts'"), "got: {result}");
    }
}