1use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
20use sqlparser::dialect::SQLiteDialect;
21use sqlparser::keywords::Keyword;
22use sqlparser::tokenizer::{Token, Tokenizer};
23
24use crate::error::{Result, SQLRiteError};
25use crate::mvcc::JournalMode;
26use crate::sql::CommandOutput;
27use crate::sql::db::database::Database;
28
/// The right-hand side of a `PRAGMA name = value` / `PRAGMA name(value)` statement,
/// classified by the kind of token it was read from.
///
/// No interpretation happens here: range checks and keyword matching (`OFF`, `wal`, ...)
/// are left to each pragma's handler.
#[derive(Debug, Clone, PartialEq)]
pub enum PragmaValue {
    /// Numeric literal kept as its source text, so each handler can parse it into the
    /// type it needs. A leading `-` written in the statement is folded into the text.
    Number(String),
    /// Word token (e.g. `OFF`, `none`, `wal`), case preserved. Double-quoted text such as
    /// `"NONE"` also ends up here (the parser tests pin this behavior).
    Identifier(String),
    /// Quoted string literal such as `'OFF'`, stored without its surrounding quotes.
    String(String),
}
45
/// A parsed `PRAGMA` statement, produced by `try_parse_pragma` and consumed by
/// `execute_pragma`.
#[derive(Debug, Clone, PartialEq)]
pub struct PragmaStatement {
    /// Pragma name exactly as written; `execute_pragma` lowercases it before dispatching.
    pub name: String,
    /// `None` for the read form (`PRAGMA name;`), `Some` for `= value` or `(value)`.
    pub value: Option<PragmaValue>,
}
53
54pub fn try_parse_pragma(sql: &str) -> Result<Option<PragmaStatement>> {
60 let dialect = SQLiteDialect {};
61 let tokens = Tokenizer::new(&dialect, sql)
62 .tokenize()
63 .map_err(|e| SQLRiteError::General(format!("PRAGMA tokenize error: {e}")))?;
64
65 let mut iter = tokens
66 .into_iter()
67 .filter(|t| !matches!(t, Token::Whitespace(_)))
68 .peekable();
69
70 match iter.peek() {
73 Some(Token::Word(w)) if w.keyword == Keyword::PRAGMA => {
74 iter.next();
75 }
76 _ => return Ok(None),
77 }
78
79 let name = match iter.next() {
80 Some(Token::Word(w)) => w.value,
81 Some(other) => {
82 return Err(SQLRiteError::General(format!(
83 "PRAGMA: expected pragma name, got {other:?}"
84 )));
85 }
86 None => {
87 return Err(SQLRiteError::General(
88 "PRAGMA: missing pragma name".to_string(),
89 ));
90 }
91 };
92
93 let value = match iter.peek() {
94 None | Some(Token::SemiColon) => None,
95 Some(Token::Eq) => {
96 iter.next();
97 Some(read_pragma_value(&mut iter)?)
98 }
99 Some(Token::LParen) => {
100 iter.next();
101 let v = read_pragma_value(&mut iter)?;
102 match iter.next() {
103 Some(Token::RParen) => {}
104 Some(other) => {
105 return Err(SQLRiteError::General(format!(
106 "PRAGMA: expected ')' to close parenthesised value, got {other:?}"
107 )));
108 }
109 None => {
110 return Err(SQLRiteError::General(
111 "PRAGMA: expected ')' to close parenthesised value".to_string(),
112 ));
113 }
114 }
115 Some(v)
116 }
117 Some(other) => {
118 return Err(SQLRiteError::General(format!(
119 "PRAGMA: expected '=', '(', ';' or end of statement after name, got {other:?}"
120 )));
121 }
122 };
123
124 if matches!(iter.peek(), Some(Token::SemiColon)) {
128 iter.next();
129 }
130 if let Some(extra) = iter.next() {
131 return Err(SQLRiteError::General(format!(
132 "PRAGMA: unexpected trailing content {extra:?}"
133 )));
134 }
135
136 Ok(Some(PragmaStatement { name, value }))
137}
138
139fn read_pragma_value<I>(iter: &mut std::iter::Peekable<I>) -> Result<PragmaValue>
140where
141 I: Iterator<Item = Token>,
142{
143 let mut neg = false;
148 let first = iter.next().ok_or_else(|| {
149 SQLRiteError::General("PRAGMA: missing value after '=' or '('".to_string())
150 })?;
151
152 let tok = if matches!(first, Token::Minus) {
153 neg = true;
154 iter.next()
155 .ok_or_else(|| SQLRiteError::General("PRAGMA: missing value after '-'".to_string()))?
156 } else {
157 first
158 };
159
160 Ok(match tok {
161 Token::Number(s, _) => {
162 if neg {
163 PragmaValue::Number(format!("-{s}"))
164 } else {
165 PragmaValue::Number(s)
166 }
167 }
168 Token::SingleQuotedString(s) | Token::DoubleQuotedString(s) => {
169 if neg {
170 return Err(SQLRiteError::General(
171 "PRAGMA: unary '-' is only valid in front of a number".to_string(),
172 ));
173 }
174 PragmaValue::String(s)
175 }
176 Token::Word(w) => {
177 if neg {
178 return Err(SQLRiteError::General(
179 "PRAGMA: unary '-' is only valid in front of a number".to_string(),
180 ));
181 }
182 PragmaValue::Identifier(w.value)
183 }
184 other => {
185 return Err(SQLRiteError::General(format!(
186 "PRAGMA: unsupported value token {other:?}"
187 )));
188 }
189 })
190}
191
192pub fn execute_pragma(stmt: PragmaStatement, db: &mut Database) -> Result<CommandOutput> {
195 match stmt.name.to_ascii_lowercase().as_str() {
196 "auto_vacuum" => pragma_auto_vacuum(stmt.value, db),
197 "journal_mode" => pragma_journal_mode(stmt.value, db),
198 "table_list" => pragma_table_list(stmt.value, db),
199 other => Err(SQLRiteError::NotImplemented(format!(
200 "PRAGMA '{other}' is not supported"
201 ))),
202 }
203}
204
205fn pragma_journal_mode(value: Option<PragmaValue>, db: &mut Database) -> Result<CommandOutput> {
211 match value {
212 None => render_journal_mode(db.journal_mode()),
213 Some(v) => {
214 let target = parse_journal_mode_target(&v)?;
215 db.set_journal_mode(target)?;
216 render_journal_mode(db.journal_mode())
219 }
220 }
221}
222
223fn render_journal_mode(mode: JournalMode) -> Result<CommandOutput> {
224 let mut t = PrintTable::new();
225 t.add_row(PrintRow::new(vec![PrintCell::new("journal_mode")]));
226 t.add_row(PrintRow::new(vec![PrintCell::new(mode.as_str())]));
227 Ok(CommandOutput {
228 status: "PRAGMA journal_mode executed. 1 row returned.".to_string(),
229 rendered: Some(t.to_string()),
230 })
231}
232
233fn parse_journal_mode_target(value: &PragmaValue) -> Result<JournalMode> {
234 let s = match value {
235 PragmaValue::Identifier(s) | PragmaValue::String(s) => s.as_str(),
236 PragmaValue::Number(s) => {
237 return Err(SQLRiteError::General(format!(
238 "PRAGMA journal_mode: expected 'wal' or 'mvcc', got numeric '{s}'"
239 )));
240 }
241 };
242 JournalMode::from_str_lossless(s).ok_or_else(|| {
243 SQLRiteError::General(format!(
244 "PRAGMA journal_mode: unknown mode '{s}' (supported: 'wal', 'mvcc')"
245 ))
246 })
247}
248
249fn pragma_table_list(value: Option<PragmaValue>, db: &Database) -> Result<CommandOutput> {
260 if value.is_some() {
261 return Err(SQLRiteError::General(
262 "PRAGMA table_list does not take a value".to_string(),
263 ));
264 }
265
266 let mut t = PrintTable::new();
267 t.add_row(PrintRow::new(vec![
268 PrintCell::new("schema"),
269 PrintCell::new("name"),
270 PrintCell::new("type"),
271 PrintCell::new("ncol"),
272 PrintCell::new("wr"),
273 PrintCell::new("strict"),
274 ]));
275
276 let mut names: Vec<&String> = db.tables.keys().collect();
277 names.sort();
278 let mut row_count = 0usize;
279 for name in names {
280 let ncol = db.tables[name].columns.len();
281 t.add_row(PrintRow::new(vec![
282 PrintCell::new("main"),
283 PrintCell::new(name),
284 PrintCell::new("table"),
285 PrintCell::new(&ncol.to_string()),
286 PrintCell::new("0"),
287 PrintCell::new("0"),
288 ]));
289 row_count += 1;
290 }
291
292 t.add_row(PrintRow::new(vec![
294 PrintCell::new("main"),
295 PrintCell::new(crate::sql::pager::MASTER_TABLE_NAME),
296 PrintCell::new("table"),
297 PrintCell::new("5"),
298 PrintCell::new("0"),
299 PrintCell::new("0"),
300 ]));
301 row_count += 1;
302
303 Ok(CommandOutput {
304 status: format!("PRAGMA table_list executed. {row_count} rows returned."),
305 rendered: Some(t.to_string()),
306 })
307}
308
309fn pragma_auto_vacuum(value: Option<PragmaValue>, db: &mut Database) -> Result<CommandOutput> {
313 match value {
314 None => {
315 let mut t = PrintTable::new();
323 t.add_row(PrintRow::new(vec![PrintCell::new("auto_vacuum")]));
324 let cell_value = match db.auto_vacuum_threshold() {
325 Some(v) => format!("{v}"),
326 None => "OFF".to_string(),
327 };
328 t.add_row(PrintRow::new(vec![PrintCell::new(&cell_value)]));
329 Ok(CommandOutput {
330 status: "PRAGMA auto_vacuum executed. 1 row returned.".to_string(),
331 rendered: Some(t.to_string()),
332 })
333 }
334 Some(v) => {
335 let new_threshold = parse_auto_vacuum_target(&v)?;
336 db.set_auto_vacuum_threshold(new_threshold)?;
337 Ok(CommandOutput {
338 status: "PRAGMA auto_vacuum executed.".to_string(),
339 rendered: None,
340 })
341 }
342 }
343}
344
345fn parse_auto_vacuum_target(value: &PragmaValue) -> Result<Option<f32>> {
350 match value {
351 PragmaValue::Identifier(s) | PragmaValue::String(s) => {
352 match s.to_ascii_lowercase().as_str() {
353 "off" | "none" => Ok(None),
354 _ => Err(SQLRiteError::General(format!(
355 "PRAGMA auto_vacuum: expected a number in 0.0..=1.0 or OFF/NONE, got '{s}'"
356 ))),
357 }
358 }
359 PragmaValue::Number(s) => {
360 let f: f32 = s.parse().map_err(|_| {
361 SQLRiteError::General(format!("PRAGMA auto_vacuum: '{s}' is not a valid number"))
362 })?;
363 Ok(Some(f))
364 }
365 }
366}
367
/// Unit tests covering the PRAGMA tokenizer-level parser (`try_parse_pragma`),
/// the value converters, and the `execute_pragma` dispatch/handlers.
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-PRAGMA input (including empty, whitespace-only and comment-only input)
    /// must yield `Ok(None)` so it can fall through to the regular SQL parser.
    #[test]
    fn try_parse_pragma_returns_none_for_non_pragma() {
        assert!(try_parse_pragma("SELECT 1;").unwrap().is_none());
        assert!(
            try_parse_pragma("CREATE TABLE t (id INTEGER);")
                .unwrap()
                .is_none()
        );
        assert!(try_parse_pragma("").unwrap().is_none());
        assert!(try_parse_pragma(" \n\t ").unwrap().is_none());
        assert!(try_parse_pragma("-- hello\n").unwrap().is_none());
    }

    /// Read form: no value. Surrounding whitespace, a missing semicolon and a
    /// lowercase keyword are all accepted; the name is kept as written.
    #[test]
    fn try_parse_pragma_read_form() {
        let stmt = try_parse_pragma("PRAGMA auto_vacuum;").unwrap().unwrap();
        assert_eq!(stmt.name, "auto_vacuum");
        assert_eq!(stmt.value, None);

        let stmt = try_parse_pragma(" PRAGMA auto_vacuum ").unwrap().unwrap();
        assert_eq!(stmt.name, "auto_vacuum");
        assert_eq!(stmt.value, None);

        let stmt = try_parse_pragma("pragma auto_vacuum;").unwrap().unwrap();
        assert_eq!(stmt.name, "auto_vacuum");
    }

    /// `= number` keeps the literal text; a leading `-` is folded into it.
    #[test]
    fn try_parse_pragma_eq_number() {
        let stmt = try_parse_pragma("PRAGMA auto_vacuum = 0.5;")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.name, "auto_vacuum");
        assert_eq!(stmt.value, Some(PragmaValue::Number("0.5".to_string())));

        let stmt = try_parse_pragma("PRAGMA auto_vacuum = 0;")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.value, Some(PragmaValue::Number("0".to_string())));

        // Negative values parse fine here; range validation happens at execution time.
        let stmt = try_parse_pragma("PRAGMA auto_vacuum = -0.1;")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.value, Some(PragmaValue::Number("-0.1".to_string())));
    }

    /// Bare words become `Identifier` with their original case preserved.
    #[test]
    fn try_parse_pragma_eq_identifier() {
        let stmt = try_parse_pragma("PRAGMA auto_vacuum = OFF;")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.value, Some(PragmaValue::Identifier("OFF".to_string())));

        let stmt = try_parse_pragma("PRAGMA auto_vacuum = none;")
            .unwrap()
            .unwrap();
        assert_eq!(
            stmt.value,
            Some(PragmaValue::Identifier("none".to_string()))
        );
    }

    /// Single-quoted text becomes `String` (quotes stripped).
    #[test]
    fn try_parse_pragma_eq_string() {
        let stmt = try_parse_pragma("PRAGMA auto_vacuum = 'OFF';")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.value, Some(PragmaValue::String("OFF".to_string())));

        // Double-quoted text comes back as an Identifier, not a String.
        // Presumably the SQLite dialect tokenizes it as a quoted word
        // (`Token::Word`) — the assertion below pins that behavior.
        let stmt = try_parse_pragma("PRAGMA auto_vacuum = \"NONE\";")
            .unwrap()
            .unwrap();
        assert_eq!(
            stmt.value,
            Some(PragmaValue::Identifier("NONE".to_string()))
        );
    }

    /// `name(value)` is equivalent to `name = value`, with or without a space before `(`.
    #[test]
    fn try_parse_pragma_paren_form() {
        let stmt = try_parse_pragma("PRAGMA auto_vacuum(0.5);")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.value, Some(PragmaValue::Number("0.5".to_string())));

        let stmt = try_parse_pragma("PRAGMA auto_vacuum (OFF);")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.value, Some(PragmaValue::Identifier("OFF".to_string())));
    }

    /// Malformed PRAGMA statements are errors rather than `Ok(None)`:
    /// - missing name,
    /// - missing value,
    /// - unclosed paren,
    /// - a second statement after `;`,
    /// - a unary minus on a non-number.
    #[test]
    fn try_parse_pragma_rejects_malformed() {
        assert!(try_parse_pragma("PRAGMA;").is_err());
        assert!(try_parse_pragma("PRAGMA = 0.5;").is_err());
        assert!(try_parse_pragma("PRAGMA auto_vacuum =;").is_err());
        assert!(try_parse_pragma("PRAGMA auto_vacuum (0.5;").is_err());
        assert!(try_parse_pragma("PRAGMA auto_vacuum; SELECT 1;").is_err());
        assert!(try_parse_pragma("PRAGMA auto_vacuum = -'OFF';").is_err());
    }

    /// OFF/NONE in any ASCII case, bare or quoted, disable auto-vacuum (`None`).
    #[test]
    fn parse_auto_vacuum_target_disables_on_off_or_none() {
        for raw in ["OFF", "off", "Off", "NONE", "none"] {
            assert_eq!(
                parse_auto_vacuum_target(&PragmaValue::Identifier(raw.to_string())).unwrap(),
                None
            );
            assert_eq!(
                parse_auto_vacuum_target(&PragmaValue::String(raw.to_string())).unwrap(),
                None
            );
        }
    }

    /// The converter only parses numbers; out-of-range values such as 1.5 still pass
    /// through here and are rejected later by the database setter.
    #[test]
    fn parse_auto_vacuum_target_passes_numbers_through() {
        assert_eq!(
            parse_auto_vacuum_target(&PragmaValue::Number("0.5".to_string())).unwrap(),
            Some(0.5_f32)
        );
        assert_eq!(
            parse_auto_vacuum_target(&PragmaValue::Number("0".to_string())).unwrap(),
            Some(0.0_f32)
        );
        assert_eq!(
            parse_auto_vacuum_target(&PragmaValue::Number("1.5".to_string())).unwrap(),
            Some(1.5_f32)
        );
    }

    /// Unknown words are rejected with a message mentioning the accepted keywords.
    #[test]
    fn parse_auto_vacuum_target_rejects_unknown_strings() {
        let err =
            parse_auto_vacuum_target(&PragmaValue::Identifier("WAL".to_string())).unwrap_err();
        assert!(format!("{err}").contains("OFF/NONE"));
    }

    /// Unsupported pragma names map to `NotImplemented`, not a generic error.
    #[test]
    fn execute_pragma_unknown_returns_not_implemented() {
        let mut db = Database::new("t".to_string());
        let err = execute_pragma(
            PragmaStatement {
                name: "synchronous".to_string(),
                value: None,
            },
            &mut db,
        )
        .unwrap_err();
        assert!(matches!(err, SQLRiteError::NotImplemented(_)));
    }

    /// Round trip: the write form renders nothing and updates the threshold; the read
    /// form renders it; OFF clears it and the read form then shows "OFF".
    #[test]
    fn execute_pragma_auto_vacuum_set_and_read() {
        let mut db = Database::new("t".to_string());

        let out = execute_pragma(
            PragmaStatement {
                name: "auto_vacuum".to_string(),
                value: Some(PragmaValue::Number("0.5".to_string())),
            },
            &mut db,
        )
        .unwrap();
        assert!(out.rendered.is_none());
        assert_eq!(db.auto_vacuum_threshold(), Some(0.5));

        let out = execute_pragma(
            PragmaStatement {
                name: "auto_vacuum".to_string(),
                value: None,
            },
            &mut db,
        )
        .unwrap();
        let rendered = out.rendered.expect("read form must render rows");
        assert!(rendered.contains("auto_vacuum"));
        assert!(rendered.contains("0.5"));

        execute_pragma(
            PragmaStatement {
                name: "auto_vacuum".to_string(),
                value: Some(PragmaValue::Identifier("OFF".to_string())),
            },
            &mut db,
        )
        .unwrap();
        assert_eq!(db.auto_vacuum_threshold(), None);

        let out = execute_pragma(
            PragmaStatement {
                name: "auto_vacuum".to_string(),
                value: None,
            },
            &mut db,
        )
        .unwrap();
        let rendered = out.rendered.unwrap();
        assert!(rendered.contains("OFF"));
    }

    /// An out-of-range threshold is rejected by the database setter (its error text
    /// mentions `auto_vacuum_threshold`), and the stored value stays at what a fresh
    /// `Database` reports — 0.25, per the final assertion.
    #[test]
    fn execute_pragma_auto_vacuum_rejects_out_of_range() {
        let mut db = Database::new("t".to_string());
        let err = execute_pragma(
            PragmaStatement {
                name: "auto_vacuum".to_string(),
                value: Some(PragmaValue::Number("1.5".to_string())),
            },
            &mut db,
        )
        .unwrap_err();
        assert!(format!("{err}").contains("auto_vacuum_threshold"));

        assert_eq!(db.auto_vacuum_threshold(), Some(0.25));
    }

    /// `table_list` shows both user tables plus the catalog table: 3 rows in total.
    #[test]
    fn execute_pragma_table_list_lists_tables_and_catalog() {
        use crate::sql::process_command;

        let mut db = Database::new("t".to_string());
        process_command(
            "CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT);",
            &mut db,
        )
        .unwrap();
        process_command("CREATE TABLE posts (id INTEGER PRIMARY KEY);", &mut db).unwrap();

        let out = execute_pragma(
            PragmaStatement {
                name: "table_list".to_string(),
                value: None,
            },
            &mut db,
        )
        .unwrap();
        let rendered = out.rendered.expect("table_list renders rows");
        assert!(rendered.contains("users"), "lists user table 'users'");
        assert!(rendered.contains("posts"), "lists user table 'posts'");
        assert!(
            rendered.contains("sqlrite_master"),
            "lists the catalog table"
        );
        assert!(rendered.contains("ncol"));
        assert!(out.status.contains("3 rows"), "status: {}", out.status);
    }

    /// `table_list` is read-only; supplying a value is an error.
    #[test]
    fn execute_pragma_table_list_rejects_value() {
        let mut db = Database::new("t".to_string());
        let err = execute_pragma(
            PragmaStatement {
                name: "table_list".to_string(),
                value: Some(PragmaValue::Identifier("x".to_string())),
            },
            &mut db,
        )
        .unwrap_err();
        assert!(format!("{err}").contains("does not take a value"));
    }

    /// Negative thresholds parse, but the database setter rejects them.
    #[test]
    fn execute_pragma_auto_vacuum_rejects_negative() {
        let mut db = Database::new("t".to_string());
        let err = execute_pragma(
            PragmaStatement {
                name: "auto_vacuum".to_string(),
                value: Some(PragmaValue::Number("-0.1".to_string())),
            },
            &mut db,
        )
        .unwrap_err();
        assert!(format!("{err}").contains("auto_vacuum_threshold"));
    }
}