1use sqlparser::dialect::PostgreSqlDialect;
36use sqlparser::tokenizer::{Token, Tokenizer};
37
38use crate::error::{Result, SqlError};
39use crate::params::ParamValue;
40
/// Statement text produced by [`bind_dsl`] (placeholders already inlined) or
/// wrapped verbatim by [`BoundDslSql::from_simple_query`].
#[derive(Debug, Clone)]
pub struct BoundDslSql(String);

impl BoundDslSql {
    /// Wraps `sql` exactly as given; no placeholder substitution is performed.
    pub fn from_simple_query(sql: String) -> Self {
        BoundDslSql(sql)
    }

    /// Borrows the statement text.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Consumes the wrapper and hands back the owned statement text.
    pub fn into_string(self) -> String {
        let BoundDslSql(sql) = self;
        sql
    }
}
69
70pub fn bind_dsl(sql: &str, params: &[ParamValue]) -> Result<BoundDslSql> {
76 if params.is_empty() {
77 return Ok(BoundDslSql(sql.to_owned()));
78 }
79 let dialect = PostgreSqlDialect {};
80 let tokens = Tokenizer::new(&dialect, sql)
81 .tokenize()
82 .map_err(|e| SqlError::Parse {
83 detail: format!("tokenize DSL for parameter binding: {e}"),
84 })?;
85
86 let mut out = String::with_capacity(sql.len());
87 for tok in &tokens {
88 if let Token::Placeholder(p) = tok
89 && p.starts_with('$')
90 {
91 let replacement =
96 placeholder_literal_token(p, params).ok_or_else(|| SqlError::Parse {
97 detail: format!(
98 "DSL parameter bind: placeholder {p} has no corresponding \
99 parameter ({len} provided)",
100 len = params.len()
101 ),
102 })?;
103 out.push_str(&replacement.to_string());
104 continue;
105 }
106 out.push_str(&tok.to_string());
107 }
108 Ok(BoundDslSql(out))
109}
110
111fn placeholder_literal_token(placeholder: &str, params: &[ParamValue]) -> Option<Token> {
112 let idx_str = placeholder.strip_prefix('$')?;
113 let idx: usize = idx_str.parse().ok()?;
114 let param = params.get(idx.checked_sub(1)?)?;
115 Some(match param {
116 ParamValue::Null => Token::make_keyword("NULL"),
117 ParamValue::Bool(true) => Token::make_keyword("TRUE"),
118 ParamValue::Bool(false) => Token::make_keyword("FALSE"),
119 ParamValue::Int64(n) => Token::Number(n.to_string(), false),
120 ParamValue::Float64(f) => Token::Number(f.to_string(), false),
121 ParamValue::Decimal(d) => Token::Number(d.to_string(), false),
122 ParamValue::Text(s) => Token::SingleQuotedString(s.clone()),
123 ParamValue::Timestamp(dt) | ParamValue::Timestamptz(dt) => {
124 Token::SingleQuotedString(dt.to_iso8601())
125 }
126 })
127}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Binds `params` into `sql`, panicking on failure, and returns the text.
    fn bound(sql: &str, params: &[ParamValue]) -> String {
        bind_dsl(sql, params).unwrap().into_string()
    }

    /// Binds `params` into `sql`, expecting failure, and returns the error's
    /// `Debug` rendering.
    fn bind_error(sql: &str, params: &[ParamValue]) -> String {
        let err = bind_dsl(sql, params).unwrap_err();
        format!("{err:?}")
    }

    #[test]
    fn upsert_int_and_text_params() {
        let s = bound(
            "UPSERT INTO t (id, n) VALUES ($1, $2)",
            &[ParamValue::Text("alice".into()), ParamValue::Int64(42)],
        );
        assert!(s.contains("'alice'"), "text param not substituted: {s}");
        assert!(s.contains("42"), "int param not substituted: {s}");
        assert!(!s.contains('$'), "placeholder survived: {s}");
    }

    #[test]
    fn search_top_k_param() {
        let s = bound(
            "SEARCH v USING VECTOR(ARRAY[1.0, 0.0, 0.0], $1)",
            &[ParamValue::Int64(5)],
        );
        assert!(s.contains(", 5)"), "got: {s}");
    }

    /// A `$1` that appears inside a quoted string is text, not a placeholder.
    #[test]
    fn placeholder_inside_string_literal_untouched() {
        let s = bound(
            "UPSERT INTO t (id, note) VALUES ($1, 'your $1 change')",
            &[ParamValue::Text("abc".into())],
        );
        assert!(
            s.contains("'your $1 change'"),
            "string literal was rewritten: {s}"
        );
        assert!(s.contains("'abc'"), "real placeholder not bound: {s}");
    }

    #[test]
    fn null_param() {
        let s = bound(
            "UPSERT INTO t (id, n) VALUES ($1, $2)",
            &[ParamValue::Text("x".into()), ParamValue::Null],
        );
        assert!(s.to_uppercase().contains("NULL"), "got: {s}");
    }

    #[test]
    fn out_of_range_placeholder_errors() {
        let msg = bind_error(
            "UPSERT INTO t (id, n) VALUES ($1, $2)",
            &[ParamValue::Text("only-one".into())],
        );
        assert!(
            msg.contains("$2") && msg.to_lowercase().contains("placeholder"),
            "error must name the unresolved placeholder: {msg}"
        );
    }

    /// Placeholders are 1-based, so `$0` can never match a parameter.
    #[test]
    fn zero_placeholder_errors() {
        let msg = bind_error(
            "UPSERT INTO t (id) VALUES ($0)",
            &[ParamValue::Text("x".into())],
        );
        assert!(msg.contains("$0"));
    }

    #[test]
    fn empty_params_is_noop() {
        let sql = "UPSERT INTO t (id) VALUES ('a')";
        assert_eq!(bound(sql, &[]), sql);
    }
}