Skip to main content

cdk_sql_common/
stmt.rs

//! SQL statements: placeholder parsing, value binding and rendering.
2use std::collections::HashMap;
3use std::sync::{Arc, RwLock};
4
5use cdk_common::database::Error;
6use once_cell::sync::Lazy;
7
8use crate::database::DatabaseExecutor;
9use crate::value::Value;
10
11/// The Column type
12pub type Column = Value;
13
/// The kind of result a SQL statement is expected to produce.
#[derive(Debug, Clone, Copy, Default)]
pub enum ExpectedSqlResponse {
    /// A single row.
    SingleRow,
    /// Every row matching the query (this is the default).
    #[default]
    ManyRows,
    /// The number of rows affected by the statement.
    AffectedRows,
    /// The first column of the first row.
    Pluck,
    /// Batch execution.
    Batch,
}
29
30/// Part value
31#[derive(Debug, Clone)]
32pub enum PlaceholderValue {
33    /// Value
34    Value(Value),
35    /// Set
36    Set(Vec<Value>),
37}
38
39impl From<Value> for PlaceholderValue {
40    fn from(value: Value) -> Self {
41        PlaceholderValue::Value(value)
42    }
43}
44
45impl From<Vec<Value>> for PlaceholderValue {
46    fn from(value: Vec<Value>) -> Self {
47        PlaceholderValue::Set(value)
48    }
49}
50
51/// SQL Part
52#[derive(Debug, Clone)]
53pub enum SqlPart {
54    /// Raw SQL statement
55    Raw(Arc<str>),
56    /// Placeholder
57    Placeholder(Arc<str>, Option<PlaceholderValue>),
58}
59
60/// SQL parser error
61#[derive(Debug, PartialEq, thiserror::Error)]
62pub enum SqlParseError {
63    /// Invalid SQL
64    #[error("Unterminated String literal")]
65    UnterminatedStringLiteral,
66    /// Invalid placeholder name
67    #[error("Invalid placeholder name")]
68    InvalidPlaceholder,
69}
70
71/// Rudimentary SQL parser.
72///
73/// This function does not validate the SQL statement, it only extracts the placeholder to be
74/// database agnostic.
75pub fn split_sql_parts(input: &str) -> Result<Vec<SqlPart>, SqlParseError> {
76    let mut parts = Vec::new();
77    let mut current = String::new();
78    let mut chars = input.chars().peekable();
79
80    while let Some(&c) = chars.peek() {
81        match c {
82            '\'' | '"' => {
83                // Start of string literal
84                let quote = c;
85                current.push(
86                    chars
87                        .next()
88                        .ok_or(SqlParseError::UnterminatedStringLiteral)?,
89                );
90
91                let mut closed = false;
92                while let Some(&next) = chars.peek() {
93                    current.push(
94                        chars
95                            .next()
96                            .ok_or(SqlParseError::UnterminatedStringLiteral)?,
97                    );
98
99                    if next == quote {
100                        if chars.peek() == Some(&quote) {
101                            // Escaped quote (e.g. '' inside strings)
102                            current.push(
103                                chars
104                                    .next()
105                                    .ok_or(SqlParseError::UnterminatedStringLiteral)?,
106                            );
107                        } else {
108                            closed = true;
109                            break;
110                        }
111                    }
112                }
113
114                if !closed {
115                    return Err(SqlParseError::UnterminatedStringLiteral);
116                }
117            }
118
119            '-' => {
120                current.push(
121                    chars
122                        .next()
123                        .ok_or(SqlParseError::UnterminatedStringLiteral)?,
124                );
125
126                if chars.peek() == Some(&'-') {
127                    while let Some(&next) = chars.peek() {
128                        current.push(
129                            chars
130                                .next()
131                                .ok_or(SqlParseError::UnterminatedStringLiteral)?,
132                        );
133                        if next == '\n' {
134                            break;
135                        }
136                    }
137                }
138            }
139
140            ':' => {
141                chars.next(); // consume ':'
142
143                if chars.peek() == Some(&':') {
144                    current.push(':');
145                    current.push(
146                        chars
147                            .next()
148                            .ok_or(SqlParseError::UnterminatedStringLiteral)?,
149                    );
150                    continue;
151                }
152
153                // Flush current raw SQL
154                if !current.is_empty() {
155                    parts.push(SqlPart::Raw(current.clone().into()));
156                    current.clear();
157                }
158
159                let mut name = String::new();
160
161                while let Some(&next) = chars.peek() {
162                    if next.is_alphanumeric() || next == '_' {
163                        name.push(
164                            chars
165                                .next()
166                                .ok_or(SqlParseError::UnterminatedStringLiteral)?,
167                        );
168                    } else {
169                        break;
170                    }
171                }
172
173                if name.is_empty() {
174                    return Err(SqlParseError::InvalidPlaceholder);
175                }
176
177                parts.push(SqlPart::Placeholder(name.into(), None));
178            }
179
180            _ => {
181                current.push(
182                    chars
183                        .next()
184                        .ok_or(SqlParseError::UnterminatedStringLiteral)?,
185                );
186            }
187        }
188    }
189
190    if !current.is_empty() {
191        parts.push(SqlPart::Raw(current.into()));
192    }
193
194    Ok(parts)
195}
196
197type Cache = HashMap<String, (Vec<SqlPart>, Option<Arc<str>>)>;
198
199/// Sql message
200#[derive(Debug, Default)]
201pub struct Statement {
202    cache: Arc<RwLock<Cache>>,
203    cached_sql: Option<Arc<str>>,
204    sql: Option<String>,
205    /// The SQL statement
206    pub parts: Vec<SqlPart>,
207    /// The expected response type
208    pub expected_response: ExpectedSqlResponse,
209}
210
211impl Statement {
212    /// Creates a new statement
213    fn new(sql: &str, cache: Arc<RwLock<Cache>>) -> Result<Self, SqlParseError> {
214        let parsed = cache
215            .read()
216            .map(|cache| cache.get(sql).cloned())
217            .ok()
218            .flatten();
219
220        if let Some((parts, cached_sql)) = parsed {
221            Ok(Self {
222                parts,
223                cached_sql,
224                sql: None,
225                cache,
226                ..Default::default()
227            })
228        } else {
229            let parts = split_sql_parts(sql)?;
230
231            if let Ok(mut cache) = cache.write() {
232                cache.insert(sql.to_owned(), (parts.clone(), None));
233            } else {
234                tracing::warn!("Failed to acquire write lock for SQL statement cache");
235            }
236
237            Ok(Self {
238                parts,
239                sql: Some(sql.to_owned()),
240                cache,
241                ..Default::default()
242            })
243        }
244    }
245
246    /// Convert Statement into a SQL statement and the list of placeholders
247    ///
248    /// By default it converts the statement into placeholder using $1..$n placeholders which seems
249    /// to be more widely supported, although it can be reimplemented with other formats since part
250    /// is public
251    pub fn to_sql(self) -> Result<(String, Vec<Value>), Error> {
252        if let Some(cached_sql) = self.cached_sql {
253            let sql = cached_sql.to_string();
254            let values = self
255                .parts
256                .into_iter()
257                .map(|x| match x {
258                    SqlPart::Placeholder(name, value) => {
259                        match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
260                            PlaceholderValue::Value(value) => Ok(vec![value]),
261                            PlaceholderValue::Set(values) => Ok(values),
262                        }
263                    }
264                    SqlPart::Raw(_) => Ok(vec![]),
265                })
266                .collect::<Result<Vec<_>, Error>>()?
267                .into_iter()
268                .flatten()
269                .collect::<Vec<_>>();
270            return Ok((sql, values));
271        }
272
273        let mut placeholder_values = Vec::new();
274        let mut can_be_cached = true;
275        let sql = self
276            .parts
277            .into_iter()
278            .map(|x| match x {
279                SqlPart::Placeholder(name, value) => {
280                    match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
281                        PlaceholderValue::Value(value) => {
282                            placeholder_values.push(value);
283                            Ok::<_, Error>(format!("${}", placeholder_values.len()))
284                        }
285                        PlaceholderValue::Set(mut values) => {
286                            can_be_cached = false;
287                            let start_size = placeholder_values.len();
288                            placeholder_values.append(&mut values);
289                            let placeholders = (start_size + 1..=placeholder_values.len())
290                                .map(|i| format!("${i}"))
291                                .collect::<Vec<_>>()
292                                .join(", ");
293                            Ok(placeholders)
294                        }
295                    }
296                }
297                SqlPart::Raw(raw) => Ok(raw.trim().to_string()),
298            })
299            .collect::<Result<Vec<String>, _>>()?
300            .join(" ");
301
302        if can_be_cached {
303            if let Some(original_sql) = self.sql {
304                let _ = self.cache.write().map(|mut cache| {
305                    if let Some((_, cached_sql)) = cache.get_mut(&original_sql) {
306                        *cached_sql = Some(sql.clone().into());
307                    }
308                });
309            }
310        }
311
312        Ok((sql, placeholder_values))
313    }
314
315    /// Binds a given placeholder to a value.
316    #[inline]
317    pub fn bind<C, V>(mut self, name: C, value: V) -> Self
318    where
319        C: ToString,
320        V: Into<Value>,
321    {
322        let name = name.to_string();
323        let value = value.into();
324        let value: PlaceholderValue = value.into();
325
326        for part in self.parts.iter_mut() {
327            if let SqlPart::Placeholder(part_name, part_value) = part {
328                if **part_name == *name.as_str() {
329                    *part_value = Some(value.clone());
330                }
331            }
332        }
333
334        self
335    }
336
337    /// Binds a single variable with a vector.
338    ///
339    /// This will rewrite the function from `:foo` (where value is vec![1, 2, 3]) to `:foo0, :foo1,
340    /// :foo2` and binds each value from the value vector accordingly.
341    ///
342    /// Returns an error if the vector is empty, as empty `IN` clauses produce invalid SQL.
343    #[inline]
344    pub fn bind_vec<C, V>(mut self, name: C, value: Vec<V>) -> Result<Self, Error>
345    where
346        C: ToString,
347        V: Into<Value>,
348    {
349        let name = name.to_string();
350
351        if value.is_empty() {
352            return Err(Error::EmptyInClause(name));
353        }
354
355        let value: PlaceholderValue = value
356            .into_iter()
357            .map(|x| x.into())
358            .collect::<Vec<Value>>()
359            .into();
360
361        for part in self.parts.iter_mut() {
362            if let SqlPart::Placeholder(part_name, part_value) = part {
363                if **part_name == *name.as_str() {
364                    *part_value = Some(value.clone());
365                }
366            }
367        }
368
369        Ok(self)
370    }
371
372    /// Executes a query and returns the affected rows
373    pub async fn pluck<C>(self, conn: &C) -> Result<Option<Value>, Error>
374    where
375        C: DatabaseExecutor,
376    {
377        conn.pluck(self).await
378    }
379
380    /// Executes a query and returns the affected rows
381    pub async fn batch<C>(self, conn: &C) -> Result<(), Error>
382    where
383        C: DatabaseExecutor,
384    {
385        conn.batch(self).await
386    }
387
388    /// Executes a query and returns the affected rows
389    pub async fn execute<C>(self, conn: &C) -> Result<usize, Error>
390    where
391        C: DatabaseExecutor,
392    {
393        conn.execute(self).await
394    }
395
396    /// Runs the query and returns the first row or None
397    pub async fn fetch_one<C>(self, conn: &C) -> Result<Option<Vec<Column>>, Error>
398    where
399        C: DatabaseExecutor,
400    {
401        conn.fetch_one(self).await
402    }
403
404    /// Runs the query and returns the first row or None
405    pub async fn fetch_all<C>(self, conn: &C) -> Result<Vec<Vec<Column>>, Error>
406    where
407        C: DatabaseExecutor,
408    {
409        conn.fetch_all(self).await
410    }
411}
412
413/// Creates a new query statement
414#[inline(always)]
415pub fn query(sql: &str) -> Result<Statement, Error> {
416    static CACHE: Lazy<Arc<RwLock<Cache>>> = Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
417    Statement::new(sql, CACHE.clone()).map_err(|e| Error::Database(Box::new(e)))
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    // Every test uses distinct SQL text. The statement cache is global and shared by tests
    // running in parallel, so distinct keys keep the tests independent.

    #[test]
    fn bind_vec_errors_on_empty_vec() {
        let stmt = query("SELECT * FROM foo WHERE id IN (:ids)").unwrap();
        let result = stmt.bind_vec("ids", Vec::<Vec<u8>>::new());
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), Error::EmptyInClause(name) if name == "ids"));
    }

    #[test]
    fn parser_preserves_postgres_cast_operator() {
        let stmt = query("SELECT (ord - 1)::int AS matched WHERE id = :id")
            .unwrap()
            .bind("id", "quote-id");

        let (sql, values) = stmt.to_sql().unwrap();

        assert_eq!(sql, "SELECT (ord - 1)::int AS matched WHERE id = $1");
        assert_eq!(values.len(), 1);
    }

    #[test]
    fn bind_vec_expands_to_consecutive_placeholders() {
        let (sql, values) = query("SELECT * FROM bar WHERE id IN (:ids) AND kind = :kind")
            .unwrap()
            .bind_vec("ids", vec!["a", "b"])
            .unwrap()
            .bind("kind", "k")
            .to_sql()
            .unwrap();

        assert_eq!(sql, "SELECT * FROM bar WHERE id IN ( $1, $2 ) AND kind = $3");
        assert_eq!(values.len(), 3);
    }

    #[test]
    fn repeated_placeholder_is_bound_at_each_occurrence() {
        let (sql, values) = query("SELECT :a, :a").unwrap().bind("a", "x").to_sql().unwrap();

        assert_eq!(sql, "SELECT $1 , $2");
        assert_eq!(values.len(), 2);
    }

    #[test]
    fn literals_and_line_comments_do_not_create_placeholders() {
        let parts = split_sql_parts("SELECT 'a:b', \"c:d\" -- e:f\nFROM t").unwrap();

        assert_eq!(parts.len(), 1);
        assert!(matches!(parts[0], SqlPart::Raw(_)));
    }

    #[test]
    fn parser_reports_errors() {
        assert_eq!(
            split_sql_parts("SELECT 'oops").unwrap_err(),
            SqlParseError::UnterminatedStringLiteral
        );
        assert_eq!(
            split_sql_parts("SELECT : 1").unwrap_err(),
            SqlParseError::InvalidPlaceholder
        );
    }

    #[test]
    fn unbound_placeholder_is_an_error() {
        let result = query("SELECT * FROM baz WHERE id = :id").unwrap().to_sql();

        assert!(matches!(result, Err(Error::MissingPlaceholder(name)) if name == "id"));
    }

    #[test]
    fn cached_statement_renders_identically() {
        let sql = "SELECT * FROM cache_check WHERE id = :id";
        let first = query(sql).unwrap().bind("id", "a").to_sql().unwrap();
        let second = query(sql).unwrap().bind("id", "b").to_sql().unwrap();

        assert_eq!(first.0, "SELECT * FROM cache_check WHERE id = $1");
        assert_eq!(first.0, second.0);
        assert_eq!(second.1.len(), 1);
    }
}