Skip to main content

cdk_sql_common/
stmt.rs

//! SQL statements: parsing of named placeholders, value binding and execution helpers.
2use std::collections::HashMap;
3use std::sync::{Arc, RwLock};
4
5use cdk_common::database::Error;
6use once_cell::sync::Lazy;
7
8use crate::database::DatabaseExecutor;
9use crate::value::Value;
10
/// A single column value in a result row (rows are returned as `Vec<Column>`); an alias of
/// [`Value`].
pub type Column = Value;
13
/// Expected response type for a given SQL statement.
///
/// [`ExpectedSqlResponse::ManyRows`] is the default.
#[derive(Debug, Clone, Copy, Default)]
pub enum ExpectedSqlResponse {
    /// A single row.
    SingleRow,
    /// All the rows that match a query.
    #[default]
    ManyRows,
    /// How many rows were affected by the query.
    AffectedRows,
    /// Return the first column of the first row.
    Pluck,
    /// Execute the statement as a batch; no rows are expected back (the `batch` executor
    /// method returns `()`).
    Batch,
}
29
/// Value bound to a named placeholder.
#[derive(Debug, Clone)]
pub enum PlaceholderValue {
    /// A single value, rendered as one positional parameter (`$n`).
    Value(Value),
    /// A list of values, rendered as one comma-separated positional parameter per element
    /// (`$n, $n+1, ...`); an empty list renders as nothing.
    Set(Vec<Value>),
}
38
39impl From<Value> for PlaceholderValue {
40    fn from(value: Value) -> Self {
41        PlaceholderValue::Value(value)
42    }
43}
44
45impl From<Vec<Value>> for PlaceholderValue {
46    fn from(value: Vec<Value>) -> Self {
47        PlaceholderValue::Set(value)
48    }
49}
50
/// A fragment of a parsed SQL statement, as produced by [`split_sql_parts`].
#[derive(Debug, Clone)]
pub enum SqlPart {
    /// Raw SQL text, copied verbatim from the input (string literals and comments included).
    Raw(Arc<str>),
    /// A named placeholder (`:name`, stored without the leading `:`) and the value bound to it,
    /// or `None` while it is still unbound.
    Placeholder(Arc<str>, Option<PlaceholderValue>),
}
59
/// Error returned by [`split_sql_parts`].
#[derive(Debug, PartialEq, thiserror::Error)]
pub enum SqlParseError {
    /// A quoted string literal or quoted identifier was opened but never closed.
    #[error("Unterminated String literal")]
    UnterminatedStringLiteral,
    /// A `:` was not followed by a placeholder name (alphanumeric characters or `_`).
    #[error("Invalid placeholder name")]
    InvalidPlaceholder,
}
70
71/// Rudimentary SQL parser.
72///
73/// This function does not validate the SQL statement, it only extracts the placeholder to be
74/// database agnostic.
75pub fn split_sql_parts(input: &str) -> Result<Vec<SqlPart>, SqlParseError> {
76    let mut parts = Vec::new();
77    let mut current = String::new();
78    let mut chars = input.chars().peekable();
79
80    while let Some(&c) = chars.peek() {
81        match c {
82            '\'' | '"' => {
83                // Start of string literal
84                let quote = c;
85                current.push(
86                    chars
87                        .next()
88                        .ok_or(SqlParseError::UnterminatedStringLiteral)?,
89                );
90
91                let mut closed = false;
92                while let Some(&next) = chars.peek() {
93                    current.push(
94                        chars
95                            .next()
96                            .ok_or(SqlParseError::UnterminatedStringLiteral)?,
97                    );
98
99                    if next == quote {
100                        if chars.peek() == Some(&quote) {
101                            // Escaped quote (e.g. '' inside strings)
102                            current.push(
103                                chars
104                                    .next()
105                                    .ok_or(SqlParseError::UnterminatedStringLiteral)?,
106                            );
107                        } else {
108                            closed = true;
109                            break;
110                        }
111                    }
112                }
113
114                if !closed {
115                    return Err(SqlParseError::UnterminatedStringLiteral);
116                }
117            }
118
119            '-' => {
120                current.push(
121                    chars
122                        .next()
123                        .ok_or(SqlParseError::UnterminatedStringLiteral)?,
124                );
125
126                if chars.peek() == Some(&'-') {
127                    while let Some(&next) = chars.peek() {
128                        current.push(
129                            chars
130                                .next()
131                                .ok_or(SqlParseError::UnterminatedStringLiteral)?,
132                        );
133                        if next == '\n' {
134                            break;
135                        }
136                    }
137                }
138            }
139
140            ':' => {
141                // Flush current raw SQL
142                if !current.is_empty() {
143                    parts.push(SqlPart::Raw(current.clone().into()));
144                    current.clear();
145                }
146
147                chars.next(); // consume ':'
148                let mut name = String::new();
149
150                while let Some(&next) = chars.peek() {
151                    if next.is_alphanumeric() || next == '_' {
152                        name.push(
153                            chars
154                                .next()
155                                .ok_or(SqlParseError::UnterminatedStringLiteral)?,
156                        );
157                    } else {
158                        break;
159                    }
160                }
161
162                if name.is_empty() {
163                    return Err(SqlParseError::InvalidPlaceholder);
164                }
165
166                parts.push(SqlPart::Placeholder(name.into(), None));
167            }
168
169            _ => {
170                current.push(
171                    chars
172                        .next()
173                        .ok_or(SqlParseError::UnterminatedStringLiteral)?,
174                );
175            }
176        }
177    }
178
179    if !current.is_empty() {
180        parts.push(SqlPart::Raw(current.into()));
181    }
182
183    Ok(parts)
184}
185
/// Parsed statements keyed by their SQL text: the parsed parts, plus the `$n`-numbered SQL once
/// `Statement::to_sql` has rendered and stored it (`None` until then).
type Cache = HashMap<String, (Vec<SqlPart>, Option<Arc<str>>)>;
187
/// A SQL statement with named placeholders, to which values can be bound before execution.
#[derive(Debug, Default)]
pub struct Statement {
    /// Shared cache of parsed (and possibly rendered) statements, keyed by SQL text.
    cache: Arc<RwLock<Cache>>,
    /// Previously rendered `$n` SQL for this statement's text, if the cache already had it.
    cached_sql: Option<Arc<str>>,
    /// Original SQL text. It is present when this statement is responsible for storing its
    /// rendered SQL in the cache, and `to_sql` uses it as the cache key.
    sql: Option<String>,
    /// The parsed SQL fragments and placeholders.
    pub parts: Vec<SqlPart>,
    /// The expected response type.
    pub expected_response: ExpectedSqlResponse,
}
199
200impl Statement {
201    /// Creates a new statement
202    fn new(sql: &str, cache: Arc<RwLock<Cache>>) -> Result<Self, SqlParseError> {
203        let parsed = cache
204            .read()
205            .map(|cache| cache.get(sql).cloned())
206            .ok()
207            .flatten();
208
209        if let Some((parts, cached_sql)) = parsed {
210            Ok(Self {
211                parts,
212                cached_sql,
213                sql: None,
214                cache,
215                ..Default::default()
216            })
217        } else {
218            let parts = split_sql_parts(sql)?;
219
220            if let Ok(mut cache) = cache.write() {
221                cache.insert(sql.to_owned(), (parts.clone(), None));
222            } else {
223                tracing::warn!("Failed to acquire write lock for SQL statement cache");
224            }
225
226            Ok(Self {
227                parts,
228                sql: Some(sql.to_owned()),
229                cache,
230                ..Default::default()
231            })
232        }
233    }
234
235    /// Convert Statement into a SQL statement and the list of placeholders
236    ///
237    /// By default it converts the statement into placeholder using $1..$n placeholders which seems
238    /// to be more widely supported, although it can be reimplemented with other formats since part
239    /// is public
240    pub fn to_sql(self) -> Result<(String, Vec<Value>), Error> {
241        if let Some(cached_sql) = self.cached_sql {
242            let sql = cached_sql.to_string();
243            let values = self
244                .parts
245                .into_iter()
246                .map(|x| match x {
247                    SqlPart::Placeholder(name, value) => {
248                        match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
249                            PlaceholderValue::Value(value) => Ok(vec![value]),
250                            PlaceholderValue::Set(values) => Ok(values),
251                        }
252                    }
253                    SqlPart::Raw(_) => Ok(vec![]),
254                })
255                .collect::<Result<Vec<_>, Error>>()?
256                .into_iter()
257                .flatten()
258                .collect::<Vec<_>>();
259            return Ok((sql, values));
260        }
261
262        let mut placeholder_values = Vec::new();
263        let mut can_be_cached = true;
264        let sql = self
265            .parts
266            .into_iter()
267            .map(|x| match x {
268                SqlPart::Placeholder(name, value) => {
269                    match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
270                        PlaceholderValue::Value(value) => {
271                            placeholder_values.push(value);
272                            Ok::<_, Error>(format!("${}", placeholder_values.len()))
273                        }
274                        PlaceholderValue::Set(mut values) => {
275                            can_be_cached = false;
276                            let start_size = placeholder_values.len();
277                            placeholder_values.append(&mut values);
278                            let placeholders = (start_size + 1..=placeholder_values.len())
279                                .map(|i| format!("${i}"))
280                                .collect::<Vec<_>>()
281                                .join(", ");
282                            Ok(placeholders)
283                        }
284                    }
285                }
286                SqlPart::Raw(raw) => Ok(raw.trim().to_string()),
287            })
288            .collect::<Result<Vec<String>, _>>()?
289            .join(" ");
290
291        if can_be_cached {
292            if let Some(original_sql) = self.sql {
293                let _ = self.cache.write().map(|mut cache| {
294                    if let Some((_, cached_sql)) = cache.get_mut(&original_sql) {
295                        *cached_sql = Some(sql.clone().into());
296                    }
297                });
298            }
299        }
300
301        Ok((sql, placeholder_values))
302    }
303
304    /// Binds a given placeholder to a value.
305    #[inline]
306    pub fn bind<C, V>(mut self, name: C, value: V) -> Self
307    where
308        C: ToString,
309        V: Into<Value>,
310    {
311        let name = name.to_string();
312        let value = value.into();
313        let value: PlaceholderValue = value.into();
314
315        for part in self.parts.iter_mut() {
316            if let SqlPart::Placeholder(part_name, part_value) = part {
317                if **part_name == *name.as_str() {
318                    *part_value = Some(value.clone());
319                }
320            }
321        }
322
323        self
324    }
325
326    /// Binds a single variable with a vector.
327    ///
328    /// This will rewrite the function from `:foo` (where value is vec![1, 2, 3]) to `:foo0, :foo1,
329    /// :foo2` and binds each value from the value vector accordingly.
330    #[inline]
331    pub fn bind_vec<C, V>(mut self, name: C, value: Vec<V>) -> Self
332    where
333        C: ToString,
334        V: Into<Value>,
335    {
336        let name = name.to_string();
337        let value: PlaceholderValue = value
338            .into_iter()
339            .map(|x| x.into())
340            .collect::<Vec<Value>>()
341            .into();
342
343        for part in self.parts.iter_mut() {
344            if let SqlPart::Placeholder(part_name, part_value) = part {
345                if **part_name == *name.as_str() {
346                    *part_value = Some(value.clone());
347                }
348            }
349        }
350
351        self
352    }
353
354    /// Executes a query and returns the affected rows
355    pub async fn pluck<C>(self, conn: &C) -> Result<Option<Value>, Error>
356    where
357        C: DatabaseExecutor,
358    {
359        conn.pluck(self).await
360    }
361
362    /// Executes a query and returns the affected rows
363    pub async fn batch<C>(self, conn: &C) -> Result<(), Error>
364    where
365        C: DatabaseExecutor,
366    {
367        conn.batch(self).await
368    }
369
370    /// Executes a query and returns the affected rows
371    pub async fn execute<C>(self, conn: &C) -> Result<usize, Error>
372    where
373        C: DatabaseExecutor,
374    {
375        conn.execute(self).await
376    }
377
378    /// Runs the query and returns the first row or None
379    pub async fn fetch_one<C>(self, conn: &C) -> Result<Option<Vec<Column>>, Error>
380    where
381        C: DatabaseExecutor,
382    {
383        conn.fetch_one(self).await
384    }
385
386    /// Runs the query and returns the first row or None
387    pub async fn fetch_all<C>(self, conn: &C) -> Result<Vec<Vec<Column>>, Error>
388    where
389        C: DatabaseExecutor,
390    {
391        conn.fetch_all(self).await
392    }
393}
394
395/// Creates a new query statement
396#[inline(always)]
397pub fn query(sql: &str) -> Result<Statement, Error> {
398    static CACHE: Lazy<Arc<RwLock<Cache>>> = Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
399    Statement::new(sql, CACHE.clone()).map_err(|e| Error::Database(Box::new(e)))
400}