use crate::fsentry::FilesystemEntry;
use rusqlite::{params, types::ToSqlOutput, CachedStatement, Connection, OpenFlags, Row, ToSql};
use std::collections::HashSet;
use std::convert::TryFrom;
use std::path::{Path, PathBuf};
/// Handle to a single SQLite database file.
///
/// Wraps the `rusqlite` connection that all methods of this type operate on.
pub struct Database {
/// The underlying SQLite connection.
conn: Connection,
}
impl Database {
pub fn create<P: AsRef<Path>>(filename: P) -> Result<Self, DatabaseError> {
if filename.as_ref().exists() {
return Err(DatabaseError::Exists(filename.as_ref().to_path_buf()));
}
let flags = OpenFlags::SQLITE_OPEN_CREATE | OpenFlags::SQLITE_OPEN_READ_WRITE;
let conn = Connection::open_with_flags(filename, flags)?;
conn.execute("BEGIN", params![])?;
Ok(Self { conn })
}
pub fn open<P: AsRef<Path>>(filename: P) -> Result<Self, DatabaseError> {
let flags = OpenFlags::SQLITE_OPEN_READ_ONLY;
let conn = Connection::open_with_flags(filename, flags)?;
Ok(Self { conn })
}
pub fn close(self) -> Result<(), DatabaseError> {
self.conn.execute("COMMIT", params![])?;
self.conn
.close()
.map_err(|(_, err)| DatabaseError::Rusqlite(err))?;
Ok(())
}
pub fn create_table(&self, table: &Table) -> Result<(), DatabaseError> {
let sql = sql_statement::create_table(table);
self.conn.execute(&sql, params![])?;
Ok(())
}
pub fn create_index(
&self,
name: &str,
table: &Table,
column: &str,
) -> Result<(), DatabaseError> {
let sql = sql_statement::create_index(name, table, column);
self.conn.execute(&sql, params![])?;
Ok(())
}
pub fn insert(&mut self, table: &Table, values: &[Value]) -> Result<(), DatabaseError> {
let mut stmt = self.conn.prepare_cached(table.insert())?;
assert!(table.has_columns(values));
stmt.execute(rusqlite::params_from_iter(values.iter().map(|v| {
v.to_sql()
.expect("conversion of Obnam value to SQLite value failed unexpectedly")
})))?;
Ok(())
}
pub fn all_rows<T>(
&self,
table: &Table,
rowfunc: &'static dyn Fn(&Row) -> Result<T, rusqlite::Error>,
) -> Result<SqlResults<T>, DatabaseError> {
let sql = sql_statement::select_all_rows(table);
SqlResults::new(
&self.conn,
&sql,
None,
Box::new(|stmt, _| {
let iter = stmt.query_map(params![], |row| rowfunc(row))?;
let iter = iter.map(|x| match x {
Ok(t) => Ok(t),
Err(e) => Err(DatabaseError::Rusqlite(e)),
});
Ok(Box::new(iter))
}),
)
}
pub fn some_rows<T>(
&self,
table: &Table,
value: &Value,
rowfunc: &'static dyn Fn(&Row) -> Result<T, rusqlite::Error>,
) -> Result<SqlResults<T>, DatabaseError> {
assert!(table.has_column(value));
let sql = sql_statement::select_some_rows(table, value.name());
SqlResults::new(
&self.conn,
&sql,
Some(OwnedValue::from(value)),
Box::new(|stmt, value| {
let iter = stmt.query_map(params![value], |row| rowfunc(row))?;
let iter = iter.map(|x| match x {
Ok(t) => Ok(t),
Err(e) => Err(DatabaseError::Rusqlite(e)),
});
Ok(Box::new(iter))
}),
)
}
}
/// Errors returned by the database functions in this module.
#[derive(Debug, thiserror::Error)]
pub enum DatabaseError {
/// An error reported by `rusqlite`, displayed with its own message.
#[error(transparent)]
Rusqlite(#[from] rusqlite::Error),
/// `Database::create` was asked to create a file that already exists.
#[error("Database {0} already exists")]
Exists(PathBuf),
}
/// Boxed iterator over the results of a query.
///
/// Each item is either a value of type `T` or a `DatabaseError`. The
/// iterator is only required to be valid for the lifetime `'stmt`.
type SqlResultsIterator<'stmt, T> = Box<dyn Iterator<Item = Result<T, DatabaseError>> + 'stmt>;
/// Boxed function that starts a query on a prepared statement.
///
/// It receives the cached statement and the optional value stored next to
/// it in `SqlResults`, and returns an iterator over the results that may
/// borrow the statement for `'stmt`.
type CreateIterFn<'conn, ItemT> = Box<
dyn for<'stmt> Fn(
&'stmt mut CachedStatement<'conn>,
&Option<OwnedValue>,
) -> Result<SqlResultsIterator<'stmt, ItemT>, DatabaseError>,
>;
/// A prepared query whose results can be iterated over with `SqlResults::iter`.
pub struct SqlResults<'conn, ItemT> {
/// The prepared (cached) statement for the query.
stmt: CachedStatement<'conn>,
/// Value passed to `create_iter` on every call; `Database::some_rows`
/// stores the value to match here, `Database::all_rows` stores `None`.
value: Option<OwnedValue>,
/// Function that runs the statement and builds the result iterator.
create_iter: CreateIterFn<'conn, ItemT>,
}
impl<'conn, ItemT> SqlResults<'conn, ItemT> {
fn new(
conn: &'conn Connection,
statement: &str,
value: Option<OwnedValue>,
create_iter: CreateIterFn<'conn, ItemT>,
) -> Result<Self, DatabaseError> {
let stmt = conn.prepare_cached(statement)?;
Ok(Self {
stmt,
value,
create_iter,
})
}
pub fn iter(&'_ mut self) -> Result<SqlResultsIterator<'_, ItemT>, DatabaseError> {
(self.create_iter)(&mut self.stmt, &self.value)
}
}
/// Description of a database table: its name and its columns.
///
/// Constructed with `Table::new`, extended with `Table::column`, and
/// finished with `Table::build`.
pub struct Table {
/// Name of the table.
table: String,
/// Columns, in the order they were added.
columns: Vec<Column>,
/// INSERT statement for the table; `None` until `Table::build` has run.
insert: Option<String>,
/// Names of all added columns, used for membership checks.
column_names: HashSet<String>,
}
impl Table {
pub fn new(table: &str) -> Self {
Self {
table: table.to_string(),
columns: vec![],
insert: None,
column_names: HashSet::new(),
}
}
pub fn column(mut self, column: Column) -> Self {
self.column_names.insert(column.name().to_string());
self.columns.push(column);
self
}
pub fn build(mut self) -> Self {
assert!(self.insert.is_none());
self.insert = Some(sql_statement::insert(&self));
self
}
fn has_columns(&self, values: &[Value]) -> bool {
assert!(self.insert.is_some());
for v in values.iter() {
if !self.column_names.contains(v.name()) {
return false;
}
}
true
}
fn has_column(&self, value: &Value) -> bool {
assert!(self.insert.is_some());
self.column_names.contains(value.name())
}
fn insert(&self) -> &str {
assert!(self.insert.is_some());
self.insert.as_ref().unwrap()
}
pub fn name(&self) -> &str {
&self.table
}
pub fn num_columns(&self) -> usize {
self.columns.len()
}
pub fn column_names(&self) -> impl Iterator<Item = &str> {
self.columns.iter().map(|c| c.name())
}
pub fn column_definitions(&self) -> String {
let mut ret = String::new();
for c in self.columns.iter() {
if !ret.is_empty() {
ret.push(',');
}
ret.push_str(c.name());
ret.push(' ');
ret.push_str(c.typename());
}
ret
}
}
/// A column of a table: the variant gives the column's SQL type, the
/// payload is the column's name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Column {
    /// An `INTEGER PRIMARY KEY` column.
    PrimaryKey(String),
    /// An `INTEGER` column.
    Int(String),
    /// A `TEXT` column.
    Text(String),
    /// A `BLOB` column.
    Blob(String),
    /// A `BOOLEAN` column.
    Bool(String),
}

