Skip to main content

cdk_sql_common/
stmt.rs

//! Statements module
2use std::collections::HashMap;
3use std::sync::{Arc, RwLock};
4
5use cdk_common::database::Error;
6use once_cell::sync::Lazy;
7
8use crate::database::DatabaseExecutor;
9use crate::value::Value;
10
11/// The Column type
12pub type Column = Value;
13
/// Shape of the result a caller expects back from an SQL statement.
#[derive(Clone, Copy, Debug, Default)]
pub enum ExpectedSqlResponse {
    /// A single row is expected.
    SingleRow,
    /// Every row matching the query is expected. This is the default variant.
    #[default]
    ManyRows,
    /// The number of rows affected by the query is expected.
    AffectedRows,
    /// Only the first column of the first row is expected.
    Pluck,
    /// The statement is a batch.
    Batch,
}
29
30/// Part value
31#[derive(Debug, Clone)]
32pub enum PlaceholderValue {
33    /// Value
34    Value(Value),
35    /// Set
36    Set(Vec<Value>),
37}
38
39impl From<Value> for PlaceholderValue {
40    fn from(value: Value) -> Self {
41        PlaceholderValue::Value(value)
42    }
43}
44
45impl From<Vec<Value>> for PlaceholderValue {
46    fn from(value: Vec<Value>) -> Self {
47        PlaceholderValue::Set(value)
48    }
49}
50
51/// SQL Part
52#[derive(Debug, Clone)]
53pub enum SqlPart {
54    /// Raw SQL statement
55    Raw(Arc<str>),
56    /// Placeholder
57    Placeholder(Arc<str>, Option<PlaceholderValue>),
58}
59
60/// SQL parser error
61#[derive(Debug, PartialEq, thiserror::Error)]
62pub enum SqlParseError {
63    /// Invalid SQL
64    #[error("Unterminated String literal")]
65    UnterminatedStringLiteral,
66    /// Invalid placeholder name
67    #[error("Invalid placeholder name")]
68    InvalidPlaceholder,
69}
70
71/// Rudimentary SQL parser.
72///
73/// This function does not validate the SQL statement, it only extracts the placeholder to be
74/// database agnostic.
75pub fn split_sql_parts(input: &str) -> Result<Vec<SqlPart>, SqlParseError> {
76    let mut parts = Vec::new();
77    let mut current = String::new();
78    let mut chars = input.chars().peekable();
79
80    while let Some(&c) = chars.peek() {
81        match c {
82            '\'' | '"' => {
83                // Start of string literal
84                let quote = c;
85                current.push(
86                    chars
87                        .next()
88                        .ok_or(SqlParseError::UnterminatedStringLiteral)?,
89                );
90
91                let mut closed = false;
92                while let Some(&next) = chars.peek() {
93                    current.push(
94                        chars
95                            .next()
96                            .ok_or(SqlParseError::UnterminatedStringLiteral)?,
97                    );
98
99                    if next == quote {
100                        if chars.peek() == Some(&quote) {
101                            // Escaped quote (e.g. '' inside strings)
102                            current.push(
103                                chars
104                                    .next()
105                                    .ok_or(SqlParseError::UnterminatedStringLiteral)?,
106                            );
107                        } else {
108                            closed = true;
109                            break;
110                        }
111                    }
112                }
113
114                if !closed {
115                    return Err(SqlParseError::UnterminatedStringLiteral);
116                }
117            }
118
119            '-' => {
120                current.push(
121                    chars
122                        .next()
123                        .ok_or(SqlParseError::UnterminatedStringLiteral)?,
124                );
125
126                if chars.peek() == Some(&'-') {
127                    while let Some(&next) = chars.peek() {
128                        current.push(
129                            chars
130                                .next()
131                                .ok_or(SqlParseError::UnterminatedStringLiteral)?,
132                        );
133                        if next == '\n' {
134                            break;
135                        }
136                    }
137                }
138            }
139
140            ':' => {
141                chars.next(); // consume ':'
142
143                if chars.peek() == Some(&':') {
144                    current.push(':');
145                    current.push(
146                        chars
147                            .next()
148                            .ok_or(SqlParseError::UnterminatedStringLiteral)?,
149                    );
150                    continue;
151                }
152
153                // Flush current raw SQL
154                if !current.is_empty() {
155                    parts.push(SqlPart::Raw(current.clone().into()));
156                    current.clear();
157                }
158
159                let mut name = String::new();
160
161                while let Some(&next) = chars.peek() {
162                    if next.is_alphanumeric() || next == '_' {
163                        name.push(
164                            chars
165                                .next()
166                                .ok_or(SqlParseError::UnterminatedStringLiteral)?,
167                        );
168                    } else {
169                        break;
170                    }
171                }
172
173                if name.is_empty() {
174                    return Err(SqlParseError::InvalidPlaceholder);
175                }
176
177                parts.push(SqlPart::Placeholder(name.into(), None));
178            }
179
180            _ => {
181                current.push(
182                    chars
183                        .next()
184                        .ok_or(SqlParseError::UnterminatedStringLiteral)?,
185                );
186            }
187        }
188    }
189
190    if !current.is_empty() {
191        parts.push(SqlPart::Raw(current.into()));
192    }
193
194    Ok(parts)
195}
196
197type Cache = HashMap<String, (Vec<SqlPart>, Option<Arc<str>>)>;
198
199/// Sql message
200#[derive(Debug, Default)]
201pub struct Statement {
202    cache: Arc<RwLock<Cache>>,
203    cached_sql: Option<Arc<str>>,
204    sql: Option<String>,
205    /// The SQL statement
206    pub parts: Vec<SqlPart>,
207    /// The expected response type
208    pub expected_response: ExpectedSqlResponse,
209}
210
211impl Statement {
212    /// Creates a new statement
213    fn new(sql: &str, cache: Arc<RwLock<Cache>>) -> Result<Self, SqlParseError> {
214        let parsed = cache
215            .read()
216            .map(|cache| cache.get(sql).cloned())
217            .ok()
218            .flatten();
219
220        if let Some((parts, cached_sql)) = parsed {
221            Ok(Self {
222                parts,
223                cached_sql,
224                sql: None,
225                cache,
226                ..Default::default()
227            })
228        } else {
229            let parts = split_sql_parts(sql)?;
230
231            if let Ok(mut cache) = cache.write() {
232                cache.insert(sql.to_owned(), (parts.clone(), None));
233            } else {
234                tracing::warn!("Failed to acquire write lock for SQL statement cache");
235            }
236
237            Ok(Self {
238                parts,
239                sql: Some(sql.to_owned()),
240                cache,
241                ..Default::default()
242            })
243        }
244    }
245
246    /// Convert Statement into a SQL statement and the list of placeholders
247    ///
248    /// By default it converts the statement into placeholder using $1..$n placeholders which seems
249    /// to be more widely supported, although it can be reimplemented with other formats since part
250    /// is public
251    pub fn to_sql(self) -> Result<(String, Vec<Value>), Error> {
252        let has_set_placeholder = self.parts.iter().any(|part| {
253            matches!(
254                part,
255                SqlPart::Placeholder(_, Some(PlaceholderValue::Set(_)))
256            )
257        });
258
259        if let (false, Some(cached_sql)) = (has_set_placeholder, self.cached_sql) {
260            let sql = cached_sql.to_string();
261            let values = self
262                .parts
263                .into_iter()
264                .map(|x| match x {
265                    SqlPart::Placeholder(name, value) => {
266                        match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
267                            PlaceholderValue::Value(value) => Ok(vec![value]),
268                            PlaceholderValue::Set(values) => Ok(values),
269                        }
270                    }
271                    SqlPart::Raw(_) => Ok(vec![]),
272                })
273                .collect::<Result<Vec<_>, Error>>()?
274                .into_iter()
275                .flatten()
276                .collect::<Vec<_>>();
277            return Ok((sql, values));
278        }
279
280        let mut placeholder_values = Vec::new();
281        let mut can_be_cached = true;
282        let sql = self
283            .parts
284            .into_iter()
285            .map(|x| match x {
286                SqlPart::Placeholder(name, value) => {
287                    match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
288                        PlaceholderValue::Value(value) => {
289                            placeholder_values.push(value);
290                            Ok::<_, Error>(format!("${}", placeholder_values.len()))
291                        }
292                        PlaceholderValue::Set(mut values) => {
293                            can_be_cached = false;
294                            let start_size = placeholder_values.len();
295                            placeholder_values.append(&mut values);
296                            let placeholders = (start_size + 1..=placeholder_values.len())
297                                .map(|i| format!("${i}"))
298                                .collect::<Vec<_>>()
299                                .join(", ");
300                            Ok(placeholders)
301                        }
302                    }
303                }
304                SqlPart::Raw(raw) => Ok(raw.trim().to_string()),
305            })
306            .collect::<Result<Vec<String>, _>>()?
307            .join(" ");
308
309        if can_be_cached {
310            if let Some(original_sql) = self.sql {
311                let _ = self.cache.write().map(|mut cache| {
312                    if let Some((_, cached_sql)) = cache.get_mut(&original_sql) {
313                        *cached_sql = Some(sql.clone().into());
314                    }
315                });
316            }
317        }
318
319        Ok((sql, placeholder_values))
320    }
321
322    /// Binds a given placeholder to a value.
323    #[inline]
324    pub fn bind<C, V>(mut self, name: C, value: V) -> Self
325    where
326        C: ToString,
327        V: Into<Value>,
328    {
329        let name = name.to_string();
330        let value = value.into();
331        let value: PlaceholderValue = value.into();
332
333        for part in self.parts.iter_mut() {
334            if let SqlPart::Placeholder(part_name, part_value) = part {
335                if **part_name == *name.as_str() {
336                    *part_value = Some(value.clone());
337                }
338            }
339        }
340
341        self
342    }
343
344    /// Binds a single variable with a vector.
345    ///
346    /// This will rewrite the function from `:foo` (where value is vec![1, 2, 3]) to `:foo0, :foo1,
347    /// :foo2` and binds each value from the value vector accordingly.
348    ///
349    /// Returns an error if the vector is empty, as empty `IN` clauses produce invalid SQL.
350    #[inline]
351    pub fn bind_vec<C, V>(mut self, name: C, value: Vec<V>) -> Result<Self, Error>
352    where
353        C: ToString,
354        V: Into<Value>,
355    {
356        let name = name.to_string();
357
358        if value.is_empty() {
359            return Err(Error::EmptyInClause(name));
360        }
361
362        let value: PlaceholderValue = value
363            .into_iter()
364            .map(|x| x.into())
365            .collect::<Vec<Value>>()
366            .into();
367
368        for part in self.parts.iter_mut() {
369            if let SqlPart::Placeholder(part_name, part_value) = part {
370                if **part_name == *name.as_str() {
371                    *part_value = Some(value.clone());
372                }
373            }
374        }
375
376        Ok(self)
377    }
378
379    /// Executes a query and returns the affected rows
380    pub async fn pluck<C>(self, conn: &C) -> Result<Option<Value>, Error>
381    where
382        C: DatabaseExecutor,
383    {
384        conn.pluck(self).await
385    }
386
387    /// Executes a query and returns the affected rows
388    pub async fn batch<C>(self, conn: &C) -> Result<(), Error>
389    where
390        C: DatabaseExecutor,
391    {
392        conn.batch(self).await
393    }
394
395    /// Executes a query and returns the affected rows
396    pub async fn execute<C>(self, conn: &C) -> Result<usize, Error>
397    where
398        C: DatabaseExecutor,
399    {
400        conn.execute(self).await
401    }
402
403    /// Runs the query and returns the first row or None
404    pub async fn fetch_one<C>(self, conn: &C) -> Result<Option<Vec<Column>>, Error>
405    where
406        C: DatabaseExecutor,
407    {
408        conn.fetch_one(self).await
409    }
410
411    /// Runs the query and returns the first row or None
412    pub async fn fetch_all<C>(self, conn: &C) -> Result<Vec<Vec<Column>>, Error>
413    where
414        C: DatabaseExecutor,
415    {
416        conn.fetch_all(self).await
417    }
418}
419
420/// Creates a new query statement
421#[inline(always)]
422pub fn query(sql: &str) -> Result<Statement, Error> {
423    static CACHE: Lazy<Arc<RwLock<Cache>>> = Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
424    Statement::new(sql, CACHE.clone()).map_err(|e| Error::Database(Box::new(e)))
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// Binding an empty vector must be rejected and must report the placeholder name.
    #[test]
    fn bind_vec_errors_on_empty_vec() {
        let outcome = query("SELECT * FROM foo WHERE id IN (:ids)")
            .unwrap()
            .bind_vec("ids", Vec::<Vec<u8>>::new());

        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        assert!(matches!(err, Error::EmptyInClause(name) if name == "ids"));
    }

    /// `::` must survive parsing untouched while `:id` still becomes `$1`.
    #[test]
    fn parser_preserves_postgres_cast_operator() {
        let statement = query("SELECT (ord - 1)::int AS matched WHERE id = :id").unwrap();
        let (rendered, bound) = statement.bind("id", "quote-id").to_sql().unwrap();

        assert_eq!(
            rendered,
            "SELECT (ord - 1)::int AS matched WHERE id = $1"
        );
        assert_eq!(bound.len(), 1);
    }

    /// SQL cached from a single-value binding must not be reused once the same query string is
    /// bound to a vector.
    #[test]
    fn bind_vec_ignores_cached_sql_for_same_query_string() {
        let raw_sql = "SELECT * FROM cached_sql_vec_bug WHERE id IN (:ids)";

        // First render populates the cache with the single-placeholder form.
        let first = query(raw_sql).unwrap().bind("ids", 1_i64);
        let (cached_sql, cached_values) = first.to_sql().unwrap();
        assert!(cached_sql.contains("$1"));
        assert_eq!(cached_values.len(), 1);

        // Second render binds a vector and must expand to three placeholders.
        let second = query(raw_sql)
            .unwrap()
            .bind_vec("ids", vec![1_i64, 2, 3])
            .unwrap();
        let (sql, values) = second.to_sql().unwrap();

        assert!(sql.contains("$1, $2, $3"));
        assert_eq!(values.len(), 3);
    }
}