1use std::ops::Deref;
4
5use crate::error::Error;
6use crate::ffi;
7use crate::types::{ToSql, ToSqlOutput, ValueRef};
8use crate::{Connection, DatabaseName, Result, Row};
9
/// Incrementally assembled SQL text.
///
/// The `push_*` methods append pieces of a statement; identifiers and
/// string literals are quoted/escaped by those methods, so callers never
/// format raw SQL themselves. `Sql::default()` is equivalent to `Sql::new()`.
#[derive(Debug, Default)]
pub struct Sql {
    buf: String,
}
13
14impl Sql {
15 pub fn new() -> Sql {
16 Sql { buf: String::new() }
17 }
18
19 pub fn push_pragma(
20 &mut self,
21 schema_name: Option<DatabaseName<'_>>,
22 pragma_name: &str,
23 ) -> Result<()> {
24 self.push_keyword("PRAGMA")?;
25 self.push_space();
26 if let Some(schema_name) = schema_name {
27 self.push_schema_name(schema_name);
28 self.push_dot();
29 }
30 self.push_keyword(pragma_name)
31 }
32
33 pub fn push_keyword(&mut self, keyword: &str) -> Result<()> {
34 if !keyword.is_empty() && is_identifier(keyword) {
35 self.buf.push_str(keyword);
36 Ok(())
37 } else {
38 Err(Error::SqliteFailure(
39 ffi::Error::new(ffi::SQLITE_MISUSE),
40 Some(format!("Invalid keyword \"{keyword}\"")),
41 ))
42 }
43 }
44
45 pub fn push_schema_name(&mut self, schema_name: DatabaseName<'_>) {
46 match schema_name {
47 DatabaseName::Main => self.buf.push_str("main"),
48 DatabaseName::Temp => self.buf.push_str("temp"),
49 DatabaseName::Attached(s) => self.push_identifier(s),
50 };
51 }
52
53 pub fn push_identifier(&mut self, s: &str) {
54 if is_identifier(s) {
55 self.buf.push_str(s);
56 } else {
57 self.wrap_and_escape(s, '"');
58 }
59 }
60
61 pub fn push_value(&mut self, value: &dyn ToSql) -> Result<()> {
62 let value = value.to_sql()?;
63 let value = match value {
64 ToSqlOutput::Borrowed(v) => v,
65 ToSqlOutput::Owned(ref v) => ValueRef::from(v),
66 #[cfg(feature = "blob")]
67 ToSqlOutput::ZeroBlob(_) => {
68 return Err(Error::SqliteFailure(
69 ffi::Error::new(ffi::SQLITE_MISUSE),
70 Some(format!("Unsupported value \"{value:?}\"")),
71 ));
72 }
73 #[cfg(feature = "array")]
74 ToSqlOutput::Array(_) => {
75 return Err(Error::SqliteFailure(
76 ffi::Error::new(ffi::SQLITE_MISUSE),
77 Some(format!("Unsupported value \"{value:?}\"")),
78 ));
79 }
80 };
81 match value {
82 ValueRef::Integer(i) => {
83 self.push_int(i);
84 }
85 ValueRef::Real(r) => {
86 self.push_real(r);
87 }
88 ValueRef::Text(s) => {
89 let s = std::str::from_utf8(s)?;
90 self.push_string_literal(s);
91 }
92 _ => {
93 return Err(Error::SqliteFailure(
94 ffi::Error::new(ffi::SQLITE_MISUSE),
95 Some(format!("Unsupported value \"{value:?}\"")),
96 ));
97 }
98 };
99 Ok(())
100 }
101
102 pub fn push_string_literal(&mut self, s: &str) {
103 self.wrap_and_escape(s, '\'');
104 }
105
106 pub fn push_int(&mut self, i: i64) {
107 self.buf.push_str(&i.to_string());
108 }
109
110 pub fn push_real(&mut self, f: f64) {
111 self.buf.push_str(&f.to_string());
112 }
113
114 pub fn push_space(&mut self) {
115 self.buf.push(' ');
116 }
117
118 pub fn push_dot(&mut self) {
119 self.buf.push('.');
120 }
121
122 pub fn push_equal_sign(&mut self) {
123 self.buf.push('=');
124 }
125
126 pub fn open_brace(&mut self) {
127 self.buf.push('(');
128 }
129
130 pub fn close_brace(&mut self) {
131 self.buf.push(')');
132 }
133
134 pub fn as_str(&self) -> &str {
135 &self.buf
136 }
137
138 fn wrap_and_escape(&mut self, s: &str, quote: char) {
139 self.buf.push(quote);
140 let chars = s.chars();
141 for ch in chars {
142 if ch == quote {
144 self.buf.push(ch);
145 }
146 self.buf.push(ch);
147 }
148 self.buf.push(quote);
149 }
150}
151
152impl Deref for Sql {
153 type Target = str;
154
155 fn deref(&self) -> &str {
156 self.as_str()
157 }
158}
159
160impl Connection {
161 pub fn pragma_query_value<T, F>(
169 &self,
170 schema_name: Option<DatabaseName<'_>>,
171 pragma_name: &str,
172 f: F,
173 ) -> Result<T>
174 where
175 F: FnOnce(&Row<'_>) -> Result<T>,
176 {
177 let mut query = Sql::new();
178 query.push_pragma(schema_name, pragma_name)?;
179 self.query_row(&query, [], f)
180 }
181
182 pub fn pragma_query<F>(
187 &self,
188 schema_name: Option<DatabaseName<'_>>,
189 pragma_name: &str,
190 mut f: F,
191 ) -> Result<()>
192 where
193 F: FnMut(&Row<'_>) -> Result<()>,
194 {
195 let mut query = Sql::new();
196 query.push_pragma(schema_name, pragma_name)?;
197 let mut stmt = self.prepare(&query)?;
198 let mut rows = stmt.query([])?;
199 while let Some(result_row) = rows.next()? {
200 let row = result_row;
201 f(row)?;
202 }
203 Ok(())
204 }
205
206 pub fn pragma<F, V>(
216 &self,
217 schema_name: Option<DatabaseName<'_>>,
218 pragma_name: &str,
219 pragma_value: V,
220 mut f: F,
221 ) -> Result<()>
222 where
223 F: FnMut(&Row<'_>) -> Result<()>,
224 V: ToSql,
225 {
226 let mut sql = Sql::new();
227 sql.push_pragma(schema_name, pragma_name)?;
228 sql.open_brace();
232 sql.push_value(&pragma_value)?;
233 sql.close_brace();
234 let mut stmt = self.prepare(&sql)?;
235 let mut rows = stmt.query([])?;
236 while let Some(result_row) = rows.next()? {
237 let row = result_row;
238 f(row)?;
239 }
240 Ok(())
241 }
242
243 pub fn pragma_update<V>(
248 &self,
249 schema_name: Option<DatabaseName<'_>>,
250 pragma_name: &str,
251 pragma_value: V,
252 ) -> Result<()>
253 where
254 V: ToSql,
255 {
256 let mut sql = Sql::new();
257 sql.push_pragma(schema_name, pragma_name)?;
258 sql.push_equal_sign();
262 sql.push_value(&pragma_value)?;
263 self.execute_batch(&sql)
264 }
265
266 pub fn pragma_update_and_check<F, T, V>(
270 &self,
271 schema_name: Option<DatabaseName<'_>>,
272 pragma_name: &str,
273 pragma_value: V,
274 f: F,
275 ) -> Result<T>
276 where
277 F: FnOnce(&Row<'_>) -> Result<T>,
278 V: ToSql,
279 {
280 let mut sql = Sql::new();
281 sql.push_pragma(schema_name, pragma_name)?;
282 sql.push_equal_sign();
286 sql.push_value(&pragma_value)?;
287 self.query_row(&sql, [], f)
288 }
289}
290
291fn is_identifier(s: &str) -> bool {
292 let chars = s.char_indices();
293 for (i, ch) in chars {
294 if i == 0 {
295 if !is_identifier_start(ch) {
296 return false;
297 }
298 } else if !is_identifier_continue(ch) {
299 return false;
300 }
301 }
302 true
303}
304
/// Whether `c` may begin a bare identifier: an ASCII letter, `_`, or any
/// non-ASCII character.
fn is_identifier_start(c: char) -> bool {
    matches!(c, 'A'..='Z' | 'a'..='z' | '_') || !c.is_ascii()
}
308
/// Whether `c` may appear after the first character of a bare identifier:
/// an ASCII letter or digit, `_`, `$`, or any non-ASCII character.
fn is_identifier_continue(c: char) -> bool {
    c.is_ascii_alphanumeric() || matches!(c, '_' | '$') || !c.is_ascii()
}
317
#[cfg(test)]
mod test {
    use super::Sql;
    use crate::pragma;
    use crate::{Connection, DatabaseName, Result};

    /// The journal-mode tests accept either reported mode.
    fn is_off_or_memory(mode: &str) -> bool {
        mode == "off" || mode == "memory"
    }

    #[test]
    fn pragma_query_value() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let version: i32 = conn.pragma_query_value(None, "user_version", |row| row.get(0))?;
        assert_eq!(0, version);
        Ok(())
    }

    #[test]
    #[cfg(feature = "modern_sqlite")]
    fn pragma_func_query_value() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let version: i32 = conn.one_column("SELECT user_version FROM pragma_user_version")?;
        assert_eq!(0, version);
        Ok(())
    }

    #[test]
    fn pragma_query_no_schema() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let mut version = -1;
        conn.pragma_query(None, "user_version", |row| {
            version = row.get(0)?;
            Ok(())
        })?;
        assert_eq!(0, version);
        Ok(())
    }

    #[test]
    fn pragma_query_with_schema() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let mut version = -1;
        conn.pragma_query(Some(DatabaseName::Main), "user_version", |row| {
            version = row.get(0)?;
            Ok(())
        })?;
        assert_eq!(0, version);
        Ok(())
    }

    #[test]
    fn pragma() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let mut names = Vec::new();
        conn.pragma(None, "table_info", "sqlite_master", |row| {
            names.push(row.get::<_, String>(1)?);
            Ok(())
        })?;
        assert_eq!(5, names.len());
        Ok(())
    }

    #[test]
    #[cfg(feature = "modern_sqlite")]
    fn pragma_func() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let mut stmt = conn.prepare("SELECT * FROM pragma_table_info(?1)")?;
        let mut rows = stmt.query(["sqlite_master"])?;
        let mut names = Vec::new();
        while let Some(row) = rows.next()? {
            names.push(row.get::<_, String>(1)?);
        }
        assert_eq!(5, names.len());
        Ok(())
    }

    #[test]
    fn pragma_update() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.pragma_update(None, "user_version", 1)
    }

    #[test]
    fn pragma_update_and_check() -> Result<()> {
        let conn = Connection::open_in_memory()?;

        // Result type fixed by the binding's annotation.
        let annotated: String =
            conn.pragma_update_and_check(None, "journal_mode", "OFF", |row| row.get(0))?;
        assert!(is_off_or_memory(&annotated), "mode: {:?}", annotated);

        // Result type fixed by a turbofish inside the closure.
        let turbofished = conn
            .pragma_update_and_check(None, "journal_mode", "OFF", |row| row.get::<_, String>(0))?;
        assert!(is_off_or_memory(&turbofished), "mode: {:?}", turbofished);

        // Value passed as a trait object rather than a concrete type.
        let dyn_value: &dyn crate::ToSql = &"OFF";
        let via_dyn = conn.pragma_update_and_check(None, "journal_mode", dyn_value, |row| {
            row.get::<_, String>(0)
        })?;
        assert!(is_off_or_memory(&via_dyn), "mode: {:?}", via_dyn);
        Ok(())
    }

    #[test]
    fn is_identifier() {
        for accepted in ["full", "r2d2"] {
            assert!(pragma::is_identifier(accepted), "{accepted:?}");
        }
        for rejected in ["sp ce", "semi;colon"] {
            assert!(!pragma::is_identifier(rejected), "{rejected:?}");
        }
    }

    #[test]
    fn double_quote() {
        let mut sql = Sql::new();
        sql.push_schema_name(DatabaseName::Attached(r#"schema";--"#));
        assert_eq!(sql.as_str(), r#""schema"";--""#);
    }

    #[test]
    fn wrap_and_escape() {
        let mut sql = Sql::new();
        sql.push_string_literal("value'; --");
        assert_eq!(sql.as_str(), "'value''; --'");
    }

    #[test]
    fn locking_mode() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let outcome = conn.pragma_update(None, "locking_mode", "exclusive");
        if cfg!(feature = "extra_check") {
            outcome.unwrap_err();
        } else {
            outcome?;
        }
        Ok(())
    }
}