1use std::collections::HashMap;
3use std::sync::{Arc, RwLock};
4
5use cdk_common::database::Error;
6use once_cell::sync::Lazy;
7
8use crate::database::DatabaseExecutor;
9use crate::value::Value;
10
/// A single column value of a fetched row, as returned by `Statement::fetch_one`
/// and `Statement::fetch_all`.
pub type Column = Value;
13
/// The kind of result a statement is expected to produce.
///
/// Defaults to [`ExpectedSqlResponse::ManyRows`]. `PartialEq`/`Eq` are derived so the
/// value can be compared with `==` instead of a `match`.
/// NOTE(review): how each variant is acted upon is decided by the `DatabaseExecutor`
/// implementations, which are not visible here.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ExpectedSqlResponse {
    /// A single row is expected.
    SingleRow,
    /// Any number of rows is expected (the default).
    #[default]
    ManyRows,
    /// The number of affected rows is expected.
    AffectedRows,
    /// A single value is expected.
    Pluck,
    /// The statement is run as a batch.
    Batch,
}
29
30#[derive(Debug, Clone)]
32pub enum PlaceholderValue {
33 Value(Value),
35 Set(Vec<Value>),
37}
38
39impl From<Value> for PlaceholderValue {
40 fn from(value: Value) -> Self {
41 PlaceholderValue::Value(value)
42 }
43}
44
45impl From<Vec<Value>> for PlaceholderValue {
46 fn from(value: Vec<Value>) -> Self {
47 PlaceholderValue::Set(value)
48 }
49}
50
/// One fragment of a SQL statement, as produced by [`split_sql_parts`].
#[derive(Debug, Clone)]
pub enum SqlPart {
    /// SQL text copied verbatim from the input, including quoted sections and comments.
    Raw(Arc<str>),
    /// A `:name` placeholder: the name without the leading `:`, and the bound value
    /// (`None` until `Statement::bind` / `Statement::bind_vec` sets it).
    Placeholder(Arc<str>, Option<PlaceholderValue>),
}
59
/// Errors reported by [`split_sql_parts`] while scanning SQL text.
#[derive(Debug, PartialEq, thiserror::Error)]
pub enum SqlParseError {
    /// A `'` or `"` quoted section was opened but never closed.
    #[error("Unterminated String literal")]
    UnterminatedStringLiteral,
    /// A placeholder started with `:` but its name was empty.
    #[error("Invalid placeholder name")]
    InvalidPlaceholder,
}
70
71pub fn split_sql_parts(input: &str) -> Result<Vec<SqlPart>, SqlParseError> {
76 let mut parts = Vec::new();
77 let mut current = String::new();
78 let mut chars = input.chars().peekable();
79
80 while let Some(&c) = chars.peek() {
81 match c {
82 '\'' | '"' => {
83 let quote = c;
85 current.push(
86 chars
87 .next()
88 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
89 );
90
91 let mut closed = false;
92 while let Some(&next) = chars.peek() {
93 current.push(
94 chars
95 .next()
96 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
97 );
98
99 if next == quote {
100 if chars.peek() == Some("e) {
101 current.push(
103 chars
104 .next()
105 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
106 );
107 } else {
108 closed = true;
109 break;
110 }
111 }
112 }
113
114 if !closed {
115 return Err(SqlParseError::UnterminatedStringLiteral);
116 }
117 }
118
119 '-' => {
120 current.push(
121 chars
122 .next()
123 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
124 );
125
126 if chars.peek() == Some(&'-') {
127 while let Some(&next) = chars.peek() {
128 current.push(
129 chars
130 .next()
131 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
132 );
133 if next == '\n' {
134 break;
135 }
136 }
137 }
138 }
139
140 ':' => {
141 if !current.is_empty() {
143 parts.push(SqlPart::Raw(current.clone().into()));
144 current.clear();
145 }
146
147 chars.next(); let mut name = String::new();
149
150 while let Some(&next) = chars.peek() {
151 if next.is_alphanumeric() || next == '_' {
152 name.push(
153 chars
154 .next()
155 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
156 );
157 } else {
158 break;
159 }
160 }
161
162 if name.is_empty() {
163 return Err(SqlParseError::InvalidPlaceholder);
164 }
165
166 parts.push(SqlPart::Placeholder(name.into(), None));
167 }
168
169 _ => {
170 current.push(
171 chars
172 .next()
173 .ok_or(SqlParseError::UnterminatedStringLiteral)?,
174 );
175 }
176 }
177 }
178
179 if !current.is_empty() {
180 parts.push(SqlPart::Raw(current.into()));
181 }
182
183 Ok(parts)
184}
185
186type Cache = HashMap<String, (Vec<SqlPart>, Option<Arc<str>>)>;
187
188#[derive(Debug, Default)]
190pub struct Statement {
191 cache: Arc<RwLock<Cache>>,
192 cached_sql: Option<Arc<str>>,
193 sql: Option<String>,
194 pub parts: Vec<SqlPart>,
196 pub expected_response: ExpectedSqlResponse,
198}
199
200impl Statement {
201 fn new(sql: &str, cache: Arc<RwLock<Cache>>) -> Result<Self, SqlParseError> {
203 let parsed = cache
204 .read()
205 .map(|cache| cache.get(sql).cloned())
206 .ok()
207 .flatten();
208
209 if let Some((parts, cached_sql)) = parsed {
210 Ok(Self {
211 parts,
212 cached_sql,
213 sql: None,
214 cache,
215 ..Default::default()
216 })
217 } else {
218 let parts = split_sql_parts(sql)?;
219
220 if let Ok(mut cache) = cache.write() {
221 cache.insert(sql.to_owned(), (parts.clone(), None));
222 } else {
223 tracing::warn!("Failed to acquire write lock for SQL statement cache");
224 }
225
226 Ok(Self {
227 parts,
228 sql: Some(sql.to_owned()),
229 cache,
230 ..Default::default()
231 })
232 }
233 }
234
235 pub fn to_sql(self) -> Result<(String, Vec<Value>), Error> {
241 if let Some(cached_sql) = self.cached_sql {
242 let sql = cached_sql.to_string();
243 let values = self
244 .parts
245 .into_iter()
246 .map(|x| match x {
247 SqlPart::Placeholder(name, value) => {
248 match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
249 PlaceholderValue::Value(value) => Ok(vec![value]),
250 PlaceholderValue::Set(values) => Ok(values),
251 }
252 }
253 SqlPart::Raw(_) => Ok(vec![]),
254 })
255 .collect::<Result<Vec<_>, Error>>()?
256 .into_iter()
257 .flatten()
258 .collect::<Vec<_>>();
259 return Ok((sql, values));
260 }
261
262 let mut placeholder_values = Vec::new();
263 let mut can_be_cached = true;
264 let sql = self
265 .parts
266 .into_iter()
267 .map(|x| match x {
268 SqlPart::Placeholder(name, value) => {
269 match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
270 PlaceholderValue::Value(value) => {
271 placeholder_values.push(value);
272 Ok::<_, Error>(format!("${}", placeholder_values.len()))
273 }
274 PlaceholderValue::Set(mut values) => {
275 can_be_cached = false;
276 let start_size = placeholder_values.len();
277 placeholder_values.append(&mut values);
278 let placeholders = (start_size + 1..=placeholder_values.len())
279 .map(|i| format!("${i}"))
280 .collect::<Vec<_>>()
281 .join(", ");
282 Ok(placeholders)
283 }
284 }
285 }
286 SqlPart::Raw(raw) => Ok(raw.trim().to_string()),
287 })
288 .collect::<Result<Vec<String>, _>>()?
289 .join(" ");
290
291 if can_be_cached {
292 if let Some(original_sql) = self.sql {
293 let _ = self.cache.write().map(|mut cache| {
294 if let Some((_, cached_sql)) = cache.get_mut(&original_sql) {
295 *cached_sql = Some(sql.clone().into());
296 }
297 });
298 }
299 }
300
301 Ok((sql, placeholder_values))
302 }
303
304 #[inline]
306 pub fn bind<C, V>(mut self, name: C, value: V) -> Self
307 where
308 C: ToString,
309 V: Into<Value>,
310 {
311 let name = name.to_string();
312 let value = value.into();
313 let value: PlaceholderValue = value.into();
314
315 for part in self.parts.iter_mut() {
316 if let SqlPart::Placeholder(part_name, part_value) = part {
317 if **part_name == *name.as_str() {
318 *part_value = Some(value.clone());
319 }
320 }
321 }
322
323 self
324 }
325
326 #[inline]
331 pub fn bind_vec<C, V>(mut self, name: C, value: Vec<V>) -> Self
332 where
333 C: ToString,
334 V: Into<Value>,
335 {
336 let name = name.to_string();
337 let value: PlaceholderValue = value
338 .into_iter()
339 .map(|x| x.into())
340 .collect::<Vec<Value>>()
341 .into();
342
343 for part in self.parts.iter_mut() {
344 if let SqlPart::Placeholder(part_name, part_value) = part {
345 if **part_name == *name.as_str() {
346 *part_value = Some(value.clone());
347 }
348 }
349 }
350
351 self
352 }
353
354 pub async fn pluck<C>(self, conn: &C) -> Result<Option<Value>, Error>
356 where
357 C: DatabaseExecutor,
358 {
359 conn.pluck(self).await
360 }
361
362 pub async fn batch<C>(self, conn: &C) -> Result<(), Error>
364 where
365 C: DatabaseExecutor,
366 {
367 conn.batch(self).await
368 }
369
370 pub async fn execute<C>(self, conn: &C) -> Result<usize, Error>
372 where
373 C: DatabaseExecutor,
374 {
375 conn.execute(self).await
376 }
377
378 pub async fn fetch_one<C>(self, conn: &C) -> Result<Option<Vec<Column>>, Error>
380 where
381 C: DatabaseExecutor,
382 {
383 conn.fetch_one(self).await
384 }
385
386 pub async fn fetch_all<C>(self, conn: &C) -> Result<Vec<Vec<Column>>, Error>
388 where
389 C: DatabaseExecutor,
390 {
391 conn.fetch_all(self).await
392 }
393}
394
395#[inline(always)]
397pub fn query(sql: &str) -> Result<Statement, Error> {
398 static CACHE: Lazy<Arc<RwLock<Cache>>> = Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
399 Statement::new(sql, CACHE.clone()).map_err(|e| Error::Database(Box::new(e)))
400}