Skip to main content

nodedb_sql/
params.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! AST-level parameter binding for prepared statements.
4//!
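//! Every `Value::Placeholder("$N")` in a parsed statement that has a matching
//! parameter (`1 <= N <= params.len()`) is rewritten to a concrete literal
//! value via sqlparser's `VisitorMut`; placeholders without a matching
//! parameter are left as-is. The visitor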
7//! traverses every expression position the AST defines — CTE bodies,
8//! window specs, `Update.from/returning/limit`, `Delete.returning`,
9//! `Insert.on_conflict`, `Expr::Array` elements, `Expr::AnyOp`/`AllOp`
10//! right-hand sides, `Expr::Interval`, and any variant sqlparser adds in
11//! the future — without us maintaining a hand-written walker.
12//!
13//! # Why a visitor, not a hand-written walker
14//!
15//! The previous implementation was a recursive match on ~20 `Expr` /
16//! `Statement` / `Query` variants. Every new sqlparser variant it didn't
17//! enumerate was a silent bug: placeholders survived into the planner and
//! surfaced as "unsupported expression: $1" in the resolver. The walker only
//! covered the variants it explicitly listed, even though correctness required
//! it to cover all of them. `VisitorMut` moves that exhaustiveness burden to
//! sqlparser itself.
21
22use core::ops::ControlFlow;
23
24use sqlparser::ast::{Statement, Value, VisitMut, VisitorMut};
25
26/// Parameter value for AST substitution.
27///
28/// Converted from pgwire binary parameters + type info in the Control Plane.
29#[derive(Debug, Clone)]
30pub enum ParamValue {
31    Null,
32    Bool(bool),
33    Int64(i64),
34    Float64(f64),
35    /// Exact arbitrary-precision decimal from a NUMERIC/DECIMAL pgwire parameter.
36    Decimal(rust_decimal::Decimal),
37    Text(String),
38    /// Naive (no-timezone) timestamp from a TIMESTAMP pgwire parameter.
39    Timestamp(nodedb_types::datetime::NdbDateTime),
40    /// UTC (timezone-aware) timestamp from a TIMESTAMPTZ pgwire parameter.
41    Timestamptz(nodedb_types::datetime::NdbDateTime),
42}
43
44/// Substitute all `$N` placeholders in a parsed statement with concrete values.
45pub fn bind_params(stmt: &mut Statement, params: &[ParamValue]) {
46    if params.is_empty() {
47        return;
48    }
49    let mut binder = ParamBinder { params };
50    let _ = stmt.visit(&mut binder);
51}
52
53/// Visitor that rewrites every `Value::Placeholder("$N")` it encounters.
54///
55/// sqlparser's `VisitMut` impls take us into every expression position of
56/// every `Expr`, `Statement`, `Query`, `SetExpr`, `TableFactor`, etc. —
57/// so we only care about the leaf: the `Value` itself.
58struct ParamBinder<'a> {
59    params: &'a [ParamValue],
60}
61
62impl VisitorMut for ParamBinder<'_> {
63    type Break = ();
64
65    fn pre_visit_value(&mut self, value: &mut Value) -> ControlFlow<Self::Break> {
66        if let Value::Placeholder(p) = value
67            && let Some(v) = placeholder_to_value(p, self.params)
68        {
69            *value = v;
70        }
71        ControlFlow::Continue(())
72    }
73}
74
75fn placeholder_to_value(placeholder: &str, params: &[ParamValue]) -> Option<Value> {
76    let idx_str = placeholder.strip_prefix('$')?;
77    let idx: usize = idx_str.parse().ok()?;
78    let param = params.get(idx.checked_sub(1)?)?;
79    Some(match param {
80        ParamValue::Null => Value::Null,
81        ParamValue::Bool(true) => Value::Boolean(true),
82        ParamValue::Bool(false) => Value::Boolean(false),
83        ParamValue::Int64(n) => Value::Number(n.to_string(), false),
84        ParamValue::Float64(f) => Value::Number(f.to_string(), false),
85        ParamValue::Decimal(d) => Value::Number(d.to_string(), false),
86        ParamValue::Text(s) => Value::SingleQuotedString(s.clone()),
87        // Timestamp/Timestamptz: emit as a typed SQL literal so the resolver
88        // can produce the correct SqlValue variant (Timestamp vs Timestamptz).
89        ParamValue::Timestamp(dt) => Value::SingleQuotedString(dt.to_iso8601()),
90        ParamValue::Timestamptz(dt) => Value::SingleQuotedString(dt.to_iso8601()),
91    })
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::statement::parse_sql;

    /// Parse `sql`, bind `params` into every resulting statement, and render
    /// the statements back to SQL text joined by `"; "`.
    fn bind_and_format(sql: &str, params: &[ParamValue]) -> String {
        let mut stmts = parse_sql(sql).unwrap();
        for stmt in &mut stmts {
            bind_params(stmt, params);
        }
        stmts
            .iter()
            .map(|s| s.to_string())
            .collect::<Vec<_>>()
            .join("; ")
    }

    #[test]
    fn bind_select_where() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE id = $1",
            &[ParamValue::Int64(42)],
        );
        assert!(result.contains("id = 42"), "got: {result}");
    }

    #[test]
    fn bind_string_param() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE name = $1",
            &[ParamValue::Text("alice".into())],
        );
        assert!(result.contains("name = 'alice'"), "got: {result}");
    }

    #[test]
    fn bind_null_param() {
        let result = bind_and_format("SELECT * FROM users WHERE name = $1", &[ParamValue::Null]);
        assert!(result.contains("name = NULL"), "got: {result}");
    }

    #[test]
    fn bind_multiple_params() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE age > $1 AND name = $2",
            &[ParamValue::Int64(18), ParamValue::Text("bob".into())],
        );
        assert!(result.contains("age > 18"), "got: {result}");
        assert!(result.contains("name = 'bob'"), "got: {result}");
    }

    #[test]
    fn bind_insert_values() {
        let result = bind_and_format(
            "INSERT INTO users (id, name) VALUES ($1, $2)",
            &[ParamValue::Int64(1), ParamValue::Text("eve".into())],
        );
        assert!(result.contains("1, 'eve'"), "got: {result}");
    }

    #[test]
    fn bind_bool_param() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE active = $1",
            &[ParamValue::Bool(true)],
        );
        assert!(result.contains("active = true"), "got: {result}");
    }

    #[test]
    fn bind_float_param() {
        let result = bind_and_format(
            "SELECT * FROM t WHERE x = $1",
            &[ParamValue::Float64(1.5)],
        );
        assert!(result.contains("x = 1.5"), "got: {result}");
    }

    #[test]
    fn no_params_noop() {
        let result = bind_and_format("SELECT 1", &[]);
        assert!(result.contains("SELECT 1"), "got: {result}");
    }

    #[test]
    fn unbound_placeholder_left_intact() {
        // `$2` has no matching parameter, so it must survive unchanged rather
        // than being replaced or dropped.
        let result = bind_and_format(
            "SELECT * FROM t WHERE a = $1 AND b = $2",
            &[ParamValue::Int64(5)],
        );
        assert!(result.contains("a = 5"), "got: {result}");
        assert!(result.contains("b = $2"), "got: {result}");
    }

    #[test]
    fn zero_placeholder_left_intact() {
        // Placeholders are 1-based; `$0` never maps to a parameter.
        let result = bind_and_format("SELECT $0, $1", &[ParamValue::Int64(9)]);
        assert!(result.contains("$0"), "got: {result}");
        assert!(!result.contains("$1"), "got: {result}");
    }

    #[test]
    fn bind_cte_body_placeholders() {
        let result = bind_and_format(
            "WITH x AS (SELECT $1 AS v) SELECT v FROM x",
            &[ParamValue::Int64(42)],
        );
        assert!(
            result.contains("SELECT 42"),
            "CTE body placeholder not substituted: {result}"
        );
        assert!(!result.contains("$1"), "placeholder survived: {result}");
    }

    #[test]
    fn bind_recursive_cte_placeholders() {
        let result = bind_and_format(
            "WITH RECURSIVE chain AS (SELECT $1 AS id UNION ALL SELECT chain.id + 1 FROM chain WHERE chain.id < $2) SELECT * FROM chain",
            &[ParamValue::Int64(1), ParamValue::Int64(10)],
        );
        assert!(
            !result.contains("$1") && !result.contains("$2"),
            "got: {result}"
        );
    }

    #[test]
    fn bind_array_elements() {
        let result = bind_and_format(
            "SELECT ARRAY[$1, $2, $3]",
            &[
                ParamValue::Int64(1),
                ParamValue::Int64(2),
                ParamValue::Int64(3),
            ],
        );
        // Reject any surviving `$`. A bare digit check proves nothing, because
        // the digits 1/2/3 also appear inside the placeholders themselves.
        assert!(
            !result.contains('$'),
            "array placeholder survived: {result}"
        );
        assert!(result.contains("1, 2, 3"), "got: {result}");
    }

    #[test]
    fn bind_any_op_rhs() {
        let result = bind_and_format(
            "SELECT * FROM t WHERE id = ANY($1)",
            &[ParamValue::Text("{a,b}".into())],
        );
        assert!(!result.contains("$1"), "got: {result}");
    }

    #[test]
    fn bind_all_op_rhs() {
        let result = bind_and_format(
            "SELECT * FROM t WHERE id = ALL($1)",
            &[ParamValue::Text("{a,b}".into())],
        );
        assert!(!result.contains("$1"), "got: {result}");
    }

    #[test]
    fn bind_update_returning() {
        let result = bind_and_format(
            "UPDATE t SET n = 1 WHERE id = $1 RETURNING $2 AS tag",
            &[ParamValue::Int64(7), ParamValue::Text("note".into())],
        );
        assert!(result.contains("id = 7"), "got: {result}");
        assert!(result.contains("'note'"), "got: {result}");
        assert!(!result.contains("$2"), "got: {result}");
    }

    #[test]
    fn bind_delete_returning() {
        let result = bind_and_format(
            "DELETE FROM t WHERE id = $1 RETURNING $2 AS tag",
            &[ParamValue::Int64(3), ParamValue::Text("gone".into())],
        );
        assert!(
            result.contains("id = 3") && result.contains("'gone'"),
            "got: {result}"
        );
        assert!(!result.contains("$2"), "got: {result}");
    }

    #[test]
    fn bind_update_limit() {
        let result = bind_and_format(
            "UPDATE t SET n = 1 WHERE id > 0 LIMIT $1",
            &[ParamValue::Int64(5)],
        );
        assert!(!result.contains("$1"), "got: {result}");
    }

    #[test]
    fn bind_interval_value_placeholder() {
        let result = bind_and_format(
            "SELECT now() - INTERVAL $1",
            &[ParamValue::Text("1 day".into())],
        );
        assert!(!result.contains("$1"), "got: {result}");
    }

    #[test]
    fn bind_update_from_subquery() {
        let result = bind_and_format(
            "UPDATE t SET n = s.y FROM (SELECT $1 AS y) s WHERE t.id = s.y",
            &[ParamValue::Int64(7)],
        );
        assert!(!result.contains("$1"), "got: {result}");
    }

    #[test]
    fn bind_insert_on_conflict_update_placeholder() {
        let result = bind_and_format(
            "INSERT INTO t (id, n) VALUES ($1, $2) ON CONFLICT (id) DO UPDATE SET n = $3",
            &[
                ParamValue::Int64(1),
                ParamValue::Int64(2),
                ParamValue::Int64(99),
            ],
        );
        assert!(!result.contains("$3"), "got: {result}");
    }

    #[test]
    fn bind_window_partition_by_placeholder() {
        let result = bind_and_format(
            "SELECT LAG(x) OVER (PARTITION BY $1 ORDER BY b) FROM t",
            &[ParamValue::Text("user_id".into())],
        );
        assert!(!result.contains("$1"), "got: {result}");
    }

    #[test]
    fn bind_window_order_by_placeholder() {
        let result = bind_and_format(
            "SELECT LAG(x) OVER (PARTITION BY a ORDER BY $1) FROM t",
            &[ParamValue::Text("ts".into())],
        );
        assert!(!result.contains("$1"), "got: {result}");
    }
}