use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
use sqlparser::dialect::SQLiteDialect;
use sqlparser::keywords::Keyword;
use sqlparser::tokenizer::{Token, Tokenizer};
use crate::error::{Result, SQLRiteError};
use crate::mvcc::JournalMode;
use crate::sql::CommandOutput;
use crate::sql::db::database::Database;
/// Argument of a `PRAGMA name = value` or `PRAGMA name(value)` statement.
///
/// The parser keeps the raw text; each pragma handler decides how to
/// interpret it (e.g. as a threshold or as a mode name).
#[derive(Clone, Debug, PartialEq)]
pub enum PragmaValue {
    /// Numeric literal text as tokenized, with a leading `-` when the value
    /// was written negated (e.g. `"-0.1"`).
    Number(String),
    /// A word such as `OFF` or `wal`; with the SQLite dialect a double-quoted
    /// value (`"NONE"`) also lands here.
    Identifier(String),
    /// A quoted string literal with its quotes stripped (`'OFF'` -> `"OFF"`).
    String(String),
}
/// A parsed `PRAGMA` statement: the pragma's name plus its optional argument.
#[derive(Clone, Debug, PartialEq)]
pub struct PragmaStatement {
    /// Pragma name as written (case preserved; `execute_pragma` lowercases it
    /// before dispatching).
    pub name: String,
    /// `None` for the read form (`PRAGMA name;`); `Some(..)` for the
    /// `= value` and `(value)` write forms.
    pub value: Option<PragmaValue>,
}
pub fn try_parse_pragma(sql: &str) -> Result<Option<PragmaStatement>> {
let dialect = SQLiteDialect {};
let tokens = Tokenizer::new(&dialect, sql)
.tokenize()
.map_err(|e| SQLRiteError::General(format!("PRAGMA tokenize error: {e}")))?;
let mut iter = tokens
.into_iter()
.filter(|t| !matches!(t, Token::Whitespace(_)))
.peekable();
match iter.peek() {
Some(Token::Word(w)) if w.keyword == Keyword::PRAGMA => {
iter.next();
}
_ => return Ok(None),
}
let name = match iter.next() {
Some(Token::Word(w)) => w.value,
Some(other) => {
return Err(SQLRiteError::General(format!(
"PRAGMA: expected pragma name, got {other:?}"
)));
}
None => {
return Err(SQLRiteError::General(
"PRAGMA: missing pragma name".to_string(),
));
}
};
let value = match iter.peek() {
None | Some(Token::SemiColon) => None,
Some(Token::Eq) => {
iter.next();
Some(read_pragma_value(&mut iter)?)
}
Some(Token::LParen) => {
iter.next();
let v = read_pragma_value(&mut iter)?;
match iter.next() {
Some(Token::RParen) => {}
Some(other) => {
return Err(SQLRiteError::General(format!(
"PRAGMA: expected ')' to close parenthesised value, got {other:?}"
)));
}
None => {
return Err(SQLRiteError::General(
"PRAGMA: expected ')' to close parenthesised value".to_string(),
));
}
}
Some(v)
}
Some(other) => {
return Err(SQLRiteError::General(format!(
"PRAGMA: expected '=', '(', ';' or end of statement after name, got {other:?}"
)));
}
};
if matches!(iter.peek(), Some(Token::SemiColon)) {
iter.next();
}
if let Some(extra) = iter.next() {
return Err(SQLRiteError::General(format!(
"PRAGMA: unexpected trailing content {extra:?}"
)));
}
Ok(Some(PragmaStatement { name, value }))
}
/// Consume one pragma value from `iter`, i.e. the tokens after `=` or `(`.
///
/// Mirrors SQLite's `pragma-value` grammar: an optionally signed number, a
/// word, or a quoted string. A sign is accepted only in front of a number.
/// `-` is folded into the returned text (`"-0.1"`); `+` is dropped because it
/// does not change the value.
///
/// # Errors
///
/// Returns `SQLRiteError::General` in any of these cases:
///
/// * no token is left;
/// * a sign is followed by nothing or by a non-number;
/// * the token is not a supported value (e.g. `;` or `)`).
fn read_pragma_value<I>(iter: &mut std::iter::Peekable<I>) -> Result<PragmaValue>
where
    I: Iterator<Item = Token>,
{
    let first = iter.next().ok_or_else(|| {
        SQLRiteError::General("PRAGMA: missing value after '=' or '('".to_string())
    })?;
    // Optional unary sign; the patterns bind nothing, so `first` is not moved.
    let sign = match first {
        Token::Minus => Some('-'),
        Token::Plus => Some('+'),
        _ => None,
    };
    let tok = match sign {
        Some(sign_char) => iter.next().ok_or_else(|| {
            SQLRiteError::General(format!("PRAGMA: missing value after '{sign_char}'"))
        })?,
        None => first,
    };

    match (sign, tok) {
        (Some('-'), Token::Number(digits, _)) => Ok(PragmaValue::Number(format!("-{digits}"))),
        // No sign or `+`: the literal text is used unchanged.
        (_, Token::Number(digits, _)) => Ok(PragmaValue::Number(digits)),
        (None, Token::SingleQuotedString(text) | Token::DoubleQuotedString(text)) => {
            Ok(PragmaValue::String(text))
        }
        (None, Token::Word(w)) => Ok(PragmaValue::Identifier(w.value)),
        (
            Some(sign_char),
            Token::SingleQuotedString(_) | Token::DoubleQuotedString(_) | Token::Word(_),
        ) => Err(SQLRiteError::General(format!(
            "PRAGMA: unary '{sign_char}' is only valid in front of a number"
        ))),
        (_, other) => Err(SQLRiteError::General(format!(
            "PRAGMA: unsupported value token {other:?}"
        ))),
    }
}
pub fn execute_pragma(stmt: PragmaStatement, db: &mut Database) -> Result<CommandOutput> {
match stmt.name.to_ascii_lowercase().as_str() {
"auto_vacuum" => pragma_auto_vacuum(stmt.value, db),
"journal_mode" => pragma_journal_mode(stmt.value, db),
other => Err(SQLRiteError::NotImplemented(format!(
"PRAGMA '{other}' is not supported"
))),
}
}
/// Handle `PRAGMA journal_mode`.
///
/// With no value, reports the current mode. With a value, it first switches
/// the database to that mode, then reports the mode now in effect. Both forms
/// therefore return the same one-row result.
///
/// # Errors
///
/// Propagates a parse error for the requested mode and any error from
/// `Database::set_journal_mode`.
fn pragma_journal_mode(value: Option<PragmaValue>, db: &mut Database) -> Result<CommandOutput> {
    if let Some(requested) = value {
        let mode = parse_journal_mode_target(&requested)?;
        db.set_journal_mode(mode)?;
    }
    render_journal_mode(db.journal_mode())
}
/// Build the result for `PRAGMA journal_mode`: a two-row, one-column table
/// holding the `journal_mode` label followed by `mode.as_str()`.
fn render_journal_mode(mode: JournalMode) -> Result<CommandOutput> {
    let label_row = PrintRow::new(vec![PrintCell::new("journal_mode")]);
    let value_row = PrintRow::new(vec![PrintCell::new(mode.as_str())]);
    let mut table = PrintTable::new();
    table.add_row(label_row);
    table.add_row(value_row);
    let rendered = Some(table.to_string());
    Ok(CommandOutput {
        status: "PRAGMA journal_mode executed. 1 row returned.".to_string(),
        rendered,
    })
}
/// Convert a `PRAGMA journal_mode = …` argument into a `JournalMode`.
///
/// Words and quoted strings are passed to `JournalMode::from_str_lossless`;
/// numeric values are rejected outright.
///
/// # Errors
///
/// Returns `SQLRiteError::General` for a numeric value, or when
/// `from_str_lossless` does not recognise the name.
fn parse_journal_mode_target(value: &PragmaValue) -> Result<JournalMode> {
    let requested: &str = match value {
        PragmaValue::Number(digits) => {
            return Err(SQLRiteError::General(format!(
                "PRAGMA journal_mode: expected 'wal' or 'mvcc', got numeric '{digits}'"
            )));
        }
        PragmaValue::Identifier(text) | PragmaValue::String(text) => text.as_str(),
    };
    match JournalMode::from_str_lossless(requested) {
        Some(mode) => Ok(mode),
        None => Err(SQLRiteError::General(format!(
            "PRAGMA journal_mode: unknown mode '{requested}' (supported: 'wal', 'mvcc')"
        ))),
    }
}
/// Handle `PRAGMA auto_vacuum`.
///
/// * Write form (`value` is `Some`): parses the value with
///   `parse_auto_vacuum_target` and stores it via
///   `Database::set_auto_vacuum_threshold`. Nothing is rendered.
/// * Read form (`value` is `None`): returns a one-column table containing an
///   `auto_vacuum` label row and the current threshold, or `OFF` when none is set.
///
/// # Errors
///
/// Propagates parse errors and any error from `set_auto_vacuum_threshold`.
fn pragma_auto_vacuum(value: Option<PragmaValue>, db: &mut Database) -> Result<CommandOutput> {
    if let Some(requested) = value {
        let threshold = parse_auto_vacuum_target(&requested)?;
        db.set_auto_vacuum_threshold(threshold)?;
        return Ok(CommandOutput {
            status: "PRAGMA auto_vacuum executed.".to_string(),
            rendered: None,
        });
    }

    let shown = db
        .auto_vacuum_threshold()
        .map_or_else(|| "OFF".to_string(), |threshold| format!("{threshold}"));
    let mut table = PrintTable::new();
    table.add_row(PrintRow::new(vec![PrintCell::new("auto_vacuum")]));
    table.add_row(PrintRow::new(vec![PrintCell::new(&shown)]));
    Ok(CommandOutput {
        status: "PRAGMA auto_vacuum executed. 1 row returned.".to_string(),
        rendered: Some(table.to_string()),
    })
}
/// Interpret a `PRAGMA auto_vacuum = …` argument.
///
/// * `OFF` / `NONE` (ASCII case-insensitive, bare or quoted) gives `Ok(None)`,
///   which disables auto-vacuum.
/// * A number gives `Ok(Some(threshold))`.
///
/// The numeric range is not checked here (`1.5` is accepted). Per this
/// file's tests, out-of-range values are rejected later by
/// `Database::set_auto_vacuum_threshold`.
///
/// # Errors
///
/// Returns `SQLRiteError::General` for any other word or string, or for
/// numeric text that does not parse as `f32`.
fn parse_auto_vacuum_target(value: &PragmaValue) -> Result<Option<f32>> {
    match value {
        PragmaValue::Number(raw) => raw.parse::<f32>().map(Some).map_err(|_| {
            SQLRiteError::General(format!("PRAGMA auto_vacuum: '{raw}' is not a valid number"))
        }),
        PragmaValue::Identifier(word) | PragmaValue::String(word) => {
            if word.eq_ignore_ascii_case("off") || word.eq_ignore_ascii_case("none") {
                Ok(None)
            } else {
                Err(SQLRiteError::General(format!(
                    "PRAGMA auto_vacuum: expected a number in 0.0..=1.0 or OFF/NONE, got '{word}'"
                )))
            }
        }
    }
}
// Unit tests for PRAGMA parsing (`try_parse_pragma`), value interpretation
// (`parse_auto_vacuum_target`) and dispatch (`execute_pragma`).
#[cfg(test)]
mod tests {
    use super::*;

