use std::path::Path;
use crate::error::{Result, StoreError};
/// A single dynamically-typed SQL value exchanged with a [`Backend`].
///
/// Only `PartialEq` (not `Eq`) is derived because `Real` holds an `f64`;
/// in particular `Value::Real(f64::NAN)` never compares equal to itself.
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    /// SQL `NULL`.
    Null,
    /// A signed 64-bit integer.
    Int(i64),
    /// A 64-bit floating-point number.
    Real(f64),
    /// An owned UTF-8 string.
    Text(String),
    /// An owned byte buffer.
    Blob(Vec<u8>),
}

/// One result row: one [`Value`] per selected column, in column order.
pub type Row = Vec<Value>;
/// Identifies which SQL dialect a [`Backend`] speaks, as reported by
/// [`Backend::dialect`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Dialect {
    /// SQLite.
    Sqlite,
    /// PostgreSQL.
    Postgres,
}
/// Minimal database abstraction: execute statements and run queries whose
/// parameters are passed as a positional slice of [`Value`]s.
pub trait Backend {
    /// The SQL dialect this backend speaks.
    fn dialect(&self) -> Dialect;

    /// Executes `sql` with `params` bound in order to its placeholders and
    /// returns the affected-row count reported by the backend.
    /// Placeholder syntax is backend-specific.
    fn exec(&self, sql: &str, params: &[Value]) -> Result<u64>;

    /// Runs `sql` with `params` bound in order and returns every result row,
    /// each row holding one [`Value`] per selected column.
    fn query(&self, sql: &str, params: &[Value]) -> Result<Vec<Row>>;
}
/// [`Backend`] implementation over a single `rusqlite` connection.
pub struct SqliteBackend {
    conn: rusqlite::Connection,
}

impl SqliteBackend {
    /// Opens the SQLite database at `path` and applies the connection
    /// pragmas: a 5-second busy timeout, `journal_mode=WAL` and
    /// `synchronous=NORMAL`.
    ///
    /// # Errors
    ///
    /// Fails if the connection cannot be opened or any pragma fails to apply.
    /// In that case the connection is dropped (closed) before returning.
    pub fn open(path: &Path) -> Result<Self> {
        let backend = Self {
            conn: rusqlite::Connection::open(path)?,
        };
        Self::apply_pragmas(&backend.conn)?;
        Ok(backend)
    }

    /// Opens a fresh in-memory database.
    ///
    /// Unlike [`SqliteBackend::open`], this does NOT apply the busy-timeout /
    /// WAL / synchronous pragmas.
    ///
    /// # Errors
    ///
    /// Fails if SQLite cannot create the in-memory connection.
    pub fn in_memory() -> Result<Self> {
        Ok(Self {
            conn: rusqlite::Connection::open_in_memory()?,
        })
    }

    /// Configures a freshly opened file-backed connection.
    fn apply_pragmas(conn: &rusqlite::Connection) -> Result<()> {
        // When the database is locked, wait up to 5 s before giving up.
        conn.busy_timeout(std::time::Duration::from_secs(5))?;
        // NOTE(review): `PRAGMA journal_mode = WAL` reports the mode actually in
        // effect, but `pragma_update` discards that result. If SQLite does not
        // switch to WAL, the failure goes unnoticed. Consider
        // `pragma_update_and_check` if WAL is a hard requirement.
        for (name, value) in [("journal_mode", "WAL"), ("synchronous", "NORMAL")] {
            conn.pragma_update(None, name, value)?;
        }
        Ok(())
    }
}
impl Backend for SqliteBackend {
fn dialect(&self) -> Dialect {
Dialect::Sqlite
}
fn exec(&self, sql: &str, params: &[Value]) -> Result<u64> {
let n = self
.conn
.execute(sql, rusqlite::params_from_iter(params.iter()))?;
Ok(n as u64)
}
fn query(&self, sql: &str, params: &[Value]) -> Result<Vec<Row>> {
let mut stmt = self.conn.prepare(sql)?;
let ncols = stmt.column_count();
let rows = stmt
.query_map(rusqlite::params_from_iter(params.iter()), |row| {
(0..ncols)
.map(|i| row.get_ref(i).map(value_from_ref))
.collect::<rusqlite::Result<Row>>()
})?
.collect::<rusqlite::Result<Vec<Row>>>()?;
Ok(rows)
}
}
/// Copies a borrowed SQLite cell into an owned [`Value`].
///
/// TEXT cells that are not valid UTF-8 are not rejected: invalid sequences
/// are replaced with U+FFFD (`String::from_utf8_lossy`).
fn value_from_ref(v: rusqlite::types::ValueRef<'_>) -> Value {
    use rusqlite::types::ValueRef as Cell;
    match v {
        Cell::Integer(n) => Value::Int(n),
        Cell::Real(x) => Value::Real(x),
        Cell::Text(bytes) => Value::Text(String::from_utf8_lossy(bytes).into_owned()),
        Cell::Blob(bytes) => Value::Blob(bytes.to_vec()),
        Cell::Null => Value::Null,
    }
}
/// Lets a [`Value`] be bound directly as a rusqlite statement parameter.
///
/// Scalars are handed over as owned SQLite values. Text and blob payloads are
/// passed as slices borrowed from `self` rather than cloned.
impl rusqlite::ToSql for Value {
    fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
        use rusqlite::types::ToSqlOutput::{Borrowed, Owned};
        use rusqlite::types::{Value as Sql, ValueRef as SqlRef};