impl Column {
    /// Name of the column.
    fn name(&self) -> &str {
        match self {
            Self::PrimaryKey(name)
            | Self::Int(name)
            | Self::Text(name)
            | Self::Blob(name)
            | Self::Bool(name) => name,
        }
    }

    /// SQL type of the column, as written in a `CREATE TABLE` statement.
    ///
    /// The type names are literals, so the result does not borrow `self`.
    fn typename(&self) -> &'static str {
        match self {
            Self::PrimaryKey(_) => "INTEGER PRIMARY KEY",
            Self::Int(_) => "INTEGER",
            Self::Text(_) => "TEXT",
            Self::Blob(_) => "BLOB",
            Self::Bool(_) => "BOOLEAN",
        }
    }

    /// Create an integer primary key column called `name`.
    pub fn primary_key(name: &str) -> Self {
        Self::PrimaryKey(name.to_string())
    }

    /// Create an integer column called `name`.
    pub fn int(name: &str) -> Self {
        Self::Int(name.to_string())
    }

    /// Create a text column called `name`.
    pub fn text(name: &str) -> Self {
        Self::Text(name.to_string())
    }

    /// Create a blob column called `name`.
    pub fn blob(name: &str) -> Self {
        Self::Blob(name.to_string())
    }

    /// Create a boolean column called `name`.
    pub fn bool(name: &str) -> Self {
        Self::Bool(name.to_string())
    }
}
/// Integer type used for integer-valued database columns.
pub type DbInt = i64;

