Skip to main content

sqlrite/sql/
params.rs

1//! Prepared-statement parameter binding (SQLR-23).
2//!
3//! Two responsibilities:
4//!
5//! 1. **Placeholder rewriting at prepare time.** The user writes `?` in
6//!    the SQL; sqlparser parses each as `Expr::Value(Placeholder("?"))`.
7//!    We walk the parsed AST left-to-right and rewrite each bare `?` to
8//!    `?N` (1-indexed source order) so the later substitution pass knows
9//!    which slot to bind. The rewritten AST is what `Statement` caches.
10//!
11//! 2. **Substitution at execute time.** Given the cached AST and a
12//!    `&[Value]` slice, walk a clone of the AST and replace every
13//!    `Expr::Value(Placeholder("?N"))` with the matching `params[N-1]`.
14//!
15//! Substitution lowers the bound value into a node shape the rest of the
16//! pipeline already understands:
17//!
18//! - Scalars (`Integer`, `Real`, `Text`, `Bool`, `Null`) become
19//!   `Expr::Value(...)` literals — same shape an inline literal would
20//!   parse to. Existing executor / parser arms handle them unchanged.
//! - Vectors become `Expr::Identifier(Ident { quote_style: Some('['), value: "<csv>", .. })`,
22//!   which is the in-band form sqlparser produces for inline bracket-array
23//!   literals like `[0.1, 0.2, ...]`. The INSERT parser, the executor's
24//!   `eval_expr_scope`, and the HNSW probe optimizer all already recognize
25//!   that shape, so a bound `Value::Vector(...)` flows through every path
26//!   that an inline `[...]` literal does — including the HNSW shortcut.
27//!
28//! Doing it as an AST-rewrite (rather than threading `&[Value]` through
29//! the executor) keeps the diff focused: every existing executor arm
30//! sees concrete literals, exactly as it does today on inline-params SQL.
31
32use std::ops::ControlFlow;
33
34use sqlparser::ast::{
35    Expr, Ident, Statement, Value as AstValue, ValueWithSpan, visit_expressions_mut,
36};
37use sqlparser::tokenizer::Span;
38
39use crate::error::{Result, SQLRiteError};
40use crate::sql::db::table::Value;
41
42/// Walks every expression in `stmt` and rewrites bare `?` placeholders to
43/// `?N` (1-indexed source order). Returns the total parameter count.
44///
45/// Idempotent for already-numbered placeholders: `?1`, `?2`, … pass
46/// through unchanged. We deliberately don't try to *renumber* already-
47/// numbered placeholders — that's a foot-gun (the user might use the
48/// same index twice on purpose to bind once and reference twice), and
49/// `Statement::new` runs this exactly once on a freshly-parsed AST.
50pub fn rewrite_placeholders(stmt: &mut Statement) -> usize {
51    let mut counter: usize = 0;
52    let _ = visit_expressions_mut(stmt, |expr| {
53        if let Expr::Value(v) = expr
54            && let AstValue::Placeholder(s) = &mut v.value
55            && s == "?"
56        {
57            counter += 1;
58            *s = format!("?{counter}");
59        }
60        ControlFlow::<()>::Continue(())
61    });
62    counter
63}
64
65/// Substitutes every `?N` placeholder in `stmt` with the matching value
66/// from `params`. Mutates the AST in place — callers should clone first
67/// if they want the original back.
68///
69/// Errors if the AST references a placeholder index outside `params`,
70/// or if a non-canonical placeholder form (`:name`, `$1`) is encountered.
71pub fn substitute_params(stmt: &mut Statement, params: &[Value]) -> Result<()> {
72    let mut bind_err: Option<SQLRiteError> = None;
73    let _ = visit_expressions_mut(stmt, |expr| {
74        let Expr::Value(v) = expr else {
75            return ControlFlow::Continue(());
76        };
77        let placeholder_str = match &v.value {
78            AstValue::Placeholder(s) => s.clone(),
79            _ => return ControlFlow::Continue(()),
80        };
81        let idx = match placeholder_index(&placeholder_str) {
82            Some(i) => i,
83            None => {
84                bind_err = Some(SQLRiteError::NotImplemented(format!(
85                    "unsupported placeholder form `{placeholder_str}`; only `?` and `?N` are supported"
86                )));
87                return ControlFlow::Break(());
88            }
89        };
90        let Some(value) = params.get(idx) else {
91            bind_err = Some(SQLRiteError::General(format!(
92                "missing bind value for `?{}` (got {} parameter{})",
93                idx + 1,
94                params.len(),
95                if params.len() == 1 { "" } else { "s" }
96            )));
97            return ControlFlow::Break(());
98        };
99        *expr = value_to_expr(value);
100        ControlFlow::<()>::Continue(())
101    });
102    if let Some(e) = bind_err {
103        return Err(e);
104    }
105    Ok(())
106}
107
/// Decodes a canonical `?N` placeholder string into its 0-indexed slot.
///
/// The canonical form is a `?` followed by one or more ASCII decimal
/// digits naming a slot `>= 1`; `"?1"` maps to `Some(0)`.
///
/// Returns `None` for everything else: named or dollar forms (`:name`,
/// `$1`), a bare `?`, slot zero (`?0`), signed numbers (`?+1`, `?-1`),
/// digits followed by other characters, and numbers too large for
/// `usize`.
fn placeholder_index(s: &str) -> Option<usize> {
    let digits = s.strip_prefix('?')?;
    // `str::parse::<usize>` tolerates a leading `+`, which would let the
    // non-canonical `?+1` through as slot 0. Insist on digits only.
    if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    // Slots are 1-based in SQL text; `checked_sub` rejects `?0`.
    digits.parse::<usize>().ok()?.checked_sub(1)
}
119
120/// Build the AST `Expr` equivalent of a runtime `Value`. The shapes
121/// match what `sqlparser` produces for inline literals so downstream
122/// executor code paths don't need to change.
123fn value_to_expr(v: &Value) -> Expr {
124    match v {
125        Value::Null => Expr::Value(ValueWithSpan {
126            value: AstValue::Null,
127            span: Span::empty(),
128        }),
129        Value::Integer(i) => Expr::Value(ValueWithSpan {
130            value: AstValue::Number(i.to_string(), false),
131            span: Span::empty(),
132        }),
133        Value::Real(f) => Expr::Value(ValueWithSpan {
134            // f64::Display picks the shortest round-tripping form;
135            // re-parsing it back via str::parse::<f64> is exact.
136            value: AstValue::Number(f.to_string(), false),
137            span: Span::empty(),
138        }),
139        Value::Text(s) => Expr::Value(ValueWithSpan {
140            value: AstValue::SingleQuotedString(s.clone()),
141            span: Span::empty(),
142        }),
143        Value::Bool(b) => Expr::Value(ValueWithSpan {
144            value: AstValue::Boolean(*b),
145            span: Span::empty(),
146        }),
147        Value::Vector(v) => {
148            // Inline bracket-array form. `i.value` carries the inner
149            // CSV without brackets — `format!("[{}]", i.value)` at the
150            // consumer side reconstructs the literal that
151            // `parse_vector_literal` accepts.
152            let inner = format_vector_inner(v);
153            Expr::Identifier(Ident {
154                value: inner,
155                quote_style: Some('['),
156                span: Span::empty(),
157            })
158        }
159    }
160}
161
/// Renders the components of `v` as text separated by `", "`, with no
/// surrounding brackets.
///
/// Each component uses `f32`'s `Display` form. An empty slice yields an
/// empty string.
fn format_vector_inner(v: &[f32]) -> String {
    let rendered: Vec<String> = v.iter().map(ToString::to_string).collect();
    rendered.join(", ")
}
173
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    /// Parses `sql` with the SQLite dialect and returns the last
    /// statement of the result.
    fn parse_one(sql: &str) -> Statement {
        Parser::parse_sql(&SQLiteDialect {}, sql)
            .unwrap()
            .pop()
            .unwrap()
    }

    /// Parses `sql` and numbers its bare placeholders, mirroring the
    /// prepare step.
    fn prepared(sql: &str) -> Statement {
        let mut stmt = parse_one(sql);
        rewrite_placeholders(&mut stmt);
        stmt
    }

    /// Prepares `sql`, binds `params`, and returns the rendered SQL.
    /// Panics if binding fails.
    fn bound_sql(sql: &str, params: &[Value]) -> String {
        let mut stmt = prepared(sql);
        substitute_params(&mut stmt, params).unwrap();
        stmt.to_string()
    }

    #[test]
    fn rewrite_assigns_indices_in_source_order() {
        let mut stmt = parse_one("SELECT * FROM t WHERE a = ? AND b = ? AND c = ?");
        assert_eq!(rewrite_placeholders(&mut stmt), 3);
        let rendered = stmt.to_string();
        for expected in ["?1", "?2", "?3"] {
            assert!(rendered.contains(expected));
        }
    }

    #[test]
    fn rewrite_zero_for_no_placeholders() {
        let mut stmt = parse_one("SELECT * FROM t WHERE a = 1");
        assert_eq!(rewrite_placeholders(&mut stmt), 0);
    }

    #[test]
    fn rewrite_idempotent_on_numbered_placeholders() {
        // Explicit `?1` / `?2` keep their text, so a second numbering
        // pass has nothing to rename. Only bare `?` placeholders are
        // counted, hence the reported count of 0; callers writing `?N`
        // are expected to know their arity from the SQL itself.
        let mut stmt = parse_one("SELECT * FROM t WHERE a = ?1 AND b = ?2");
        let renamed = rewrite_placeholders(&mut stmt);
        assert_eq!(renamed, 0);
    }

    #[test]
    fn substitute_replaces_scalar_params() {
        let sql = bound_sql(
            "SELECT * FROM t WHERE a = ? AND b = ? AND c = ?",
            &[
                Value::Integer(1),
                Value::Text("x".into()),
                Value::Bool(true),
            ],
        );
        assert!(sql.contains("a = 1"), "got: {sql}");
        assert!(sql.contains("b = 'x'"), "got: {sql}");
        // sqlparser prints a boolean literal as lowercase `true`.
        assert!(sql.contains("c = true"), "got: {sql}");
    }

    #[test]
    fn substitute_replaces_vector_param_as_bracket_array() {
        let sql = bound_sql(
            "SELECT id FROM t ORDER BY vec_distance_l2(v, ?) LIMIT 5",
            &[Value::Vector(vec![0.1, 0.2, 0.3])],
        );
        // A bracket-quoted identifier is printed as `[<inner>]`.
        assert!(sql.contains("[0.1, 0.2, 0.3]"), "got: {sql}");
    }

    #[test]
    fn substitute_errors_on_too_few_params() {
        let mut stmt = prepared("SELECT * FROM t WHERE a = ? AND b = ?");
        let err = substitute_params(&mut stmt, &[Value::Integer(1)]).unwrap_err();
        assert!(format!("{err}").contains("missing bind value"));
    }

    #[test]
    fn substitute_replaces_null_param() {
        let sql = bound_sql("SELECT * FROM t WHERE a = ?", &[Value::Null]);
        assert!(sql.to_uppercase().contains("NULL"), "got: {sql}");
    }

    #[test]
    fn placeholder_index_decodes_canonical_form() {
        let cases = [
            ("?1", Some(0)),
            ("?42", Some(41)),
            ("?", None),
            ("?0", None),
            (":name", None),
            ("$1", None),
        ];
        for (input, expected) in cases {
            assert_eq!(placeholder_index(input), expected);
        }
    }
}