cdk_sql_common/
stmt.rs

//! SQL statements: placeholder parsing, value binding and rendering.
2use std::collections::HashMap;
3use std::sync::{Arc, RwLock};
4
5use cdk_common::database::Error;
6use once_cell::sync::Lazy;
7
8use crate::database::DatabaseExecutor;
9use crate::value::Value;
10
11/// The Column type
12pub type Column = Value;
13
/// Shape of the result a SQL statement is expected to produce.
#[derive(Debug, Clone, Copy, Default)]
pub enum ExpectedSqlResponse {
    /// A single row
    SingleRow,
    /// All the rows that match a query (the default)
    #[default]
    ManyRows,
    /// The number of rows affected by the query
    AffectedRows,
    /// The first column of the first row
    Pluck,
    /// Batch execution; no result value is expected back
    Batch,
}
29
30/// Part value
31#[derive(Debug, Clone)]
32pub enum PlaceholderValue {
33    /// Value
34    Value(Value),
35    /// Set
36    Set(Vec<Value>),
37}
38
39impl From<Value> for PlaceholderValue {
40    fn from(value: Value) -> Self {
41        PlaceholderValue::Value(value)
42    }
43}
44
45impl From<Vec<Value>> for PlaceholderValue {
46    fn from(value: Vec<Value>) -> Self {
47        PlaceholderValue::Set(value)
48    }
49}
50
51/// SQL Part
52#[derive(Debug, Clone)]
53pub enum SqlPart {
54    /// Raw SQL statement
55    Raw(Arc<str>),
56    /// Placeholder
57    Placeholder(Arc<str>, Option<PlaceholderValue>),
58}
59
60/// SQL parser error
61#[derive(Debug, PartialEq, thiserror::Error)]
62pub enum SqlParseError {
63    /// Invalid SQL
64    #[error("Unterminated String literal")]
65    UnterminatedStringLiteral,
66    /// Invalid placeholder name
67    #[error("Invalid placeholder name")]
68    InvalidPlaceholder,
69}
70
71/// Rudimentary SQL parser.
72///
73/// This function does not validate the SQL statement, it only extracts the placeholder to be
74/// database agnostic.
75pub fn split_sql_parts(input: &str) -> Result<Vec<SqlPart>, SqlParseError> {
76    let mut parts = Vec::new();
77    let mut current = String::new();
78    let mut chars = input.chars().peekable();
79
80    while let Some(&c) = chars.peek() {
81        match c {
82            '\'' | '"' => {
83                // Start of string literal
84                let quote = c;
85                current.push(chars.next().unwrap());
86
87                let mut closed = false;
88                while let Some(&next) = chars.peek() {
89                    current.push(chars.next().unwrap());
90
91                    if next == quote {
92                        if chars.peek() == Some(&quote) {
93                            // Escaped quote (e.g. '' inside strings)
94                            current.push(chars.next().unwrap());
95                        } else {
96                            closed = true;
97                            break;
98                        }
99                    }
100                }
101
102                if !closed {
103                    return Err(SqlParseError::UnterminatedStringLiteral);
104                }
105            }
106
107            ':' => {
108                // Flush current raw SQL
109                if !current.is_empty() {
110                    parts.push(SqlPart::Raw(current.clone().into()));
111                    current.clear();
112                }
113
114                chars.next(); // consume ':'
115                let mut name = String::new();
116
117                while let Some(&next) = chars.peek() {
118                    if next.is_alphanumeric() || next == '_' {
119                        name.push(chars.next().unwrap());
120                    } else {
121                        break;
122                    }
123                }
124
125                if name.is_empty() {
126                    return Err(SqlParseError::InvalidPlaceholder);
127                }
128
129                parts.push(SqlPart::Placeholder(name.into(), None));
130            }
131
132            _ => {
133                current.push(chars.next().unwrap());
134            }
135        }
136    }
137
138    if !current.is_empty() {
139        parts.push(SqlPart::Raw(current.into()));
140    }
141
142    Ok(parts)
143}
144
145type Cache = HashMap<String, (Vec<SqlPart>, Option<Arc<str>>)>;
146
147/// Sql message
148#[derive(Debug, Default)]
149pub struct Statement {
150    cache: Arc<RwLock<Cache>>,
151    cached_sql: Option<Arc<str>>,
152    sql: Option<String>,
153    /// The SQL statement
154    pub parts: Vec<SqlPart>,
155    /// The expected response type
156    pub expected_response: ExpectedSqlResponse,
157}
158
159impl Statement {
160    /// Creates a new statement
161    fn new(sql: &str, cache: Arc<RwLock<Cache>>) -> Result<Self, SqlParseError> {
162        let parsed = cache
163            .read()
164            .map(|cache| cache.get(sql).cloned())
165            .ok()
166            .flatten();
167
168        if let Some((parts, cached_sql)) = parsed {
169            Ok(Self {
170                parts,
171                cached_sql,
172                sql: None,
173                cache,
174                ..Default::default()
175            })
176        } else {
177            let parts = split_sql_parts(sql)?;
178
179            if let Ok(mut cache) = cache.write() {
180                cache.insert(sql.to_owned(), (parts.clone(), None));
181            } else {
182                tracing::warn!("Failed to acquire write lock for SQL statement cache");
183            }
184
185            Ok(Self {
186                parts,
187                sql: Some(sql.to_owned()),
188                cache,
189                ..Default::default()
190            })
191        }
192    }
193
194    /// Convert Statement into a SQL statement and the list of placeholders
195    ///
196    /// By default it converts the statement into placeholder using $1..$n placeholders which seems
197    /// to be more widely supported, although it can be reimplemented with other formats since part
198    /// is public
199    pub fn to_sql(self) -> Result<(String, Vec<Value>), Error> {
200        if let Some(cached_sql) = self.cached_sql {
201            let sql = cached_sql.to_string();
202            let values = self
203                .parts
204                .into_iter()
205                .map(|x| match x {
206                    SqlPart::Placeholder(name, value) => {
207                        match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
208                            PlaceholderValue::Value(value) => Ok(vec![value]),
209                            PlaceholderValue::Set(values) => Ok(values),
210                        }
211                    }
212                    SqlPart::Raw(_) => Ok(vec![]),
213                })
214                .collect::<Result<Vec<_>, Error>>()?
215                .into_iter()
216                .flatten()
217                .collect::<Vec<_>>();
218            return Ok((sql, values));
219        }
220
221        let mut placeholder_values = Vec::new();
222        let mut can_be_cached = true;
223        let sql = self
224            .parts
225            .into_iter()
226            .map(|x| match x {
227                SqlPart::Placeholder(name, value) => {
228                    match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
229                        PlaceholderValue::Value(value) => {
230                            placeholder_values.push(value);
231                            Ok::<_, Error>(format!("${}", placeholder_values.len()))
232                        }
233                        PlaceholderValue::Set(mut values) => {
234                            can_be_cached = false;
235                            let start_size = placeholder_values.len();
236                            placeholder_values.append(&mut values);
237                            let placeholders = (start_size + 1..=placeholder_values.len())
238                                .map(|i| format!("${i}"))
239                                .collect::<Vec<_>>()
240                                .join(", ");
241                            Ok(placeholders)
242                        }
243                    }
244                }
245                SqlPart::Raw(raw) => Ok(raw.trim().to_string()),
246            })
247            .collect::<Result<Vec<String>, _>>()?
248            .join(" ");
249
250        if can_be_cached {
251            if let Some(original_sql) = self.sql {
252                let _ = self.cache.write().map(|mut cache| {
253                    if let Some((_, cached_sql)) = cache.get_mut(&original_sql) {
254                        *cached_sql = Some(sql.clone().into());
255                    }
256                });
257            }
258        }
259
260        Ok((sql, placeholder_values))
261    }
262
263    /// Binds a given placeholder to a value.
264    #[inline]
265    pub fn bind<C, V>(mut self, name: C, value: V) -> Self
266    where
267        C: ToString,
268        V: Into<Value>,
269    {
270        let name = name.to_string();
271        let value = value.into();
272        let value: PlaceholderValue = value.into();
273
274        for part in self.parts.iter_mut() {
275            if let SqlPart::Placeholder(part_name, part_value) = part {
276                if **part_name == *name.as_str() {
277                    *part_value = Some(value.clone());
278                }
279            }
280        }
281
282        self
283    }
284
285    /// Binds a single variable with a vector.
286    ///
287    /// This will rewrite the function from `:foo` (where value is vec![1, 2, 3]) to `:foo0, :foo1,
288    /// :foo2` and binds each value from the value vector accordingly.
289    #[inline]
290    pub fn bind_vec<C, V>(mut self, name: C, value: Vec<V>) -> Self
291    where
292        C: ToString,
293        V: Into<Value>,
294    {
295        let name = name.to_string();
296        let value: PlaceholderValue = value
297            .into_iter()
298            .map(|x| x.into())
299            .collect::<Vec<Value>>()
300            .into();
301
302        for part in self.parts.iter_mut() {
303            if let SqlPart::Placeholder(part_name, part_value) = part {
304                if **part_name == *name.as_str() {
305                    *part_value = Some(value.clone());
306                }
307            }
308        }
309
310        self
311    }
312
313    /// Executes a query and returns the affected rows
314    pub async fn pluck<C>(self, conn: &C) -> Result<Option<Value>, Error>
315    where
316        C: DatabaseExecutor,
317    {
318        conn.pluck(self).await
319    }
320
321    /// Executes a query and returns the affected rows
322    pub async fn batch<C>(self, conn: &C) -> Result<(), Error>
323    where
324        C: DatabaseExecutor,
325    {
326        conn.batch(self).await
327    }
328
329    /// Executes a query and returns the affected rows
330    pub async fn execute<C>(self, conn: &C) -> Result<usize, Error>
331    where
332        C: DatabaseExecutor,
333    {
334        conn.execute(self).await
335    }
336
337    /// Runs the query and returns the first row or None
338    pub async fn fetch_one<C>(self, conn: &C) -> Result<Option<Vec<Column>>, Error>
339    where
340        C: DatabaseExecutor,
341    {
342        conn.fetch_one(self).await
343    }
344
345    /// Runs the query and returns the first row or None
346    pub async fn fetch_all<C>(self, conn: &C) -> Result<Vec<Vec<Column>>, Error>
347    where
348        C: DatabaseExecutor,
349    {
350        conn.fetch_all(self).await
351    }
352}
353
354/// Creates a new query statement
355#[inline(always)]
356pub fn query(sql: &str) -> Result<Statement, Error> {
357    static CACHE: Lazy<Arc<RwLock<Cache>>> = Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
358    Statement::new(sql, CACHE.clone()).map_err(|e| Error::Database(Box::new(e)))
359}