/// A borrowed value for a named column.
///
/// The first field of every variant is the column name, the second is the
/// value itself. Every field is `Copy`, so the whole value is cheap to copy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Value<'a> {
    /// Value for a primary key column.
    PrimaryKey(&'a str, DbInt),
    /// Value for an integer column.
    Int(&'a str, DbInt),
    /// Value for a text column.
    Text(&'a str, &'a str),
    /// Value for a blob column.
    Blob(&'a str, &'a [u8]),
    /// Value for a boolean column.
    Bool(&'a str, bool),
}

impl<'a> Value<'a> {
    /// Name of the column this value is for.
    ///
    /// The name is borrowed for `'a`, not from `self`, so it remains usable
    /// after the `Value` itself has gone out of scope.
    pub fn name(&self) -> &'a str {
        match *self {
            Self::PrimaryKey(name, _)
            | Self::Int(name, _)
            | Self::Text(name, _)
            | Self::Blob(name, _)
            | Self::Bool(name, _) => name,
        }
    }

    /// Create a primary key value for column `name`.
    pub fn primary_key(name: &'a str, value: DbInt) -> Self {
        Self::PrimaryKey(name, value)
    }

    /// Create an integer value for column `name`.
    pub fn int(name: &'a str, value: DbInt) -> Self {
        Self::Int(name, value)
    }

    /// Create a text value for column `name`.
    pub fn text(name: &'a str, value: &'a str) -> Self {
        Self::Text(name, value)
    }

    /// Create a blob value for column `name`.
    pub fn blob(name: &'a str, value: &'a [u8]) -> Self {
        Self::Blob(name, value)
    }

    /// Create a boolean value for column `name`.
    pub fn bool(name: &'a str, value: bool) -> Self {
        Self::Bool(name, value)
    }
}
#[allow(clippy::useless_conversion)]
impl<'a> ToSql for Value<'a> {
fn to_sql(&self) -> Result<rusqlite::types::ToSqlOutput, rusqlite::Error> {
use rusqlite::types::ValueRef;
let v = match self {
Self::PrimaryKey(_, v) => ValueRef::Integer(
i64::try_from(*v)
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?,
),
Self::Int(_, v) => ValueRef::Integer(
i64::try_from(*v)
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?,
),
Self::Bool(_, v) => ValueRef::Integer(
i64::try_from(*v)
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?,
),
Self::Text(_, v) => ValueRef::Text(v.as_ref()),
Self::Blob(_, v) => ValueRef::Blob(v),
};
Ok(ToSqlOutput::Borrowed(v))
}
}
/// Owned counterpart of `Value`: the column name and the value are stored
/// in the variant instead of being borrowed.
///
/// Derives the same `Debug` as `Value`, plus `Clone` and equality, so that
/// types holding an `OwnedValue` can be printed and compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OwnedValue {
    /// Value for a primary key column.
    PrimaryKey(String, DbInt),
    /// Value for an integer column.
    Int(String, DbInt),
    /// Value for a text column.
    Text(String, String),
    /// Value for a blob column.
    Blob(String, Vec<u8>),
    /// Value for a boolean column.
    Bool(String, bool),
}
impl From<&Value<'_>> for OwnedValue {
    /// Copy the column name and the value out of a borrowed `Value`.
    fn from(v: &Value) -> Self {
        match v {
            Value::PrimaryKey(name, key) => Self::PrimaryKey((*name).to_owned(), *key),
            Value::Int(name, number) => Self::Int((*name).to_owned(), *number),
            Value::Text(name, text) => Self::Text((*name).to_owned(), (*text).to_owned()),
            Value::Blob(name, bytes) => Self::Blob((*name).to_owned(), bytes.to_vec()),
            Value::Bool(name, flag) => Self::Bool((*name).to_owned(), *flag),
        }
    }
}
impl ToSql for OwnedValue {
#[allow(clippy::useless_conversion)]
fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput> {
use rusqlite::types::Value;
let v = match self {
Self::PrimaryKey(_, v) => Value::Integer(
i64::try_from(*v)
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?,
),
Self::Int(_, v) => Value::Integer(
i64::try_from(*v)
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?,
),
Self::Bool(_, v) => Value::Integer(
i64::try_from(*v)
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?,
),
Self::Text(_, v) => Value::Text(v.to_string()),
Self::Blob(_, v) => Value::Blob(v.to_vec()),
};
Ok(ToSqlOutput::Owned(v))
}
}
impl rusqlite::types::ToSql for FilesystemEntry {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
let json = serde_json::to_string(self)
.map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
let json = rusqlite::types::Value::Text(json);
Ok(ToSqlOutput::Owned(json))
}
}
/// Builders for the SQL statements used by this module.
///
/// Table, index and column names are inserted into the statements verbatim.
mod sql_statement {
    use super::Table;

