1use sqlparser::dialect::PostgreSqlDialect;
34use sqlparser::tokenizer::{Token, Tokenizer};
35
36use crate::error::{Result, SqlError};
37use crate::params::ParamValue;
38
/// DSL statement text produced either by [`bind_dsl`] (with `$N` placeholders
/// inlined as literals) or wrapped verbatim by
/// [`BoundDslSql::from_simple_query`].
///
/// The inner string is private, so those two functions are the only ways to
/// build a value outside this module.
#[derive(Debug, Clone)]
pub struct BoundDslSql(String);

impl BoundDslSql {
    /// Wraps `sql` exactly as given; no placeholder substitution is performed.
    pub fn from_simple_query(sql: String) -> Self {
        BoundDslSql(sql)
    }

    /// Borrows the statement text.
    pub fn as_str(&self) -> &str {
        let BoundDslSql(text) = self;
        text.as_str()
    }

    /// Consumes the wrapper and hands back the owned statement text.
    pub fn into_string(self) -> String {
        let BoundDslSql(text) = self;
        text
    }
}
67
68pub fn bind_dsl(sql: &str, params: &[ParamValue]) -> Result<BoundDslSql> {
74 if params.is_empty() {
75 return Ok(BoundDslSql(sql.to_owned()));
76 }
77 let dialect = PostgreSqlDialect {};
78 let tokens = Tokenizer::new(&dialect, sql)
79 .tokenize()
80 .map_err(|e| SqlError::Parse {
81 detail: format!("tokenize DSL for parameter binding: {e}"),
82 })?;
83
84 let mut out = String::with_capacity(sql.len());
85 for tok in &tokens {
86 if let Token::Placeholder(p) = tok
87 && p.starts_with('$')
88 {
89 let replacement =
94 placeholder_literal_token(p, params).ok_or_else(|| SqlError::Parse {
95 detail: format!(
96 "DSL parameter bind: placeholder {p} has no corresponding \
97 parameter ({len} provided)",
98 len = params.len()
99 ),
100 })?;
101 out.push_str(&replacement.to_string());
102 continue;
103 }
104 out.push_str(&tok.to_string());
105 }
106 Ok(BoundDslSql(out))
107}
108
109fn placeholder_literal_token(placeholder: &str, params: &[ParamValue]) -> Option<Token> {
110 let idx_str = placeholder.strip_prefix('$')?;
111 let idx: usize = idx_str.parse().ok()?;
112 let param = params.get(idx.checked_sub(1)?)?;
113 Some(match param {
114 ParamValue::Null => Token::make_keyword("NULL"),
115 ParamValue::Bool(true) => Token::make_keyword("TRUE"),
116 ParamValue::Bool(false) => Token::make_keyword("FALSE"),
117 ParamValue::Int64(n) => Token::Number(n.to_string(), false),
118 ParamValue::Float64(f) => Token::Number(f.to_string(), false),
119 ParamValue::Text(s) => Token::SingleQuotedString(s.clone()),
120 })
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Binds `sql` with `params`, panicking on failure, and returns the owned text.
    fn bind_ok(sql: &str, params: &[ParamValue]) -> String {
        bind_dsl(sql, params).unwrap().into_string()
    }

    /// Binds `sql` with `params`, panicking on success, and returns the
    /// error's `Debug` rendering.
    fn bind_err(sql: &str, params: &[ParamValue]) -> String {
        let err = bind_dsl(sql, params).unwrap_err();
        format!("{err:?}")
    }

    #[test]
    fn upsert_int_and_text_params() {
        let sql = bind_ok(
            "UPSERT INTO t (id, n) VALUES ($1, $2)",
            &[ParamValue::Text("alice".into()), ParamValue::Int64(42)],
        );
        assert!(sql.contains("'alice'"), "text param not substituted: {sql}");
        assert!(sql.contains("42"), "int param not substituted: {sql}");
        assert!(!sql.contains('$'), "placeholder survived: {sql}");
    }

    #[test]
    fn search_top_k_param() {
        let sql = bind_ok(
            "SEARCH v USING VECTOR(ARRAY[1.0, 0.0, 0.0], $1)",
            &[ParamValue::Int64(5)],
        );
        assert!(sql.contains(", 5)"), "got: {sql}");
    }

    /// A `$1` inside a quoted string is literal text, not a placeholder.
    #[test]
    fn placeholder_inside_string_literal_untouched() {
        let sql = bind_ok(
            "UPSERT INTO t (id, note) VALUES ($1, 'your $1 change')",
            &[ParamValue::Text("abc".into())],
        );
        assert!(
            sql.contains("'your $1 change'"),
            "string literal was rewritten: {sql}"
        );
        assert!(sql.contains("'abc'"), "real placeholder not bound: {sql}");
    }

    #[test]
    fn null_param() {
        let sql = bind_ok(
            "UPSERT INTO t (id, n) VALUES ($1, $2)",
            &[ParamValue::Text("x".into()), ParamValue::Null],
        );
        assert!(sql.to_uppercase().contains("NULL"), "got: {sql}");
    }

    #[test]
    fn out_of_range_placeholder_errors() {
        let msg = bind_err(
            "UPSERT INTO t (id, n) VALUES ($1, $2)",
            &[ParamValue::Text("only-one".into())],
        );
        assert!(
            msg.contains("$2") && msg.to_lowercase().contains("placeholder"),
            "error must name the unresolved placeholder: {msg}"
        );
    }

    /// Placeholders are 1-based, so `$0` never resolves.
    #[test]
    fn zero_placeholder_errors() {
        let msg = bind_err(
            "UPSERT INTO t (id) VALUES ($0)",
            &[ParamValue::Text("x".into())],
        );
        assert!(msg.contains("$0"));
    }

    #[test]
    fn empty_params_is_noop() {
        let sql = "UPSERT INTO t (id) VALUES ('a')";
        assert_eq!(bind_ok(sql, &[]), sql);
    }
}