1use std::{
2 collections::HashSet,
3 fs, io,
4 marker::PhantomData,
5 path::{Path, PathBuf},
6 sync::{Arc, RwLock},
7};
8
9use crate::Dataset;
10
11use gix_tempfile::{
12 AutoRemove, ContainingDirectory, Handle,
13 handle::{Writable, persist},
14};
15use r2d2::{Pool, PooledConnection};
16use r2d2_sqlite::{
17 SqliteConnectionManager,
18 rusqlite::{OpenFlags, OptionalExtension},
19};
20use sanitize_filename::sanitize;
21use serde::{Serialize, de::DeserializeOwned};
22use serde_rusqlite::{columns_from_statement, from_row_with_columns};
23
/// Module-wide result alias pinning the error type to [`SqliteDatasetError`].
pub type Result<T> = core::result::Result<T, SqliteDatasetError>;
26
/// Errors that can occur while reading or writing a SQLite-backed dataset.
#[derive(thiserror::Error, Debug)]
pub enum SqliteDatasetError {
    /// Filesystem-level failure (create/remove file or directory).
    #[error("IO error: {0}")]
    Io(#[from] io::Error),

    /// Error reported by the underlying SQLite driver.
    #[error("Sql error: {0}")]
    Sql(#[from] serde_rusqlite::rusqlite::Error),

    /// MessagePack serialization failure when encoding an item.
    #[error("Serde error: {0}")]
    Serde(#[from] rmp_serde::encode::Error),

    /// The target database file exists and `overwrite` was not requested.
    #[error("Overwrite flag is set to false and the database file already exists: {0}")]
    FileExists(PathBuf),

    /// The r2d2 connection pool could not be created.
    #[error("Failed to create connection pool: {0}")]
    ConnectionPool(#[from] r2d2::Error),

    /// Renaming the temporary database file to its final location failed.
    #[error("Could not persist the temporary database file: {0}")]
    PersistDbFile(#[from] persist::Error<Writable>),

    /// Catch-all for static message strings (see `From<&'static str>`).
    #[error("{0}")]
    Other(&'static str),
}
58
59impl From<&'static str> for SqliteDatasetError {
60 fn from(s: &'static str) -> Self {
61 SqliteDatasetError::Other(s)
62 }
63}
64
/// Read-only dataset backed by a SQLite database file, where each split is a
/// table and each item is one row addressed by `row_id`.
#[derive(Debug)]
pub struct SqliteDataset<I> {
    /// Path of the backing database file.
    db_file: PathBuf,
    /// Split name; also the table queried by `select_statement`.
    split: String,
    /// Pool of read-only connections shared across threads.
    conn_pool: Pool<SqliteConnectionManager>,
    /// Column names of the select statement (used for column-wise rows).
    columns: Vec<String>,
    /// Number of items, computed once at construction.
    len: usize,
    /// Prepared-at-construction SQL used by `get` (`where row_id = ?`).
    select_statement: String,
    /// True when the table stores one MessagePack blob per row
    /// (`row_id integer, item blob`) rather than one column per field.
    row_serialized: bool,
    /// Marks the item type without storing any `I`.
    phantom: PhantomData<I>,
}
101
102impl<I> SqliteDataset<I> {
103 pub fn from_db_file<P: AsRef<Path>>(db_file: P, split: &str) -> Result<Self> {
105 let conn_pool = create_conn_pool(&db_file, false)?;
107
108 let row_serialized = Self::check_if_row_serialized(&conn_pool, split)?;
110
111 let select_statement = if row_serialized {
113 format!("select item from {split} where row_id = ?")
114 } else {
115 format!("select * from {split} where row_id = ?")
116 };
117
118 let (columns, len) = fetch_columns_and_len(&conn_pool, &select_statement, split)?;
120
121 Ok(SqliteDataset {
122 db_file: db_file.as_ref().to_path_buf(),
123 split: split.to_string(),
124 conn_pool,
125 columns,
126 len,
127 select_statement,
128 row_serialized,
129 phantom: PhantomData,
130 })
131 }
132
133 fn check_if_row_serialized(
137 conn_pool: &Pool<SqliteConnectionManager>,
138 split: &str,
139 ) -> Result<bool> {
140 struct Column {
142 name: String,
143 ty: String,
144 }
145
146 const COLUMN_NAME: usize = 1;
147 const COLUMN_TYPE: usize = 2;
148
149 let sql_statement = format!("PRAGMA table_info({split})");
150
151 let conn = conn_pool.get()?;
152
153 let mut stmt = conn.prepare(sql_statement.as_str())?;
154 let column_iter = stmt.query_map([], |row| {
155 Ok(Column {
156 name: row
157 .get::<usize, String>(COLUMN_NAME)
158 .unwrap()
159 .to_lowercase(),
160 ty: row
161 .get::<usize, String>(COLUMN_TYPE)
162 .unwrap()
163 .to_lowercase(),
164 })
165 })?;
166
167 let mut columns: Vec<Column> = vec![];
168
169 for column in column_iter {
170 columns.push(column?);
171 }
172
173 if columns.len() != 2 {
174 Ok(false)
175 } else {
176 Ok(columns[0].name == "row_id"
178 && columns[0].ty == "integer"
179 && columns[1].name == "item"
180 && columns[1].ty == "blob")
181 }
182 }
183
184 pub fn db_file(&self) -> PathBuf {
186 self.db_file.clone()
187 }
188
189 pub fn split(&self) -> &str {
191 self.split.as_str()
192 }
193}
194
impl<I> Dataset<I> for SqliteDataset<I>
where
    I: Clone + Send + Sync + DeserializeOwned,
{
    /// Fetches the item at `index`, or `None` when no such row exists.
    fn get(&self, index: usize) -> Option<I> {
        // SQLite row ids are 1-based while dataset indices are 0-based.
        let row_id = index + 1;

        // NOTE(review): pool, prepare, and deserialization failures panic via
        // `unwrap` here; only a genuinely missing row maps to `None` (via
        // `.optional()`).
        let connection = self.conn_pool.get().unwrap();
        let mut statement = connection.prepare(self.select_statement.as_str()).unwrap();

        if self.row_serialized {
            // Row-serialized layout: the entire item is a single MessagePack
            // blob in column 0.
            statement
                .query_row([row_id], |row| {
                    Ok(
                        rmp_serde::from_slice::<I>(row.get_ref(0).unwrap().as_blob().unwrap())
                            .unwrap(),
                    )
                })
                .optional() .unwrap()
        } else {
            // Column layout: deserialize the item field-by-field using the
            // column names captured at construction.
            statement
                .query_row([row_id], |row| {
                    Ok(from_row_with_columns::<I>(row, &self.columns).unwrap())
                })
                .optional() .unwrap()
        }
    }

    /// Number of items in the split (computed once from max(row_id)).
    fn len(&self) -> usize {
        self.len
    }
}
237
238fn fetch_columns_and_len(
240 conn_pool: &Pool<SqliteConnectionManager>,
241 select_statement: &str,
242 split: &str,
243) -> Result<(Vec<String>, usize)> {
244 let connection = conn_pool.get()?;
246 let statement = connection.prepare(select_statement)?;
247 let columns = columns_from_statement(&statement);
248
249 let mut statement =
257 connection.prepare(format!("select coalesce(max(row_id), 0) from {split}").as_str())?;
258
259 let len = statement.query_row([], |row| {
260 let len: usize = row.get(0)?;
261 Ok(len)
262 })?;
263 Ok((columns, len))
264}
265
266fn create_conn_pool<P: AsRef<Path>>(
268 db_file: P,
269 write: bool,
270) -> Result<Pool<SqliteConnectionManager>> {
271 let sqlite_flags = if write {
272 OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE
273 } else {
274 OpenFlags::SQLITE_OPEN_READ_ONLY
275 };
276
277 let manager = SqliteConnectionManager::file(db_file).with_flags(sqlite_flags);
278 Pool::new(manager).map_err(SqliteDatasetError::ConnectionPool)
279}
280
/// Locator for a SQLite dataset: either an explicit database file, or a
/// dataset name resolved against a base directory.
#[derive(Clone, Debug)]
pub struct SqliteDatasetStorage {
    /// Dataset name (sanitized into `<name>.db` when no file is given).
    name: Option<String>,
    /// Explicit database file path; takes precedence over `name`.
    db_file: Option<PathBuf>,
    /// Directory used to resolve `name`; defaults to the OS cache dir.
    base_dir: Option<PathBuf>,
}
289
290impl SqliteDatasetStorage {
291 pub fn from_name(name: &str) -> Self {
297 SqliteDatasetStorage {
298 name: Some(name.to_string()),
299 db_file: None,
300 base_dir: None,
301 }
302 }
303
304 pub fn from_file<P: AsRef<Path>>(db_file: P) -> Self {
310 SqliteDatasetStorage {
311 name: None,
312 db_file: Some(db_file.as_ref().to_path_buf()),
313 base_dir: None,
314 }
315 }
316
317 pub fn with_base_dir<P: AsRef<Path>>(mut self, base_dir: P) -> Self {
323 self.base_dir = Some(base_dir.as_ref().to_path_buf());
324 self
325 }
326
327 pub fn exists(&self) -> bool {
333 self.db_file().exists()
334 }
335
336 pub fn db_file(&self) -> PathBuf {
342 match &self.db_file {
343 Some(db_file) => db_file.clone(),
344 None => {
345 let name = sanitize(self.name.as_ref().expect("Name is not set"));
346 Self::base_dir(self.base_dir.to_owned()).join(format!("{name}.db"))
347 }
348 }
349 }
350
351 pub fn base_dir(base_dir: Option<PathBuf>) -> PathBuf {
361 match base_dir {
362 Some(base_dir) => base_dir,
363 None => dirs::cache_dir()
364 .expect("Could not get cache directory")
365 .join("burn-dataset"),
366 }
367 }
368
369 pub fn writer<I>(&self, overwrite: bool) -> Result<SqliteDatasetWriter<I>>
379 where
380 I: Clone + Send + Sync + Serialize + DeserializeOwned,
381 {
382 SqliteDatasetWriter::new(self.db_file(), overwrite)
383 }
384
385 pub fn reader<I>(&self, split: &str) -> Result<SqliteDataset<I>>
395 where
396 I: Clone + Send + Sync + Serialize + DeserializeOwned,
397 {
398 if !self.exists() {
399 panic!("The database file does not exist");
400 }
401
402 SqliteDataset::from_db_file(self.db_file(), split)
403 }
404}
405
/// Writer that appends serialized items into per-split tables of a temporary
/// SQLite file, persisting it to `db_file` on completion.
#[derive(Debug)]
pub struct SqliteDatasetWriter<I> {
    /// Final database file location (written on `set_completed`).
    db_file: PathBuf,
    /// Handle to the temporary database file while writing is in progress.
    db_file_tmp: Option<Handle<Writable>>,
    /// Names of splits (tables) created so far; shared across threads.
    splits: Arc<RwLock<HashSet<String>>>,
    /// Whether an existing `db_file` may be removed on init.
    overwrite: bool,
    /// Write-capable connection pool over the temporary file.
    conn_pool: Option<Pool<SqliteConnectionManager>>,
    /// Set once `set_completed` runs; further writes are rejected.
    is_completed: Arc<RwLock<bool>>,
    /// Marks the item type without storing any `I`.
    phantom: PhantomData<I>,
}
426
427impl<I> SqliteDatasetWriter<I>
428where
429 I: Clone + Send + Sync + Serialize + DeserializeOwned,
430{
431 pub fn new<P: AsRef<Path>>(db_file: P, overwrite: bool) -> Result<Self> {
442 let writer = Self {
443 db_file: db_file.as_ref().to_path_buf(),
444 db_file_tmp: None,
445 splits: Arc::new(RwLock::new(HashSet::new())),
446 overwrite,
447 conn_pool: None,
448 is_completed: Arc::new(RwLock::new(false)),
449 phantom: PhantomData,
450 };
451
452 writer.init()
453 }
454
455 fn init(mut self) -> Result<Self> {
461 if self.db_file.exists() {
463 if self.overwrite {
464 fs::remove_file(&self.db_file)?;
465 } else {
466 return Err(SqliteDatasetError::FileExists(self.db_file));
467 }
468 }
469
470 let db_file_dir = self
472 .db_file
473 .parent()
474 .ok_or("Unable to get parent directory")?;
475
476 if !db_file_dir.exists() {
477 fs::create_dir_all(db_file_dir)?;
478 }
479
480 let mut db_file_tmp = self.db_file.clone();
482 db_file_tmp.set_extension("db.tmp");
483 if db_file_tmp.exists() {
484 fs::remove_file(&db_file_tmp)?;
485 }
486
487 gix_tempfile::signal::setup(Default::default());
491 self.db_file_tmp = Some(gix_tempfile::writable_at(
492 &db_file_tmp,
493 ContainingDirectory::Exists,
494 AutoRemove::Tempfile,
495 )?);
496
497 let conn_pool = create_conn_pool(db_file_tmp, true)?;
498 self.conn_pool = Some(conn_pool);
499
500 Ok(self)
501 }
502
503 pub fn write(&self, split: &str, item: &I) -> Result<usize> {
516 let is_completed = self.is_completed.read().unwrap();
518
519 if *is_completed {
521 return Err(SqliteDatasetError::Other(
522 "Cannot save to a completed dataset writer",
523 ));
524 }
525
526 if !self.splits.read().unwrap().contains(split) {
528 self.create_table(split)?;
529 }
530
531 let conn_pool = self.conn_pool.as_ref().unwrap();
533 let conn = conn_pool.get()?;
534
535 let serialized_item = rmp_serde::to_vec(item)?;
537
538 pragma_update_with_error_handling(&conn, "synchronous", "OFF")?;
542 pragma_update_with_error_handling(&conn, "journal_mode", "OFF")?;
543
544 let insert_statement = format!("insert into {split} (item) values (?)");
546 conn.execute(insert_statement.as_str(), [serialized_item])?;
547
548 let index = (conn.last_insert_rowid() - 1) as usize;
550
551 Ok(index)
552 }
553
554 pub fn set_completed(&mut self) -> Result<()> {
556 let mut is_completed = self.is_completed.write().unwrap();
557
558 if let Some(pool) = self.conn_pool.take() {
562 std::mem::drop(pool);
563 }
564
565 let _file_result = self
567 .db_file_tmp
568 .take() .unwrap() .persist(&self.db_file)?
571 .ok_or("Unable to persist the database file")?;
572
573 *is_completed = true;
574 Ok(())
575 }
576
577 fn create_table(&self, split: &str) -> Result<()> {
591 if self.splits.read().unwrap().contains(split) {
593 return Ok(());
594 }
595
596 let conn_pool = self.conn_pool.as_ref().unwrap();
597 let connection = conn_pool.get()?;
598 let create_table_statement = format!(
599 "create table if not exists {split} (row_id integer primary key autoincrement not \
600 null, item blob not null)"
601 );
602
603 connection.execute(create_table_statement.as_str(), [])?;
604
605 self.splits.write().unwrap().insert(split.to_string());
607
608 Ok(())
609 }
610}
611
612fn pragma_update_with_error_handling(
618 conn: &PooledConnection<SqliteConnectionManager>,
619 setting: &str,
620 value: &str,
621) -> Result<()> {
622 let result = conn.pragma_update(None, setting, value);
623 if let Err(error) = result
624 && error != rusqlite::Error::ExecuteReturnedResults
625 {
626 return Err(SqliteDatasetError::Sql(error));
627 }
628
629 Ok(())
630}
631
#[cfg(test)]
mod tests {
    use rayon::prelude::*;
    use rstest::{fixture, rstest};
    use serde::{Deserialize, Serialize};
    use tempfile::{NamedTempFile, TempDir, tempdir};

    use super::*;

    type SqlDs = SqliteDataset<Sample>;

    /// Mirrors the column layout of the `train` table in the fixture db.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct Sample {
        column_str: String,
        column_bytes: Vec<u8>,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
    }

    /// Fixture: dataset over the checked-in two-row test database.
    #[fixture]
    fn train_dataset() -> SqlDs {
        SqliteDataset::<Sample>::from_db_file("tests/data/sqlite-dataset.db", "train").unwrap()
    }

    #[rstest]
    pub fn len(train_dataset: SqlDs) {
        assert_eq!(train_dataset.len(), 2);
    }

    #[rstest]
    pub fn get_some(train_dataset: SqlDs) {
        let item = train_dataset.get(0).unwrap();
        assert_eq!(item.column_str, "HI1");
        assert_eq!(item.column_bytes, vec![55, 231, 159]);
        assert_eq!(item.column_int, 1);
        assert!(item.column_bool);
        assert_eq!(item.column_float, 1.0);
    }

    #[rstest]
    pub fn get_none(train_dataset: SqlDs) {
        // An index past the end of the dataset must yield None.
        assert_eq!(train_dataset.get(10), None);
    }

    #[rstest]
    pub fn multi_thread(train_dataset: SqlDs) {
        let indices: Vec<usize> = vec![0, 1, 1, 3, 4, 5, 6, 0, 8, 1];
        let results: Vec<Option<Sample>> =
            indices.par_iter().map(|&i| train_dataset.get(i)).collect();

        // Only indices 0 and 1 are in range (len == 2); exactly five of the
        // requested indices are hits.
        let mut match_count = 0;
        for (_index, result) in indices.iter().zip(results.iter()) {
            if let Some(_val) = result {
                match_count += 1
            }
        }

        assert_eq!(match_count, 5);
    }

    #[test]
    fn sqlite_dataset_storage() {
        // Storage by a non-existing file path.
        let storage = SqliteDatasetStorage::from_file("non-existing.db");
        assert!(!storage.exists());

        // Storage by a non-existing dataset name.
        let storage = SqliteDatasetStorage::from_name("non-existing.db");
        assert!(!storage.exists());

        // Storage over the existing fixture db: reading `train` must work.
        let storage = SqliteDatasetStorage::from_file("tests/data/sqlite-dataset.db");
        assert!(storage.exists());
        let result = storage.reader::<Sample>("train");
        assert!(result.is_ok());
        let train = result.unwrap();
        assert_eq!(train.len(), 2);

        // An existing (empty) file can be opened for writing with overwrite.
        let temp_file = NamedTempFile::new().unwrap();
        let storage = SqliteDatasetStorage::from_file(temp_file.path());
        assert!(storage.exists());
        let result = storage.writer::<Sample>(true);
        assert!(result.is_ok());
    }

    /// Item with a nested-collection column to exercise MessagePack
    /// round-trips through the row-serialized layout.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct Complex {
        column_str: String,
        column_bytes: Vec<u8>,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
        column_complex: Vec<Vec<Vec<[u8; 3]>>>,
    }

    /// Fixture: fresh temporary directory per test.
    #[fixture]
    fn tmp_dir() -> TempDir {
        tempdir().unwrap()
    }
    type Writer = SqliteDatasetWriter<Complex>;

    /// Fixture: a writer plus its temp dir (returned so the directory is
    /// kept alive for the duration of the test).
    #[fixture]
    fn writer_fixture(tmp_dir: TempDir) -> (Writer, TempDir) {
        let temp_dir_str = tmp_dir.path();
        let storage = SqliteDatasetStorage::from_name("preprocessed").with_base_dir(temp_dir_str);
        let overwrite = true;
        let result = storage.writer::<Complex>(overwrite);
        assert!(result.is_ok());
        let writer = result.unwrap();
        (writer, tmp_dir)
    }

    #[test]
    fn test_new() {
        // With overwrite, an existing target file is removed at init.
        let test_path = NamedTempFile::new().unwrap();
        let _writer = SqliteDatasetWriter::<Complex>::new(&test_path, true).unwrap();
        assert!(!test_path.path().exists());

        // Without overwrite, an existing target file is an error.
        let test_path = NamedTempFile::new().unwrap();
        let result = SqliteDatasetWriter::<Complex>::new(&test_path, false);
        assert!(result.is_err());

        // A non-existing target is fine; the final file only appears after
        // `set_completed`.
        let temp = NamedTempFile::new().unwrap();
        let test_path = temp.path().to_path_buf();
        assert!(temp.close().is_ok());
        assert!(!test_path.exists());
        let _writer = SqliteDatasetWriter::<Complex>::new(&test_path, true).unwrap();
        assert!(!test_path.exists());
    }

    #[rstest]
    pub fn sqlite_writer_write(writer_fixture: (Writer, TempDir)) {
        let (writer, _tmp_dir) = writer_fixture;

        assert!(writer.overwrite);
        assert!(!writer.db_file.exists());

        let new_item = Complex {
            column_str: "HI1".to_string(),
            column_bytes: vec![1_u8, 2, 3],
            column_int: 0,
            column_bool: true,
            column_float: 1.0,
            column_complex: vec![vec![vec![[1, 23_u8, 3]]]],
        };

        // First item of a split gets index 0.
        let index = writer.write("train", &new_item).unwrap();
        assert_eq!(index, 0);

        let mut writer = writer;

        writer.set_completed().expect("Failed to set completed");

        // Completion renames the temp file into place and drops the handle.
        assert!(writer.db_file.exists());
        assert!(writer.db_file_tmp.is_none());

        // Writes after completion must be rejected.
        let result = writer.write("train", &new_item);

        assert!(result.is_err());

        // The persisted file round-trips the written item.
        let dataset = SqliteDataset::<Complex>::from_db_file(writer.db_file, "train").unwrap();

        let fetched_item = dataset.get(0).unwrap();
        assert_eq!(fetched_item, new_item);
        assert_eq!(dataset.len(), 1);
    }

    #[rstest]
    pub fn sqlite_writer_write_multi_thread(writer_fixture: (Writer, TempDir)) {
        let (writer, _tmp_dir) = writer_fixture;

        let writer = Arc::new(writer);
        let record_count = 20;

        let splits = ["train", "test"];

        // Write concurrently, alternating items between the two splits.
        (0..record_count).into_par_iter().for_each(|index: i64| {
            let thread_id: std::thread::ThreadId = std::thread::current().id();
            let sample = Complex {
                column_str: format!("test_{thread_id:?}_{index}"),
                column_bytes: vec![index as u8, 2, 3],
                column_int: index,
                column_bool: true,
                column_float: 1.0,
                column_complex: vec![vec![vec![[1, index as u8, 3]]]],
            };

            let split = splits[index as usize % 2];

            let _index = writer.write(split, &sample).unwrap();
        });

        let mut writer = Arc::try_unwrap(writer).unwrap();

        writer
            .set_completed()
            .expect("Should set completed successfully");

        let train =
            SqliteDataset::<Complex>::from_db_file(writer.db_file.clone(), "train").unwrap();
        let test = SqliteDataset::<Complex>::from_db_file(writer.db_file, "test").unwrap();

        // Records were alternated, so each split holds half of them.
        assert_eq!(train.len(), record_count as usize / 2);
        assert_eq!(test.len(), record_count as usize / 2);
    }
}