    /// `CREATE TABLE` statement for `table`.
    pub fn create_table(table: &Table) -> String {
        let definitions = table.column_definitions();
        format!("CREATE TABLE {} ({})", table.name(), definitions)
    }

    /// `CREATE INDEX` statement for an index `name` on `column` of `table`.
    pub fn create_index(name: &str, table: &Table, column: &str) -> String {
        let table_name = table.name();
        format!("CREATE INDEX {} ON {} ({})", name, table_name, column)
    }

    /// `INSERT` statement for `table`, with one placeholder per column.
    pub fn insert(table: &Table) -> String {
        let names = column_names(table);
        let marks = placeholders(table.column_names().count());
        format!("INSERT INTO {} ({}) VALUES ({})", table.name(), names, marks)
    }

    /// `SELECT` statement returning every row of `table`.
    pub fn select_all_rows(table: &Table) -> String {
        let table_name = table.name();
        format!("SELECT * FROM {}", table_name)
    }

    /// `SELECT` statement returning the rows of `table` where `column`
    /// equals the single bound parameter.
    pub fn select_some_rows(table: &Table, column: &str) -> String {
        let table_name = table.name();
        format!("SELECT * FROM {} WHERE {} = ?", table_name, column)
    }

    /// Comma-separated column names of `table`, in column order.
    fn column_names(table: &Table) -> String {
        let names: Vec<&str> = table.column_names().collect();
        names.join(",")
    }

    /// `num_columns` question marks separated by commas.
    fn placeholders(num_columns: usize) -> String {
        vec!["?"; num_columns].join(",")
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use std::path::Path;
    use tempfile::tempdir;

    /// Row conversion function: extract the `bar` column.
    fn get_bar(row: &rusqlite::Row) -> Result<DbInt, rusqlite::Error> {
        row.get("bar")
    }

    /// The table used by all tests: `foo`, with one integer column `bar`.
    fn table() -> Table {
        let bar = Column::int("bar");
        Table::new("foo").column(bar).build()
    }

    /// Create a new database at `file` containing the test table.
    fn create_db(file: &Path) -> Database {
        let db = Database::create(file).unwrap();
        db.create_table(&table()).unwrap();
        db
    }

    /// Open the existing database at `file`.
    fn open_db(file: &Path) -> Database {
        Database::open(file).unwrap()
    }

    /// Insert one row with `value` in the `bar` column.
    fn insert(db: &mut Database, value: DbInt) {
        let row = [Value::int("bar", value)];
        db.insert(&table(), &row).unwrap();
    }

    /// All `bar` values in the database, in the order the query returns them.
    fn values(db: Database) -> Vec<DbInt> {
        let table = table();
        let mut rows = db.all_rows(&table, &get_bar).unwrap();
        // Collect into a local first so the iterator, which borrows `rows`,
        // is gone before `rows` itself is dropped.
        let collected: Vec<DbInt> = rows.iter().unwrap().map(|row| row.unwrap()).collect();
        collected
    }

    #[test]
    fn creates_db() {
        let dir = tempdir().unwrap();
        let filename = dir.path().join("test.db");
        Database::create(&filename).unwrap().close().unwrap();
        let _ = Database::open(&filename).unwrap();
    }

    #[test]
    fn inserts_row() {
        let dir = tempdir().unwrap();
        let filename = dir.path().join("test.db");

        let mut db = create_db(&filename);
        insert(&mut db, 42);
        db.close().unwrap();

        let stored = values(open_db(&filename));
        assert_eq!(stored, vec![42]);
    }

    #[test]
    fn inserts_many_rows() {
        const N: DbInt = 1000;
        let dir = tempdir().unwrap();
        let filename = dir.path().join("test.db");

        let mut db = create_db(&filename);
        for i in 0..N {
            insert(&mut db, i);
        }
        db.close().unwrap();

        let stored = values(open_db(&filename));
        assert_eq!(stored.len() as DbInt, N);

        let expected: Vec<DbInt> = (0..N).collect();
        assert_eq!(stored, expected);
    }

    #[test]
    fn round_trips_int_max() {
        let dir = tempdir().unwrap();
        let filename = dir.path().join("test.db");

        let mut db = create_db(&filename);
        insert(&mut db, DbInt::MAX);
        db.close().unwrap();

        let stored = values(open_db(&filename));
        assert_eq!(stored, vec![DbInt::MAX]);
    }
}