use std::{
collections::HashSet,
fs, io,
marker::PhantomData,
path::{Path, PathBuf},
sync::{Arc, RwLock},
};
use crate::Dataset;
use gix_tempfile::{
AutoRemove, ContainingDirectory, Handle,
handle::{Writable, persist},
};
use r2d2::{Pool, PooledConnection};
use r2d2_sqlite::{
SqliteConnectionManager,
rusqlite::{OpenFlags, OptionalExtension},
};
use sanitize_filename::sanitize;
use serde::{Serialize, de::DeserializeOwned};
use serde_rusqlite::{columns_from_statement, from_row_with_columns};
/// Module-wide result alias: every fallible operation here fails with [`SqliteDatasetError`].
pub type Result<T> = core::result::Result<T, SqliteDatasetError>;
/// Errors that can occur while reading from or writing to a SQLite-backed dataset.
#[derive(thiserror::Error, Debug)]
pub enum SqliteDatasetError {
    /// Filesystem failure (removing stale files, creating parent directories, ...).
    #[error("IO error: {0}")]
    Io(#[from] io::Error),

    /// Error reported by the underlying SQLite driver.
    #[error("Sql error: {0}")]
    Sql(#[from] serde_rusqlite::rusqlite::Error),

    /// MessagePack serialization of an item failed.
    #[error("Serde error: {0}")]
    Serde(#[from] rmp_serde::encode::Error),

    /// The target database file already exists and the overwrite flag was false.
    #[error("Overwrite flag is set to false and the database file already exists: {0}")]
    FileExists(PathBuf),

    /// The r2d2 connection pool could not be created.
    #[error("Failed to create connection pool: {0}")]
    ConnectionPool(#[from] r2d2::Error),

    /// Renaming the temporary database file to its final path failed.
    #[error("Could not persist the temporary database file: {0}")]
    PersistDbFile(#[from] persist::Error<Writable>),

    /// Catch-all for static error messages (see the `From<&'static str>` impl).
    #[error("{0}")]
    Other(&'static str),
}
impl From<&'static str> for SqliteDatasetError {
fn from(s: &'static str) -> Self {
SqliteDatasetError::Other(s)
}
}
/// Dataset backed by a SQLite database file; items are fetched lazily by index.
#[derive(Debug)]
pub struct SqliteDataset<I> {
    // Path to the backing database file.
    db_file: PathBuf,
    // Name of the table holding this split's rows (e.g. "train", "test").
    split: String,
    // Read-only connection pool shared across threads.
    conn_pool: Pool<SqliteConnectionManager>,
    // Column names of the split table, used for per-column deserialization.
    columns: Vec<String>,
    // Row count, derived from max(row_id) when the dataset was opened.
    len: usize,
    // Pre-built SELECT statement used by `get`.
    select_statement: String,
    // True when the table stores each item as a single MessagePack blob
    // (row_id/item layout) rather than one column per field.
    row_serialized: bool,
    // Marks the item type without storing a value of it.
    phantom: PhantomData<I>,
}
impl<I> SqliteDataset<I> {
    /// Opens an existing database file and prepares a dataset over the given
    /// `split` table.
    ///
    /// Detects whether the table stores whole items as MessagePack blobs
    /// (`row_id`/`item` layout) or one column per field, and caches the
    /// matching select statement plus the row count.
    ///
    /// # Errors
    /// Returns an error when the file cannot be opened or the split table
    /// cannot be inspected.
    pub fn from_db_file<P: AsRef<Path>>(db_file: P, split: &str) -> Result<Self> {
        // Readers never mutate the database, so open a read-only pool.
        let conn_pool = create_conn_pool(&db_file, false)?;
        let row_serialized = Self::check_if_row_serialized(&conn_pool, split)?;
        let select_statement = if row_serialized {
            format!("select item from {split} where row_id = ?")
        } else {
            format!("select * from {split} where row_id = ?")
        };
        let (columns, len) = fetch_columns_and_len(&conn_pool, &select_statement, split)?;
        Ok(SqliteDataset {
            db_file: db_file.as_ref().to_path_buf(),
            split: split.to_string(),
            conn_pool,
            columns,
            len,
            select_statement,
            row_serialized,
            phantom: PhantomData,
        })
    }

    /// Returns `true` when the split table has exactly the two columns
    /// `row_id integer` and `item blob`, i.e. items were written as serialized
    /// blobs rather than one column per field.
    fn check_if_row_serialized(
        conn_pool: &Pool<SqliteConnectionManager>,
        split: &str,
    ) -> Result<bool> {
        // Minimal projection of a `PRAGMA table_info` row.
        struct Column {
            name: String,
            ty: String,
        }

        // Column positions in `PRAGMA table_info` output: (cid, name, type, ...).
        const COLUMN_NAME: usize = 1;
        const COLUMN_TYPE: usize = 2;

        let sql_statement = format!("PRAGMA table_info({split})");
        let conn = conn_pool.get()?;
        let mut stmt = conn.prepare(sql_statement.as_str())?;

        // Propagate driver errors with `?` instead of panicking: a malformed or
        // unexpected PRAGMA row now surfaces as `SqliteDatasetError::Sql`.
        let column_iter = stmt.query_map([], |row| {
            Ok(Column {
                name: row.get::<usize, String>(COLUMN_NAME)?.to_lowercase(),
                ty: row.get::<usize, String>(COLUMN_TYPE)?.to_lowercase(),
            })
        })?;

        let columns = column_iter.collect::<core::result::Result<Vec<Column>, _>>()?;

        if columns.len() != 2 {
            Ok(false)
        } else {
            Ok(columns[0].name == "row_id"
                && columns[0].ty == "integer"
                && columns[1].name == "item"
                && columns[1].ty == "blob")
        }
    }

    /// Path of the backing database file.
    pub fn db_file(&self) -> PathBuf {
        self.db_file.clone()
    }

    /// Name of the split (table) this dataset reads from.
    pub fn split(&self) -> &str {
        self.split.as_str()
    }
}
impl<I> Dataset<I> for SqliteDataset<I>
where
    I: Clone + Send + Sync + DeserializeOwned,
{
    /// Fetches the item at zero-based `index`; returns `None` when no such row exists.
    fn get(&self, index: usize) -> Option<I> {
        // SQLite rowids are 1-based.
        let row_id = index + 1;
        let conn = self.conn_pool.get().unwrap();
        let mut stmt = conn.prepare(self.select_statement.as_str()).unwrap();

        let fetched = if self.row_serialized {
            // The whole item lives in one MessagePack blob in the `item` column.
            stmt.query_row([row_id], |row| {
                let bytes = row.get_ref(0).unwrap().as_blob().unwrap();
                Ok(rmp_serde::from_slice::<I>(bytes).unwrap())
            })
            .optional()
        } else {
            // Item fields are spread over individual table columns.
            stmt.query_row([row_id], |row| {
                Ok(from_row_with_columns::<I>(row, &self.columns).unwrap())
            })
            .optional()
        };

        fetched.unwrap()
    }

    /// Number of rows in the split, cached at load time.
    fn len(&self) -> usize {
        self.len
    }
}
/// Discovers the column names of `split` (via the prepared select statement)
/// and its row count (`coalesce(max(row_id), 0)`, so an empty table yields 0).
fn fetch_columns_and_len(
    conn_pool: &Pool<SqliteConnectionManager>,
    select_statement: &str,
    split: &str,
) -> Result<(Vec<String>, usize)> {
    let connection = conn_pool.get()?;

    // Prepare the select purely to read its column metadata.
    let columns = {
        let stmt = connection.prepare(select_statement)?;
        columns_from_statement(&stmt)
    };

    let count_sql = format!("select coalesce(max(row_id), 0) from {split}");
    let mut stmt = connection.prepare(count_sql.as_str())?;
    let len: usize = stmt.query_row([], |row| row.get(0))?;

    Ok((columns, len))
}
/// Builds an r2d2 connection pool over `db_file`.
///
/// Read-only by default; with `write` set, the database is opened read-write
/// and created if it does not yet exist.
fn create_conn_pool<P: AsRef<Path>>(
    db_file: P,
    write: bool,
) -> Result<Pool<SqliteConnectionManager>> {
    let flags = match write {
        true => OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE,
        false => OpenFlags::SQLITE_OPEN_READ_ONLY,
    };
    let manager = SqliteConnectionManager::file(db_file).with_flags(flags);
    // `?` converts `r2d2::Error` via the `#[from]`-derived ConnectionPool variant.
    Ok(Pool::new(manager)?)
}
/// Locator for a dataset database file: either an explicit path, or a
/// sanitized name resolved under a base directory.
#[derive(Clone, Debug)]
pub struct SqliteDatasetStorage {
    // Logical dataset name; sanitized into `<name>.db` when no file is set.
    name: Option<String>,
    // Explicit database file path; takes precedence over `name`.
    db_file: Option<PathBuf>,
    // Directory used with `name`; falls back to the OS cache dir.
    base_dir: Option<PathBuf>,
}
impl SqliteDatasetStorage {
pub fn from_name(name: &str) -> Self {
SqliteDatasetStorage {
name: Some(name.to_string()),
db_file: None,
base_dir: None,
}
}
pub fn from_file<P: AsRef<Path>>(db_file: P) -> Self {
SqliteDatasetStorage {
name: None,
db_file: Some(db_file.as_ref().to_path_buf()),
base_dir: None,
}
}
pub fn with_base_dir<P: AsRef<Path>>(mut self, base_dir: P) -> Self {
self.base_dir = Some(base_dir.as_ref().to_path_buf());
self
}
pub fn exists(&self) -> bool {
self.db_file().exists()
}
pub fn db_file(&self) -> PathBuf {
match &self.db_file {
Some(db_file) => db_file.clone(),
None => {
let name = sanitize(self.name.as_ref().expect("Name is not set"));
Self::base_dir(self.base_dir.to_owned()).join(format!("{name}.db"))
}
}
}
pub fn base_dir(base_dir: Option<PathBuf>) -> PathBuf {
match base_dir {
Some(base_dir) => base_dir,
None => dirs::cache_dir()
.expect("Could not get cache directory")
.join("burn-dataset"),
}
}
pub fn writer<I>(&self, overwrite: bool) -> Result<SqliteDatasetWriter<I>>
where
I: Clone + Send + Sync + Serialize + DeserializeOwned,
{
SqliteDatasetWriter::new(self.db_file(), overwrite)
}
pub fn reader<I>(&self, split: &str) -> Result<SqliteDataset<I>>
where
I: Clone + Send + Sync + Serialize + DeserializeOwned,
{
if !self.exists() {
panic!("The database file does not exist");
}
SqliteDataset::from_db_file(self.db_file(), split)
}
}
/// Writer that appends serialized items into per-split tables of a temporary
/// database file, then atomically publishes it via `set_completed`.
#[derive(Debug)]
pub struct SqliteDatasetWriter<I> {
    // Final path of the database file once completed.
    db_file: PathBuf,
    // Handle to the `.db.tmp` staging file; `None` after completion.
    db_file_tmp: Option<Handle<Writable>>,
    // Names of split tables already created (shared across writer threads).
    splits: Arc<RwLock<HashSet<String>>>,
    // Whether an existing target file may be deleted at init time.
    overwrite: bool,
    // Writable pool over the temp file; taken/dropped on completion.
    conn_pool: Option<Pool<SqliteConnectionManager>>,
    // Once true, further `write` calls are rejected.
    is_completed: Arc<RwLock<bool>>,
    // Marks the item type without storing a value of it.
    phantom: PhantomData<I>,
}
impl<I> SqliteDatasetWriter<I>
where
    I: Clone + Send + Sync + Serialize + DeserializeOwned,
{
    /// Creates a writer targeting `db_file`.
    ///
    /// When `overwrite` is false and the file already exists, this fails with
    /// [`SqliteDatasetError::FileExists`]. All writes go to a temporary
    /// `.db.tmp` file that is renamed into place by [`Self::set_completed`].
    pub fn new<P: AsRef<Path>>(db_file: P, overwrite: bool) -> Result<Self> {
        let writer = Self {
            db_file: db_file.as_ref().to_path_buf(),
            db_file_tmp: None,
            splits: Arc::new(RwLock::new(HashSet::new())),
            overwrite,
            conn_pool: None,
            is_completed: Arc::new(RwLock::new(false)),
            phantom: PhantomData,
        };
        writer.init()
    }

    /// Prepares the filesystem and connection pool: removes (or rejects) an
    /// existing target file, creates parent directories, and opens a writable
    /// pool over a fresh temporary file.
    fn init(mut self) -> Result<Self> {
        if self.db_file.exists() {
            if self.overwrite {
                fs::remove_file(&self.db_file)?;
            } else {
                return Err(SqliteDatasetError::FileExists(self.db_file));
            }
        }
        let db_file_dir = self
            .db_file
            .parent()
            .ok_or("Unable to get parent directory")?;
        if !db_file_dir.exists() {
            fs::create_dir_all(db_file_dir)?;
        }
        // Stage all writes in `<db_file>.db.tmp`; a stale temp file from a
        // previously aborted run is discarded first.
        let mut db_file_tmp = self.db_file.clone();
        db_file_tmp.set_extension("db.tmp");
        if db_file_tmp.exists() {
            fs::remove_file(&db_file_tmp)?;
        }
        // Install signal handlers so the temp file is cleaned up if the
        // process is interrupted before completion.
        gix_tempfile::signal::setup(Default::default());
        self.db_file_tmp = Some(gix_tempfile::writable_at(
            &db_file_tmp,
            ContainingDirectory::Exists,
            AutoRemove::Tempfile,
        )?);
        let conn_pool = create_conn_pool(db_file_tmp, true)?;
        self.conn_pool = Some(conn_pool);
        Ok(self)
    }

    /// Serializes `item` with MessagePack and appends it to the `split` table,
    /// creating the table on first use. Returns the zero-based index of the
    /// inserted row.
    ///
    /// # Errors
    /// Fails when the writer has already been completed, or on
    /// serialization/SQL errors.
    pub fn write(&self, split: &str, item: &I) -> Result<usize> {
        // Hold the read lock for the whole write so `set_completed` (which
        // takes the write lock) cannot run concurrently with an insert.
        let is_completed = self.is_completed.read().unwrap();
        if *is_completed {
            return Err(SqliteDatasetError::Other(
                "Cannot save to a completed dataset writer",
            ));
        }
        if !self.splits.read().unwrap().contains(split) {
            self.create_table(split)?;
        }
        let conn_pool = self.conn_pool.as_ref().unwrap();
        let conn = conn_pool.get()?;
        let serialized_item = rmp_serde::to_vec(item)?;
        // Disable sync/journal on this connection for bulk-write speed; the
        // file is only published after `set_completed` persists it.
        pragma_update_with_error_handling(&conn, "synchronous", "OFF")?;
        pragma_update_with_error_handling(&conn, "journal_mode", "OFF")?;
        let insert_statement = format!("insert into {split} (item) values (?)");
        conn.execute(insert_statement.as_str(), [serialized_item])?;
        // rowids start at 1; expose zero-based indices to callers.
        let index = (conn.last_insert_rowid() - 1) as usize;
        Ok(index)
    }

    /// Finalizes the dataset: drops the connection pool, renames the temporary
    /// file to its final path, and marks the writer as completed so further
    /// `write` calls are rejected.
    pub fn set_completed(&mut self) -> Result<()> {
        let mut is_completed = self.is_completed.write().unwrap();
        // Drop the pool first so every connection is closed before the file
        // is moved.
        if let Some(pool) = self.conn_pool.take() {
            std::mem::drop(pool);
        }
        // NOTE(review): if persist fails here the pool is already gone but
        // `is_completed` stays false, so a later `write` would panic on the
        // missing pool — confirm whether that failure mode is acceptable.
        let _file_result = self
            .db_file_tmp
            .take()
            .unwrap()
            .persist(&self.db_file)?
            .ok_or("Unable to persist the database file")?;
        *is_completed = true;
        Ok(())
    }

    /// Creates the `split` table (`row_id` autoincrement key + `item` blob) if
    /// needed and records it in the set of known splits.
    fn create_table(&self, split: &str) -> Result<()> {
        // Second check narrows the race with concurrent writers; `if not
        // exists` in the DDL makes concurrent creation safe regardless.
        if self.splits.read().unwrap().contains(split) {
            return Ok(());
        }
        let conn_pool = self.conn_pool.as_ref().unwrap();
        let connection = conn_pool.get()?;
        let create_table_statement = format!(
            "create table if not exists {split} (row_id integer primary key autoincrement not \
            null, item blob not null)"
        );
        connection.execute(create_table_statement.as_str(), [])?;
        self.splits.write().unwrap().insert(split.to_string());
        Ok(())
    }
}
fn pragma_update_with_error_handling(
conn: &PooledConnection<SqliteConnectionManager>,
setting: &str,
value: &str,
) -> Result<()> {
let result = conn.pragma_update(None, setting, value);
if let Err(error) = result
&& error != rusqlite::Error::ExecuteReturnedResults
{
return Err(SqliteDatasetError::Sql(error));
}
Ok(())
}
#[cfg(test)]
mod tests {
    use rayon::prelude::*;
    use rstest::{fixture, rstest};
    use serde::{Deserialize, Serialize};
    use tempfile::{NamedTempFile, TempDir, tempdir};

    use super::*;

    type SqlDs = SqliteDataset<Sample>;

    // Row shape matching the columns of the checked-in fixture database.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct Sample {
        column_str: String,
        column_bytes: Vec<u8>,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
    }

    // Fixture: opens the "train" split of the bundled test database (2 rows).
    #[fixture]
    fn train_dataset() -> SqlDs {
        SqliteDataset::<Sample>::from_db_file("tests/data/sqlite-dataset.db", "train").unwrap()
    }

    #[rstest]
    pub fn len(train_dataset: SqlDs) {
        assert_eq!(train_dataset.len(), 2);
    }

    // An in-range index yields the expected deserialized row.
    #[rstest]
    pub fn get_some(train_dataset: SqlDs) {
        let item = train_dataset.get(0).unwrap();
        assert_eq!(item.column_str, "HI1");
        assert_eq!(item.column_bytes, vec![55, 231, 159]);
        assert_eq!(item.column_int, 1);
        assert!(item.column_bool);
        assert_eq!(item.column_float, 1.0);
    }

    // An out-of-range index yields None rather than panicking.
    #[rstest]
    pub fn get_none(train_dataset: SqlDs) {
        assert_eq!(train_dataset.get(10), None);
    }

    // Concurrent reads through the pool: only the five in-range indices
    // (values 0 and 1, given len == 2) produce Some.
    #[rstest]
    pub fn multi_thread(train_dataset: SqlDs) {
        let indices: Vec<usize> = vec![0, 1, 1, 3, 4, 5, 6, 0, 8, 1];
        let results: Vec<Option<Sample>> =
            indices.par_iter().map(|&i| train_dataset.get(i)).collect();

        let mut match_count = 0;
        for (_index, result) in indices.iter().zip(results.iter()) {
            if let Some(_val) = result {
                match_count += 1
            }
        }

        assert_eq!(match_count, 5);
    }

    #[test]
    fn sqlite_dataset_storage() {
        // Non-existing file- and name-based storages report absent.
        let storage = SqliteDatasetStorage::from_file("non-existing.db");
        assert!(!storage.exists());

        let storage = SqliteDatasetStorage::from_name("non-existing.db");
        assert!(!storage.exists());

        // The fixture database can be opened for reading.
        let storage = SqliteDatasetStorage::from_file("tests/data/sqlite-dataset.db");
        assert!(storage.exists());
        let result = storage.reader::<Sample>("train");
        assert!(result.is_ok());
        let train = result.unwrap();
        assert_eq!(train.len(), 2);

        // A writer can be created over an existing file when overwrite is on.
        let temp_file = NamedTempFile::new().unwrap();
        let storage = SqliteDatasetStorage::from_file(temp_file.path());
        assert!(storage.exists());
        let result = storage.writer::<Sample>(true);
        assert!(result.is_ok());
    }

    // Item with a deeply nested collection to exercise MessagePack round-trips.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct Complex {
        column_str: String,
        column_bytes: Vec<u8>,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
        column_complex: Vec<Vec<Vec<[u8; 3]>>>,
    }

    // Fixture: a scratch directory removed when the TempDir is dropped.
    #[fixture]
    fn tmp_dir() -> TempDir {
        tempdir().unwrap()
    }

    type Writer = SqliteDatasetWriter<Complex>;

    // Fixture: a writer rooted in the temp dir; the TempDir is returned too so
    // it outlives the writer in each test.
    #[fixture]
    fn writer_fixture(tmp_dir: TempDir) -> (Writer, TempDir) {
        let temp_dir_str = tmp_dir.path();
        let storage = SqliteDatasetStorage::from_name("preprocessed").with_base_dir(temp_dir_str);
        let overwrite = true;
        let result = storage.writer::<Complex>(overwrite);
        assert!(result.is_ok());
        let writer = result.unwrap();
        (writer, tmp_dir)
    }

    #[test]
    fn test_new() {
        // overwrite=true removes an existing target file at init.
        let test_path = NamedTempFile::new().unwrap();
        let _writer = SqliteDatasetWriter::<Complex>::new(&test_path, true).unwrap();
        assert!(!test_path.path().exists());

        // overwrite=false on an existing file is rejected.
        let test_path = NamedTempFile::new().unwrap();
        let result = SqliteDatasetWriter::<Complex>::new(&test_path, false);
        assert!(result.is_err());

        // A non-existing target stays absent until set_completed publishes it.
        let temp = NamedTempFile::new().unwrap();
        let test_path = temp.path().to_path_buf();
        assert!(temp.close().is_ok());
        assert!(!test_path.exists());
        let _writer = SqliteDatasetWriter::<Complex>::new(&test_path, true).unwrap();
        assert!(!test_path.exists());
    }

    #[rstest]
    pub fn sqlite_writer_write(writer_fixture: (Writer, TempDir)) {
        let (writer, _tmp_dir) = writer_fixture;
        assert!(writer.overwrite);
        assert!(!writer.db_file.exists());

        let new_item = Complex {
            column_str: "HI1".to_string(),
            column_bytes: vec![1_u8, 2, 3],
            column_int: 0,
            column_bool: true,
            column_float: 1.0,
            column_complex: vec![vec![vec![[1, 23_u8, 3]]]],
        };

        let index = writer.write("train", &new_item).unwrap();
        assert_eq!(index, 0);

        // Completion publishes the file and invalidates further writes.
        let mut writer = writer;
        writer.set_completed().expect("Failed to set completed");
        assert!(writer.db_file.exists());
        assert!(writer.db_file_tmp.is_none());

        let result = writer.write("train", &new_item);
        assert!(result.is_err());

        // The written item round-trips through a fresh reader.
        let dataset = SqliteDataset::<Complex>::from_db_file(writer.db_file, "train").unwrap();
        let fetched_item = dataset.get(0).unwrap();
        assert_eq!(fetched_item, new_item);
        assert_eq!(dataset.len(), 1);
    }

    // Parallel writes across two splits; records alternate by index parity.
    #[rstest]
    pub fn sqlite_writer_write_multi_thread(writer_fixture: (Writer, TempDir)) {
        let (writer, _tmp_dir) = writer_fixture;
        let writer = Arc::new(writer);
        let record_count = 20;
        let splits = ["train", "test"];

        (0..record_count).into_par_iter().for_each(|index: i64| {
            let thread_id: std::thread::ThreadId = std::thread::current().id();
            let sample = Complex {
                column_str: format!("test_{thread_id:?}_{index}"),
                column_bytes: vec![index as u8, 2, 3],
                column_int: index,
                column_bool: true,
                column_float: 1.0,
                column_complex: vec![vec![vec![[1, index as u8, 3]]]],
            };
            let split = splits[index as usize % 2];
            let _index = writer.write(split, &sample).unwrap();
        });

        let mut writer = Arc::try_unwrap(writer).unwrap();
        writer
            .set_completed()
            .expect("Should set completed successfully");

        let train =
            SqliteDataset::<Complex>::from_db_file(writer.db_file.clone(), "train").unwrap();
        let test = SqliteDataset::<Complex>::from_db_file(writer.db_file, "test").unwrap();
        assert_eq!(train.len(), record_count as usize / 2);
        assert_eq!(test.len(), record_count as usize / 2);
    }
}