use std::ops::ControlFlow;
use sqlparser::ast::{
Expr, Ident, Statement, Value as AstValue, ValueWithSpan, visit_expressions_mut,
};
use sqlparser::tokenizer::Span;
use crate::error::{Result, SQLRiteError};
use crate::sql::db::table::Value;
pub fn rewrite_placeholders(stmt: &mut Statement) -> usize {
let mut counter: usize = 0;
let _ = visit_expressions_mut(stmt, |expr| {
if let Expr::Value(v) = expr
&& let AstValue::Placeholder(s) = &mut v.value
&& s == "?"
{
counter += 1;
*s = format!("?{counter}");
}
ControlFlow::<()>::Continue(())
});
counter
}
pub fn substitute_params(stmt: &mut Statement, params: &[Value]) -> Result<()> {
let mut bind_err: Option<SQLRiteError> = None;
let _ = visit_expressions_mut(stmt, |expr| {
let Expr::Value(v) = expr else {
return ControlFlow::Continue(());
};
let placeholder_str = match &v.value {
AstValue::Placeholder(s) => s.clone(),
_ => return ControlFlow::Continue(()),
};
let idx = match placeholder_index(&placeholder_str) {
Some(i) => i,
None => {
bind_err = Some(SQLRiteError::NotImplemented(format!(
"unsupported placeholder form `{placeholder_str}`; only `?` and `?N` are supported"
)));
return ControlFlow::Break(());
}
};
let Some(value) = params.get(idx) else {
bind_err = Some(SQLRiteError::General(format!(
"missing bind value for `?{}` (got {} parameter{})",
idx + 1,
params.len(),
if params.len() == 1 { "" } else { "s" }
)));
return ControlFlow::Break(());
};
*expr = value_to_expr(value);
ControlFlow::<()>::Continue(())
});
if let Some(e) = bind_err {
return Err(e);
}
Ok(())
}
/// Decodes a canonical `?N` placeholder into its zero-based parameter index.
///
/// Returns `None` in these cases:
/// - the string does not start with `?`;
/// - the text after `?` is empty or contains anything other than ASCII digits;
/// - the number is `0`, because SQL parameters are 1-based;
/// - the number does not fit in `usize`.
fn placeholder_index(s: &str) -> Option<usize> {
    let digits = s.strip_prefix('?')?;
    // `usize::from_str` tolerates a leading `+` (so `?+1` would decode as `?1`);
    // insist on plain ASCII digits so only the canonical form is accepted.
    if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    // `?0` has no slot: checked_sub maps it to None.
    digits.parse::<usize>().ok()?.checked_sub(1)
}
/// Converts a bound runtime [`Value`] into the AST expression that replaces its placeholder.
///
/// Scalars become `Expr::Value` literals with an empty span:
/// - `Null`, `Bool`, and `Text` map to the matching literal (`Text` becomes a
///   single-quoted string).
/// - Integers and reals become numeric literals.
/// - A real always keeps a fractional part (`2.0` is emitted as `2.0`, not `2`),
///   so it cannot be mistaken for an integer literal.
///
/// Vectors are encoded as an `Ident` quoted with `[`, so the statement renders as
/// `[a, b, c]`.
fn value_to_expr(v: &Value) -> Expr {
    let literal = match v {
        Value::Null => AstValue::Null,
        Value::Integer(i) => AstValue::Number(i.to_string(), false),
        Value::Real(f) => {
            // `Display` for floats never uses exponent notation, but it drops the
            // fractional part of integral values ("2.0" -> "2"). Restore it so the
            // literal stays a real. Non-finite values ("NaN", "inf") are passed
            // through unchanged.
            let mut text = f.to_string();
            if f.is_finite() && !text.contains('.') {
                text.push_str(".0");
            }
            AstValue::Number(text, false)
        }
        Value::Text(s) => AstValue::SingleQuotedString(s.clone()),
        Value::Bool(b) => AstValue::Boolean(*b),
        Value::Vector(items) => {
            // Vectors are not plain literals; they travel as a bracket-quoted identifier.
            return Expr::Identifier(Ident {
                value: format_vector_inner(items),
                quote_style: Some('['),
                span: Span::empty(),
            });
        }
    };
    Expr::Value(ValueWithSpan {
        value: literal,
        span: Span::empty(),
    })
}
/// Renders vector components as a comma-and-space separated list, without brackets.
///
/// Each component uses `f32`'s `Display` form, so integral values print without a
/// fractional part (`1.0` becomes `1`). An empty slice yields an empty string.
fn format_vector_inner(v: &[f32]) -> String {
    v.iter()
        .map(|component| component.to_string())
        .collect::<Vec<String>>()
        .join(", ")
}
// Unit tests for placeholder numbering, parameter substitution, and index decoding.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::dialect::SqlriteDialect;
    use sqlparser::parser::Parser;

    /// Parses `sql` with the project dialect and returns the last parsed statement.
    /// Panics on parse failure, which is acceptable in tests.
    fn parse_one(sql: &str) -> Statement {
        let mut ast = Parser::parse_sql(&SqlriteDialect::new(), sql).unwrap();
        ast.pop().unwrap()
    }

    /// Three bare `?` placeholders are counted and numbered `?1`..`?3`.
    /// NOTE(review): only presence of each `?N` is checked, not their order.
    #[test]
    fn rewrite_assigns_indices_in_source_order() {
        let mut stmt = parse_one("SELECT * FROM t WHERE a = ? AND b = ? AND c = ?");
        let n = rewrite_placeholders(&mut stmt);
        assert_eq!(n, 3);
        let sql = stmt.to_string();
        assert!(sql.contains("?1"));
        assert!(sql.contains("?2"));
        assert!(sql.contains("?3"));
    }

    /// A statement without placeholders reports zero rewrites.
    #[test]
    fn rewrite_zero_for_no_placeholders() {
        let mut stmt = parse_one("SELECT * FROM t WHERE a = 1");
        assert_eq!(rewrite_placeholders(&mut stmt), 0);
    }

    /// Already-numbered `?N` placeholders are not counted as rewrites.
    /// NOTE(review): despite the name, only the returned count is asserted.
    #[test]
    fn rewrite_idempotent_on_numbered_placeholders() {
        let mut stmt = parse_one("SELECT * FROM t WHERE a = ?1 AND b = ?2");
        let n = rewrite_placeholders(&mut stmt);
        assert_eq!(n, 0);
    }

    /// Integer, text, and boolean parameters are rendered as SQL literals.
    #[test]
    fn substitute_replaces_scalar_params() {
        let mut stmt = parse_one("SELECT * FROM t WHERE a = ? AND b = ? AND c = ?");
        rewrite_placeholders(&mut stmt);
        substitute_params(
            &mut stmt,
            &[
                Value::Integer(1),
                Value::Text("x".into()),
                Value::Bool(true),
            ],
        )
        .unwrap();
        let sql = stmt.to_string();
        assert!(sql.contains("a = 1"), "got: {sql}");
        assert!(sql.contains("b = 'x'"), "got: {sql}");
        assert!(sql.contains("c = true"), "got: {sql}");
    }

    /// A vector parameter renders as a bracketed list inside a function call.
    #[test]
    fn substitute_replaces_vector_param_as_bracket_array() {
        let mut stmt = parse_one("SELECT id FROM t ORDER BY vec_distance_l2(v, ?) LIMIT 5");
        rewrite_placeholders(&mut stmt);
        substitute_params(&mut stmt, &[Value::Vector(vec![0.1, 0.2, 0.3])]).unwrap();
        let sql = stmt.to_string();
        assert!(sql.contains("[0.1, 0.2, 0.3]"), "got: {sql}");
    }

    /// Referencing `?2` with only one parameter supplied yields a "missing bind value" error.
    #[test]
    fn substitute_errors_on_too_few_params() {
        let mut stmt = parse_one("SELECT * FROM t WHERE a = ? AND b = ?");
        rewrite_placeholders(&mut stmt);
        let err = substitute_params(&mut stmt, &[Value::Integer(1)]).unwrap_err();
        assert!(format!("{err}").contains("missing bind value"));
    }

    /// A NULL parameter renders as the NULL keyword (case-insensitive check).
    #[test]
    fn substitute_replaces_null_param() {
        let mut stmt = parse_one("SELECT * FROM t WHERE a = ?");
        rewrite_placeholders(&mut stmt);
        substitute_params(&mut stmt, &[Value::Null]).unwrap();
        let sql = stmt.to_string();
        assert!(sql.to_uppercase().contains("NULL"), "got: {sql}");
    }

    /// `?N` decodes to the zero-based index `N - 1`; bare `?`, `?0`, and other forms are rejected.
    #[test]
    fn placeholder_index_decodes_canonical_form() {
        assert_eq!(placeholder_index("?1"), Some(0));
        assert_eq!(placeholder_index("?42"), Some(41));
        assert_eq!(placeholder_index("?"), None);
        assert_eq!(placeholder_index("?0"), None);
        assert_eq!(placeholder_index(":name"), None);
        assert_eq!(placeholder_index("$1"), None);
    }
}