// burn_dataset/dataset/sqlite.rs

1use std::{
2    collections::HashSet,
3    fs, io,
4    marker::PhantomData,
5    path::{Path, PathBuf},
6    sync::{Arc, RwLock},
7};
8
9use crate::Dataset;
10
11use gix_tempfile::{
12    AutoRemove, ContainingDirectory, Handle,
13    handle::{Writable, persist},
14};
15use r2d2::{Pool, PooledConnection};
16use r2d2_sqlite::{
17    SqliteConnectionManager,
18    rusqlite::{OpenFlags, OptionalExtension},
19};
20use sanitize_filename::sanitize;
21use serde::{Serialize, de::DeserializeOwned};
22use serde_rusqlite::{columns_from_statement, from_row_with_columns};
23
/// Result type for the sqlite dataset.
///
/// Shadows `std::result::Result` inside this module so that fallible
/// functions default to [`SqliteDatasetError`] as the error type.
pub type Result<T> = core::result::Result<T, SqliteDatasetError>;
26
/// Sqlite dataset error.
#[derive(thiserror::Error, Debug)]
pub enum SqliteDatasetError {
    /// IO related error.
    #[error("IO error: {0}")]
    Io(#[from] io::Error),

    /// Sql related error.
    #[error("Sql error: {0}")]
    Sql(#[from] serde_rusqlite::rusqlite::Error),

    /// Serde related error (MessagePack encoding of items).
    #[error("Serde error: {0}")]
    Serde(#[from] rmp_serde::encode::Error),

    /// The database file already exists error (writer created with `overwrite = false`).
    #[error("Overwrite flag is set to false and the database file already exists: {0}")]
    FileExists(PathBuf),

    /// Error when creating the connection pool.
    #[error("Failed to create connection pool: {0}")]
    ConnectionPool(#[from] r2d2::Error),

    /// Error when persisting the temporary database file.
    #[error("Could not persist the temporary database file: {0}")]
    PersistDbFile(#[from] persist::Error<Writable>),

    /// Any other error.
    #[error("{0}")]
    Other(&'static str),
}
58
59impl From<&'static str> for SqliteDatasetError {
60    fn from(s: &'static str) -> Self {
61        SqliteDatasetError::Other(s)
62    }
63}
64
65/// This struct represents a dataset where all items are stored in an SQLite database.
66/// Each instance of this struct corresponds to a specific table within the SQLite database,
67/// and allows for interaction with the data stored in the table in a structured and typed manner.
68///
69/// The SQLite database must contain a table with the same name as the `split` field. This table should
70/// have a primary key column named `row_id`, which is used to index the rows in the table. The `row_id`
71/// should start at 1, while the corresponding dataset `index` should start at 0, i.e., `row_id` = `index` + 1.
72///
73/// Table columns can be represented in two ways:
74///
75/// 1. The table can have a column for each field in the `I` struct. In this case, the column names in the table
76///    should match the field names of the `I` struct. The field names can be a subset of column names and
77///    can be in any order.
78///
79/// For the supported field types, refer to:
80/// - [Serialization field types](https://docs.rs/serde_rusqlite/latest/serde_rusqlite)
81/// - [SQLite data types](https://www.sqlite.org/datatype3.html)
82///
83/// 2. The fields in the `I` struct can be serialized into a single column `item` in the table. In this case, the table
84///    should have a single column named `item` of type `BLOB`. This is useful when the `I` struct contains complex fields
85///    that cannot be mapped to a SQLite type, such as nested structs, vectors, etc. The serialization is done using
86///    [MessagePack](https://msgpack.org/).
87///
88/// Note: The code automatically figures out which of the above two cases is applicable, and uses the appropriate
89/// method to read the data from the table.
#[derive(Debug)]
pub struct SqliteDataset<I> {
    /// Path to the SQLite database file backing this dataset.
    db_file: PathBuf,
    /// Split name; also the table name queried in the database.
    split: String,
    /// Read-only connection pool shared across `get` calls.
    conn_pool: Pool<SqliteConnectionManager>,
    /// Column names of the select statement (used by serde_rusqlite decoding).
    columns: Vec<String>,
    /// Cached row count, computed once at construction.
    len: usize,
    /// Prepared select text, chosen based on `row_serialized`.
    select_statement: String,
    /// True when the table stores each item as a single MessagePack blob.
    row_serialized: bool,
    /// Marks the item type without storing one.
    phantom: PhantomData<I>,
}
101
102impl<I> SqliteDataset<I> {
103    /// Initializes a `SqliteDataset` from a SQLite database file and a split name.
104    pub fn from_db_file<P: AsRef<Path>>(db_file: P, split: &str) -> Result<Self> {
105        // Create a connection pool
106        let conn_pool = create_conn_pool(&db_file, false)?;
107
108        // Determine how the table is stored
109        let row_serialized = Self::check_if_row_serialized(&conn_pool, split)?;
110
111        // Create a select statement and save it
112        let select_statement = if row_serialized {
113            format!("select item from {split} where row_id = ?")
114        } else {
115            format!("select * from {split} where row_id = ?")
116        };
117
118        // Save the column names and the number of rows
119        let (columns, len) = fetch_columns_and_len(&conn_pool, &select_statement, split)?;
120
121        Ok(SqliteDataset {
122            db_file: db_file.as_ref().to_path_buf(),
123            split: split.to_string(),
124            conn_pool,
125            columns,
126            len,
127            select_statement,
128            row_serialized,
129            phantom: PhantomData,
130        })
131    }
132
133    /// Returns true if table has two columns: row_id (integer) and item (blob).
134    ///
135    /// This is used to determine if the table is row serialized or not.
136    fn check_if_row_serialized(
137        conn_pool: &Pool<SqliteConnectionManager>,
138        split: &str,
139    ) -> Result<bool> {
140        // This struct is used to store the column name and type
141        struct Column {
142            name: String,
143            ty: String,
144        }
145
146        const COLUMN_NAME: usize = 1;
147        const COLUMN_TYPE: usize = 2;
148
149        let sql_statement = format!("PRAGMA table_info({split})");
150
151        let conn = conn_pool.get()?;
152
153        let mut stmt = conn.prepare(sql_statement.as_str())?;
154        let column_iter = stmt.query_map([], |row| {
155            Ok(Column {
156                name: row
157                    .get::<usize, String>(COLUMN_NAME)
158                    .unwrap()
159                    .to_lowercase(),
160                ty: row
161                    .get::<usize, String>(COLUMN_TYPE)
162                    .unwrap()
163                    .to_lowercase(),
164            })
165        })?;
166
167        let mut columns: Vec<Column> = vec![];
168
169        for column in column_iter {
170            columns.push(column?);
171        }
172
173        if columns.len() != 2 {
174            Ok(false)
175        } else {
176            // Check if the column names and types match the expected values
177            Ok(columns[0].name == "row_id"
178                && columns[0].ty == "integer"
179                && columns[1].name == "item"
180                && columns[1].ty == "blob")
181        }
182    }
183
184    /// Get the database file name.
185    pub fn db_file(&self) -> PathBuf {
186        self.db_file.clone()
187    }
188
189    /// Get the split name.
190    pub fn split(&self) -> &str {
191        self.split.as_str()
192    }
193}
194
195impl<I> Dataset<I> for SqliteDataset<I>
196where
197    I: Clone + Send + Sync + DeserializeOwned,
198{
199    /// Get an item from the dataset.
200    fn get(&self, index: usize) -> Option<I> {
201        // Row ids start with 1 (one) and index starts with 0 (zero)
202        let row_id = index + 1;
203
204        // Get a connection from the pool
205        let connection = self.conn_pool.get().unwrap();
206        let mut statement = connection.prepare(self.select_statement.as_str()).unwrap();
207
208        if self.row_serialized {
209            // Fetch with a single column `item` and deserialize it with MessagePack
210            statement
211                .query_row([row_id], |row| {
212                    // Deserialize item (blob) with MessagePack (rmp-serde)
213                    Ok(
214                        rmp_serde::from_slice::<I>(row.get_ref(0).unwrap().as_blob().unwrap())
215                            .unwrap(),
216                    )
217                })
218                .optional() //Converts Error (not found) to None
219                .unwrap()
220        } else {
221            // Fetch a row with multiple columns and deserialize it serde_rusqlite
222            statement
223                .query_row([row_id], |row| {
224                    // Deserialize the row with serde_rusqlite
225                    Ok(from_row_with_columns::<I>(row, &self.columns).unwrap())
226                })
227                .optional() //Converts Error (not found) to None
228                .unwrap()
229        }
230    }
231
232    /// Return the number of rows in the dataset.
233    fn len(&self) -> usize {
234        self.len
235    }
236}
237
238/// Fetch the column names and the number of rows from the database.
239fn fetch_columns_and_len(
240    conn_pool: &Pool<SqliteConnectionManager>,
241    select_statement: &str,
242    split: &str,
243) -> Result<(Vec<String>, usize)> {
244    // Save the column names
245    let connection = conn_pool.get()?;
246    let statement = connection.prepare(select_statement)?;
247    let columns = columns_from_statement(&statement);
248
249    // Count the number of rows and save it as len
250    //
251    // NOTE: Using coalesce(max(row_id), 0) instead of count(*) because count(*) is super slow for large tables.
252    // The coalesce(max(row_id), 0) returns 0 if the table is empty, otherwise it returns the max row_id,
253    // which corresponds to the number of rows in the table.
254    // The main assumption, which always holds true, is that the row_id is always increasing and there are no gaps.
255    // This is true for all the datasets that we are using, otherwise row_id will not correspond to the index.
256    let mut statement =
257        connection.prepare(format!("select coalesce(max(row_id), 0) from {split}").as_str())?;
258
259    let len = statement.query_row([], |row| {
260        let len: usize = row.get(0)?;
261        Ok(len)
262    })?;
263    Ok((columns, len))
264}
265
266/// Helper function to create a connection pool
267fn create_conn_pool<P: AsRef<Path>>(
268    db_file: P,
269    write: bool,
270) -> Result<Pool<SqliteConnectionManager>> {
271    let sqlite_flags = if write {
272        OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE
273    } else {
274        OpenFlags::SQLITE_OPEN_READ_ONLY
275    };
276
277    let manager = SqliteConnectionManager::file(db_file).with_flags(sqlite_flags);
278    Pool::new(manager).map_err(SqliteDatasetError::ConnectionPool)
279}
280
/// The `SqliteDatasetStorage` struct represents a SQLite database for storing datasets.
/// It consists of an optional name, a database file path, and a base directory for storage.
#[derive(Clone, Debug)]
pub struct SqliteDatasetStorage {
    /// Dataset name; used to derive the db file name when `db_file` is unset.
    name: Option<String>,
    /// Explicit database file path; takes precedence over `name`/`base_dir`.
    db_file: Option<PathBuf>,
    /// Base directory for name-derived files; falls back to the OS cache dir.
    base_dir: Option<PathBuf>,
}
289
290impl SqliteDatasetStorage {
291    /// Creates a new instance of `SqliteDatasetStorage` using a dataset name.
292    ///
293    /// # Arguments
294    ///
295    /// * `name` - A string slice that holds the name of the dataset.
296    pub fn from_name(name: &str) -> Self {
297        SqliteDatasetStorage {
298            name: Some(name.to_string()),
299            db_file: None,
300            base_dir: None,
301        }
302    }
303
304    /// Creates a new instance of `SqliteDatasetStorage` using a database file path.
305    ///
306    /// # Arguments
307    ///
308    /// * `db_file` - A reference to the Path that represents the database file path.
309    pub fn from_file<P: AsRef<Path>>(db_file: P) -> Self {
310        SqliteDatasetStorage {
311            name: None,
312            db_file: Some(db_file.as_ref().to_path_buf()),
313            base_dir: None,
314        }
315    }
316
317    /// Sets the base directory for storing the dataset.
318    ///
319    /// # Arguments
320    ///
321    /// * `base_dir` - A string slice that represents the base directory.
322    pub fn with_base_dir<P: AsRef<Path>>(mut self, base_dir: P) -> Self {
323        self.base_dir = Some(base_dir.as_ref().to_path_buf());
324        self
325    }
326
327    /// Checks if the database file exists in the given path.
328    ///
329    /// # Returns
330    ///
331    /// * A boolean value indicating whether the file exists or not.
332    pub fn exists(&self) -> bool {
333        self.db_file().exists()
334    }
335
336    /// Fetches the database file path.
337    ///
338    /// # Returns
339    ///
340    /// * A `PathBuf` instance representing the file path.
341    pub fn db_file(&self) -> PathBuf {
342        match &self.db_file {
343            Some(db_file) => db_file.clone(),
344            None => {
345                let name = sanitize(self.name.as_ref().expect("Name is not set"));
346                Self::base_dir(self.base_dir.to_owned()).join(format!("{name}.db"))
347            }
348        }
349    }
350
351    /// Determines the base directory for storing the dataset.
352    ///
353    /// # Arguments
354    ///
355    /// * `base_dir` - An `Option` that may contain a `PathBuf` instance representing the base directory.
356    ///
357    /// # Returns
358    ///
359    /// * A `PathBuf` instance representing the base directory.
360    pub fn base_dir(base_dir: Option<PathBuf>) -> PathBuf {
361        match base_dir {
362            Some(base_dir) => base_dir,
363            None => dirs::cache_dir()
364                .expect("Could not get cache directory")
365                .join("burn-dataset"),
366        }
367    }
368
369    /// Provides a writer instance for the SQLite dataset.
370    ///
371    /// # Arguments
372    ///
373    /// * `overwrite` - A boolean indicating if the existing database file should be overwritten.
374    ///
375    /// # Returns
376    ///
377    /// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise.
378    pub fn writer<I>(&self, overwrite: bool) -> Result<SqliteDatasetWriter<I>>
379    where
380        I: Clone + Send + Sync + Serialize + DeserializeOwned,
381    {
382        SqliteDatasetWriter::new(self.db_file(), overwrite)
383    }
384
385    /// Provides a reader instance for the SQLite dataset.
386    ///
387    /// # Arguments
388    ///
389    /// * `split` - A string slice that defines the data split for reading (e.g., "train", "test").
390    ///
391    /// # Returns
392    ///
393    /// * A `Result` which is `Ok` if the reader could be created, `Err` otherwise.
394    pub fn reader<I>(&self, split: &str) -> Result<SqliteDataset<I>>
395    where
396        I: Clone + Send + Sync + Serialize + DeserializeOwned,
397    {
398        if !self.exists() {
399            panic!("The database file does not exist");
400        }
401
402        SqliteDataset::from_db_file(self.db_file(), split)
403    }
404}
405
406/// This `SqliteDatasetWriter` struct is a SQLite database writer dedicated to storing datasets.
407/// It retains the current writer's state and its database connection.
408///
409/// Being thread-safe, this writer can be concurrently used across multiple threads.
410///
411/// Typical applications include:
412///
413/// - Generation of a new dataset
414/// - Storage of preprocessed data or metadata
415/// - Enlargement of a dataset's item count post preprocessing
#[derive(Debug)]
pub struct SqliteDatasetWriter<I> {
    /// Final destination of the database file.
    db_file: PathBuf,
    /// Temp-file handle; auto-removes the file if the writer is dropped early.
    db_file_tmp: Option<Handle<Writable>>,
    /// Splits (tables) created so far; shared for thread-safe writes.
    splits: Arc<RwLock<HashSet<String>>>,
    /// Whether an existing db file may be removed at init.
    overwrite: bool,
    /// Write connection pool to the temp file; dropped on completion.
    conn_pool: Option<Pool<SqliteConnectionManager>>,
    /// Set once `set_completed` persists the file; blocks further writes.
    is_completed: Arc<RwLock<bool>>,
    /// Marks the item type without storing one.
    phantom: PhantomData<I>,
}
426
impl<I> SqliteDatasetWriter<I>
where
    I: Clone + Send + Sync + Serialize + DeserializeOwned,
{
    /// Creates a new instance of `SqliteDatasetWriter`.
    ///
    /// # Arguments
    ///
    /// * `db_file` - A reference to the Path that represents the database file path.
    /// * `overwrite` - A boolean indicating if the existing database file should be overwritten.
    ///
    /// # Returns
    ///
    /// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise.
    pub fn new<P: AsRef<Path>>(db_file: P, overwrite: bool) -> Result<Self> {
        let writer = Self {
            db_file: db_file.as_ref().to_path_buf(),
            db_file_tmp: None,
            splits: Arc::new(RwLock::new(HashSet::new())),
            overwrite,
            conn_pool: None,
            is_completed: Arc::new(RwLock::new(false)),
            phantom: PhantomData,
        };

        writer.init()
    }

    /// Initializes the dataset writer by creating the database file, tables, and connection pool.
    ///
    /// # Returns
    ///
    /// * A `Result` which is `Ok` if the writer could be initialized, `Err` otherwise.
    fn init(mut self) -> Result<Self> {
        // Remove the db file if it already exists; otherwise refuse to
        // clobber it unless `overwrite` was requested.
        if self.db_file.exists() {
            if self.overwrite {
                fs::remove_file(&self.db_file)?;
            } else {
                return Err(SqliteDatasetError::FileExists(self.db_file));
            }
        }

        // Create the database file directory if it does not exist
        let db_file_dir = self
            .db_file
            .parent()
            .ok_or("Unable to get parent directory")?;

        if !db_file_dir.exists() {
            fs::create_dir_all(db_file_dir)?;
        }

        // Create a temp database file name as {base_dir}/{name}.db.tmp
        // Writes go to the temp file first; `set_completed` renames it into place.
        let mut db_file_tmp = self.db_file.clone();
        db_file_tmp.set_extension("db.tmp");
        if db_file_tmp.exists() {
            fs::remove_file(&db_file_tmp)?;
        }

        // Create the temp database file and wrap it with a gix_tempfile::Handle
        // This will ensure that the temp file is deleted when the writer is dropped
        // or when process exits with SIGINT or SIGTERM (tempfile crate does not do this)
        gix_tempfile::signal::setup(Default::default());
        self.db_file_tmp = Some(gix_tempfile::writable_at(
            &db_file_tmp,
            ContainingDirectory::Exists,
            AutoRemove::Tempfile,
        )?);

        // Open a writable connection pool against the temp file.
        let conn_pool = create_conn_pool(db_file_tmp, true)?;
        self.conn_pool = Some(conn_pool);

        Ok(self)
    }

    /// Serializes and writes an item to the database. The item is written to the table for the
    /// specified split. If the table does not exist, it is created. If the table exists, the item
    /// is appended to the table. The serialization is done using the [MessagePack](https://msgpack.org/)
    ///
    /// # Arguments
    ///
    /// * `split` - A string slice that defines the data split for writing (e.g., "train", "test").
    /// * `item` - A reference to the item to be written to the database.
    ///
    /// # Returns
    ///
    /// * A `Result` containing the index of the inserted row if successful, an error otherwise.
    pub fn write(&self, split: &str, item: &I) -> Result<usize> {
        // Acquire the read lock (won't block other reads)
        let is_completed = self.is_completed.read().unwrap();

        // If the writer is completed, return an error
        if *is_completed {
            return Err(SqliteDatasetError::Other(
                "Cannot save to a completed dataset writer",
            ));
        }

        // create the table for the split if it does not exist
        // (`create_table` is idempotent, so a racing duplicate call is harmless)
        if !self.splits.read().unwrap().contains(split) {
            self.create_table(split)?;
        }

        // Get a connection from the pool
        let conn_pool = self.conn_pool.as_ref().unwrap();
        let conn = conn_pool.get()?;

        // Serialize the item using MessagePack
        let serialized_item = rmp_serde::to_vec(item)?;

        // Turn off the synchronous and journal mode for speed up
        // We are sacrificing durability for speed but it's okay because
        // we always recreate the dataset if it is not completed.
        pragma_update_with_error_handling(&conn, "synchronous", "OFF")?;
        pragma_update_with_error_handling(&conn, "journal_mode", "OFF")?;

        // Insert the serialized item into the database
        let insert_statement = format!("insert into {split} (item) values (?)");
        conn.execute(insert_statement.as_str(), [serialized_item])?;

        // Get the primary key of the last inserted row and convert to index (row_id-1)
        let index = (conn.last_insert_rowid() - 1) as usize;

        Ok(index)
    }

    /// Marks the dataset as completed and persists the temporary database file.
    pub fn set_completed(&mut self) -> Result<()> {
        // Hold the write lock for the whole rename so concurrent `write`
        // calls cannot slip in between the pool drop and the flag flip.
        let mut is_completed = self.is_completed.write().unwrap();

        // Force close the connection pool
        // This is required on Windows platform where the connection pool prevents
        // from persisting the db by renaming the temp file.
        if let Some(pool) = self.conn_pool.take() {
            std::mem::drop(pool);
        }

        // Rename the database file from tmp to db
        let _file_result = self
            .db_file_tmp
            .take() // take ownership of the temporary file and set to None
            .unwrap() // unwrap the temporary file
            .persist(&self.db_file)?
            .ok_or("Unable to persist the database file")?;

        *is_completed = true;
        Ok(())
    }

    /// Creates table for the data split.
    ///
    /// Note: call is idempotent and thread-safe.
    ///
    /// # Arguments
    ///
    /// * `split` - A string slice that defines the data split for the table (e.g., "train", "test").
    ///
    /// # Returns
    ///
    /// * A `Result` which is `Ok` if the table could be created, `Err` otherwise.
    ///
    /// TODO (@antimora): add support creating a table with columns corresponding to the item fields
    fn create_table(&self, split: &str) -> Result<()> {
        // Check if the split already exists
        if self.splits.read().unwrap().contains(split) {
            return Ok(());
        }

        let conn_pool = self.conn_pool.as_ref().unwrap();
        let connection = conn_pool.get()?;
        // `if not exists` makes the statement safe under concurrent callers.
        let create_table_statement = format!(
            "create table if not exists  {split} (row_id integer primary key autoincrement not \
             null, item blob not null)"
        );

        connection.execute(create_table_statement.as_str(), [])?;

        // Add the split to the splits
        self.splits.write().unwrap().insert(split.to_string());

        Ok(())
    }
}
611
612/// Runs a pragma update and ignores the `ExecuteReturnedResults` error.
613///
614/// Sometimes ExecuteReturnedResults is returned when running a pragma update. This is not an error
615/// and can be ignored. This function runs the pragma update and ignores the error if it is
616/// `ExecuteReturnedResults`.
617fn pragma_update_with_error_handling(
618    conn: &PooledConnection<SqliteConnectionManager>,
619    setting: &str,
620    value: &str,
621) -> Result<()> {
622    let result = conn.pragma_update(None, setting, value);
623    if let Err(error) = result
624        && error != rusqlite::Error::ExecuteReturnedResults
625    {
626        return Err(SqliteDatasetError::Sql(error));
627    }
628
629    Ok(())
630}
631
#[cfg(test)]
mod tests {
    use rayon::prelude::*;
    use rstest::{fixture, rstest};
    use serde::{Deserialize, Serialize};
    use tempfile::{NamedTempFile, TempDir, tempdir};

    use super::*;

    type SqlDs = SqliteDataset<Sample>;

    // Item type matching the column-per-field layout of the test fixture db.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct Sample {
        column_str: String,
        column_bytes: Vec<u8>,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
    }

    // Opens the checked-in fixture database (2 rows in the "train" table).
    #[fixture]
    fn train_dataset() -> SqlDs {
        SqliteDataset::<Sample>::from_db_file("tests/data/sqlite-dataset.db", "train").unwrap()
    }

    #[rstest]
    pub fn len(train_dataset: SqlDs) {
        assert_eq!(train_dataset.len(), 2);
    }

    #[rstest]
    pub fn get_some(train_dataset: SqlDs) {
        let item = train_dataset.get(0).unwrap();
        assert_eq!(item.column_str, "HI1");
        assert_eq!(item.column_bytes, vec![55, 231, 159]);
        assert_eq!(item.column_int, 1);
        assert!(item.column_bool);
        assert_eq!(item.column_float, 1.0);
    }

    #[rstest]
    pub fn get_none(train_dataset: SqlDs) {
        // Index past the end of the 2-row table yields None.
        assert_eq!(train_dataset.get(10), None);
    }

    #[rstest]
    pub fn multi_thread(train_dataset: SqlDs) {
        // Mix of valid (0, 1) and out-of-range indices, read in parallel.
        let indices: Vec<usize> = vec![0, 1, 1, 3, 4, 5, 6, 0, 8, 1];
        let results: Vec<Option<Sample>> =
            indices.par_iter().map(|&i| train_dataset.get(i)).collect();

        let mut match_count = 0;
        for (_index, result) in indices.iter().zip(results.iter()) {
            if let Some(_val) = result {
                match_count += 1
            }
        }

        // Only indices 0 and 1 exist; they appear 5 times in `indices`.
        assert_eq!(match_count, 5);
    }

    #[test]
    fn sqlite_dataset_storage() {
        // Test with non-existing file
        let storage = SqliteDatasetStorage::from_file("non-existing.db");
        assert!(!storage.exists());

        // Test with non-existing name
        let storage = SqliteDatasetStorage::from_name("non-existing.db");
        assert!(!storage.exists());

        // Test with existing file
        let storage = SqliteDatasetStorage::from_file("tests/data/sqlite-dataset.db");
        assert!(storage.exists());
        let result = storage.reader::<Sample>("train");
        assert!(result.is_ok());
        let train = result.unwrap();
        assert_eq!(train.len(), 2);

        // Test get writer
        let temp_file = NamedTempFile::new().unwrap();
        let storage = SqliteDatasetStorage::from_file(temp_file.path());
        assert!(storage.exists());
        let result = storage.writer::<Sample>(true);
        assert!(result.is_ok());
    }

    // Item type with a nested field that only the blob (MessagePack) layout
    // can represent.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct Complex {
        column_str: String,
        column_bytes: Vec<u8>,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
        column_complex: Vec<Vec<Vec<[u8; 3]>>>,
    }

    /// Create a temporary directory.
    #[fixture]
    fn tmp_dir() -> TempDir {
        // Create a TempDir. This object will be automatically
        // deleted when it goes out of scope.
        tempdir().unwrap()
    }
    type Writer = SqliteDatasetWriter<Complex>;

    /// Create a SqliteDatasetWriter with a temporary directory.
    /// Make sure to return the temporary directory so that it is not deleted.
    #[fixture]
    fn writer_fixture(tmp_dir: TempDir) -> (Writer, TempDir) {
        let temp_dir_str = tmp_dir.path();
        let storage = SqliteDatasetStorage::from_name("preprocessed").with_base_dir(temp_dir_str);
        let overwrite = true;
        let result = storage.writer::<Complex>(overwrite);
        assert!(result.is_ok());
        let writer = result.unwrap();
        (writer, tmp_dir)
    }

    #[test]
    fn test_new() {
        // Test that the constructor works with overwrite = true
        // (the existing file is removed; writes target the .db.tmp file).
        let test_path = NamedTempFile::new().unwrap();
        let _writer = SqliteDatasetWriter::<Complex>::new(&test_path, true).unwrap();
        assert!(!test_path.path().exists());

        // Test that the constructor works with overwrite = false
        let test_path = NamedTempFile::new().unwrap();
        let result = SqliteDatasetWriter::<Complex>::new(&test_path, false);
        assert!(result.is_err());

        // Test that the constructor works with no existing file
        let temp = NamedTempFile::new().unwrap();
        let test_path = temp.path().to_path_buf();
        assert!(temp.close().is_ok());
        assert!(!test_path.exists());
        let _writer = SqliteDatasetWriter::<Complex>::new(&test_path, true).unwrap();
        assert!(!test_path.exists());
    }

    #[rstest]
    pub fn sqlite_writer_write(writer_fixture: (Writer, TempDir)) {
        // Get the dataset_saver from the fixture and tmp_dir (will be deleted after scope)
        let (writer, _tmp_dir) = writer_fixture;

        assert!(writer.overwrite);
        assert!(!writer.db_file.exists());

        let new_item = Complex {
            column_str: "HI1".to_string(),
            column_bytes: vec![1_u8, 2, 3],
            column_int: 0,
            column_bool: true,
            column_float: 1.0,
            column_complex: vec![vec![vec![[1, 23_u8, 3]]]],
        };

        let index = writer.write("train", &new_item).unwrap();
        assert_eq!(index, 0);

        let mut writer = writer;

        writer.set_completed().expect("Failed to set completed");

        // Completion persists the temp file to the final path.
        assert!(writer.db_file.exists());
        assert!(writer.db_file_tmp.is_none());

        let result = writer.write("train", &new_item);

        // Should fail because the writer is completed
        assert!(result.is_err());

        // Round-trip: the persisted file is readable and contains the item.
        let dataset = SqliteDataset::<Complex>::from_db_file(writer.db_file, "train").unwrap();

        let fetched_item = dataset.get(0).unwrap();
        assert_eq!(fetched_item, new_item);
        assert_eq!(dataset.len(), 1);
    }

    #[rstest]
    pub fn sqlite_writer_write_multi_thread(writer_fixture: (Writer, TempDir)) {
        // Get the dataset_saver from the fixture and tmp_dir (will be deleted after scope)
        let (writer, _tmp_dir) = writer_fixture;

        let writer = Arc::new(writer);
        let record_count = 20;

        let splits = ["train", "test"];

        // Writers are thread-safe: hammer both splits from a rayon pool.
        (0..record_count).into_par_iter().for_each(|index: i64| {
            let thread_id: std::thread::ThreadId = std::thread::current().id();
            let sample = Complex {
                column_str: format!("test_{thread_id:?}_{index}"),
                column_bytes: vec![index as u8, 2, 3],
                column_int: index,
                column_bool: true,
                column_float: 1.0,
                column_complex: vec![vec![vec![[1, index as u8, 3]]]],
            };

            // half for train and half for test
            let split = splits[index as usize % 2];

            let _index = writer.write(split, &sample).unwrap();
        });

        let mut writer = Arc::try_unwrap(writer).unwrap();

        writer
            .set_completed()
            .expect("Should set completed successfully");

        let train =
            SqliteDataset::<Complex>::from_db_file(writer.db_file.clone(), "train").unwrap();
        let test = SqliteDataset::<Complex>::from_db_file(writer.db_file, "test").unwrap();

        assert_eq!(train.len(), record_count as usize / 2);
        assert_eq!(test.len(), record_count as usize / 2);
    }
}