1use std::collections::HashMap;
3use std::sync::{Arc, RwLock};
4
5use cdk_common::database::Error;
6use once_cell::sync::Lazy;
7
8use crate::database::DatabaseExecutor;
9use crate::value::Value;
10
/// A single column value of a result row; rows are returned as `Vec<Column>` by
/// [`Statement::fetch_one`] and [`Statement::fetch_all`].
pub type Column = Value;
13
/// The kind of result expected when a [`Statement`] is run, stored in
/// [`Statement::expected_response`].
///
/// NOTE(review): the method pairings noted on each variant are inferred from the names of the
/// `Statement` execution methods; confirm against the `DatabaseExecutor` implementations.
#[derive(Debug, Clone, Copy, Default)]
pub enum ExpectedSqlResponse {
    /// At most one row (presumably paired with `fetch_one`).
    SingleRow,
    /// Any number of rows; this is the default (presumably paired with `fetch_all`).
    #[default]
    ManyRows,
    /// A count of affected rows (presumably paired with `execute`).
    AffectedRows,
    /// A single value (presumably paired with `pluck`).
    Pluck,
    /// No result value (presumably paired with `batch`).
    Batch,
}
29
/// A value bound to a named placeholder of a [`Statement`].
#[derive(Debug, Clone)]
pub enum PlaceholderValue {
    /// A single value; [`Statement::to_sql`] renders it as one `$n` placeholder.
    Value(Value),
    /// A list of values (e.g. for `IN (...)`); [`Statement::to_sql`] renders it as a
    /// comma-separated run of `$n` placeholders, one per element.
    Set(Vec<Value>),
}
38
39impl From<Value> for PlaceholderValue {
40 fn from(value: Value) -> Self {
41 PlaceholderValue::Value(value)
42 }
43}
44
45impl From<Vec<Value>> for PlaceholderValue {
46 fn from(value: Vec<Value>) -> Self {
47 PlaceholderValue::Set(value)
48 }
49}
50
/// One piece of a parsed SQL statement, as produced by [`split_sql_parts`].
#[derive(Debug, Clone)]
pub enum SqlPart {
    /// SQL text taken from the input as-is.
    Raw(Arc<str>),
    /// A `:name` placeholder: the name without the leading `:`, and the bound value, which is
    /// `None` until set via [`Statement::bind`] or [`Statement::bind_vec`].
    Placeholder(Arc<str>, Option<PlaceholderValue>),
}
59
/// Errors produced by [`split_sql_parts`].
#[derive(Debug, PartialEq, thiserror::Error)]
pub enum SqlParseError {
    /// A `'...'` or `"..."` quoted section is never closed.
    #[error("Unterminated String literal")]
    UnterminatedStringLiteral,
    /// A `:` is not followed by a placeholder name (alphanumeric characters or `_`).
    #[error("Invalid placeholder name")]
    InvalidPlaceholder,
}
70
71pub fn split_sql_parts(input: &str) -> Result<Vec<SqlPart>, SqlParseError> {
76 let mut parts = Vec::new();
77 let mut current = String::new();
78 let mut chars = input.chars().peekable();
79
80 while let Some(&c) = chars.peek() {
81 match c {
82 '\'' | '"' => {
83 let quote = c;
85 current.push(
86 chars
87 .next()
88 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
89 );
90
91 let mut closed = false;
92 while let Some(&next) = chars.peek() {
93 current.push(
94 chars
95 .next()
96 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
97 );
98
99 if next == quote {
100 if chars.peek() == Some("e) {
101 current.push(
103 chars
104 .next()
105 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
106 );
107 } else {
108 closed = true;
109 break;
110 }
111 }
112 }
113
114 if !closed {
115 return Err(SqlParseError::UnterminatedStringLiteral);
116 }
117 }
118
119 '-' => {
120 current.push(
121 chars
122 .next()
123 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
124 );
125
126 if chars.peek() == Some(&'-') {
127 while let Some(&next) = chars.peek() {
128 current.push(
129 chars
130 .next()
131 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
132 );
133 if next == '\n' {
134 break;
135 }
136 }
137 }
138 }
139
140 ':' => {
141 chars.next(); if chars.peek() == Some(&':') {
144 current.push(':');
145 current.push(
146 chars
147 .next()
148 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
149 );
150 continue;
151 }
152
153 if !current.is_empty() {
155 parts.push(SqlPart::Raw(current.clone().into()));
156 current.clear();
157 }
158
159 let mut name = String::new();
160
161 while let Some(&next) = chars.peek() {
162 if next.is_alphanumeric() || next == '_' {
163 name.push(
164 chars
165 .next()
166 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
167 );
168 } else {
169 break;
170 }
171 }
172
173 if name.is_empty() {
174 return Err(SqlParseError::InvalidPlaceholder);
175 }
176
177 parts.push(SqlPart::Placeholder(name.into(), None));
178 }
179
180 _ => {
181 current.push(
182 chars
183 .next()
184 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
185 );
186 }
187 }
188 }
189
190 if !current.is_empty() {
191 parts.push(SqlPart::Raw(current.into()));
192 }
193
194 Ok(parts)
195}
196
/// Statement cache keyed by the original SQL text.
///
/// Each entry holds the parsed parts (with no values bound) and, once available, the SQL
/// rendered by [`Statement::to_sql`] for a call in which no list value was bound.
type Cache = HashMap<String, (Vec<SqlPart>, Option<Arc<str>>)>;
198
/// A parsed SQL statement with named `:placeholders`, ready for binding and execution.
#[derive(Debug, Default)]
pub struct Statement {
    /// Statement cache this statement reads from and writes its rendered SQL to.
    cache: Arc<RwLock<Cache>>,
    /// Previously rendered SQL for the same text, reused by `to_sql` when no list is bound.
    cached_sql: Option<Arc<str>>,
    /// Original SQL text, used as the cache key when storing the rendered SQL; `None` when
    /// this statement will not store a rendering.
    sql: Option<String>,
    /// The parsed statement, including any values bound so far.
    pub parts: Vec<SqlPart>,
    /// The kind of result expected when the statement is executed.
    pub expected_response: ExpectedSqlResponse,
}
210
211impl Statement {
212 fn new(sql: &str, cache: Arc<RwLock<Cache>>) -> Result<Self, SqlParseError> {
214 let parsed = cache
215 .read()
216 .map(|cache| cache.get(sql).cloned())
217 .ok()
218 .flatten();
219
220 if let Some((parts, cached_sql)) = parsed {
221 Ok(Self {
222 parts,
223 cached_sql,
224 sql: None,
225 cache,
226 ..Default::default()
227 })
228 } else {
229 let parts = split_sql_parts(sql)?;
230
231 if let Ok(mut cache) = cache.write() {
232 cache.insert(sql.to_owned(), (parts.clone(), None));
233 } else {
234 tracing::warn!("Failed to acquire write lock for SQL statement cache");
235 }
236
237 Ok(Self {
238 parts,
239 sql: Some(sql.to_owned()),
240 cache,
241 ..Default::default()
242 })
243 }
244 }
245
246 pub fn to_sql(self) -> Result<(String, Vec<Value>), Error> {
252 let has_set_placeholder = self.parts.iter().any(|part| {
253 matches!(
254 part,
255 SqlPart::Placeholder(_, Some(PlaceholderValue::Set(_)))
256 )
257 });
258
259 if let (false, Some(cached_sql)) = (has_set_placeholder, self.cached_sql) {
260 let sql = cached_sql.to_string();
261 let values = self
262 .parts
263 .into_iter()
264 .map(|x| match x {
265 SqlPart::Placeholder(name, value) => {
266 match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
267 PlaceholderValue::Value(value) => Ok(vec![value]),
268 PlaceholderValue::Set(values) => Ok(values),
269 }
270 }
271 SqlPart::Raw(_) => Ok(vec![]),
272 })
273 .collect::<Result<Vec<_>, Error>>()?
274 .into_iter()
275 .flatten()
276 .collect::<Vec<_>>();
277 return Ok((sql, values));
278 }
279
280 let mut placeholder_values = Vec::new();
281 let mut can_be_cached = true;
282 let sql = self
283 .parts
284 .into_iter()
285 .map(|x| match x {
286 SqlPart::Placeholder(name, value) => {
287 match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
288 PlaceholderValue::Value(value) => {
289 placeholder_values.push(value);
290 Ok::<_, Error>(format!("${}", placeholder_values.len()))
291 }
292 PlaceholderValue::Set(mut values) => {
293 can_be_cached = false;
294 let start_size = placeholder_values.len();
295 placeholder_values.append(&mut values);
296 let placeholders = (start_size + 1..=placeholder_values.len())
297 .map(|i| format!("${i}"))
298 .collect::<Vec<_>>()
299 .join(", ");
300 Ok(placeholders)
301 }
302 }
303 }
304 SqlPart::Raw(raw) => Ok(raw.trim().to_string()),
305 })
306 .collect::<Result<Vec<String>, _>>()?
307 .join(" ");
308
309 if can_be_cached {
310 if let Some(original_sql) = self.sql {
311 let _ = self.cache.write().map(|mut cache| {
312 if let Some((_, cached_sql)) = cache.get_mut(&original_sql) {
313 *cached_sql = Some(sql.clone().into());
314 }
315 });
316 }
317 }
318
319 Ok((sql, placeholder_values))
320 }
321
322 #[inline]
324 pub fn bind<C, V>(mut self, name: C, value: V) -> Self
325 where
326 C: ToString,
327 V: Into<Value>,
328 {
329 let name = name.to_string();
330 let value = value.into();
331 let value: PlaceholderValue = value.into();
332
333 for part in self.parts.iter_mut() {
334 if let SqlPart::Placeholder(part_name, part_value) = part {
335 if **part_name == *name.as_str() {
336 *part_value = Some(value.clone());
337 }
338 }
339 }
340
341 self
342 }
343
344 #[inline]
351 pub fn bind_vec<C, V>(mut self, name: C, value: Vec<V>) -> Result<Self, Error>
352 where
353 C: ToString,
354 V: Into<Value>,
355 {
356 let name = name.to_string();
357
358 if value.is_empty() {
359 return Err(Error::EmptyInClause(name));
360 }
361
362 let value: PlaceholderValue = value
363 .into_iter()
364 .map(|x| x.into())
365 .collect::<Vec<Value>>()
366 .into();
367
368 for part in self.parts.iter_mut() {
369 if let SqlPart::Placeholder(part_name, part_value) = part {
370 if **part_name == *name.as_str() {
371 *part_value = Some(value.clone());
372 }
373 }
374 }
375
376 Ok(self)
377 }
378
379 pub async fn pluck<C>(self, conn: &C) -> Result<Option<Value>, Error>
381 where
382 C: DatabaseExecutor,
383 {
384 conn.pluck(self).await
385 }
386
387 pub async fn batch<C>(self, conn: &C) -> Result<(), Error>
389 where
390 C: DatabaseExecutor,
391 {
392 conn.batch(self).await
393 }
394
395 pub async fn execute<C>(self, conn: &C) -> Result<usize, Error>
397 where
398 C: DatabaseExecutor,
399 {
400 conn.execute(self).await
401 }
402
403 pub async fn fetch_one<C>(self, conn: &C) -> Result<Option<Vec<Column>>, Error>
405 where
406 C: DatabaseExecutor,
407 {
408 conn.fetch_one(self).await
409 }
410
411 pub async fn fetch_all<C>(self, conn: &C) -> Result<Vec<Vec<Column>>, Error>
413 where
414 C: DatabaseExecutor,
415 {
416 conn.fetch_all(self).await
417 }
418}
419
420#[inline(always)]
422pub fn query(sql: &str) -> Result<Statement, Error> {
423 static CACHE: Lazy<Arc<RwLock<Cache>>> = Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
424 Statement::new(sql, CACHE.clone()).map_err(|e| Error::Database(Box::new(e)))
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// `bind_vec` rejects an empty list with `Error::EmptyInClause` naming the placeholder.
    #[test]
    fn bind_vec_errors_on_empty_vec() {
        let outcome = query("SELECT * FROM foo WHERE id IN (:ids)")
            .unwrap()
            .bind_vec("ids", Vec::<Vec<u8>>::new());

        assert!(outcome.is_err());
        let error = outcome.unwrap_err();
        assert!(matches!(error, Error::EmptyInClause(name) if name == "ids"));
    }

    /// A `::` cast must pass through parsing and rendering unchanged.
    #[test]
    fn parser_preserves_postgres_cast_operator() {
        let statement = query("SELECT (ord - 1)::int AS matched WHERE id = :id")
            .unwrap()
            .bind("id", "quote-id");

        let (rendered, bound) = statement.to_sql().unwrap();

        assert_eq!(rendered, "SELECT (ord - 1)::int AS matched WHERE id = $1");
        assert_eq!(bound.len(), 1);
    }

    /// A rendering cached from a scalar binding must not be reused once a list is bound to
    /// the same SQL text.
    #[test]
    fn bind_vec_ignores_cached_sql_for_same_query_string() {
        let raw_sql = "SELECT * FROM cached_sql_vec_bug WHERE id IN (:ids)";

        // First render with a scalar binding.
        let (scalar_sql, scalar_values) =
            query(raw_sql).unwrap().bind("ids", 1_i64).to_sql().unwrap();
        assert!(scalar_sql.contains("$1"));
        assert_eq!(scalar_values.len(), 1);

        // Same SQL text, now bound to a list.
        let (list_sql, list_values) = query(raw_sql)
            .unwrap()
            .bind_vec("ids", vec![1_i64, 2, 3])
            .unwrap()
            .to_sql()
            .unwrap();
        assert!(list_sql.contains("$1, $2, $3"));
        assert_eq!(list_values.len(), 3);
    }
}
470}