    // Input that does not begin with the PRAGMA keyword yields Ok(None), not an
    // error. That includes other statements, empty or whitespace-only input, and
    // comment-only input.
    #[test]
    fn try_parse_pragma_returns_none_for_non_pragma() {
        assert!(try_parse_pragma("SELECT 1;").unwrap().is_none());
        assert!(
            try_parse_pragma("CREATE TABLE t (id INTEGER);")
                .unwrap()
                .is_none()
        );
        assert!(try_parse_pragma("").unwrap().is_none());
        assert!(try_parse_pragma(" \n\t ").unwrap().is_none());
        assert!(try_parse_pragma("-- hello\n").unwrap().is_none());
    }

    // Read form: `PRAGMA name` parses with no value. The `;` is optional,
    // surrounding whitespace is ignored, and the keyword is case-insensitive.
    #[test]
    fn try_parse_pragma_read_form() {
        let stmt = try_parse_pragma("PRAGMA auto_vacuum;").unwrap().unwrap();
        assert_eq!(stmt.name, "auto_vacuum");
        assert_eq!(stmt.value, None);
        let stmt = try_parse_pragma(" PRAGMA auto_vacuum ").unwrap().unwrap();
        assert_eq!(stmt.name, "auto_vacuum");
        assert_eq!(stmt.value, None);
        let stmt = try_parse_pragma("pragma auto_vacuum;").unwrap().unwrap();
        assert_eq!(stmt.name, "auto_vacuum");
    }