        let output = match self {
            Value::Null => Owned(Sql::Null),
            Value::Int(n) => Owned(Sql::Integer(*n)),
            Value::Real(x) => Owned(Sql::Real(*x)),
            Value::Text(text) => Borrowed(SqlRef::Text(text.as_bytes())),
            Value::Blob(bytes) => Borrowed(SqlRef::Blob(bytes.as_slice())),
        };
        Ok(output)
    }
}
/// Extracts a 32-byte hash stored as a BLOB column.
///
/// # Errors
///
/// Returns [`StoreError::MalformedRow`] if `v` is not a [`Value::Blob`], or
/// if the blob is not exactly 32 bytes long.
pub(crate) fn blob32(v: &Value) -> Result<[u8; 32]> {
    let bytes = match v {
        Value::Blob(bytes) => bytes,
        other => {
            return Err(StoreError::MalformedRow(format!(
                "expected blob hash, got {other:?}"
            )))
        }
    };
    if bytes.len() != 32 {
        return Err(StoreError::MalformedRow(format!(
            "expected 32-byte hash, got {} bytes",
            bytes.len()
        )));
    }
    let mut hash = [0u8; 32];
    hash.copy_from_slice(bytes);
    Ok(hash)
}
/// Reads a non-negative integer column as `u64`.
///
/// # Errors
///
/// Returns [`StoreError::MalformedRow`] if `v` is not a [`Value::Int`], or if
/// the integer is negative. A plain `as` cast would silently reinterpret a
/// negative value as a huge `u64` (e.g. `-1` becomes `u64::MAX`), hiding the
/// corrupt row from the caller.
pub(crate) fn as_u64(v: &Value) -> Result<u64> {
    match v {
        // Every non-negative i64 fits in u64, so this cast is lossless.
        Value::Int(i) if *i >= 0 => Ok(*i as u64),
        Value::Int(i) => Err(StoreError::MalformedRow(format!(
            "expected non-negative integer, got {i}"
        ))),
        other => Err(StoreError::MalformedRow(format!(
            "expected integer, got {other:?}"
        ))),
    }
}
/// Unit tests for the SQLite backend and the row-decoding helpers, run
/// against a private in-memory database.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn round_trips_values() {
        let db = SqliteBackend::in_memory().unwrap();
        db.exec(
            "CREATE TABLE t (i INTEGER, r REAL, s TEXT, b BLOB, n INTEGER)",
            &[],
        )
        .unwrap();
        db.exec(
            "INSERT INTO t (i, r, s, b, n) VALUES (?, ?, ?, ?, ?)",
            &[
                Value::Int(42),
                Value::Real(1.5),
                Value::Text("hi".into()),
                Value::Blob(vec![1, 2, 3]),
                Value::Null,
            ],
        )
        .unwrap();
        let rows = db.query("SELECT i, r, s, b, n FROM t", &[]).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(
            rows[0],
            vec![
                Value::Int(42),
                Value::Real(1.5),
                Value::Text("hi".into()),
                Value::Blob(vec![1, 2, 3]),
                Value::Null,
            ]
        );
    }

    #[test]
    fn dialect_is_sqlite() {
        let db = SqliteBackend::in_memory().unwrap();
        assert_eq!(db.dialect(), Dialect::Sqlite);
    }

    // `exec` reports changed rows, and `query` binds its parameters to filter.
    #[test]
    fn query_binds_positional_parameters() {
        let db = SqliteBackend::in_memory().unwrap();
        db.exec("CREATE TABLE kv (k INTEGER, v TEXT)", &[]).unwrap();
        for (k, v) in [(1, "a"), (2, "b")] {
            let changed = db
                .exec(
                    "INSERT INTO kv (k, v) VALUES (?, ?)",
                    &[Value::Int(k), Value::Text(v.into())],
                )
                .unwrap();
            assert_eq!(changed, 1);
        }
        let rows = db
            .query("SELECT v FROM kv WHERE k = ?", &[Value::Int(2)])
            .unwrap();
        assert_eq!(rows, vec![vec![Value::Text("b".into())]]);
    }

    #[test]
    fn blob32_requires_exactly_32_bytes() {
        let bytes: Vec<u8> = (0..32).collect();
        let hash = blob32(&Value::Blob(bytes.clone())).unwrap();
        assert_eq!(hash.to_vec(), bytes);
        assert!(matches!(
            blob32(&Value::Blob(vec![0; 31])),
            Err(StoreError::MalformedRow(_))
        ));
        assert!(matches!(
            blob32(&Value::Int(1)),
            Err(StoreError::MalformedRow(_))
        ));
    }

    #[test]
    fn as_u64_accepts_only_integers() {
        assert_eq!(as_u64(&Value::Int(0)).unwrap(), 0);
        assert_eq!(as_u64(&Value::Int(7)).unwrap(), 7);
        assert!(matches!(
            as_u64(&Value::Text("7".into())),
            Err(StoreError::MalformedRow(_))
        ));
    }
}