cdk_sql_common/
stmt.rs

//! SQL statements: placeholder parsing, value binding, and rendering to positional SQL.
2use std::collections::HashMap;
3use std::sync::{Arc, RwLock};
4
5use cdk_common::database::Error;
6use once_cell::sync::Lazy;
7
8use crate::database::DatabaseExecutor;
9use crate::value::Value;
10
11/// The Column type
12pub type Column = Value;
13
/// The kind of result a caller expects back from executing a SQL statement.
#[derive(Clone, Copy, Debug, Default)]
pub enum ExpectedSqlResponse {
    /// A single row.
    SingleRow,
    /// Every row that matches the query (the default).
    #[default]
    ManyRows,
    /// The number of rows affected by the statement.
    AffectedRows,
    /// The first column of the first row.
    Pluck,
    /// Batch execution.
    Batch,
}
29
30/// Part value
31#[derive(Debug, Clone)]
32pub enum PlaceholderValue {
33    /// Value
34    Value(Value),
35    /// Set
36    Set(Vec<Value>),
37}
38
39impl From<Value> for PlaceholderValue {
40    fn from(value: Value) -> Self {
41        PlaceholderValue::Value(value)
42    }
43}
44
45impl From<Vec<Value>> for PlaceholderValue {
46    fn from(value: Vec<Value>) -> Self {
47        PlaceholderValue::Set(value)
48    }
49}
50
51/// SQL Part
52#[derive(Debug, Clone)]
53pub enum SqlPart {
54    /// Raw SQL statement
55    Raw(Arc<str>),
56    /// Placeholder
57    Placeholder(Arc<str>, Option<PlaceholderValue>),
58}
59
60/// SQL parser error
61#[derive(Debug, PartialEq, thiserror::Error)]
62pub enum SqlParseError {
63    /// Invalid SQL
64    #[error("Unterminated String literal")]
65    UnterminatedStringLiteral,
66    /// Invalid placeholder name
67    #[error("Invalid placeholder name")]
68    InvalidPlaceholder,
69}
70
71/// Rudimentary SQL parser.
72///
73/// This function does not validate the SQL statement, it only extracts the placeholder to be
74/// database agnostic.
75pub fn split_sql_parts(input: &str) -> Result<Vec<SqlPart>, SqlParseError> {
76    let mut parts = Vec::new();
77    let mut current = String::new();
78    let mut chars = input.chars().peekable();
79
80    while let Some(&c) = chars.peek() {
81        match c {
82            '\'' | '"' => {
83                // Start of string literal
84                let quote = c;
85                current.push(chars.next().unwrap());
86
87                let mut closed = false;
88                while let Some(&next) = chars.peek() {
89                    current.push(chars.next().unwrap());
90
91                    if next == quote {
92                        if chars.peek() == Some(&quote) {
93                            // Escaped quote (e.g. '' inside strings)
94                            current.push(chars.next().unwrap());
95                        } else {
96                            closed = true;
97                            break;
98                        }
99                    }
100                }
101
102                if !closed {
103                    return Err(SqlParseError::UnterminatedStringLiteral);
104                }
105            }
106
107            '-' => {
108                current.push(chars.next().unwrap());
109                if chars.peek() == Some(&'-') {
110                    while let Some(&next) = chars.peek() {
111                        current.push(chars.next().unwrap());
112                        if next == '\n' {
113                            break;
114                        }
115                    }
116                }
117            }
118
119            ':' => {
120                // Flush current raw SQL
121                if !current.is_empty() {
122                    parts.push(SqlPart::Raw(current.clone().into()));
123                    current.clear();
124                }
125
126                chars.next(); // consume ':'
127                let mut name = String::new();
128
129                while let Some(&next) = chars.peek() {
130                    if next.is_alphanumeric() || next == '_' {
131                        name.push(chars.next().unwrap());
132                    } else {
133                        break;
134                    }
135                }
136
137                if name.is_empty() {
138                    return Err(SqlParseError::InvalidPlaceholder);
139                }
140
141                parts.push(SqlPart::Placeholder(name.into(), None));
142            }
143
144            _ => {
145                current.push(chars.next().unwrap());
146            }
147        }
148    }
149
150    if !current.is_empty() {
151        parts.push(SqlPart::Raw(current.into()));
152    }
153
154    Ok(parts)
155}
156
157type Cache = HashMap<String, (Vec<SqlPart>, Option<Arc<str>>)>;
158
159/// Sql message
160#[derive(Debug, Default)]
161pub struct Statement {
162    cache: Arc<RwLock<Cache>>,
163    cached_sql: Option<Arc<str>>,
164    sql: Option<String>,
165    /// The SQL statement
166    pub parts: Vec<SqlPart>,
167    /// The expected response type
168    pub expected_response: ExpectedSqlResponse,
169}
170
171impl Statement {
172    /// Creates a new statement
173    fn new(sql: &str, cache: Arc<RwLock<Cache>>) -> Result<Self, SqlParseError> {
174        let parsed = cache
175            .read()
176            .map(|cache| cache.get(sql).cloned())
177            .ok()
178            .flatten();
179
180        if let Some((parts, cached_sql)) = parsed {
181            Ok(Self {
182                parts,
183                cached_sql,
184                sql: None,
185                cache,
186                ..Default::default()
187            })
188        } else {
189            let parts = split_sql_parts(sql)?;
190
191            if let Ok(mut cache) = cache.write() {
192                cache.insert(sql.to_owned(), (parts.clone(), None));
193            } else {
194                tracing::warn!("Failed to acquire write lock for SQL statement cache");
195            }
196
197            Ok(Self {
198                parts,
199                sql: Some(sql.to_owned()),
200                cache,
201                ..Default::default()
202            })
203        }
204    }
205
206    /// Convert Statement into a SQL statement and the list of placeholders
207    ///
208    /// By default it converts the statement into placeholder using $1..$n placeholders which seems
209    /// to be more widely supported, although it can be reimplemented with other formats since part
210    /// is public
211    pub fn to_sql(self) -> Result<(String, Vec<Value>), Error> {
212        if let Some(cached_sql) = self.cached_sql {
213            let sql = cached_sql.to_string();
214            let values = self
215                .parts
216                .into_iter()
217                .map(|x| match x {
218                    SqlPart::Placeholder(name, value) => {
219                        match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
220                            PlaceholderValue::Value(value) => Ok(vec![value]),
221                            PlaceholderValue::Set(values) => Ok(values),
222                        }
223                    }
224                    SqlPart::Raw(_) => Ok(vec![]),
225                })
226                .collect::<Result<Vec<_>, Error>>()?
227                .into_iter()
228                .flatten()
229                .collect::<Vec<_>>();
230            return Ok((sql, values));
231        }
232
233        let mut placeholder_values = Vec::new();
234        let mut can_be_cached = true;
235        let sql = self
236            .parts
237            .into_iter()
238            .map(|x| match x {
239                SqlPart::Placeholder(name, value) => {
240                    match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
241                        PlaceholderValue::Value(value) => {
242                            placeholder_values.push(value);
243                            Ok::<_, Error>(format!("${}", placeholder_values.len()))
244                        }
245                        PlaceholderValue::Set(mut values) => {
246                            can_be_cached = false;
247                            let start_size = placeholder_values.len();
248                            placeholder_values.append(&mut values);
249                            let placeholders = (start_size + 1..=placeholder_values.len())
250                                .map(|i| format!("${i}"))
251                                .collect::<Vec<_>>()
252                                .join(", ");
253                            Ok(placeholders)
254                        }
255                    }
256                }
257                SqlPart::Raw(raw) => Ok(raw.trim().to_string()),
258            })
259            .collect::<Result<Vec<String>, _>>()?
260            .join(" ");
261
262        if can_be_cached {
263            if let Some(original_sql) = self.sql {
264                let _ = self.cache.write().map(|mut cache| {
265                    if let Some((_, cached_sql)) = cache.get_mut(&original_sql) {
266                        *cached_sql = Some(sql.clone().into());
267                    }
268                });
269            }
270        }
271
272        Ok((sql, placeholder_values))
273    }
274
275    /// Binds a given placeholder to a value.
276    #[inline]
277    pub fn bind<C, V>(mut self, name: C, value: V) -> Self
278    where
279        C: ToString,
280        V: Into<Value>,
281    {
282        let name = name.to_string();
283        let value = value.into();
284        let value: PlaceholderValue = value.into();
285
286        for part in self.parts.iter_mut() {
287            if let SqlPart::Placeholder(part_name, part_value) = part {
288                if **part_name == *name.as_str() {
289                    *part_value = Some(value.clone());
290                }
291            }
292        }
293
294        self
295    }
296
297    /// Binds a single variable with a vector.
298    ///
299    /// This will rewrite the function from `:foo` (where value is vec![1, 2, 3]) to `:foo0, :foo1,
300    /// :foo2` and binds each value from the value vector accordingly.
301    #[inline]
302    pub fn bind_vec<C, V>(mut self, name: C, value: Vec<V>) -> Self
303    where
304        C: ToString,
305        V: Into<Value>,
306    {
307        let name = name.to_string();
308        let value: PlaceholderValue = value
309            .into_iter()
310            .map(|x| x.into())
311            .collect::<Vec<Value>>()
312            .into();
313
314        for part in self.parts.iter_mut() {
315            if let SqlPart::Placeholder(part_name, part_value) = part {
316                if **part_name == *name.as_str() {
317                    *part_value = Some(value.clone());
318                }
319            }
320        }
321
322        self
323    }
324
325    /// Executes a query and returns the affected rows
326    pub async fn pluck<C>(self, conn: &C) -> Result<Option<Value>, Error>
327    where
328        C: DatabaseExecutor,
329    {
330        conn.pluck(self).await
331    }
332
333    /// Executes a query and returns the affected rows
334    pub async fn batch<C>(self, conn: &C) -> Result<(), Error>
335    where
336        C: DatabaseExecutor,
337    {
338        conn.batch(self).await
339    }
340
341    /// Executes a query and returns the affected rows
342    pub async fn execute<C>(self, conn: &C) -> Result<usize, Error>
343    where
344        C: DatabaseExecutor,
345    {
346        conn.execute(self).await
347    }
348
349    /// Runs the query and returns the first row or None
350    pub async fn fetch_one<C>(self, conn: &C) -> Result<Option<Vec<Column>>, Error>
351    where
352        C: DatabaseExecutor,
353    {
354        conn.fetch_one(self).await
355    }
356
357    /// Runs the query and returns the first row or None
358    pub async fn fetch_all<C>(self, conn: &C) -> Result<Vec<Vec<Column>>, Error>
359    where
360        C: DatabaseExecutor,
361    {
362        conn.fetch_all(self).await
363    }
364}
365
366/// Creates a new query statement
367#[inline(always)]
368pub fn query(sql: &str) -> Result<Statement, Error> {
369    static CACHE: Lazy<Arc<RwLock<Cache>>> = Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
370    Statement::new(sql, CACHE.clone()).map_err(|e| Error::Database(Box::new(e)))
371}