use std::{
    collections::HashSet,
    fs, io,
    marker::PhantomData,
    path::{Path, PathBuf},
    sync::{Arc, RwLock},
};

use crate::Dataset;

use gix_tempfile::{
    handle::{persist, Writable},
    AutoRemove, ContainingDirectory, Handle,
};
use r2d2::{Pool, PooledConnection};
use r2d2_sqlite::{
    rusqlite::{OpenFlags, OptionalExtension},
    SqliteConnectionManager,
};
use sanitize_filename::sanitize;
use serde::{de::DeserializeOwned, Serialize};
use serde_rusqlite::{columns_from_statement, from_row_with_columns};

/// Result type alias using [`SqliteDatasetError`].
pub type Result<T> = core::result::Result<T, SqliteDatasetError>;

/// Errors that can occur when reading or writing a SQLite dataset.
#[derive(thiserror::Error, Debug)]
pub enum SqliteDatasetError {
    #[error("IO error: {0}")]
    Io(#[from] io::Error),

    #[error("Sql error: {0}")]
    Sql(#[from] serde_rusqlite::rusqlite::Error),

    #[error("Serde error: {0}")]
    Serde(#[from] rmp_serde::encode::Error),

    #[error("Overwrite flag is set to false and the database file already exists: {0}")]
    FileExists(PathBuf),

    #[error("Failed to create connection pool: {0}")]
    ConnectionPool(#[from] r2d2::Error),

    #[error("Could not persist the temporary database file: {0}")]
    PersistDbFile(#[from] persist::Error<Writable>),

    #[error("{0}")]
    Other(&'static str),
}

impl From<&'static str> for SqliteDatasetError {
    fn from(s: &'static str) -> Self {
        SqliteDatasetError::Other(s)
    }
}

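/// A dataset backed by a SQLite database file, where each split is stored in
/// its own table and each item occupies one row addressed by `row_id`.
///
/// A minimal read sketch; the item type `MyItem` is an illustrative
/// assumption, not part of this module:
///
/// ```rust,ignore
/// #[derive(Clone, serde::Deserialize)]
/// struct MyItem {
///     label: i64,
/// }
///
/// // Open the "train" table of an existing database and fetch items by index.
/// let dataset = SqliteDataset::<MyItem>::from_db_file("data.db", "train").unwrap();
/// let first: Option<MyItem> = dataset.get(0);
/// ```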
#[derive(Debug)]
pub struct SqliteDataset<I> {
    db_file: PathBuf,
    split: String,
    conn_pool: Pool<SqliteConnectionManager>,
    columns: Vec<String>,
    len: usize,
    select_statement: String,
    row_serialized: bool,
    phantom: PhantomData<I>,
}

impl<I> SqliteDataset<I> {
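    /// Initializes a dataset from a SQLite database file and a split (table)
    /// name, probing the table layout to decide how rows are deserialized.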
    pub fn from_db_file<P: AsRef<Path>>(db_file: P, split: &str) -> Result<Self> {
        // Create a read-only connection pool.
        let conn_pool = create_conn_pool(&db_file, false)?;

        // Determine whether rows are stored as single serialized blobs or as
        // one column per field.
        let row_serialized = Self::check_if_row_serialized(&conn_pool, split)?;

        let select_statement = if row_serialized {
            format!("select item from {split} where row_id = ?")
        } else {
            format!("select * from {split} where row_id = ?")
        };

        let (columns, len) = fetch_columns_and_len(&conn_pool, &select_statement, split)?;

        Ok(SqliteDataset {
            db_file: db_file.as_ref().to_path_buf(),
            split: split.to_string(),
            conn_pool,
            columns,
            len,
            select_statement,
            row_serialized,
            phantom: PhantomData,
        })
    }

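    /// Returns `true` if the split's table has exactly two columns named
    /// `row_id` (integer) and `item` (blob), i.e. items were stored as
    /// MessagePack-serialized blobs, as written by `SqliteDatasetWriter`.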
    fn check_if_row_serialized(
        conn_pool: &Pool<SqliteConnectionManager>,
        split: &str,
    ) -> Result<bool> {
        // Column name and declared type, as reported by `PRAGMA table_info`.
        struct Column {
            name: String,
            ty: String,
        }

        // `PRAGMA table_info` returns (cid, name, type, notnull, dflt_value, pk).
        const COLUMN_NAME: usize = 1;
        const COLUMN_TYPE: usize = 2;

        let sql_statement = format!("PRAGMA table_info({split})");

        let conn = conn_pool.get()?;

        let mut stmt = conn.prepare(sql_statement.as_str())?;
        let column_iter = stmt.query_map([], |row| {
            Ok(Column {
                name: row
                    .get::<usize, String>(COLUMN_NAME)
                    .unwrap()
                    .to_lowercase(),
                ty: row
                    .get::<usize, String>(COLUMN_TYPE)
                    .unwrap()
                    .to_lowercase(),
            })
        })?;

        let mut columns: Vec<Column> = vec![];

        for column in column_iter {
            columns.push(column?);
        }

        if columns.len() != 2 {
            Ok(false)
        } else {
            Ok(columns[0].name == "row_id"
                && columns[0].ty == "integer"
                && columns[1].name == "item"
                && columns[1].ty == "blob")
        }
    }

    /// Returns the path of the underlying database file.
    pub fn db_file(&self) -> PathBuf {
        self.db_file.clone()
    }

    /// Returns the split (table) name.
    pub fn split(&self) -> &str {
        self.split.as_str()
    }
}

impl<I> Dataset<I> for SqliteDataset<I>
where
    I: Clone + Send + Sync + DeserializeOwned,
{
    fn get(&self, index: usize) -> Option<I> {
        // Row ids are 1-based in SQLite, while dataset indices are 0-based.
        let row_id = index + 1;

        let connection = self.conn_pool.get().unwrap();
        let mut statement = connection.prepare(self.select_statement.as_str()).unwrap();

        if self.row_serialized {
            // Deserialize the MessagePack blob stored in the `item` column.
            statement
                .query_row([row_id], |row| {
                    Ok(
                        rmp_serde::from_slice::<I>(row.get_ref(0).unwrap().as_blob().unwrap())
                            .unwrap(),
                    )
                })
                .optional()
                .unwrap()
        } else {
            // Map the row's columns onto the item's fields by name.
            statement
                .query_row([row_id], |row| {
                    Ok(from_row_with_columns::<I>(row, &self.columns).unwrap())
                })
                .optional()
                .unwrap()
        }
    }

    fn len(&self) -> usize {
        self.len
    }
}

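/// Returns the column names of the select statement together with the number
/// of rows in the split.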
fn fetch_columns_and_len(
    conn_pool: &Pool<SqliteConnectionManager>,
    select_statement: &str,
    split: &str,
) -> Result<(Vec<String>, usize)> {
    let connection = conn_pool.get()?;
    let statement = connection.prepare(select_statement)?;
    let columns = columns_from_statement(&statement);

    // Use max(row_id) rather than count(*): row_id is the integer primary
    // key, so its maximum is available without scanning the whole table.
    let mut statement =
        connection.prepare(format!("select coalesce(max(row_id), 0) from {split}").as_str())?;

    let len = statement.query_row([], |row| {
        let len: usize = row.get(0)?;
        Ok(len)
    })?;
    Ok((columns, len))
}

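/// Creates an `r2d2` connection pool for the given database file, read-only or
/// read-write/create depending on the `write` flag.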
fn create_conn_pool<P: AsRef<Path>>(
    db_file: P,
    write: bool,
) -> Result<Pool<SqliteConnectionManager>> {
    let sqlite_flags = if write {
        OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE
    } else {
        OpenFlags::SQLITE_OPEN_READ_ONLY
    };

    let manager = SqliteConnectionManager::file(db_file).with_flags(sqlite_flags);
    Pool::new(manager).map_err(SqliteDatasetError::ConnectionPool)
}

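/// Locates a SQLite dataset file either by explicit path or by a sanitized
/// name resolved under a base directory (defaulting to `~/.cache/burn-dataset`).
///
/// A usage sketch; the item type `MyItem` is an illustrative assumption:
///
/// ```rust,ignore
/// // Resolves to `~/.cache/burn-dataset/my-dataset.db` unless a base dir is set.
/// let storage = SqliteDatasetStorage::from_name("my-dataset");
///
/// if storage.exists() {
///     let train = storage.reader::<MyItem>("train").unwrap();
/// }
/// ```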
#[derive(Clone, Debug)]
pub struct SqliteDatasetStorage {
    name: Option<String>,
    db_file: Option<PathBuf>,
    base_dir: Option<PathBuf>,
}

impl SqliteDatasetStorage {
    /// Creates a storage handle from a dataset name; the file name is derived
    /// from the sanitized name.
    pub fn from_name(name: &str) -> Self {
        SqliteDatasetStorage {
            name: Some(name.to_string()),
            db_file: None,
            base_dir: None,
        }
    }

    /// Creates a storage handle from an explicit database file path.
    pub fn from_file<P: AsRef<Path>>(db_file: P) -> Self {
        SqliteDatasetStorage {
            name: None,
            db_file: Some(db_file.as_ref().to_path_buf()),
            base_dir: None,
        }
    }

    /// Sets the base directory used when the storage was created from a name.
    pub fn with_base_dir<P: AsRef<Path>>(mut self, base_dir: P) -> Self {
        self.base_dir = Some(base_dir.as_ref().to_path_buf());
        self
    }

    /// Checks whether the database file exists.
    pub fn exists(&self) -> bool {
        self.db_file().exists()
    }

    /// Determines the database file path, deriving it from the sanitized name
    /// and base directory when no explicit file was given.
    pub fn db_file(&self) -> PathBuf {
        match &self.db_file {
            Some(db_file) => db_file.clone(),
            None => {
                let name = sanitize(self.name.as_ref().expect("Name is not set"));
                Self::base_dir(self.base_dir.to_owned()).join(format!("{name}.db"))
            }
        }
    }

    /// Returns the base directory, falling back to `~/.cache/burn-dataset`
    /// when none is set.
    pub fn base_dir(base_dir: Option<PathBuf>) -> PathBuf {
        match base_dir {
            Some(base_dir) => base_dir,
            None => {
                let home_dir = dirs::home_dir().expect("Could not get home directory");

                home_dir.join(".cache").join("burn-dataset")
            }
        }
    }

    /// Provides a writer for the database file.
    pub fn writer<I>(&self, overwrite: bool) -> Result<SqliteDatasetWriter<I>>
    where
        I: Clone + Send + Sync + Serialize + DeserializeOwned,
    {
        SqliteDatasetWriter::new(self.db_file(), overwrite)
    }

    /// Provides a reader for the given split.
    ///
    /// # Panics
    ///
    /// Panics if the database file does not exist.
    pub fn reader<I>(&self, split: &str) -> Result<SqliteDataset<I>>
    where
        I: Clone + Send + Sync + Serialize + DeserializeOwned,
    {
        if !self.exists() {
            panic!("The database file does not exist");
        }

        SqliteDataset::from_db_file(self.db_file(), split)
    }
}

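/// Writes items into a SQLite database, one table per split, buffering the
/// output in a temporary `<file>.db.tmp` that is persisted to its final path
/// by [`set_completed`](Self::set_completed).
///
/// A write sketch; the item type `MyItem` is an illustrative assumption and
/// must implement `Serialize`:
///
/// ```rust,ignore
/// let mut writer = SqliteDatasetWriter::<MyItem>::new("data.db", true).unwrap();
/// let item = MyItem { /* ... */ };
/// let index = writer.write("train", &item).unwrap(); // 0 for the first item
/// writer.set_completed().unwrap(); // closes connections, persists data.db
/// ```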
#[derive(Debug)]
pub struct SqliteDatasetWriter<I> {
    db_file: PathBuf,
    db_file_tmp: Option<Handle<Writable>>,
    splits: Arc<RwLock<HashSet<String>>>,
    overwrite: bool,
    conn_pool: Option<Pool<SqliteConnectionManager>>,
    is_completed: Arc<RwLock<bool>>,
    phantom: PhantomData<I>,
}

impl<I> SqliteDatasetWriter<I>
where
    I: Clone + Send + Sync + Serialize + DeserializeOwned,
{
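    /// Creates a writer for the given database file. With `overwrite` set, an
    /// existing file is removed; otherwise `FileExists` is returned.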
    pub fn new<P: AsRef<Path>>(db_file: P, overwrite: bool) -> Result<Self> {
        let writer = Self {
            db_file: db_file.as_ref().to_path_buf(),
            db_file_tmp: None,
            splits: Arc::new(RwLock::new(HashSet::new())),
            overwrite,
            conn_pool: None,
            is_completed: Arc::new(RwLock::new(false)),
            phantom: PhantomData,
        };

        writer.init()
    }

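    /// Prepares the output location: enforces the overwrite policy, creates
    /// parent directories, and opens a temporary database file that is
    /// persisted on completion (and auto-removed on abnormal termination).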
    fn init(mut self) -> Result<Self> {
        // Enforce the overwrite policy before touching the file system.
        if self.db_file.exists() {
            if self.overwrite {
                fs::remove_file(&self.db_file)?;
            } else {
                return Err(SqliteDatasetError::FileExists(self.db_file));
            }
        }

        let db_file_dir = self
            .db_file
            .parent()
            .ok_or("Unable to get parent directory")?;

        if !db_file_dir.exists() {
            fs::create_dir_all(db_file_dir)?;
        }

        // Write to `<file>.db.tmp` first; it is persisted to the final path
        // in `set_completed`.
        let mut db_file_tmp = self.db_file.clone();
        db_file_tmp.set_extension("db.tmp");
        if db_file_tmp.exists() {
            fs::remove_file(&db_file_tmp)?;
        }

        // Register signal handlers so the temporary file is cleaned up if the
        // process is interrupted.
        gix_tempfile::signal::setup(Default::default());
        self.db_file_tmp = Some(gix_tempfile::writable_at(
            &db_file_tmp,
            ContainingDirectory::Exists,
            AutoRemove::Tempfile,
        )?);

        let conn_pool = create_conn_pool(db_file_tmp, true)?;
        self.conn_pool = Some(conn_pool);

        Ok(self)
    }

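    /// Serializes an item with MessagePack and appends it to the given split,
    /// creating the split's table on first use. Returns the 0-based index of
    /// the written item.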
    pub fn write(&self, split: &str, item: &I) -> Result<usize> {
        // Hold the read lock for the duration of the write so `set_completed`
        // cannot run concurrently.
        let is_completed = self.is_completed.read().unwrap();

        if *is_completed {
            return Err(SqliteDatasetError::Other(
                "Cannot save to a completed dataset writer",
            ));
        }

        if !self.splits.read().unwrap().contains(split) {
            self.create_table(split)?;
        }

        let conn_pool = self.conn_pool.as_ref().unwrap();
        let conn = conn_pool.get()?;

        let serialized_item = rmp_serde::to_vec(item)?;

        // Disable fsyncs and journaling for write throughput; the temporary
        // file is discarded on failure anyway.
        pragma_update_with_error_handling(&conn, "synchronous", "OFF")?;
        pragma_update_with_error_handling(&conn, "journal_mode", "OFF")?;

        let insert_statement = format!("insert into {split} (item) values (?)");
        conn.execute(insert_statement.as_str(), [serialized_item])?;

        // `last_insert_rowid` is 1-based; convert to a 0-based index.
        let index = (conn.last_insert_rowid() - 1) as usize;

        Ok(index)
    }

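    /// Marks the writer as completed: closes all pooled connections and
    /// persists the temporary database file to its final path. No further
    /// writes are accepted afterwards.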
    pub fn set_completed(&mut self) -> Result<()> {
        let mut is_completed = self.is_completed.write().unwrap();

        // Drop the connection pool first so every connection to the temporary
        // file is closed before it is persisted.
        if let Some(pool) = self.conn_pool.take() {
            std::mem::drop(pool);
        }

        let _file_result = self
            .db_file_tmp
            .take()
            .unwrap()
            .persist(&self.db_file)?
            .ok_or("Unable to persist the database file")?;

        *is_completed = true;
        Ok(())
    }

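    /// Creates the split's table if needed and records the split name. The
    /// table stores each item as a MessagePack blob keyed by an
    /// autoincrementing `row_id`.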
    fn create_table(&self, split: &str) -> Result<()> {
        // Return early if the table was already created by another thread.
        if self.splits.read().unwrap().contains(split) {
            return Ok(());
        }

        let conn_pool = self.conn_pool.as_ref().unwrap();
        let connection = conn_pool.get()?;
        let create_table_statement = format!(
            "create table if not exists {split} (row_id integer primary key autoincrement not \
             null, item blob not null)"
        );

        connection.execute(create_table_statement.as_str(), [])?;

        self.splits.write().unwrap().insert(split.to_string());

        Ok(())
    }
}

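/// Runs `PRAGMA <setting> = <value>`, ignoring the spurious
/// `ExecuteReturnedResults` error that some pragmas (e.g. `journal_mode`)
/// trigger by returning a result row.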
fn pragma_update_with_error_handling(
    conn: &PooledConnection<SqliteConnectionManager>,
    setting: &str,
    value: &str,
) -> Result<()> {
    let result = conn.pragma_update(None, setting, value);
    if let Err(error) = result {
        if error != rusqlite::Error::ExecuteReturnedResults {
            return Err(SqliteDatasetError::Sql(error));
        }
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    use rayon::prelude::*;
    use rstest::{fixture, rstest};
    use serde::{Deserialize, Serialize};
    use tempfile::{tempdir, NamedTempFile, TempDir};

    use super::*;

    type SqlDs = SqliteDataset<Sample>;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct Sample {
        column_str: String,
        column_bytes: Vec<u8>,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
    }

    #[fixture]
    fn train_dataset() -> SqlDs {
        SqliteDataset::<Sample>::from_db_file("tests/data/sqlite-dataset.db", "train").unwrap()
    }

    #[rstest]
    pub fn len(train_dataset: SqlDs) {
        assert_eq!(train_dataset.len(), 2);
    }

    #[rstest]
    pub fn get_some(train_dataset: SqlDs) {
        let item = train_dataset.get(0).unwrap();
        assert_eq!(item.column_str, "HI1");
        assert_eq!(item.column_bytes, vec![55, 231, 159]);
        assert_eq!(item.column_int, 1);
        assert!(item.column_bool);
        assert_eq!(item.column_float, 1.0);
    }

    #[rstest]
    pub fn get_none(train_dataset: SqlDs) {
        assert_eq!(train_dataset.get(10), None);
    }

    #[rstest]
    pub fn multi_thread(train_dataset: SqlDs) {
        let indices: Vec<usize> = vec![0, 1, 1, 3, 4, 5, 6, 0, 8, 1];
        let results: Vec<Option<Sample>> =
            indices.par_iter().map(|&i| train_dataset.get(i)).collect();

        // Only indices 0 and 1 exist in the two-row dataset; they appear five
        // times in `indices`.
        let match_count = results.iter().filter(|result| result.is_some()).count();

        assert_eq!(match_count, 5);
    }

    #[test]
    fn sqlite_dataset_storage() {
        // Existence checks for a non-existing file and a non-existing name.
        let storage = SqliteDatasetStorage::from_file("non-existing.db");
        assert!(!storage.exists());

        let storage = SqliteDatasetStorage::from_name("non-existing.db");
        assert!(!storage.exists());

        // Reading an existing database file.
        let storage = SqliteDatasetStorage::from_file("tests/data/sqlite-dataset.db");
        assert!(storage.exists());
        let result = storage.reader::<Sample>("train");
        assert!(result.is_ok());
        let train = result.unwrap();
        assert_eq!(train.len(), 2);

        // Obtaining a writer for an existing file (with overwrite).
        let temp_file = NamedTempFile::new().unwrap();
        let storage = SqliteDatasetStorage::from_file(temp_file.path());
        assert!(storage.exists());
        let result = storage.writer::<Sample>(true);
        assert!(result.is_ok());
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct Complex {
        column_str: String,
        column_bytes: Vec<u8>,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
        column_complex: Vec<Vec<Vec<[u8; 3]>>>,
    }

    /// Creates a temporary directory that lives for the duration of a test.
    #[fixture]
    fn tmp_dir() -> TempDir {
        tempdir().unwrap()
    }

    type Writer = SqliteDatasetWriter<Complex>;

    /// Creates a writer on a database file under the temporary directory; the
    /// directory is returned alongside so it is not dropped early.
    #[fixture]
    fn writer_fixture(tmp_dir: TempDir) -> (Writer, TempDir) {
        let temp_dir_str = tmp_dir.path();
        let storage = SqliteDatasetStorage::from_name("preprocessed").with_base_dir(temp_dir_str);
        let overwrite = true;
        let result = storage.writer::<Complex>(overwrite);
        assert!(result.is_ok());
        let writer = result.unwrap();
        (writer, tmp_dir)
    }

    #[test]
    fn test_new() {
        // Overwrite on an existing file: the file is removed and writing is
        // redirected to the temporary `.db.tmp` file.
        let test_path = NamedTempFile::new().unwrap();
        let _writer = SqliteDatasetWriter::<Complex>::new(&test_path, true).unwrap();
        assert!(!test_path.path().exists());

        // Without overwrite, an existing file is an error.
        let test_path = NamedTempFile::new().unwrap();
        let result = SqliteDatasetWriter::<Complex>::new(&test_path, false);
        assert!(result.is_err());

        // A non-existing path is fine; the final file only appears after
        // `set_completed`.
        let temp = NamedTempFile::new().unwrap();
        let test_path = temp.path().to_path_buf();
        assert!(temp.close().is_ok());
        assert!(!test_path.exists());
        let _writer = SqliteDatasetWriter::<Complex>::new(&test_path, true).unwrap();
        assert!(!test_path.exists());
    }

    #[rstest]
    pub fn sqlite_writer_write(writer_fixture: (Writer, TempDir)) {
        // Keep the temporary directory alive for the duration of the test.
        let (writer, _tmp_dir) = writer_fixture;

        assert!(writer.overwrite);
        assert!(!writer.db_file.exists());

        let new_item = Complex {
            column_str: "HI1".to_string(),
            column_bytes: vec![1_u8, 2, 3],
            column_int: 0,
            column_bool: true,
            column_float: 1.0,
            column_complex: vec![vec![vec![[1, 23_u8, 3]]]],
        };

        let index = writer.write("train", &new_item).unwrap();
        assert_eq!(index, 0);

        let mut writer = writer;

        writer.set_completed().expect("Failed to set completed");

        assert!(writer.db_file.exists());
        assert!(writer.db_file_tmp.is_none());

        // Writing after completion must fail.
        let result = writer.write("train", &new_item);
        assert!(result.is_err());

        // The written item can be read back unchanged.
        let dataset = SqliteDataset::<Complex>::from_db_file(writer.db_file, "train").unwrap();

        let fetched_item = dataset.get(0).unwrap();
        assert_eq!(fetched_item, new_item);
        assert_eq!(dataset.len(), 1);
    }

    #[rstest]
    pub fn sqlite_writer_write_multi_thread(writer_fixture: (Writer, TempDir)) {
        // Keep the temporary directory alive for the duration of the test.
        let (writer, _tmp_dir) = writer_fixture;

        let writer = Arc::new(writer);
        let record_count = 20;

        let splits = ["train", "test"];

        (0..record_count).into_par_iter().for_each(|index: i64| {
            let thread_id: std::thread::ThreadId = std::thread::current().id();
            let sample = Complex {
                column_str: format!("test_{:?}_{}", thread_id, index),
                column_bytes: vec![index as u8, 2, 3],
                column_int: index,
                column_bool: true,
                column_float: 1.0,
                column_complex: vec![vec![vec![[1, index as u8, 3]]]],
            };

            // Alternate items between the two splits.
            let split = splits[index as usize % 2];

            let _index = writer.write(split, &sample).unwrap();
        });

        let mut writer = Arc::try_unwrap(writer).unwrap();

        writer
            .set_completed()
            .expect("Should set completed successfully");

        let train =
            SqliteDataset::<Complex>::from_db_file(writer.db_file.clone(), "train").unwrap();
        let test = SqliteDataset::<Complex>::from_db_file(writer.db_file, "test").unwrap();

        assert_eq!(train.len(), record_count as usize / 2);
        assert_eq!(test.len(), record_count as usize / 2);
    }
}