1use sqlparser::ast::{
7 self, Expr, GroupByExpr, Query, Select, SelectItem, SetExpr, Statement, Value,
8};
9
/// A positional parameter value substituted for a `$N` placeholder by `bind_params`.
///
/// `PartialEq` is derived (not `Eq`, because of the `f64` payload) so values can be
/// compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq)]
pub enum ParamValue {
    /// SQL `NULL`.
    Null,
    /// Boolean literal (`true` / `false`).
    Bool(bool),
    /// 64-bit signed integer, bound as a numeric literal.
    Int64(i64),
    /// 64-bit float, bound as a numeric literal.
    Float64(f64),
    /// Text, bound as a single-quoted string literal.
    Text(String),
}
21
22pub fn bind_params(stmt: &mut Statement, params: &[ParamValue]) {
27 if params.is_empty() {
28 return;
29 }
30 bind_statement(stmt, params);
31}
32
33fn placeholder_to_value(placeholder: &str, params: &[ParamValue]) -> Option<Value> {
34 let idx_str = placeholder.strip_prefix('$')?;
35 let idx: usize = idx_str.parse().ok()?;
36 let param = params.get(idx.checked_sub(1)?)?;
37 Some(match param {
38 ParamValue::Null => Value::Null,
39 ParamValue::Bool(true) => Value::Boolean(true),
40 ParamValue::Bool(false) => Value::Boolean(false),
41 ParamValue::Int64(n) => Value::Number(n.to_string(), false),
42 ParamValue::Float64(f) => Value::Number(f.to_string(), false),
43 ParamValue::Text(s) => Value::SingleQuotedString(s.clone()),
44 })
45}
46
47fn bind_statement(stmt: &mut Statement, params: &[ParamValue]) {
50 match stmt {
51 Statement::Query(q) => bind_query(q, params),
52 Statement::Insert(ins) => {
53 if let Some(ref mut src) = ins.source {
54 bind_query(src, params);
55 }
56 if let Some(ref mut sel) = ins.returning {
57 for item in sel {
58 bind_select_item(item, params);
59 }
60 }
61 }
62 Statement::Update(upd) => {
63 for a in &mut upd.assignments {
64 bind_expr(&mut a.value, params);
65 }
66 if let Some(ref mut w) = upd.selection {
67 bind_expr(w, params);
68 }
69 }
70 Statement::Delete(del) => {
71 if let Some(ref mut w) = del.selection {
72 bind_expr(w, params);
73 }
74 }
75 _ => {}
76 }
77}
78
79fn bind_query(query: &mut Query, params: &[ParamValue]) {
80 bind_set_expr(&mut query.body, params);
81 if let Some(ref mut order_by) = query.order_by
82 && let ast::OrderByKind::Expressions(ref mut exprs) = order_by.kind
83 {
84 for item in exprs {
85 bind_expr(&mut item.expr, params);
86 }
87 }
88 if let Some(limit_clause) = &mut query.limit_clause
89 && let ast::LimitClause::LimitOffset { limit, offset, .. } = limit_clause
90 {
91 if let Some(limit_expr) = limit {
92 bind_expr(limit_expr, params);
93 }
94 if let Some(offset_val) = offset {
95 bind_expr(&mut offset_val.value, params);
96 }
97 }
98}
99
100fn bind_set_expr(body: &mut SetExpr, params: &[ParamValue]) {
101 match body {
102 SetExpr::Select(sel) => bind_select(sel, params),
103 SetExpr::Query(q) => bind_query(q, params),
104 SetExpr::SetOperation { left, right, .. } => {
105 bind_set_expr(left, params);
106 bind_set_expr(right, params);
107 }
108 SetExpr::Values(vals) => {
109 for row in &mut vals.rows {
110 for expr in row {
111 bind_expr(expr, params);
112 }
113 }
114 }
115 _ => {}
116 }
117}
118
119fn bind_select(sel: &mut Select, params: &[ParamValue]) {
120 for item in &mut sel.projection {
121 bind_select_item(item, params);
122 }
123 if let Some(ref mut w) = sel.selection {
124 bind_expr(w, params);
125 }
126 match &mut sel.group_by {
127 GroupByExpr::Expressions(exprs, _) => {
128 for e in exprs {
129 bind_expr(e, params);
130 }
131 }
132 GroupByExpr::All(_) => {}
133 }
134 if let Some(ref mut having) = sel.having {
135 bind_expr(having, params);
136 }
137}
138
139fn bind_select_item(item: &mut SelectItem, params: &[ParamValue]) {
140 match item {
141 SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
142 bind_expr(e, params);
143 }
144 _ => {}
145 }
146}
147
148fn bind_expr(expr: &mut Expr, params: &[ParamValue]) {
149 match expr {
150 Expr::Value(ast::ValueWithSpan { value, .. }) => {
151 if let Value::Placeholder(p) = value
152 && let Some(v) = placeholder_to_value(p, params)
153 {
154 *value = v;
155 }
156 }
157 Expr::BinaryOp { left, right, .. } => {
158 bind_expr(left, params);
159 bind_expr(right, params);
160 }
161 Expr::UnaryOp { expr: e, .. } => bind_expr(e, params),
162 Expr::Nested(e) => bind_expr(e, params),
163 Expr::Between {
164 expr: e, low, high, ..
165 } => {
166 bind_expr(e, params);
167 bind_expr(low, params);
168 bind_expr(high, params);
169 }
170 Expr::InList { expr: e, list, .. } => {
171 bind_expr(e, params);
172 for item in list {
173 bind_expr(item, params);
174 }
175 }
176 Expr::InSubquery {
177 expr: e, subquery, ..
178 } => {
179 bind_expr(e, params);
180 bind_query(subquery, params);
181 }
182 Expr::IsNull(e) | Expr::IsNotNull(e) => bind_expr(e, params),
183 Expr::IsFalse(e) | Expr::IsTrue(e) => bind_expr(e, params),
184 Expr::IsNotFalse(e) | Expr::IsNotTrue(e) => bind_expr(e, params),
185 Expr::Like {
186 expr: e, pattern, ..
187 }
188 | Expr::ILike {
189 expr: e, pattern, ..
190 } => {
191 bind_expr(e, params);
192 bind_expr(pattern, params);
193 }
194 Expr::Cast { expr: e, .. } => {
195 bind_expr(e, params);
196 }
197 Expr::Function(f) => {
198 if let ast::FunctionArguments::List(ref mut args) = f.args {
199 for arg in &mut args.args {
200 if let ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) = arg {
201 bind_expr(e, params);
202 }
203 }
204 }
205 }
206 Expr::Case {
207 operand,
208 conditions,
209 else_result,
210 ..
211 } => {
212 if let Some(e) = operand {
213 bind_expr(e, params);
214 }
215 for cw in conditions {
216 bind_expr(&mut cw.condition, params);
217 bind_expr(&mut cw.result, params);
218 }
219 if let Some(e) = else_result {
220 bind_expr(e, params);
221 }
222 }
223 Expr::Exists { subquery, .. } => bind_query(subquery, params),
224 Expr::Subquery(q) => bind_query(q, params),
225 _ => {}
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::statement::parse_sql;

    /// Parses `sql`, binds `params` into every resulting statement, and renders the
    /// statements back to SQL joined by `"; "`.
    fn bind_and_format(sql: &str, params: &[ParamValue]) -> String {
        let mut statements = parse_sql(sql).unwrap();
        statements
            .iter_mut()
            .for_each(|statement| bind_params(statement, params));
        let rendered: Vec<String> = statements.iter().map(ToString::to_string).collect();
        rendered.join("; ")
    }

    /// Asserts that `rendered` contains `needle`, echoing the full SQL on failure.
    #[track_caller]
    fn assert_contains(rendered: &str, needle: &str) {
        assert!(rendered.contains(needle), "got: {rendered}");
    }

    #[test]
    fn bind_select_where() {
        let sql = bind_and_format(
            "SELECT * FROM users WHERE id = $1",
            &[ParamValue::Int64(42)],
        );
        assert_contains(&sql, "id = 42");
    }

    #[test]
    fn bind_string_param() {
        let sql = bind_and_format(
            "SELECT * FROM users WHERE name = $1",
            &[ParamValue::Text("alice".into())],
        );
        assert_contains(&sql, "name = 'alice'");
    }

    #[test]
    fn bind_null_param() {
        let sql = bind_and_format("SELECT * FROM users WHERE name = $1", &[ParamValue::Null]);
        assert_contains(&sql, "name = NULL");
    }

    #[test]
    fn bind_multiple_params() {
        let sql = bind_and_format(
            "SELECT * FROM users WHERE age > $1 AND name = $2",
            &[ParamValue::Int64(18), ParamValue::Text("bob".into())],
        );
        assert_contains(&sql, "age > 18");
        assert_contains(&sql, "name = 'bob'");
    }

    #[test]
    fn bind_insert_values() {
        let sql = bind_and_format(
            "INSERT INTO users (id, name) VALUES ($1, $2)",
            &[ParamValue::Int64(1), ParamValue::Text("eve".into())],
        );
        assert_contains(&sql, "1, 'eve'");
    }

    #[test]
    fn bind_bool_param() {
        let sql = bind_and_format(
            "SELECT * FROM users WHERE active = $1",
            &[ParamValue::Bool(true)],
        );
        assert_contains(&sql, "active = true");
    }

    #[test]
    fn no_params_noop() {
        let sql = bind_and_format("SELECT 1", &[]);
        assert_contains(&sql, "SELECT 1");
    }
}