1use std::collections::HashMap;
3use std::sync::{Arc, RwLock};
4
5use cdk_common::database::Error;
6use once_cell::sync::Lazy;
7
8use crate::database::DatabaseExecutor;
9use crate::value::Value;
10
/// A single column value in a result row (an alias for [`Value`]).
pub type Column = Value;
13
/// The kind of result a [`Statement`] expects from the database.
///
/// NOTE(review): the per-variant notes below are inferred from the variant names and the
/// `Statement` helper methods; confirm against the `DatabaseExecutor` implementations.
#[derive(Debug, Clone, Copy, Default)]
pub enum ExpectedSqlResponse {
    /// At most one row (presumably used by `fetch_one`).
    SingleRow,
    /// Any number of rows. This is the default.
    #[default]
    ManyRows,
    /// Only a count of affected rows (presumably used by `execute`).
    AffectedRows,
    /// A single value (presumably used by `pluck`).
    Pluck,
    /// Statements run for their side effects only (presumably used by `batch`).
    Batch,
}
29
/// A value bound to a named placeholder of a [`Statement`].
#[derive(Debug, Clone)]
pub enum PlaceholderValue {
    /// A single value, rendered as one positional parameter (`$N`).
    Value(Value),
    /// A list of values, rendered as a comma-separated run of positional parameters
    /// (`$N, $N+1, ...`), which suits list contexts such as `IN (...)`.
    Set(Vec<Value>),
}
38
39impl From<Value> for PlaceholderValue {
40 fn from(value: Value) -> Self {
41 PlaceholderValue::Value(value)
42 }
43}
44
45impl From<Vec<Value>> for PlaceholderValue {
46 fn from(value: Vec<Value>) -> Self {
47 PlaceholderValue::Set(value)
48 }
49}
50
/// One fragment of a parsed SQL statement, as produced by `split_sql_parts`.
#[derive(Debug, Clone)]
pub enum SqlPart {
    /// SQL text copied from the source.
    Raw(Arc<str>),
    /// A `:name` placeholder: its name (without the colon) and the bound value, `None`
    /// until one is bound.
    Placeholder(Arc<str>, Option<PlaceholderValue>),
}
59
/// Errors produced while splitting SQL text into [`SqlPart`]s.
#[derive(Debug, PartialEq, thiserror::Error)]
pub enum SqlParseError {
    /// A quoted string (`'...'` or `"..."`) was never closed.
    #[error("Unterminated String literal")]
    UnterminatedStringLiteral,
    /// A `:` was not followed by a placeholder name (alphanumeric characters or `_`).
    #[error("Invalid placeholder name")]
    InvalidPlaceholder,
}
70
71pub fn split_sql_parts(input: &str) -> Result<Vec<SqlPart>, SqlParseError> {
76 let mut parts = Vec::new();
77 let mut current = String::new();
78 let mut chars = input.chars().peekable();
79
80 while let Some(&c) = chars.peek() {
81 match c {
82 '\'' | '"' => {
83 let quote = c;
85 current.push(chars.next().unwrap());
86
87 let mut closed = false;
88 while let Some(&next) = chars.peek() {
89 current.push(chars.next().unwrap());
90
91 if next == quote {
92 if chars.peek() == Some("e) {
93 current.push(chars.next().unwrap());
95 } else {
96 closed = true;
97 break;
98 }
99 }
100 }
101
102 if !closed {
103 return Err(SqlParseError::UnterminatedStringLiteral);
104 }
105 }
106
107 '-' => {
108 current.push(chars.next().unwrap());
109 if chars.peek() == Some(&'-') {
110 while let Some(&next) = chars.peek() {
111 current.push(chars.next().unwrap());
112 if next == '\n' {
113 break;
114 }
115 }
116 }
117 }
118
119 ':' => {
120 if !current.is_empty() {
122 parts.push(SqlPart::Raw(current.clone().into()));
123 current.clear();
124 }
125
126 chars.next(); let mut name = String::new();
128
129 while let Some(&next) = chars.peek() {
130 if next.is_alphanumeric() || next == '_' {
131 name.push(chars.next().unwrap());
132 } else {
133 break;
134 }
135 }
136
137 if name.is_empty() {
138 return Err(SqlParseError::InvalidPlaceholder);
139 }
140
141 parts.push(SqlPart::Placeholder(name.into(), None));
142 }
143
144 _ => {
145 current.push(chars.next().unwrap());
146 }
147 }
148 }
149
150 if !current.is_empty() {
151 parts.push(SqlPart::Raw(current.into()));
152 }
153
154 Ok(parts)
155}
156
/// Parse cache keyed by the original SQL text.
///
/// Each entry holds the parsed parts (with no values bound) and, optionally, the SQL rendered by
/// `Statement::to_sql` for a statement that bound no list (`Set`) placeholders.
type Cache = HashMap<String, (Vec<SqlPart>, Option<Arc<str>>)>;
158
/// A parsed SQL statement together with the values bound to its placeholders.
///
/// Create one with [`query`], bind values with `bind` / `bind_vec`, then either run it with one
/// of the executor helpers or render it with `to_sql`.
#[derive(Debug, Default)]
pub struct Statement {
    /// Shared parse cache this statement was created from.
    cache: Arc<RwLock<Cache>>,
    /// Previously rendered SQL for this statement's text, when the cache had one.
    cached_sql: Option<Arc<str>>,
    /// Original SQL text, kept when `to_sql` may need to store its rendered SQL in `cache`.
    sql: Option<String>,
    /// Raw text and placeholder fragments, in source order.
    pub parts: Vec<SqlPart>,
    /// Kind of result the caller expects.
    // NOTE(review): not read in this file; presumably consumed by `DatabaseExecutor`
    // implementations — confirm.
    pub expected_response: ExpectedSqlResponse,
}
170
171impl Statement {
172 fn new(sql: &str, cache: Arc<RwLock<Cache>>) -> Result<Self, SqlParseError> {
174 let parsed = cache
175 .read()
176 .map(|cache| cache.get(sql).cloned())
177 .ok()
178 .flatten();
179
180 if let Some((parts, cached_sql)) = parsed {
181 Ok(Self {
182 parts,
183 cached_sql,
184 sql: None,
185 cache,
186 ..Default::default()
187 })
188 } else {
189 let parts = split_sql_parts(sql)?;
190
191 if let Ok(mut cache) = cache.write() {
192 cache.insert(sql.to_owned(), (parts.clone(), None));
193 } else {
194 tracing::warn!("Failed to acquire write lock for SQL statement cache");
195 }
196
197 Ok(Self {
198 parts,
199 sql: Some(sql.to_owned()),
200 cache,
201 ..Default::default()
202 })
203 }
204 }
205
206 pub fn to_sql(self) -> Result<(String, Vec<Value>), Error> {
212 if let Some(cached_sql) = self.cached_sql {
213 let sql = cached_sql.to_string();
214 let values = self
215 .parts
216 .into_iter()
217 .map(|x| match x {
218 SqlPart::Placeholder(name, value) => {
219 match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
220 PlaceholderValue::Value(value) => Ok(vec![value]),
221 PlaceholderValue::Set(values) => Ok(values),
222 }
223 }
224 SqlPart::Raw(_) => Ok(vec![]),
225 })
226 .collect::<Result<Vec<_>, Error>>()?
227 .into_iter()
228 .flatten()
229 .collect::<Vec<_>>();
230 return Ok((sql, values));
231 }
232
233 let mut placeholder_values = Vec::new();
234 let mut can_be_cached = true;
235 let sql = self
236 .parts
237 .into_iter()
238 .map(|x| match x {
239 SqlPart::Placeholder(name, value) => {
240 match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
241 PlaceholderValue::Value(value) => {
242 placeholder_values.push(value);
243 Ok::<_, Error>(format!("${}", placeholder_values.len()))
244 }
245 PlaceholderValue::Set(mut values) => {
246 can_be_cached = false;
247 let start_size = placeholder_values.len();
248 placeholder_values.append(&mut values);
249 let placeholders = (start_size + 1..=placeholder_values.len())
250 .map(|i| format!("${i}"))
251 .collect::<Vec<_>>()
252 .join(", ");
253 Ok(placeholders)
254 }
255 }
256 }
257 SqlPart::Raw(raw) => Ok(raw.trim().to_string()),
258 })
259 .collect::<Result<Vec<String>, _>>()?
260 .join(" ");
261
262 if can_be_cached {
263 if let Some(original_sql) = self.sql {
264 let _ = self.cache.write().map(|mut cache| {
265 if let Some((_, cached_sql)) = cache.get_mut(&original_sql) {
266 *cached_sql = Some(sql.clone().into());
267 }
268 });
269 }
270 }
271
272 Ok((sql, placeholder_values))
273 }
274
275 #[inline]
277 pub fn bind<C, V>(mut self, name: C, value: V) -> Self
278 where
279 C: ToString,
280 V: Into<Value>,
281 {
282 let name = name.to_string();
283 let value = value.into();
284 let value: PlaceholderValue = value.into();
285
286 for part in self.parts.iter_mut() {
287 if let SqlPart::Placeholder(part_name, part_value) = part {
288 if **part_name == *name.as_str() {
289 *part_value = Some(value.clone());
290 }
291 }
292 }
293
294 self
295 }
296
297 #[inline]
302 pub fn bind_vec<C, V>(mut self, name: C, value: Vec<V>) -> Self
303 where
304 C: ToString,
305 V: Into<Value>,
306 {
307 let name = name.to_string();
308 let value: PlaceholderValue = value
309 .into_iter()
310 .map(|x| x.into())
311 .collect::<Vec<Value>>()
312 .into();
313
314 for part in self.parts.iter_mut() {
315 if let SqlPart::Placeholder(part_name, part_value) = part {
316 if **part_name == *name.as_str() {
317 *part_value = Some(value.clone());
318 }
319 }
320 }
321
322 self
323 }
324
325 pub async fn pluck<C>(self, conn: &C) -> Result<Option<Value>, Error>
327 where
328 C: DatabaseExecutor,
329 {
330 conn.pluck(self).await
331 }
332
333 pub async fn batch<C>(self, conn: &C) -> Result<(), Error>
335 where
336 C: DatabaseExecutor,
337 {
338 conn.batch(self).await
339 }
340
341 pub async fn execute<C>(self, conn: &C) -> Result<usize, Error>
343 where
344 C: DatabaseExecutor,
345 {
346 conn.execute(self).await
347 }
348
349 pub async fn fetch_one<C>(self, conn: &C) -> Result<Option<Vec<Column>>, Error>
351 where
352 C: DatabaseExecutor,
353 {
354 conn.fetch_one(self).await
355 }
356
357 pub async fn fetch_all<C>(self, conn: &C) -> Result<Vec<Vec<Column>>, Error>
359 where
360 C: DatabaseExecutor,
361 {
362 conn.fetch_all(self).await
363 }
364}
365
366#[inline(always)]
368pub fn query(sql: &str) -> Result<Statement, Error> {
369 static CACHE: Lazy<Arc<RwLock<Cache>>> = Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
370 Statement::new(sql, CACHE.clone()).map_err(|e| Error::Database(Box::new(e)))
371}