1use std::collections::HashMap;
3use std::sync::{Arc, RwLock};
4
5use cdk_common::database::Error;
6use once_cell::sync::Lazy;
7
8use crate::database::DatabaseExecutor;
9use crate::value::Value;
10
/// A single column of a result row.
///
/// Alias of [`Value`]; rows returned by [`Statement::fetch_one`] and
/// [`Statement::fetch_all`] are `Vec<Column>`.
pub type Column = Value;
13
/// Kind of response expected from executing a [`Statement`].
///
/// Stored in [`Statement::expected_response`]. Nothing in this file reads the value, so the
/// per-variant notes below are inferred from the variant names.
/// NOTE(review): confirm against the `DatabaseExecutor` implementations.
#[derive(Debug, Clone, Copy, Default)]
pub enum ExpectedSqlResponse {
    /// Presumably a query expected to yield at most one row.
    SingleRow,
    /// Presumably a query expected to yield any number of rows. This is the default variant.
    #[default]
    ManyRows,
    /// Presumably a statement whose result is a count of affected rows.
    AffectedRows,
    /// Presumably a query from which a single value is taken.
    Pluck,
    /// Presumably a batch of statements with no result of interest.
    Batch,
}
29
/// Value bound to a named placeholder.
#[derive(Debug, Clone)]
pub enum PlaceholderValue {
    /// A single value; rendered as one positional parameter by [`Statement::to_sql`].
    Value(Value),
    /// A list of values; rendered as a comma-separated list of positional parameters by
    /// [`Statement::to_sql`].
    Set(Vec<Value>),
}
38
39impl From<Value> for PlaceholderValue {
40 fn from(value: Value) -> Self {
41 PlaceholderValue::Value(value)
42 }
43}
44
45impl From<Vec<Value>> for PlaceholderValue {
46 fn from(value: Vec<Value>) -> Self {
47 PlaceholderValue::Set(value)
48 }
49}
50
/// One segment of a SQL statement as produced by [`split_sql_parts`].
#[derive(Debug, Clone)]
pub enum SqlPart {
    /// SQL text copied verbatim from the input, including any quoted literals.
    Raw(Arc<str>),
    /// A named placeholder: its name (without the leading `:`) and the value bound to it, if
    /// any. The parser always emits `None` for the value.
    Placeholder(Arc<str>, Option<PlaceholderValue>),
}
59
/// Errors reported by [`split_sql_parts`].
#[derive(Debug, PartialEq, thiserror::Error)]
pub enum SqlParseError {
    /// A `'` or `"` quoted literal was opened but the input ended before it was closed.
    #[error("Unterminated String literal")]
    UnterminatedStringLiteral,
    /// A `:` was not followed by at least one alphanumeric character or `_`.
    #[error("Invalid placeholder name")]
    InvalidPlaceholder,
}
70
71pub fn split_sql_parts(input: &str) -> Result<Vec<SqlPart>, SqlParseError> {
76 let mut parts = Vec::new();
77 let mut current = String::new();
78 let mut chars = input.chars().peekable();
79
80 while let Some(&c) = chars.peek() {
81 match c {
82 '\'' | '"' => {
83 let quote = c;
85 current.push(chars.next().unwrap());
86
87 let mut closed = false;
88 while let Some(&next) = chars.peek() {
89 current.push(chars.next().unwrap());
90
91 if next == quote {
92 if chars.peek() == Some("e) {
93 current.push(chars.next().unwrap());
95 } else {
96 closed = true;
97 break;
98 }
99 }
100 }
101
102 if !closed {
103 return Err(SqlParseError::UnterminatedStringLiteral);
104 }
105 }
106
107 ':' => {
108 if !current.is_empty() {
110 parts.push(SqlPart::Raw(current.clone().into()));
111 current.clear();
112 }
113
114 chars.next(); let mut name = String::new();
116
117 while let Some(&next) = chars.peek() {
118 if next.is_alphanumeric() || next == '_' {
119 name.push(chars.next().unwrap());
120 } else {
121 break;
122 }
123 }
124
125 if name.is_empty() {
126 return Err(SqlParseError::InvalidPlaceholder);
127 }
128
129 parts.push(SqlPart::Placeholder(name.into(), None));
130 }
131
132 _ => {
133 current.push(chars.next().unwrap());
134 }
135 }
136 }
137
138 if !current.is_empty() {
139 parts.push(SqlPart::Raw(current.into()));
140 }
141
142 Ok(parts)
143}
144
/// Statement cache keyed by the original SQL text.
///
/// Each entry holds the parsed parts and, once a statement has been rendered without any
/// list-valued placeholder, the rendered SQL string.
type Cache = HashMap<String, (Vec<SqlPart>, Option<Arc<str>>)>;
146
/// A parsed SQL statement together with the values bound to its placeholders.
#[derive(Debug, Default)]
pub struct Statement {
    /// Shared cache of parsed statements, keyed by the original SQL text.
    cache: Arc<RwLock<Cache>>,
    /// Rendered SQL copied from the cache when the statement was created, if one was stored.
    cached_sql: Option<Arc<str>>,
    /// Original SQL text. Set only when the statement was parsed on a cache miss, so that the
    /// rendered SQL can later be stored under this key.
    sql: Option<String>,
    /// Parsed parts of the statement; placeholders carry their bound values.
    pub parts: Vec<SqlPart>,
    /// Kind of response expected from executing the statement.
    pub expected_response: ExpectedSqlResponse,
}
158
159impl Statement {
160 fn new(sql: &str, cache: Arc<RwLock<Cache>>) -> Result<Self, SqlParseError> {
162 let parsed = cache
163 .read()
164 .map(|cache| cache.get(sql).cloned())
165 .ok()
166 .flatten();
167
168 if let Some((parts, cached_sql)) = parsed {
169 Ok(Self {
170 parts,
171 cached_sql,
172 sql: None,
173 cache,
174 ..Default::default()
175 })
176 } else {
177 let parts = split_sql_parts(sql)?;
178
179 if let Ok(mut cache) = cache.write() {
180 cache.insert(sql.to_owned(), (parts.clone(), None));
181 } else {
182 tracing::warn!("Failed to acquire write lock for SQL statement cache");
183 }
184
185 Ok(Self {
186 parts,
187 sql: Some(sql.to_owned()),
188 cache,
189 ..Default::default()
190 })
191 }
192 }
193
194 pub fn to_sql(self) -> Result<(String, Vec<Value>), Error> {
200 if let Some(cached_sql) = self.cached_sql {
201 let sql = cached_sql.to_string();
202 let values = self
203 .parts
204 .into_iter()
205 .map(|x| match x {
206 SqlPart::Placeholder(name, value) => {
207 match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
208 PlaceholderValue::Value(value) => Ok(vec![value]),
209 PlaceholderValue::Set(values) => Ok(values),
210 }
211 }
212 SqlPart::Raw(_) => Ok(vec![]),
213 })
214 .collect::<Result<Vec<_>, Error>>()?
215 .into_iter()
216 .flatten()
217 .collect::<Vec<_>>();
218 return Ok((sql, values));
219 }
220
221 let mut placeholder_values = Vec::new();
222 let mut can_be_cached = true;
223 let sql = self
224 .parts
225 .into_iter()
226 .map(|x| match x {
227 SqlPart::Placeholder(name, value) => {
228 match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
229 PlaceholderValue::Value(value) => {
230 placeholder_values.push(value);
231 Ok::<_, Error>(format!("${}", placeholder_values.len()))
232 }
233 PlaceholderValue::Set(mut values) => {
234 can_be_cached = false;
235 let start_size = placeholder_values.len();
236 placeholder_values.append(&mut values);
237 let placeholders = (start_size + 1..=placeholder_values.len())
238 .map(|i| format!("${i}"))
239 .collect::<Vec<_>>()
240 .join(", ");
241 Ok(placeholders)
242 }
243 }
244 }
245 SqlPart::Raw(raw) => Ok(raw.trim().to_string()),
246 })
247 .collect::<Result<Vec<String>, _>>()?
248 .join(" ");
249
250 if can_be_cached {
251 if let Some(original_sql) = self.sql {
252 let _ = self.cache.write().map(|mut cache| {
253 if let Some((_, cached_sql)) = cache.get_mut(&original_sql) {
254 *cached_sql = Some(sql.clone().into());
255 }
256 });
257 }
258 }
259
260 Ok((sql, placeholder_values))
261 }
262
263 #[inline]
265 pub fn bind<C, V>(mut self, name: C, value: V) -> Self
266 where
267 C: ToString,
268 V: Into<Value>,
269 {
270 let name = name.to_string();
271 let value = value.into();
272 let value: PlaceholderValue = value.into();
273
274 for part in self.parts.iter_mut() {
275 if let SqlPart::Placeholder(part_name, part_value) = part {
276 if **part_name == *name.as_str() {
277 *part_value = Some(value.clone());
278 }
279 }
280 }
281
282 self
283 }
284
285 #[inline]
290 pub fn bind_vec<C, V>(mut self, name: C, value: Vec<V>) -> Self
291 where
292 C: ToString,
293 V: Into<Value>,
294 {
295 let name = name.to_string();
296 let value: PlaceholderValue = value
297 .into_iter()
298 .map(|x| x.into())
299 .collect::<Vec<Value>>()
300 .into();
301
302 for part in self.parts.iter_mut() {
303 if let SqlPart::Placeholder(part_name, part_value) = part {
304 if **part_name == *name.as_str() {
305 *part_value = Some(value.clone());
306 }
307 }
308 }
309
310 self
311 }
312
313 pub async fn pluck<C>(self, conn: &C) -> Result<Option<Value>, Error>
315 where
316 C: DatabaseExecutor,
317 {
318 conn.pluck(self).await
319 }
320
321 pub async fn batch<C>(self, conn: &C) -> Result<(), Error>
323 where
324 C: DatabaseExecutor,
325 {
326 conn.batch(self).await
327 }
328
329 pub async fn execute<C>(self, conn: &C) -> Result<usize, Error>
331 where
332 C: DatabaseExecutor,
333 {
334 conn.execute(self).await
335 }
336
337 pub async fn fetch_one<C>(self, conn: &C) -> Result<Option<Vec<Column>>, Error>
339 where
340 C: DatabaseExecutor,
341 {
342 conn.fetch_one(self).await
343 }
344
345 pub async fn fetch_all<C>(self, conn: &C) -> Result<Vec<Vec<Column>>, Error>
347 where
348 C: DatabaseExecutor,
349 {
350 conn.fetch_all(self).await
351 }
352}
353
354#[inline(always)]
356pub fn query(sql: &str) -> Result<Statement, Error> {
357 static CACHE: Lazy<Arc<RwLock<Cache>>> = Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
358 Statement::new(sql, CACHE.clone()).map_err(|e| Error::Database(Box::new(e)))
359}