burn_dataset/dataset/sqlite.rs

use std::{
    collections::HashSet,
    fs, io,
    marker::PhantomData,
    path::{Path, PathBuf},
    sync::{Arc, RwLock},
};

use crate::Dataset;

use gix_tempfile::{
    handle::{persist, Writable},
    AutoRemove, ContainingDirectory, Handle,
};
use r2d2::{Pool, PooledConnection};
use r2d2_sqlite::{
    rusqlite::{OpenFlags, OptionalExtension},
    SqliteConnectionManager,
};
use sanitize_filename::sanitize;
use serde::{de::DeserializeOwned, Serialize};
use serde_rusqlite::{columns_from_statement, from_row_with_columns};

/// Result type for the sqlite dataset.
pub type Result<T> = core::result::Result<T, SqliteDatasetError>;

/// Sqlite dataset error.
#[derive(thiserror::Error, Debug)]
pub enum SqliteDatasetError {
    /// IO related error.
    #[error("IO error: {0}")]
    Io(#[from] io::Error),

    /// Sql related error.
    #[error("Sql error: {0}")]
    Sql(#[from] serde_rusqlite::rusqlite::Error),

    /// Serde related error.
    #[error("Serde error: {0}")]
    Serde(#[from] rmp_serde::encode::Error),

    /// The database file already exists error.
    #[error("Overwrite flag is set to false and the database file already exists: {0}")]
    FileExists(PathBuf),

    /// Error when creating the connection pool.
    #[error("Failed to create connection pool: {0}")]
    ConnectionPool(#[from] r2d2::Error),

    /// Error when persisting the temporary database file.
    #[error("Could not persist the temporary database file: {0}")]
    PersistDbFile(#[from] persist::Error<Writable>),

    /// Any other error.
    #[error("{0}")]
    Other(&'static str),
}

impl From<&'static str> for SqliteDatasetError {
    fn from(s: &'static str) -> Self {
        SqliteDatasetError::Other(s)
    }
}

/// This struct represents a dataset where all items are stored in an SQLite database.
/// Each instance of this struct corresponds to a specific table within the SQLite database,
/// and allows for interaction with the data stored in the table in a structured and typed manner.
///
/// The SQLite database must contain a table with the same name as the `split` field. This table should
/// have a primary key column named `row_id`, which is used to index the rows in the table. The `row_id`
/// should start at 1, while the corresponding dataset `index` should start at 0, i.e., `row_id` = `index` + 1.
///
/// Table columns can be represented in two ways:
///
/// 1. The table can have a column for each field in the `I` struct. In this case, the column names in the table
///    should match the field names of the `I` struct. The field names can be a subset of column names and
///    can be in any order.
///
///    For the supported field types, refer to:
///    - [Serialization field types](https://docs.rs/serde_rusqlite/latest/serde_rusqlite)
///    - [SQLite data types](https://www.sqlite.org/datatype3.html)
///
/// 2. The fields in the `I` struct can be serialized into a single column `item` in the table. In this case, the table
///    should have a single column named `item` of type `BLOB`. This is useful when the `I` struct contains complex fields
///    that cannot be mapped to a SQLite type, such as nested structs, vectors, etc. The serialization is done using
///    [MessagePack](https://msgpack.org/).
///
/// Note: The code automatically figures out which of the above two cases is applicable, and uses the appropriate
/// method to read the data from the table.
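///
/// # Example
///
/// A minimal usage sketch; the database path and `MyItem` type below are
/// illustrative, not shipped with the crate:
///
/// ```ignore
/// use burn_dataset::{Dataset, SqliteDataset};
///
/// #[derive(Clone, serde::Deserialize)]
/// struct MyItem {
///     label: i64,
/// }
///
/// let dataset = SqliteDataset::<MyItem>::from_db_file("data.db", "train").unwrap();
/// let first = dataset.get(0); // fetches the row with row_id = 1
/// ```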
#[derive(Debug)]
pub struct SqliteDataset<I> {
    db_file: PathBuf,
    split: String,
    conn_pool: Pool<SqliteConnectionManager>,
    columns: Vec<String>,
    len: usize,
    select_statement: String,
    row_serialized: bool,
    phantom: PhantomData<I>,
}

impl<I> SqliteDataset<I> {
    /// Initializes a `SqliteDataset` from a SQLite database file and a split name.
    pub fn from_db_file<P: AsRef<Path>>(db_file: P, split: &str) -> Result<Self> {
        // Create a connection pool
        let conn_pool = create_conn_pool(&db_file, false)?;

        // Determine how the table is stored
        let row_serialized = Self::check_if_row_serialized(&conn_pool, split)?;

        // Create a select statement and save it
        let select_statement = if row_serialized {
            format!("select item from {split} where row_id = ?")
        } else {
            format!("select * from {split} where row_id = ?")
        };

        // Save the column names and the number of rows
        let (columns, len) = fetch_columns_and_len(&conn_pool, &select_statement, split)?;

        Ok(SqliteDataset {
            db_file: db_file.as_ref().to_path_buf(),
            split: split.to_string(),
            conn_pool,
            columns,
            len,
            select_statement,
            row_serialized,
            phantom: PhantomData,
        })
    }

    /// Returns true if the table has exactly two columns: `row_id` (integer) and `item` (blob).
    ///
    /// This is used to determine whether the table stores row-serialized items.
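    ///
    /// A row-serialized table's schema looks like the following sketch (the
    /// layout is illustrative):
    ///
    /// ```text
    /// row_id  integer  primary key
    /// item    blob
    /// ```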
    fn check_if_row_serialized(
        conn_pool: &Pool<SqliteConnectionManager>,
        split: &str,
    ) -> Result<bool> {
        // This struct is used to store the column name and type
        struct Column {
            name: String,
            ty: String,
        }

        const COLUMN_NAME: usize = 1;
        const COLUMN_TYPE: usize = 2;

        let sql_statement = format!("PRAGMA table_info({split})");

        let conn = conn_pool.get()?;

        let mut stmt = conn.prepare(sql_statement.as_str())?;
        let column_iter = stmt.query_map([], |row| {
            Ok(Column {
                name: row
                    .get::<usize, String>(COLUMN_NAME)
                    .unwrap()
                    .to_lowercase(),
                ty: row
                    .get::<usize, String>(COLUMN_TYPE)
                    .unwrap()
                    .to_lowercase(),
            })
        })?;

        let mut columns: Vec<Column> = vec![];

        for column in column_iter {
            columns.push(column?);
        }

        if columns.len() != 2 {
            Ok(false)
        } else {
            // Check if the column names and types match the expected values
            Ok(columns[0].name == "row_id"
                && columns[0].ty == "integer"
                && columns[1].name == "item"
                && columns[1].ty == "blob")
        }
    }

    /// Get the database file name.
    pub fn db_file(&self) -> PathBuf {
        self.db_file.clone()
    }

    /// Get the split name.
    pub fn split(&self) -> &str {
        self.split.as_str()
    }
}

impl<I> Dataset<I> for SqliteDataset<I>
where
    I: Clone + Send + Sync + DeserializeOwned,
{
    /// Get an item from the dataset.
    fn get(&self, index: usize) -> Option<I> {
        // Row ids start at 1 (one) while the index starts at 0 (zero)
        let row_id = index + 1;

        // Get a connection from the pool
        let connection = self.conn_pool.get().unwrap();
        let mut statement = connection.prepare(self.select_statement.as_str()).unwrap();

        if self.row_serialized {
            // Fetch the single column `item` and deserialize it with MessagePack
            statement
                .query_row([row_id], |row| {
                    // Deserialize item (blob) with MessagePack (rmp-serde)
                    Ok(
                        rmp_serde::from_slice::<I>(row.get_ref(0).unwrap().as_blob().unwrap())
                            .unwrap(),
                    )
                })
                .optional() // Converts Error (not found) to None
                .unwrap()
        } else {
            // Fetch a row with multiple columns and deserialize it with serde_rusqlite
            statement
                .query_row([row_id], |row| {
                    // Deserialize the row with serde_rusqlite
                    Ok(from_row_with_columns::<I>(row, &self.columns).unwrap())
                })
                .optional() // Converts Error (not found) to None
                .unwrap()
        }
    }

    /// Return the number of rows in the dataset.
    fn len(&self) -> usize {
        self.len
    }
}

/// Fetch the column names and the number of rows from the database.
fn fetch_columns_and_len(
    conn_pool: &Pool<SqliteConnectionManager>,
    select_statement: &str,
    split: &str,
) -> Result<(Vec<String>, usize)> {
    // Save the column names
    let connection = conn_pool.get()?;
    let statement = connection.prepare(select_statement)?;
    let columns = columns_from_statement(&statement);

    // Count the number of rows and save it as len
    //
    // NOTE: Using coalesce(max(row_id), 0) instead of count(*) because count(*) is very slow for
    // large tables. coalesce(max(row_id), 0) returns 0 if the table is empty; otherwise it returns
    // the max row_id, which corresponds to the number of rows in the table.
    // The main assumption is that row_id values are monotonically increasing with no gaps,
    // which holds for the datasets we generate; otherwise row_id would not correspond to the index.
    let mut statement =
        connection.prepare(format!("select coalesce(max(row_id), 0) from {split}").as_str())?;

    let len = statement.query_row([], |row| {
        let len: usize = row.get(0)?;
        Ok(len)
    })?;
    Ok((columns, len))
}

/// Helper function to create a connection pool.
fn create_conn_pool<P: AsRef<Path>>(
    db_file: P,
    write: bool,
) -> Result<Pool<SqliteConnectionManager>> {
    let sqlite_flags = if write {
        OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE
    } else {
        OpenFlags::SQLITE_OPEN_READ_ONLY
    };

    let manager = SqliteConnectionManager::file(db_file).with_flags(sqlite_flags);
    Pool::new(manager).map_err(SqliteDatasetError::ConnectionPool)
}

/// The `SqliteDatasetStorage` struct represents a SQLite database for storing datasets.
/// It consists of an optional name, a database file path, and a base directory for storage.
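///
/// # Example
///
/// A minimal sketch of the storage workflow; the dataset name, base directory,
/// and `MyItem` type are illustrative:
///
/// ```ignore
/// let storage = SqliteDatasetStorage::from_name("my-dataset").with_base_dir("/tmp/datasets");
/// let mut writer = storage.writer::<MyItem>(true).unwrap();
/// // ... write items with writer.write(...), then:
/// writer.set_completed().unwrap();
/// let dataset = storage.reader::<MyItem>("train").unwrap();
/// ```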
#[derive(Clone, Debug)]
pub struct SqliteDatasetStorage {
    name: Option<String>,
    db_file: Option<PathBuf>,
    base_dir: Option<PathBuf>,
}

impl SqliteDatasetStorage {
    /// Creates a new instance of `SqliteDatasetStorage` using a dataset name.
    ///
    /// # Arguments
    ///
    /// * `name` - A string slice that holds the name of the dataset.
    pub fn from_name(name: &str) -> Self {
        SqliteDatasetStorage {
            name: Some(name.to_string()),
            db_file: None,
            base_dir: None,
        }
    }

    /// Creates a new instance of `SqliteDatasetStorage` using a database file path.
    ///
    /// # Arguments
    ///
    /// * `db_file` - A reference to the Path that represents the database file path.
    pub fn from_file<P: AsRef<Path>>(db_file: P) -> Self {
        SqliteDatasetStorage {
            name: None,
            db_file: Some(db_file.as_ref().to_path_buf()),
            base_dir: None,
        }
    }

    /// Sets the base directory for storing the dataset.
    ///
    /// # Arguments
    ///
    /// * `base_dir` - A path that represents the base directory.
    pub fn with_base_dir<P: AsRef<Path>>(mut self, base_dir: P) -> Self {
        self.base_dir = Some(base_dir.as_ref().to_path_buf());
        self
    }

    /// Checks if the database file exists in the given path.
    ///
    /// # Returns
    ///
    /// * A boolean value indicating whether the file exists or not.
    pub fn exists(&self) -> bool {
        self.db_file().exists()
    }

    /// Fetches the database file path.
    ///
    /// # Returns
    ///
    /// * A `PathBuf` instance representing the file path.
    pub fn db_file(&self) -> PathBuf {
        match &self.db_file {
            Some(db_file) => db_file.clone(),
            None => {
                let name = sanitize(self.name.as_ref().expect("Name is not set"));
                Self::base_dir(self.base_dir.to_owned()).join(format!("{name}.db"))
            }
        }
    }

    /// Determines the base directory for storing the dataset.
    ///
    /// # Arguments
    ///
    /// * `base_dir` - An `Option` that may contain a `PathBuf` instance representing the base directory.
    ///
    /// # Returns
    ///
    /// * A `PathBuf` instance representing the base directory.
    pub fn base_dir(base_dir: Option<PathBuf>) -> PathBuf {
        match base_dir {
            Some(base_dir) => base_dir,
            None => {
                let home_dir = dirs::home_dir().expect("Could not get home directory");

                home_dir.join(".cache").join("burn-dataset")
            }
        }
    }

    /// Provides a writer instance for the SQLite dataset.
    ///
    /// # Arguments
    ///
    /// * `overwrite` - A boolean indicating if the existing database file should be overwritten.
    ///
    /// # Returns
    ///
    /// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise.
    pub fn writer<I>(&self, overwrite: bool) -> Result<SqliteDatasetWriter<I>>
    where
        I: Clone + Send + Sync + Serialize + DeserializeOwned,
    {
        SqliteDatasetWriter::new(self.db_file(), overwrite)
    }

    /// Provides a reader instance for the SQLite dataset.
    ///
    /// # Arguments
    ///
    /// * `split` - A string slice that defines the data split for reading (e.g., "train", "test").
    ///
    /// # Returns
    ///
    /// * A `Result` which is `Ok` if the reader could be created, `Err` otherwise.
    pub fn reader<I>(&self, split: &str) -> Result<SqliteDataset<I>>
    where
        I: Clone + Send + Sync + Serialize + DeserializeOwned,
    {
        if !self.exists() {
            return Err(SqliteDatasetError::Other(
                "The database file does not exist",
            ));
        }

        SqliteDataset::from_db_file(self.db_file(), split)
    }
}

/// This `SqliteDatasetWriter` struct is a SQLite database writer dedicated to storing datasets.
/// It retains the current writer's state and its database connection.
///
/// Being thread-safe, this writer can be used concurrently across multiple threads.
///
/// Typical applications include:
///
/// - Generating a new dataset
/// - Storing preprocessed data or metadata
/// - Growing a dataset's item count after preprocessing
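///
/// # Example
///
/// A minimal sketch of the write-then-complete flow; the output path, `MyItem`
/// type, and `item` value are illustrative:
///
/// ```ignore
/// let mut writer = SqliteDatasetWriter::<MyItem>::new("out.db", true).unwrap();
/// let index = writer.write("train", &item).unwrap();
/// writer.set_completed().unwrap(); // persists the temp file as out.db
/// ```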
#[derive(Debug)]
pub struct SqliteDatasetWriter<I> {
    db_file: PathBuf,
    db_file_tmp: Option<Handle<Writable>>,
    splits: Arc<RwLock<HashSet<String>>>,
    overwrite: bool,
    conn_pool: Option<Pool<SqliteConnectionManager>>,
    is_completed: Arc<RwLock<bool>>,
    phantom: PhantomData<I>,
}

impl<I> SqliteDatasetWriter<I>
where
    I: Clone + Send + Sync + Serialize + DeserializeOwned,
{
    /// Creates a new instance of `SqliteDatasetWriter`.
    ///
    /// # Arguments
    ///
    /// * `db_file` - A reference to the Path that represents the database file path.
    /// * `overwrite` - A boolean indicating if the existing database file should be overwritten.
    ///
    /// # Returns
    ///
    /// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise.
    pub fn new<P: AsRef<Path>>(db_file: P, overwrite: bool) -> Result<Self> {
        let writer = Self {
            db_file: db_file.as_ref().to_path_buf(),
            db_file_tmp: None,
            splits: Arc::new(RwLock::new(HashSet::new())),
            overwrite,
            conn_pool: None,
            is_completed: Arc::new(RwLock::new(false)),
            phantom: PhantomData,
        };

        writer.init()
    }

    /// Initializes the dataset writer by creating the database file, tables, and connection pool.
    ///
    /// # Returns
    ///
    /// * A `Result` which is `Ok` if the writer could be initialized, `Err` otherwise.
    fn init(mut self) -> Result<Self> {
        // Remove the db file if it already exists
        if self.db_file.exists() {
            if self.overwrite {
                fs::remove_file(&self.db_file)?;
            } else {
                return Err(SqliteDatasetError::FileExists(self.db_file));
            }
        }

        // Create the database file directory if it does not exist
        let db_file_dir = self
            .db_file
            .parent()
            .ok_or("Unable to get parent directory")?;

        if !db_file_dir.exists() {
            fs::create_dir_all(db_file_dir)?;
        }

        // Create a temp database file name as {base_dir}/{name}.db.tmp
        let mut db_file_tmp = self.db_file.clone();
        db_file_tmp.set_extension("db.tmp");
        if db_file_tmp.exists() {
            fs::remove_file(&db_file_tmp)?;
        }

        // Create the temp database file and wrap it with a gix_tempfile::Handle.
        // This ensures that the temp file is deleted when the writer is dropped
        // or when the process exits with SIGINT or SIGTERM (the tempfile crate does not do this).
        gix_tempfile::signal::setup(Default::default());
        self.db_file_tmp = Some(gix_tempfile::writable_at(
            &db_file_tmp,
            ContainingDirectory::Exists,
            AutoRemove::Tempfile,
        )?);

        let conn_pool = create_conn_pool(db_file_tmp, true)?;
        self.conn_pool = Some(conn_pool);

        Ok(self)
    }

    /// Serializes and writes an item to the database. The item is written to the table for the
    /// specified split. If the table does not exist, it is created. If the table exists, the item
    /// is appended to the table. The serialization is done using [MessagePack](https://msgpack.org/).
    ///
    /// # Arguments
    ///
    /// * `split` - A string slice that defines the data split for writing (e.g., "train", "test").
    /// * `item` - A reference to the item to be written to the database.
    ///
    /// # Returns
    ///
    /// * A `Result` containing the index of the inserted row if successful, an error otherwise.
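    ///
    /// A minimal sketch (assuming a writer created with `overwrite = true` and an
    /// `item` of the writer's item type):
    ///
    /// ```ignore
    /// let index = writer.write("train", &item).unwrap();
    /// assert_eq!(index, 0); // the first inserted row (row_id = 1) maps to index 0
    /// ```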
    pub fn write(&self, split: &str, item: &I) -> Result<usize> {
        // Acquire the read lock (won't block other reads)
        let is_completed = self.is_completed.read().unwrap();

        // If the writer is completed, return an error
        if *is_completed {
            return Err(SqliteDatasetError::Other(
                "Cannot save to a completed dataset writer",
            ));
        }

        // Create the table for the split if it does not exist
        if !self.splits.read().unwrap().contains(split) {
            self.create_table(split)?;
        }

        // Get a connection from the pool
        let conn_pool = self.conn_pool.as_ref().unwrap();
        let conn = conn_pool.get()?;

        // Serialize the item using MessagePack
        let serialized_item = rmp_serde::to_vec(item)?;

        // Turn off synchronous writes and the journal for speed.
        // We are sacrificing durability for speed, but that's okay because
        // we always recreate the dataset if it is not completed.
        pragma_update_with_error_handling(&conn, "synchronous", "OFF")?;
        pragma_update_with_error_handling(&conn, "journal_mode", "OFF")?;

        // Insert the serialized item into the database
        let insert_statement = format!("insert into {split} (item) values (?)");
        conn.execute(insert_statement.as_str(), [serialized_item])?;

        // Get the primary key of the last inserted row and convert it to an index (row_id - 1)
        let index = (conn.last_insert_rowid() - 1) as usize;

        Ok(index)
    }

    /// Marks the dataset as completed and persists the temporary database file.
    pub fn set_completed(&mut self) -> Result<()> {
        let mut is_completed = self.is_completed.write().unwrap();

        // Force close the connection pool.
        // This is required on Windows, where the open connection pool prevents
        // the database from being persisted by renaming the temp file.
        if let Some(pool) = self.conn_pool.take() {
            std::mem::drop(pool);
        }

        // Rename the database file from tmp to db
        let _file_result = self
            .db_file_tmp
            .take() // take ownership of the temporary file and set the field to None
            .unwrap() // unwrap the temporary file
            .persist(&self.db_file)?
            .ok_or("Unable to persist the database file")?;

        *is_completed = true;
        Ok(())
    }

    /// Creates a table for the data split.
    ///
    /// Note: the call is idempotent and thread-safe.
    ///
    /// # Arguments
    ///
    /// * `split` - A string slice that defines the data split for the table (e.g., "train", "test").
    ///
    /// # Returns
    ///
    /// * A `Result` which is `Ok` if the table could be created, `Err` otherwise.
    ///
    /// TODO (@antimora): add support for creating a table with columns corresponding to the item fields
    fn create_table(&self, split: &str) -> Result<()> {
        // Check if the split already exists
        if self.splits.read().unwrap().contains(split) {
            return Ok(());
        }

        let conn_pool = self.conn_pool.as_ref().unwrap();
        let connection = conn_pool.get()?;
        let create_table_statement = format!(
            "create table if not exists {split} (row_id integer primary key autoincrement not \
             null, item blob not null)"
        );

        connection.execute(create_table_statement.as_str(), [])?;

        // Add the split to the set of created splits
        self.splits.write().unwrap().insert(split.to_string());

        Ok(())
    }
}

/// Runs a pragma update, ignoring the `ExecuteReturnedResults` error.
///
/// `ExecuteReturnedResults` is sometimes returned when running a pragma update. This is not
/// an actual error and can safely be ignored.
fn pragma_update_with_error_handling(
    conn: &PooledConnection<SqliteConnectionManager>,
    setting: &str,
    value: &str,
) -> Result<()> {
    let result = conn.pragma_update(None, setting, value);
    if let Err(error) = result {
        if error != rusqlite::Error::ExecuteReturnedResults {
            return Err(SqliteDatasetError::Sql(error));
        }
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    use rayon::prelude::*;
    use rstest::{fixture, rstest};
    use serde::{Deserialize, Serialize};
    use tempfile::{tempdir, NamedTempFile, TempDir};

    use super::*;

    type SqlDs = SqliteDataset<Sample>;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct Sample {
        column_str: String,
        column_bytes: Vec<u8>,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
    }

    #[fixture]
    fn train_dataset() -> SqlDs {
        SqliteDataset::<Sample>::from_db_file("tests/data/sqlite-dataset.db", "train").unwrap()
    }

    #[rstest]
    pub fn len(train_dataset: SqlDs) {
        assert_eq!(train_dataset.len(), 2);
    }

    #[rstest]
    pub fn get_some(train_dataset: SqlDs) {
        let item = train_dataset.get(0).unwrap();
        assert_eq!(item.column_str, "HI1");
        assert_eq!(item.column_bytes, vec![55, 231, 159]);
        assert_eq!(item.column_int, 1);
        assert!(item.column_bool);
        assert_eq!(item.column_float, 1.0);
    }

    #[rstest]
    pub fn get_none(train_dataset: SqlDs) {
        assert_eq!(train_dataset.get(10), None);
    }

    #[rstest]
    pub fn multi_thread(train_dataset: SqlDs) {
        let indices: Vec<usize> = vec![0, 1, 1, 3, 4, 5, 6, 0, 8, 1];
        let results: Vec<Option<Sample>> =
            indices.par_iter().map(|&i| train_dataset.get(i)).collect();

        // Only indices 0 and 1 exist in the two-row dataset, so five of the ten lookups succeed
        let match_count = results.iter().filter(|result| result.is_some()).count();

        assert_eq!(match_count, 5);
    }

    #[test]
    fn sqlite_dataset_storage() {
        // Test with non-existing file
        let storage = SqliteDatasetStorage::from_file("non-existing.db");
        assert!(!storage.exists());

        // Test with non-existing name
        let storage = SqliteDatasetStorage::from_name("non-existing.db");
        assert!(!storage.exists());

        // Test with existing file
        let storage = SqliteDatasetStorage::from_file("tests/data/sqlite-dataset.db");
        assert!(storage.exists());
        let result = storage.reader::<Sample>("train");
        assert!(result.is_ok());
        let train = result.unwrap();
        assert_eq!(train.len(), 2);

        // Test get writer
        let temp_file = NamedTempFile::new().unwrap();
        let storage = SqliteDatasetStorage::from_file(temp_file.path());
        assert!(storage.exists());
        let result = storage.writer::<Sample>(true);
        assert!(result.is_ok());
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct Complex {
        column_str: String,
        column_bytes: Vec<u8>,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
        column_complex: Vec<Vec<Vec<[u8; 3]>>>,
    }

    /// Create a temporary directory.
    #[fixture]
    fn tmp_dir() -> TempDir {
        // Create a TempDir. This object will be automatically
        // deleted when it goes out of scope.
        tempdir().unwrap()
    }

    type Writer = SqliteDatasetWriter<Complex>;

    /// Create a SqliteDatasetWriter with a temporary directory.
    /// Make sure to return the temporary directory so that it is not deleted.
    #[fixture]
    fn writer_fixture(tmp_dir: TempDir) -> (Writer, TempDir) {
        let temp_dir_str = tmp_dir.path();
        let storage = SqliteDatasetStorage::from_name("preprocessed").with_base_dir(temp_dir_str);
        let overwrite = true;
        let result = storage.writer::<Complex>(overwrite);
        assert!(result.is_ok());
        let writer = result.unwrap();
        (writer, tmp_dir)
    }

    #[test]
    fn test_new() {
        // Test that the constructor works with overwrite = true
        let test_path = NamedTempFile::new().unwrap();
        let _writer = SqliteDatasetWriter::<Complex>::new(&test_path, true).unwrap();
        assert!(!test_path.path().exists());

        // Test that the constructor fails with overwrite = false when the file already exists
        let test_path = NamedTempFile::new().unwrap();
        let result = SqliteDatasetWriter::<Complex>::new(&test_path, false);
        assert!(result.is_err());

        // Test that the constructor works when there is no existing file
        let temp = NamedTempFile::new().unwrap();
        let test_path = temp.path().to_path_buf();
        assert!(temp.close().is_ok());
        assert!(!test_path.exists());
        let _writer = SqliteDatasetWriter::<Complex>::new(&test_path, true).unwrap();
        assert!(!test_path.exists());
    }

    #[rstest]
    pub fn sqlite_writer_write(writer_fixture: (Writer, TempDir)) {
        // Get the writer from the fixture along with tmp_dir (deleted when it goes out of scope)
        let (writer, _tmp_dir) = writer_fixture;

        assert!(writer.overwrite);
        assert!(!writer.db_file.exists());

        let new_item = Complex {
            column_str: "HI1".to_string(),
            column_bytes: vec![1_u8, 2, 3],
            column_int: 0,
            column_bool: true,
            column_float: 1.0,
            column_complex: vec![vec![vec![[1, 23_u8, 3]]]],
        };

        let index = writer.write("train", &new_item).unwrap();
        assert_eq!(index, 0);

        let mut writer = writer;

        writer.set_completed().expect("Failed to set completed");

        assert!(writer.db_file.exists());
        assert!(writer.db_file_tmp.is_none());

        let result = writer.write("train", &new_item);

        // Should fail because the writer is completed
        assert!(result.is_err());

        let dataset = SqliteDataset::<Complex>::from_db_file(writer.db_file, "train").unwrap();

        let fetched_item = dataset.get(0).unwrap();
        assert_eq!(fetched_item, new_item);
        assert_eq!(dataset.len(), 1);
    }

    #[rstest]
    pub fn sqlite_writer_write_multi_thread(writer_fixture: (Writer, TempDir)) {
        // Get the writer from the fixture along with tmp_dir (deleted when it goes out of scope)
        let (writer, _tmp_dir) = writer_fixture;

        let writer = Arc::new(writer);
        let record_count = 20;

        let splits = ["train", "test"];

        (0..record_count).into_par_iter().for_each(|index: i64| {
            let thread_id: std::thread::ThreadId = std::thread::current().id();
            let sample = Complex {
                column_str: format!("test_{:?}_{}", thread_id, index),
                column_bytes: vec![index as u8, 2, 3],
                column_int: index,
                column_bool: true,
                column_float: 1.0,
                column_complex: vec![vec![vec![[1, index as u8, 3]]]],
            };

            // Half go to train and half to test
            let split = splits[index as usize % 2];

            let _index = writer.write(split, &sample).unwrap();
        });

        let mut writer = Arc::try_unwrap(writer).unwrap();

        writer
            .set_completed()
            .expect("Should set completed successfully");

        let train =
            SqliteDataset::<Complex>::from_db_file(writer.db_file.clone(), "train").unwrap();
        let test = SqliteDataset::<Complex>::from_db_file(writer.db_file, "test").unwrap();

        assert_eq!(train.len(), record_count as usize / 2);
        assert_eq!(test.len(), record_count as usize / 2);
    }
}