1use std::collections::HashMap;
3use std::sync::{Arc, RwLock};
4
5use cdk_common::database::Error;
6use once_cell::sync::Lazy;
7
8use crate::database::DatabaseExecutor;
9use crate::value::Value;
10
/// A single column of a result row.
///
/// Alias for [`Value`]; rows returned by [`Statement::fetch_one`] and
/// [`Statement::fetch_all`] are `Vec<Column>`.
pub type Column = Value;
13
/// The kind of response a statement is expected to produce when executed.
///
/// Stored in [`Statement::expected_response`]. Nothing in this file assigns a
/// variant other than the default, so the per-variant notes below are
/// inferred from the variant names.
/// NOTE(review): confirm the exact meaning against the `DatabaseExecutor`
/// implementations.
#[derive(Debug, Clone, Copy, Default)]
pub enum ExpectedSqlResponse {
    /// Presumably: at most one row is expected.
    SingleRow,
    /// Presumably: any number of rows is expected. This is the default variant.
    #[default]
    ManyRows,
    /// Presumably: only the number of affected rows is expected.
    AffectedRows,
    /// Presumably: a single value (first column of the first row) is expected.
    Pluck,
    /// Presumably: a batch of statements with no result is expected.
    Batch,
}
29
/// The value bound to a named placeholder.
#[derive(Debug, Clone)]
pub enum PlaceholderValue {
    /// A single value; rendered by `Statement::to_sql` as one positional
    /// parameter (`$n`).
    Value(Value),
    /// A list of values; rendered by `Statement::to_sql` as a comma-separated
    /// list of positional parameters (`$n, $n+1, ...`), e.g. for `IN (...)`.
    Set(Vec<Value>),
}
38
39impl From<Value> for PlaceholderValue {
40 fn from(value: Value) -> Self {
41 PlaceholderValue::Value(value)
42 }
43}
44
45impl From<Vec<Value>> for PlaceholderValue {
46 fn from(value: Vec<Value>) -> Self {
47 PlaceholderValue::Set(value)
48 }
49}
50
/// One fragment of a parsed SQL statement, as produced by [`split_sql_parts`].
#[derive(Debug, Clone)]
pub enum SqlPart {
    /// Literal SQL text, copied verbatim from the input (including any quoted
    /// strings and `--` line comments it contains).
    Raw(Arc<str>),
    /// A named placeholder (`:name` in the input, stored without the colon)
    /// and the value bound to it; `None` until a value is bound.
    Placeholder(Arc<str>, Option<PlaceholderValue>),
}
59
/// Errors returned by [`split_sql_parts`] when the SQL text cannot be split.
#[derive(Debug, PartialEq, thiserror::Error)]
pub enum SqlParseError {
    /// A `'` or `"` quoted section was opened but never closed.
    #[error("Unterminated String literal")]
    UnterminatedStringLiteral,
    /// A single `:` was not followed by at least one alphanumeric character
    /// or underscore.
    #[error("Invalid placeholder name")]
    InvalidPlaceholder,
}
70
71pub fn split_sql_parts(input: &str) -> Result<Vec<SqlPart>, SqlParseError> {
76 let mut parts = Vec::new();
77 let mut current = String::new();
78 let mut chars = input.chars().peekable();
79
80 while let Some(&c) = chars.peek() {
81 match c {
82 '\'' | '"' => {
83 let quote = c;
85 current.push(
86 chars
87 .next()
88 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
89 );
90
91 let mut closed = false;
92 while let Some(&next) = chars.peek() {
93 current.push(
94 chars
95 .next()
96 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
97 );
98
99 if next == quote {
100 if chars.peek() == Some("e) {
101 current.push(
103 chars
104 .next()
105 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
106 );
107 } else {
108 closed = true;
109 break;
110 }
111 }
112 }
113
114 if !closed {
115 return Err(SqlParseError::UnterminatedStringLiteral);
116 }
117 }
118
119 '-' => {
120 current.push(
121 chars
122 .next()
123 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
124 );
125
126 if chars.peek() == Some(&'-') {
127 while let Some(&next) = chars.peek() {
128 current.push(
129 chars
130 .next()
131 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
132 );
133 if next == '\n' {
134 break;
135 }
136 }
137 }
138 }
139
140 ':' => {
141 chars.next(); if chars.peek() == Some(&':') {
144 current.push(':');
145 current.push(
146 chars
147 .next()
148 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
149 );
150 continue;
151 }
152
153 if !current.is_empty() {
155 parts.push(SqlPart::Raw(current.clone().into()));
156 current.clear();
157 }
158
159 let mut name = String::new();
160
161 while let Some(&next) = chars.peek() {
162 if next.is_alphanumeric() || next == '_' {
163 name.push(
164 chars
165 .next()
166 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
167 );
168 } else {
169 break;
170 }
171 }
172
173 if name.is_empty() {
174 return Err(SqlParseError::InvalidPlaceholder);
175 }
176
177 parts.push(SqlPart::Placeholder(name.into(), None));
178 }
179
180 _ => {
181 current.push(
182 chars
183 .next()
184 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
185 );
186 }
187 }
188 }
189
190 if !current.is_empty() {
191 parts.push(SqlPart::Raw(current.into()));
192 }
193
194 Ok(parts)
195}
196
/// Statement cache keyed by the original SQL text.
///
/// Each entry holds the parsed (unbound) parts and, once a statement for that
/// text has been rendered without any `Set` placeholder, the rendered SQL.
type Cache = HashMap<String, (Vec<SqlPart>, Option<Arc<str>>)>;
198
/// A parsed SQL statement whose named placeholders can be bound to values
/// and which can then be rendered with `to_sql` or run through a
/// `DatabaseExecutor`.
#[derive(Debug, Default)]
pub struct Statement {
    /// Shared cache this statement was created from; `to_sql` writes the
    /// rendered SQL back into it.
    cache: Arc<RwLock<Cache>>,
    /// Rendered SQL found in the cache when the statement was created, if any.
    cached_sql: Option<Arc<str>>,
    /// Original SQL text, kept when the rendered SQL may still need to be
    /// written back to the cache under this key.
    sql: Option<String>,
    /// Parsed fragments of the statement, in source order.
    pub parts: Vec<SqlPart>,
    /// Kind of response expected when the statement is executed.
    pub expected_response: ExpectedSqlResponse,
}
210
211impl Statement {
212 fn new(sql: &str, cache: Arc<RwLock<Cache>>) -> Result<Self, SqlParseError> {
214 let parsed = cache
215 .read()
216 .map(|cache| cache.get(sql).cloned())
217 .ok()
218 .flatten();
219
220 if let Some((parts, cached_sql)) = parsed {
221 Ok(Self {
222 parts,
223 cached_sql,
224 sql: None,
225 cache,
226 ..Default::default()
227 })
228 } else {
229 let parts = split_sql_parts(sql)?;
230
231 if let Ok(mut cache) = cache.write() {
232 cache.insert(sql.to_owned(), (parts.clone(), None));
233 } else {
234 tracing::warn!("Failed to acquire write lock for SQL statement cache");
235 }
236
237 Ok(Self {
238 parts,
239 sql: Some(sql.to_owned()),
240 cache,
241 ..Default::default()
242 })
243 }
244 }
245
246 pub fn to_sql(self) -> Result<(String, Vec<Value>), Error> {
252 if let Some(cached_sql) = self.cached_sql {
253 let sql = cached_sql.to_string();
254 let values = self
255 .parts
256 .into_iter()
257 .map(|x| match x {
258 SqlPart::Placeholder(name, value) => {
259 match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
260 PlaceholderValue::Value(value) => Ok(vec![value]),
261 PlaceholderValue::Set(values) => Ok(values),
262 }
263 }
264 SqlPart::Raw(_) => Ok(vec![]),
265 })
266 .collect::<Result<Vec<_>, Error>>()?
267 .into_iter()
268 .flatten()
269 .collect::<Vec<_>>();
270 return Ok((sql, values));
271 }
272
273 let mut placeholder_values = Vec::new();
274 let mut can_be_cached = true;
275 let sql = self
276 .parts
277 .into_iter()
278 .map(|x| match x {
279 SqlPart::Placeholder(name, value) => {
280 match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
281 PlaceholderValue::Value(value) => {
282 placeholder_values.push(value);
283 Ok::<_, Error>(format!("${}", placeholder_values.len()))
284 }
285 PlaceholderValue::Set(mut values) => {
286 can_be_cached = false;
287 let start_size = placeholder_values.len();
288 placeholder_values.append(&mut values);
289 let placeholders = (start_size + 1..=placeholder_values.len())
290 .map(|i| format!("${i}"))
291 .collect::<Vec<_>>()
292 .join(", ");
293 Ok(placeholders)
294 }
295 }
296 }
297 SqlPart::Raw(raw) => Ok(raw.trim().to_string()),
298 })
299 .collect::<Result<Vec<String>, _>>()?
300 .join(" ");
301
302 if can_be_cached {
303 if let Some(original_sql) = self.sql {
304 let _ = self.cache.write().map(|mut cache| {
305 if let Some((_, cached_sql)) = cache.get_mut(&original_sql) {
306 *cached_sql = Some(sql.clone().into());
307 }
308 });
309 }
310 }
311
312 Ok((sql, placeholder_values))
313 }
314
315 #[inline]
317 pub fn bind<C, V>(mut self, name: C, value: V) -> Self
318 where
319 C: ToString,
320 V: Into<Value>,
321 {
322 let name = name.to_string();
323 let value = value.into();
324 let value: PlaceholderValue = value.into();
325
326 for part in self.parts.iter_mut() {
327 if let SqlPart::Placeholder(part_name, part_value) = part {
328 if **part_name == *name.as_str() {
329 *part_value = Some(value.clone());
330 }
331 }
332 }
333
334 self
335 }
336
337 #[inline]
344 pub fn bind_vec<C, V>(mut self, name: C, value: Vec<V>) -> Result<Self, Error>
345 where
346 C: ToString,
347 V: Into<Value>,
348 {
349 let name = name.to_string();
350
351 if value.is_empty() {
352 return Err(Error::EmptyInClause(name));
353 }
354
355 let value: PlaceholderValue = value
356 .into_iter()
357 .map(|x| x.into())
358 .collect::<Vec<Value>>()
359 .into();
360
361 for part in self.parts.iter_mut() {
362 if let SqlPart::Placeholder(part_name, part_value) = part {
363 if **part_name == *name.as_str() {
364 *part_value = Some(value.clone());
365 }
366 }
367 }
368
369 Ok(self)
370 }
371
372 pub async fn pluck<C>(self, conn: &C) -> Result<Option<Value>, Error>
374 where
375 C: DatabaseExecutor,
376 {
377 conn.pluck(self).await
378 }
379
380 pub async fn batch<C>(self, conn: &C) -> Result<(), Error>
382 where
383 C: DatabaseExecutor,
384 {
385 conn.batch(self).await
386 }
387
388 pub async fn execute<C>(self, conn: &C) -> Result<usize, Error>
390 where
391 C: DatabaseExecutor,
392 {
393 conn.execute(self).await
394 }
395
396 pub async fn fetch_one<C>(self, conn: &C) -> Result<Option<Vec<Column>>, Error>
398 where
399 C: DatabaseExecutor,
400 {
401 conn.fetch_one(self).await
402 }
403
404 pub async fn fetch_all<C>(self, conn: &C) -> Result<Vec<Vec<Column>>, Error>
406 where
407 C: DatabaseExecutor,
408 {
409 conn.fetch_all(self).await
410 }
411}
412
413#[inline(always)]
415pub fn query(sql: &str) -> Result<Statement, Error> {
416 static CACHE: Lazy<Arc<RwLock<Cache>>> = Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
417 Statement::new(sql, CACHE.clone()).map_err(|e| Error::Database(Box::new(e)))
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty vector cannot be rendered as an `IN (...)` list, so binding
    /// one must fail and name the offending placeholder.
    #[test]
    fn bind_vec_errors_on_empty_vec() {
        let outcome = query("SELECT * FROM foo WHERE id IN (:ids)")
            .unwrap()
            .bind_vec("ids", Vec::<Vec<u8>>::new());

        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        assert!(matches!(err, Error::EmptyInClause(ref name) if name == "ids"));
    }

    /// `::` is a cast operator and must survive parsing untouched, while the
    /// `:id` placeholder is still replaced.
    #[test]
    fn parser_preserves_postgres_cast_operator() {
        let (sql, values) = query("SELECT (ord - 1)::int AS matched WHERE id = :id")
            .unwrap()
            .bind("id", "quote-id")
            .to_sql()
            .unwrap();

        assert_eq!(sql, "SELECT (ord - 1)::int AS matched WHERE id = $1");
        assert_eq!(values.len(), 1);
    }
}