Skip to main content

nodedb_sql/
params.rs

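//! AST-level parameter binding for prepared statements.
//!
//! Replaces `Value::Placeholder("$N")` nodes in the sqlparser AST with
//! concrete literal values, eliminating the need for SQL text substitution.
//! Placeholders are 1-based (`$1` is the first parameter); a placeholder with
//! no matching parameter is left in the AST untouched.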
5
6use sqlparser::ast::{
7    self, Expr, GroupByExpr, Query, Select, SelectItem, SetExpr, Statement, Value,
8};
9
/// Parameter value for AST substitution.
///
/// Converted from pgwire binary parameters + type info in the Control Plane.
/// Each variant is turned into a single sqlparser literal when a `$N`
/// placeholder is bound.
#[derive(Debug, Clone, PartialEq)]
pub enum ParamValue {
    /// SQL `NULL`.
    Null,
    /// Boolean literal (`true` / `false`).
    Bool(bool),
    /// 64-bit signed integer, bound as a numeric literal.
    Int64(i64),
    /// 64-bit float, bound as a numeric literal.
    Float64(f64),
    /// Text, bound as a single-quoted string literal.
    Text(String),
}
21
22/// Substitute all `$N` placeholders in a parsed statement with concrete values.
23///
24/// Walks the AST and replaces every `Value::Placeholder("$N")` with the
25/// corresponding literal from `params` (0-indexed: `$1` → `params[0]`).
26pub fn bind_params(stmt: &mut Statement, params: &[ParamValue]) {
27    if params.is_empty() {
28        return;
29    }
30    bind_statement(stmt, params);
31}
32
33fn placeholder_to_value(placeholder: &str, params: &[ParamValue]) -> Option<Value> {
34    let idx_str = placeholder.strip_prefix('$')?;
35    let idx: usize = idx_str.parse().ok()?;
36    let param = params.get(idx.checked_sub(1)?)?;
37    Some(match param {
38        ParamValue::Null => Value::Null,
39        ParamValue::Bool(true) => Value::Boolean(true),
40        ParamValue::Bool(false) => Value::Boolean(false),
41        ParamValue::Int64(n) => Value::Number(n.to_string(), false),
42        ParamValue::Float64(f) => Value::Number(f.to_string(), false),
43        ParamValue::Text(s) => Value::SingleQuotedString(s.clone()),
44    })
45}
46
// ── AST walkers ─────────────────────────────────────────────────────
48
49fn bind_statement(stmt: &mut Statement, params: &[ParamValue]) {
50    match stmt {
51        Statement::Query(q) => bind_query(q, params),
52        Statement::Insert(ins) => {
53            if let Some(ref mut src) = ins.source {
54                bind_query(src, params);
55            }
56            if let Some(ref mut sel) = ins.returning {
57                for item in sel {
58                    bind_select_item(item, params);
59                }
60            }
61        }
62        Statement::Update(upd) => {
63            for a in &mut upd.assignments {
64                bind_expr(&mut a.value, params);
65            }
66            if let Some(ref mut w) = upd.selection {
67                bind_expr(w, params);
68            }
69        }
70        Statement::Delete(del) => {
71            if let Some(ref mut w) = del.selection {
72                bind_expr(w, params);
73            }
74        }
75        _ => {}
76    }
77}
78
79fn bind_query(query: &mut Query, params: &[ParamValue]) {
80    bind_set_expr(&mut query.body, params);
81    if let Some(ref mut order_by) = query.order_by
82        && let ast::OrderByKind::Expressions(ref mut exprs) = order_by.kind
83    {
84        for item in exprs {
85            bind_expr(&mut item.expr, params);
86        }
87    }
88    if let Some(limit_clause) = &mut query.limit_clause
89        && let ast::LimitClause::LimitOffset { limit, offset, .. } = limit_clause
90    {
91        if let Some(limit_expr) = limit {
92            bind_expr(limit_expr, params);
93        }
94        if let Some(offset_val) = offset {
95            bind_expr(&mut offset_val.value, params);
96        }
97    }
98}
99
100fn bind_set_expr(body: &mut SetExpr, params: &[ParamValue]) {
101    match body {
102        SetExpr::Select(sel) => bind_select(sel, params),
103        SetExpr::Query(q) => bind_query(q, params),
104        SetExpr::SetOperation { left, right, .. } => {
105            bind_set_expr(left, params);
106            bind_set_expr(right, params);
107        }
108        SetExpr::Values(vals) => {
109            for row in &mut vals.rows {
110                for expr in row {
111                    bind_expr(expr, params);
112                }
113            }
114        }
115        _ => {}
116    }
117}
118
119fn bind_select(sel: &mut Select, params: &[ParamValue]) {
120    for item in &mut sel.projection {
121        bind_select_item(item, params);
122    }
123    if let Some(ref mut w) = sel.selection {
124        bind_expr(w, params);
125    }
126    match &mut sel.group_by {
127        GroupByExpr::Expressions(exprs, _) => {
128            for e in exprs {
129                bind_expr(e, params);
130            }
131        }
132        GroupByExpr::All(_) => {}
133    }
134    if let Some(ref mut having) = sel.having {
135        bind_expr(having, params);
136    }
137}
138
139fn bind_select_item(item: &mut SelectItem, params: &[ParamValue]) {
140    match item {
141        SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
142            bind_expr(e, params);
143        }
144        _ => {}
145    }
146}
147
148fn bind_expr(expr: &mut Expr, params: &[ParamValue]) {
149    match expr {
150        Expr::Value(ast::ValueWithSpan { value, .. }) => {
151            if let Value::Placeholder(p) = value
152                && let Some(v) = placeholder_to_value(p, params)
153            {
154                *value = v;
155            }
156        }
157        Expr::BinaryOp { left, right, .. } => {
158            bind_expr(left, params);
159            bind_expr(right, params);
160        }
161        Expr::UnaryOp { expr: e, .. } => bind_expr(e, params),
162        Expr::Nested(e) => bind_expr(e, params),
163        Expr::Between {
164            expr: e, low, high, ..
165        } => {
166            bind_expr(e, params);
167            bind_expr(low, params);
168            bind_expr(high, params);
169        }
170        Expr::InList { expr: e, list, .. } => {
171            bind_expr(e, params);
172            for item in list {
173                bind_expr(item, params);
174            }
175        }
176        Expr::InSubquery {
177            expr: e, subquery, ..
178        } => {
179            bind_expr(e, params);
180            bind_query(subquery, params);
181        }
182        Expr::IsNull(e) | Expr::IsNotNull(e) => bind_expr(e, params),
183        Expr::IsFalse(e) | Expr::IsTrue(e) => bind_expr(e, params),
184        Expr::IsNotFalse(e) | Expr::IsNotTrue(e) => bind_expr(e, params),
185        Expr::Like {
186            expr: e, pattern, ..
187        }
188        | Expr::ILike {
189            expr: e, pattern, ..
190        } => {
191            bind_expr(e, params);
192            bind_expr(pattern, params);
193        }
194        Expr::Cast { expr: e, .. } => {
195            bind_expr(e, params);
196        }
197        Expr::Function(f) => {
198            if let ast::FunctionArguments::List(ref mut args) = f.args {
199                for arg in &mut args.args {
200                    if let ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) = arg {
201                        bind_expr(e, params);
202                    }
203                }
204            }
205        }
206        Expr::Case {
207            operand,
208            conditions,
209            else_result,
210            ..
211        } => {
212            if let Some(e) = operand {
213                bind_expr(e, params);
214            }
215            for cw in conditions {
216                bind_expr(&mut cw.condition, params);
217                bind_expr(&mut cw.result, params);
218            }
219            if let Some(e) = else_result {
220                bind_expr(e, params);
221            }
222        }
223        Expr::Exists { subquery, .. } => bind_query(subquery, params),
224        Expr::Subquery(q) => bind_query(q, params),
225        _ => {}
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::statement::parse_sql;

    /// Parse `sql`, bind `params` into every resulting statement, and render
    /// the statements back to SQL text joined by `"; "`.
    fn bind_and_format(sql: &str, params: &[ParamValue]) -> String {
        let mut stmts = parse_sql(sql).unwrap();
        stmts.iter_mut().for_each(|stmt| bind_params(stmt, params));
        stmts
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join("; ")
    }

    /// Assert that `rendered` contains `needle`, echoing the full text on failure.
    fn assert_contains(rendered: &str, needle: &str) {
        assert!(rendered.contains(needle), "expected `{needle}` in: {rendered}");
    }

    #[test]
    fn bind_select_where() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE id = $1",
            &[ParamValue::Int64(42)],
        );
        assert_contains(&result, "id = 42");
    }

    #[test]
    fn bind_string_param() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE name = $1",
            &[ParamValue::Text("alice".into())],
        );
        assert_contains(&result, "name = 'alice'");
    }

    #[test]
    fn bind_null_param() {
        let result = bind_and_format("SELECT * FROM users WHERE name = $1", &[ParamValue::Null]);
        assert_contains(&result, "name = NULL");
    }

    #[test]
    fn bind_multiple_params() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE age > $1 AND name = $2",
            &[ParamValue::Int64(18), ParamValue::Text("bob".into())],
        );
        assert_contains(&result, "age > 18");
        assert_contains(&result, "name = 'bob'");
    }

    #[test]
    fn bind_insert_values() {
        let result = bind_and_format(
            "INSERT INTO users (id, name) VALUES ($1, $2)",
            &[ParamValue::Int64(1), ParamValue::Text("eve".into())],
        );
        assert_contains(&result, "1, 'eve'");
    }

    #[test]
    fn bind_bool_param() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE active = $1",
            &[ParamValue::Bool(true)],
        );
        assert_contains(&result, "active = true");
    }

    #[test]
    fn bind_negative_int_param() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE balance = $1",
            &[ParamValue::Int64(-7)],
        );
        assert_contains(&result, "balance = -7");
    }

    #[test]
    fn bind_fractional_float_param() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE score = $1",
            &[ParamValue::Float64(1.5)],
        );
        assert_contains(&result, "score = 1.5");
    }

    #[test]
    fn bind_limit_param() {
        let result = bind_and_format("SELECT * FROM users LIMIT $1", &[ParamValue::Int64(10)]);
        assert_contains(&result, "LIMIT 10");
    }

    #[test]
    fn unbound_placeholder_left_intact() {
        let result = bind_and_format(
            "SELECT * FROM users WHERE id = $1 AND age > $2",
            &[ParamValue::Int64(7)],
        );
        assert_contains(&result, "id = 7");
        assert_contains(&result, "age > $2");
    }

    #[test]
    fn no_params_noop() {
        let result = bind_and_format("SELECT 1", &[]);
        assert!(result.contains("SELECT 1"));
    }
}