use std::{
    collections::HashSet,
    fs, io,
    marker::PhantomData,
    path::{Path, PathBuf},
    sync::{Arc, RwLock},
};

use crate::Dataset;

use gix_tempfile::{
    AutoRemove, ContainingDirectory, Handle,
    handle::{Writable, persist},
};
use r2d2::{Pool, PooledConnection};
use r2d2_sqlite::{
    SqliteConnectionManager,
    rusqlite::{OpenFlags, OptionalExtension},
};
use sanitize_filename::sanitize;
use serde::{Serialize, de::DeserializeOwned};
use serde_rusqlite::{columns_from_statement, from_row_with_columns};

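/// Result type with [`SqliteDatasetError`] as the error variant.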
pub type Result<T> = core::result::Result<T, SqliteDatasetError>;

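/// Errors that can occur while reading or writing a SQLite dataset.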
#[derive(thiserror::Error, Debug)]
pub enum SqliteDatasetError {
    #[error("IO error: {0}")]
    Io(#[from] io::Error),

    #[error("Sql error: {0}")]
    Sql(#[from] serde_rusqlite::rusqlite::Error),

    #[error("Serde error: {0}")]
    Serde(#[from] rmp_serde::encode::Error),

    #[error("Overwrite flag is set to false and the database file already exists: {0}")]
    FileExists(PathBuf),

    #[error("Failed to create connection pool: {0}")]
    ConnectionPool(#[from] r2d2::Error),

    #[error("Could not persist the temporary database file: {0}")]
    PersistDbFile(#[from] persist::Error<Writable>),

    #[error("{0}")]
    Other(&'static str),
}

impl From<&'static str> for SqliteDatasetError {
    fn from(s: &'static str) -> Self {
        SqliteDatasetError::Other(s)
    }
}

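/// Dataset backed by a table in a SQLite database file.
///
/// Each split lives in its own table. Items are stored either as one
/// MessagePack blob per row (`row_id`, `item`) or with one column per field.
///
/// A minimal usage sketch; the file path, split name, and item type are
/// illustrative:
///
/// ```ignore
/// let dataset = SqliteDataset::<MyItem>::from_db_file("data.db", "train")?;
/// assert!(dataset.get(0).is_some());
/// ```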
#[derive(Debug)]
pub struct SqliteDataset<I> {
    db_file: PathBuf,
    split: String,
    conn_pool: Pool<SqliteConnectionManager>,
    columns: Vec<String>,
    len: usize,
    select_statement: String,
    row_serialized: bool,
    phantom: PhantomData<I>,
}

impl<I> SqliteDataset<I> {
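    /// Initializes a `SqliteDataset` from an existing SQLite database file and a
    /// split (table) name.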
    pub fn from_db_file<P: AsRef<Path>>(db_file: P, split: &str) -> Result<Self> {
        let conn_pool = create_conn_pool(&db_file, false)?;

        let row_serialized = Self::check_if_row_serialized(&conn_pool, split)?;

        let select_statement = if row_serialized {
            format!("select item from {split} where row_id = ?")
        } else {
            format!("select * from {split} where row_id = ?")
        };

        let (columns, len) = fetch_columns_and_len(&conn_pool, &select_statement, split)?;

        Ok(SqliteDataset {
            db_file: db_file.as_ref().to_path_buf(),
            split: split.to_string(),
            conn_pool,
            columns,
            len,
            select_statement,
            row_serialized,
            phantom: PhantomData,
        })
    }

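    /// Returns true if the split table has exactly two columns, `row_id` (integer)
    /// and `item` (blob), i.e. items are stored as whole serialized blobs rather
    /// than one column per field.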
    fn check_if_row_serialized(
        conn_pool: &Pool<SqliteConnectionManager>,
        split: &str,
    ) -> Result<bool> {
        struct Column {
            name: String,
            ty: String,
        }

        const COLUMN_NAME: usize = 1;
        const COLUMN_TYPE: usize = 2;

        let sql_statement = format!("PRAGMA table_info({split})");

        let conn = conn_pool.get()?;

        let mut stmt = conn.prepare(sql_statement.as_str())?;
        let column_iter = stmt.query_map([], |row| {
            Ok(Column {
                name: row
                    .get::<usize, String>(COLUMN_NAME)
                    .unwrap()
                    .to_lowercase(),
                ty: row
                    .get::<usize, String>(COLUMN_TYPE)
                    .unwrap()
                    .to_lowercase(),
            })
        })?;

        let mut columns: Vec<Column> = vec![];

        for column in column_iter {
            columns.push(column?);
        }

        if columns.len() != 2 {
            Ok(false)
        } else {
            Ok(columns[0].name == "row_id"
                && columns[0].ty == "integer"
                && columns[1].name == "item"
                && columns[1].ty == "blob")
        }
    }

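    /// Returns the path to the underlying database file.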
    pub fn db_file(&self) -> PathBuf {
        self.db_file.clone()
    }

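    /// Returns the name of the split (the table queried by this dataset).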
    pub fn split(&self) -> &str {
        self.split.as_str()
    }
}

impl<I> Dataset<I> for SqliteDataset<I>
where
    I: Clone + Send + Sync + DeserializeOwned,
{
    fn get(&self, index: usize) -> Option<I> {
        // Row ids are 1-based while dataset indices are 0-based.
        let row_id = index + 1;

        let connection = self.conn_pool.get().unwrap();
        let mut statement = connection.prepare(self.select_statement.as_str()).unwrap();

        if self.row_serialized {
            // The whole item is stored as a MessagePack blob in the `item`
            // column; `optional()` turns "no rows" into `None`.
            statement
                .query_row([row_id], |row| {
                    Ok(
                        rmp_serde::from_slice::<I>(row.get_ref(0).unwrap().as_blob().unwrap())
                            .unwrap(),
                    )
                })
                .optional()
                .unwrap()
        } else {
            // Each field of the item is stored in its own column.
            statement
                .query_row([row_id], |row| {
                    Ok(from_row_with_columns::<I>(row, &self.columns).unwrap())
                })
                .optional()
                .unwrap()
        }
    }

    fn len(&self) -> usize {
        self.len
    }
}

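/// Fetches the column names of the select statement and the number of rows in the split.
///
/// The length is derived from `max(row_id)`, which assumes row ids are contiguous
/// and start at 1.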
fn fetch_columns_and_len(
    conn_pool: &Pool<SqliteConnectionManager>,
    select_statement: &str,
    split: &str,
) -> Result<(Vec<String>, usize)> {
    let connection = conn_pool.get()?;
    let statement = connection.prepare(select_statement)?;
    let columns = columns_from_statement(&statement);

    let mut statement =
        connection.prepare(format!("select coalesce(max(row_id), 0) from {split}").as_str())?;

    let len = statement.query_row([], |row| {
        let len: usize = row.get(0)?;
        Ok(len)
    })?;
    Ok((columns, len))
}

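/// Creates a connection pool for the database file, read-only unless `write` is
/// set, in which case the file is created if missing.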
fn create_conn_pool<P: AsRef<Path>>(
    db_file: P,
    write: bool,
) -> Result<Pool<SqliteConnectionManager>> {
    let sqlite_flags = if write {
        OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE
    } else {
        OpenFlags::SQLITE_OPEN_READ_ONLY
    };

    let manager = SqliteConnectionManager::file(db_file).with_flags(sqlite_flags);
    Pool::new(manager).map_err(SqliteDatasetError::ConnectionPool)
}

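/// Locator for a SQLite dataset file, given either an explicit path or a
/// sanitized dataset name under a base directory.
///
/// A minimal sketch; the dataset name and item type are illustrative:
///
/// ```ignore
/// let storage = SqliteDatasetStorage::from_name("my-dataset");
/// if storage.exists() {
///     let train = storage.reader::<MyItem>("train")?;
/// }
/// ```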
#[derive(Clone, Debug)]
pub struct SqliteDatasetStorage {
    name: Option<String>,
    db_file: Option<PathBuf>,
    base_dir: Option<PathBuf>,
}

impl SqliteDatasetStorage {
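    /// Creates storage from a dataset name; the file path is derived from the
    /// sanitized name and the base directory.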
    pub fn from_name(name: &str) -> Self {
        SqliteDatasetStorage {
            name: Some(name.to_string()),
            db_file: None,
            base_dir: None,
        }
    }

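    /// Creates storage from an explicit database file path.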
    pub fn from_file<P: AsRef<Path>>(db_file: P) -> Self {
        SqliteDatasetStorage {
            name: None,
            db_file: Some(db_file.as_ref().to_path_buf()),
            base_dir: None,
        }
    }

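    /// Sets the base directory used when the storage was created from a name.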
    pub fn with_base_dir<P: AsRef<Path>>(mut self, base_dir: P) -> Self {
        self.base_dir = Some(base_dir.as_ref().to_path_buf());
        self
    }

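    /// Returns true if the database file exists.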
    pub fn exists(&self) -> bool {
        self.db_file().exists()
    }

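    /// Returns the database file path, deriving it from the sanitized name when
    /// no explicit path was given.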
    pub fn db_file(&self) -> PathBuf {
        match &self.db_file {
            Some(db_file) => db_file.clone(),
            None => {
                let name = sanitize(self.name.as_ref().expect("Name is not set"));
                Self::base_dir(self.base_dir.to_owned()).join(format!("{name}.db"))
            }
        }
    }

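    /// Resolves the base directory, defaulting to `$HOME/.cache/burn-dataset`.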
    pub fn base_dir(base_dir: Option<PathBuf>) -> PathBuf {
        match base_dir {
            Some(base_dir) => base_dir,
            None => {
                let home_dir = dirs::home_dir().expect("Could not get home directory");

                home_dir.join(".cache").join("burn-dataset")
            }
        }
    }

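    /// Creates a [`SqliteDatasetWriter`] for this storage, optionally overwriting
    /// an existing database file.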
    pub fn writer<I>(&self, overwrite: bool) -> Result<SqliteDatasetWriter<I>>
    where
        I: Clone + Send + Sync + Serialize + DeserializeOwned,
    {
        SqliteDatasetWriter::new(self.db_file(), overwrite)
    }

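    /// Creates a [`SqliteDataset`] reader for the given split.
    ///
    /// # Panics
    ///
    /// Panics if the database file does not exist.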
    pub fn reader<I>(&self, split: &str) -> Result<SqliteDataset<I>>
    where
        I: Clone + Send + Sync + Serialize + DeserializeOwned,
    {
        if !self.exists() {
            panic!("The database file does not exist");
        }

        SqliteDataset::from_db_file(self.db_file(), split)
    }
}

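/// Writes items into per-split tables of a temporary SQLite file and persists it
/// to the target path once [`set_completed`](Self::set_completed) is called.
/// `write` takes `&self`, so items can be written from multiple threads.
///
/// A minimal sketch; the path, split name, and item are illustrative:
///
/// ```ignore
/// let mut writer = SqliteDatasetWriter::<MyItem>::new("data.db", true)?;
/// let index = writer.write("train", &item)?;
/// writer.set_completed()?;
/// ```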
#[derive(Debug)]
pub struct SqliteDatasetWriter<I> {
    db_file: PathBuf,
    db_file_tmp: Option<Handle<Writable>>,
    splits: Arc<RwLock<HashSet<String>>>,
    overwrite: bool,
    conn_pool: Option<Pool<SqliteConnectionManager>>,
    is_completed: Arc<RwLock<bool>>,
    phantom: PhantomData<I>,
}

impl<I> SqliteDatasetWriter<I>
where
    I: Clone + Send + Sync + Serialize + DeserializeOwned,
{
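    /// Creates a writer targeting `db_file`; fails with
    /// [`SqliteDatasetError::FileExists`] if the file exists and `overwrite` is false.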
    pub fn new<P: AsRef<Path>>(db_file: P, overwrite: bool) -> Result<Self> {
        let writer = Self {
            db_file: db_file.as_ref().to_path_buf(),
            db_file_tmp: None,
            splits: Arc::new(RwLock::new(HashSet::new())),
            overwrite,
            conn_pool: None,
            is_completed: Arc::new(RwLock::new(false)),
            phantom: PhantomData,
        };

        writer.init()
    }

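    /// Prepares the writer: removes stale files, creates missing parent
    /// directories, and opens a connection pool on a temporary `.db.tmp` file.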
    fn init(mut self) -> Result<Self> {
        if self.db_file.exists() {
            if self.overwrite {
                fs::remove_file(&self.db_file)?;
            } else {
                return Err(SqliteDatasetError::FileExists(self.db_file));
            }
        }

        let db_file_dir = self
            .db_file
            .parent()
            .ok_or("Unable to get parent directory")?;

        if !db_file_dir.exists() {
            fs::create_dir_all(db_file_dir)?;
        }

        let mut db_file_tmp = self.db_file.clone();
        db_file_tmp.set_extension("db.tmp");
        if db_file_tmp.exists() {
            fs::remove_file(&db_file_tmp)?;
        }

        // Register signal handlers so the temporary file is cleaned up if the
        // process is interrupted before completion.
        gix_tempfile::signal::setup(Default::default());
        self.db_file_tmp = Some(gix_tempfile::writable_at(
            &db_file_tmp,
            ContainingDirectory::Exists,
            AutoRemove::Tempfile,
        )?);

        let conn_pool = create_conn_pool(db_file_tmp, true)?;
        self.conn_pool = Some(conn_pool);

        Ok(self)
    }

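    /// Serializes the item with MessagePack and appends it to the split's table,
    /// returning the zero-based index of the inserted item.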
    pub fn write(&self, split: &str, item: &I) -> Result<usize> {
        let is_completed = self.is_completed.read().unwrap();

        if *is_completed {
            return Err(SqliteDatasetError::Other(
                "Cannot save to a completed dataset writer",
            ));
        }

        if !self.splits.read().unwrap().contains(split) {
            self.create_table(split)?;
        }

        let conn_pool = self.conn_pool.as_ref().unwrap();
        let conn = conn_pool.get()?;

        let serialized_item = rmp_serde::to_vec(item)?;

        // Disable syncing and journaling for faster bulk inserts; the file is
        // persisted atomically later, so durability while writing is not needed.
        pragma_update_with_error_handling(&conn, "synchronous", "OFF")?;
        pragma_update_with_error_handling(&conn, "journal_mode", "OFF")?;

        let insert_statement = format!("insert into {split} (item) values (?)");
        conn.execute(insert_statement.as_str(), [serialized_item])?;

        // `row_id` starts at 1, so convert it to a zero-based index.
        let index = (conn.last_insert_rowid() - 1) as usize;

        Ok(index)
    }

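    /// Marks the writer as completed: closes all connections and persists the
    /// temporary file to its final path.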
    pub fn set_completed(&mut self) -> Result<()> {
        let mut is_completed = self.is_completed.write().unwrap();

        // Drop the connection pool to close all connections before persisting;
        // otherwise the temporary file may still be locked.
        if let Some(pool) = self.conn_pool.take() {
            std::mem::drop(pool);
        }

        let _file_result = self
            .db_file_tmp
            .take()
            .unwrap()
            .persist(&self.db_file)?
            .ok_or("Unable to persist the database file")?;

        *is_completed = true;
        Ok(())
    }

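    /// Creates the table for the split if it does not exist and records the split name.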
    fn create_table(&self, split: &str) -> Result<()> {
        // Another thread may have created the table in the meantime.
        if self.splits.read().unwrap().contains(split) {
            return Ok(());
        }

        let conn_pool = self.conn_pool.as_ref().unwrap();
        let connection = conn_pool.get()?;
        let create_table_statement = format!(
            "create table if not exists {split} (row_id integer primary key autoincrement not \
             null, item blob not null)"
        );

        connection.execute(create_table_statement.as_str(), [])?;

        self.splits.write().unwrap().insert(split.to_string());

        Ok(())
    }
}

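/// Applies a `PRAGMA` update, treating the spurious `ExecuteReturnedResults` error
/// (reported by some pragmas even on success) as success.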
fn pragma_update_with_error_handling(
    conn: &PooledConnection<SqliteConnectionManager>,
    setting: &str,
    value: &str,
) -> Result<()> {
    let result = conn.pragma_update(None, setting, value);
    if let Err(error) = result {
        if error != serde_rusqlite::rusqlite::Error::ExecuteReturnedResults {
            return Err(SqliteDatasetError::Sql(error));
        }
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    use rayon::prelude::*;
    use rstest::{fixture, rstest};
    use serde::{Deserialize, Serialize};
    use tempfile::{NamedTempFile, TempDir, tempdir};

    use super::*;

    type SqlDs = SqliteDataset<Sample>;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct Sample {
        column_str: String,
        column_bytes: Vec<u8>,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
    }

    #[fixture]
    fn train_dataset() -> SqlDs {
        SqliteDataset::<Sample>::from_db_file("tests/data/sqlite-dataset.db", "train").unwrap()
    }

    #[rstest]
    pub fn len(train_dataset: SqlDs) {
        assert_eq!(train_dataset.len(), 2);
    }

    #[rstest]
    pub fn get_some(train_dataset: SqlDs) {
        let item = train_dataset.get(0).unwrap();
        assert_eq!(item.column_str, "HI1");
        assert_eq!(item.column_bytes, vec![55, 231, 159]);
        assert_eq!(item.column_int, 1);
        assert!(item.column_bool);
        assert_eq!(item.column_float, 1.0);
    }

    #[rstest]
    pub fn get_none(train_dataset: SqlDs) {
        assert_eq!(train_dataset.get(10), None);
    }

    #[rstest]
    pub fn multi_thread(train_dataset: SqlDs) {
        let indices: Vec<usize> = vec![0, 1, 1, 3, 4, 5, 6, 0, 8, 1];
        let results: Vec<Option<Sample>> =
            indices.par_iter().map(|&i| train_dataset.get(i)).collect();

        // Only indices 0 and 1 exist in the dataset (len == 2), and they appear
        // five times in `indices`.
        let match_count = results.iter().filter(|result| result.is_some()).count();

        assert_eq!(match_count, 5);
    }

    #[test]
    fn sqlite_dataset_storage() {
        // From file: the file does not exist.
        let storage = SqliteDatasetStorage::from_file("non-existing.db");
        assert!(!storage.exists());

        // From name: the derived file does not exist.
        let storage = SqliteDatasetStorage::from_name("non-existing.db");
        assert!(!storage.exists());

        // An existing database file can be opened for reading.
        let storage = SqliteDatasetStorage::from_file("tests/data/sqlite-dataset.db");
        assert!(storage.exists());
        let result = storage.reader::<Sample>("train");
        assert!(result.is_ok());
        let train = result.unwrap();
        assert_eq!(train.len(), 2);

        // An existing file can be turned into a writer when overwrite is allowed.
        let temp_file = NamedTempFile::new().unwrap();
        let storage = SqliteDatasetStorage::from_file(temp_file.path());
        assert!(storage.exists());
        let result = storage.writer::<Sample>(true);
        assert!(result.is_ok());
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct Complex {
        column_str: String,
        column_bytes: Vec<u8>,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
        column_complex: Vec<Vec<Vec<[u8; 3]>>>,
    }

    #[fixture]
    fn tmp_dir() -> TempDir {
        tempdir().unwrap()
    }

    type Writer = SqliteDatasetWriter<Complex>;

    #[fixture]
    fn writer_fixture(tmp_dir: TempDir) -> (Writer, TempDir) {
        let temp_dir_str = tmp_dir.path();
        let storage = SqliteDatasetStorage::from_name("preprocessed").with_base_dir(temp_dir_str);
        let overwrite = true;
        let result = storage.writer::<Complex>(overwrite);
        assert!(result.is_ok());
        let writer = result.unwrap();
        (writer, tmp_dir)
    }

    #[test]
    fn test_new() {
        // Overwriting an existing file succeeds; the target only appears after completion.
        let test_path = NamedTempFile::new().unwrap();
        let _writer = SqliteDatasetWriter::<Complex>::new(&test_path, true).unwrap();
        assert!(!test_path.path().exists());

        // Without overwrite, an existing file is an error.
        let test_path = NamedTempFile::new().unwrap();
        let result = SqliteDatasetWriter::<Complex>::new(&test_path, false);
        assert!(result.is_err());

        // A non-existing target is fine; the file is only persisted by `set_completed`.
        let temp = NamedTempFile::new().unwrap();
        let test_path = temp.path().to_path_buf();
        assert!(temp.close().is_ok());
        assert!(!test_path.exists());
        let _writer = SqliteDatasetWriter::<Complex>::new(&test_path, true).unwrap();
        assert!(!test_path.exists());
    }

    #[rstest]
    pub fn sqlite_writer_write(writer_fixture: (Writer, TempDir)) {
        let (writer, _tmp_dir) = writer_fixture;

        assert!(writer.overwrite);
        assert!(!writer.db_file.exists());

        let new_item = Complex {
            column_str: "HI1".to_string(),
            column_bytes: vec![1_u8, 2, 3],
            column_int: 0,
            column_bool: true,
            column_float: 1.0,
            column_complex: vec![vec![vec![[1, 23_u8, 3]]]],
        };

        let index = writer.write("train", &new_item).unwrap();
        assert_eq!(index, 0);

        let mut writer = writer;

        writer.set_completed().expect("Failed to set completed");

        assert!(writer.db_file.exists());
        assert!(writer.db_file_tmp.is_none());

        // Writing after completion must fail.
        let result = writer.write("train", &new_item);
        assert!(result.is_err());

        let dataset = SqliteDataset::<Complex>::from_db_file(writer.db_file, "train").unwrap();

        let fetched_item = dataset.get(0).unwrap();
        assert_eq!(fetched_item, new_item);
        assert_eq!(dataset.len(), 1);
    }

    #[rstest]
    pub fn sqlite_writer_write_multi_thread(writer_fixture: (Writer, TempDir)) {
        let (writer, _tmp_dir) = writer_fixture;

        let writer = Arc::new(writer);
        let record_count = 20;

        let splits = ["train", "test"];

        (0..record_count).into_par_iter().for_each(|index: i64| {
            let thread_id: std::thread::ThreadId = std::thread::current().id();
            let sample = Complex {
                column_str: format!("test_{thread_id:?}_{index}"),
                column_bytes: vec![index as u8, 2, 3],
                column_int: index,
                column_bool: true,
                column_float: 1.0,
                column_complex: vec![vec![vec![[1, index as u8, 3]]]],
            };

            // Alternate between splits so both tables receive records.
            let split = splits[index as usize % 2];

            let _index = writer.write(split, &sample).unwrap();
        });

        let mut writer = Arc::try_unwrap(writer).unwrap();

        writer
            .set_completed()
            .expect("Should set completed successfully");

        let train =
            SqliteDataset::<Complex>::from_db_file(writer.db_file.clone(), "train").unwrap();
        let test = SqliteDataset::<Complex>::from_db_file(writer.db_file, "test").unwrap();

        assert_eq!(train.len(), record_count as usize / 2);
        assert_eq!(test.len(), record_count as usize / 2);
    }
}