    // `= number` keeps the literal text verbatim; a leading `-` is folded in.
    #[test]
    fn try_parse_pragma_eq_number() {
        let stmt = try_parse_pragma("PRAGMA auto_vacuum = 0.5;")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.name, "auto_vacuum");
        assert_eq!(stmt.value, Some(PragmaValue::Number("0.5".to_string())));
        let stmt = try_parse_pragma("PRAGMA auto_vacuum = 0;")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.value, Some(PragmaValue::Number("0".to_string())));
        let stmt = try_parse_pragma("PRAGMA auto_vacuum = -0.1;")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.value, Some(PragmaValue::Number("-0.1".to_string())));
    }

    // Bare words become `Identifier`, with their original case preserved.
    #[test]
    fn try_parse_pragma_eq_identifier() {
        let stmt = try_parse_pragma("PRAGMA auto_vacuum = OFF;")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.value, Some(PragmaValue::Identifier("OFF".to_string())));
        let stmt = try_parse_pragma("PRAGMA auto_vacuum = none;")
            .unwrap()
            .unwrap();
        assert_eq!(
            stmt.value,
            Some(PragmaValue::Identifier("none".to_string()))
        );
    }

    // Single-quoted values become `String`. A double-quoted value is a
    // delimited identifier under the SQLite dialect, so it becomes `Identifier`;
    // the second assertion pins that behaviour.
    #[test]
    fn try_parse_pragma_eq_string() {
        let stmt = try_parse_pragma("PRAGMA auto_vacuum = 'OFF';")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.value, Some(PragmaValue::String("OFF".to_string())));
        let stmt = try_parse_pragma("PRAGMA auto_vacuum = \"NONE\";")
            .unwrap()
            .unwrap();
        assert_eq!(
            stmt.value,
            Some(PragmaValue::Identifier("NONE".to_string()))
        );
    }

    // `name(value)` parses to the same value as `name = value`, with or
    // without a space before `(`.
    #[test]
    fn try_parse_pragma_paren_form() {
        let stmt = try_parse_pragma("PRAGMA auto_vacuum(0.5);")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.value, Some(PragmaValue::Number("0.5".to_string())));
        let stmt = try_parse_pragma("PRAGMA auto_vacuum (OFF);")
            .unwrap()
            .unwrap();
        assert_eq!(stmt.value, Some(PragmaValue::Identifier("OFF".to_string())));
    }

    // Malformed PRAGMAs must be errors, not Ok(None).
    #[test]
    fn try_parse_pragma_rejects_malformed() {
        // Missing pragma name.
        assert!(try_parse_pragma("PRAGMA;").is_err());
        // `=` where the name should be.
        assert!(try_parse_pragma("PRAGMA = 0.5;").is_err());
        // `=` with no value.
        assert!(try_parse_pragma("PRAGMA auto_vacuum =;").is_err());
        // Unclosed parenthesis.
        assert!(try_parse_pragma("PRAGMA auto_vacuum (0.5;").is_err());
        // A second statement after the `;` is not allowed.
        assert!(try_parse_pragma("PRAGMA auto_vacuum; SELECT 1;").is_err());
        // A unary sign is only valid before a number.
        assert!(try_parse_pragma("PRAGMA auto_vacuum = -'OFF';").is_err());
    }

    // OFF / NONE in any ASCII case, bare or quoted, disable auto-vacuum (None).
    #[test]
    fn parse_auto_vacuum_target_disables_on_off_or_none() {
        for raw in ["OFF", "off", "Off", "NONE", "none"] {
            assert_eq!(
                parse_auto_vacuum_target(&PragmaValue::Identifier(raw.to_string())).unwrap(),
                None
            );
            assert_eq!(
                parse_auto_vacuum_target(&PragmaValue::String(raw.to_string())).unwrap(),
                None
            );
        }
    }

    // Numbers are parsed without a range check: 1.5 is accepted at this layer
    // (range enforcement is exercised in the execute_pragma tests below).
    #[test]
    fn parse_auto_vacuum_target_passes_numbers_through() {
        assert_eq!(
            parse_auto_vacuum_target(&PragmaValue::Number("0.5".to_string())).unwrap(),
            Some(0.5_f32)
        );
        assert_eq!(
            parse_auto_vacuum_target(&PragmaValue::Number("0".to_string())).unwrap(),
            Some(0.0_f32)
        );
        assert_eq!(
            parse_auto_vacuum_target(&PragmaValue::Number("1.5".to_string())).unwrap(),
            Some(1.5_f32)
        );
    }

    // Any other word is rejected, and the message names the accepted keywords.
    #[test]
    fn parse_auto_vacuum_target_rejects_unknown_strings() {
        let err =
            parse_auto_vacuum_target(&PragmaValue::Identifier("WAL".to_string())).unwrap_err();
        assert!(format!("{err}").contains("OFF/NONE"));
    }

    // Unsupported pragma names map to NotImplemented rather than General.
    #[test]
    fn execute_pragma_unknown_returns_not_implemented() {
        let mut db = Database::new("t".to_string());
        let err = execute_pragma(
            PragmaStatement {
                name: "synchronous".to_string(),
                value: None,
            },
            &mut db,
        )
        .unwrap_err();
        assert!(matches!(err, SQLRiteError::NotImplemented(_)));
    }

    // Stateful round trip; the steps are order-dependent:
    //   1. Set 0.5: no rows are rendered.
    //   2. Read: the table shows 0.5.
    //   3. Set OFF: the threshold is cleared.
    //   4. Read: the table shows OFF.
    #[test]
    fn execute_pragma_auto_vacuum_set_and_read() {
        let mut db = Database::new("t".to_string());
        let out = execute_pragma(
            PragmaStatement {
                name: "auto_vacuum".to_string(),
                value: Some(PragmaValue::Number("0.5".to_string())),
            },
            &mut db,
        )
        .unwrap();
        assert!(out.rendered.is_none());
        assert_eq!(db.auto_vacuum_threshold(), Some(0.5));
        let out = execute_pragma(
            PragmaStatement {
                name: "auto_vacuum".to_string(),
                value: None,
            },
            &mut db,
        )
        .unwrap();
        let rendered = out.rendered.expect("read form must render rows");
        assert!(rendered.contains("auto_vacuum"));
        assert!(rendered.contains("0.5"));
        execute_pragma(
            PragmaStatement {
                name: "auto_vacuum".to_string(),
                value: Some(PragmaValue::Identifier("OFF".to_string())),
            },
            &mut db,
        )
        .unwrap();
        assert_eq!(db.auto_vacuum_threshold(), None);
        let out = execute_pragma(
            PragmaStatement {
                name: "auto_vacuum".to_string(),
                value: None,
            },
            &mut db,
        )
        .unwrap();
        let rendered = out.rendered.unwrap();
        assert!(rendered.contains("OFF"));
    }

    // Above-range thresholds are rejected by the database setter, and a failed
    // set leaves the previous value untouched. NOTE(review): the final
    // assertion assumes `Database::new` starts with a threshold of 0.25.
    #[test]
    fn execute_pragma_auto_vacuum_rejects_out_of_range() {
        let mut db = Database::new("t".to_string());
        let err = execute_pragma(
            PragmaStatement {
                name: "auto_vacuum".to_string(),
                value: Some(PragmaValue::Number("1.5".to_string())),
            },
            &mut db,
        )
        .unwrap_err();
        assert!(format!("{err}").contains("auto_vacuum_threshold"));
        assert_eq!(db.auto_vacuum_threshold(), Some(0.25));
    }

    // Negative thresholds parse successfully but are rejected by the setter.
    #[test]
    fn execute_pragma_auto_vacuum_rejects_negative() {
        let mut db = Database::new("t".to_string());
        let err = execute_pragma(
            PragmaStatement {
                name: "auto_vacuum".to_string(),
                value: Some(PragmaValue::Number("-0.1".to_string())),
            },
            &mut db,
        )
        .unwrap_err();
        assert!(format!("{err}").contains("auto_vacuum_threshold"));
    }
}