burn_dataset/dataset/sqlite.rs

use std::{
    collections::HashSet,
    fs, io,
    marker::PhantomData,
    path::{Path, PathBuf},
    sync::{Arc, RwLock},
};

use crate::Dataset;

use gix_tempfile::{
    AutoRemove, ContainingDirectory, Handle,
    handle::{Writable, persist},
};
use r2d2::{Pool, PooledConnection};
use r2d2_sqlite::{
    SqliteConnectionManager,
    rusqlite::{OpenFlags, OptionalExtension},
};
use sanitize_filename::sanitize;
use serde::{Serialize, de::DeserializeOwned};
use serde_rusqlite::{columns_from_statement, from_row_with_columns};

/// Result type for the sqlite dataset.
pub type Result<T> = core::result::Result<T, SqliteDatasetError>;

/// Sqlite dataset error.
#[derive(thiserror::Error, Debug)]
pub enum SqliteDatasetError {
    /// IO related error.
    #[error("IO error: {0}")]
    Io(#[from] io::Error),

    /// Sql related error.
    #[error("Sql error: {0}")]
    Sql(#[from] serde_rusqlite::rusqlite::Error),

    /// Serde related error.
    #[error("Serde error: {0}")]
    Serde(#[from] rmp_serde::encode::Error),

    /// The database file already exists error.
    #[error("Overwrite flag is set to false and the database file already exists: {0}")]
    FileExists(PathBuf),

    /// Error when creating the connection pool.
    #[error("Failed to create connection pool: {0}")]
    ConnectionPool(#[from] r2d2::Error),

    /// Error when persisting the temporary database file.
    #[error("Could not persist the temporary database file: {0}")]
    PersistDbFile(#[from] persist::Error<Writable>),

    /// Any other error.
    #[error("{0}")]
    Other(&'static str),
}

impl From<&'static str> for SqliteDatasetError {
    fn from(s: &'static str) -> Self {
        SqliteDatasetError::Other(s)
    }
}

/// This struct represents a dataset where all items are stored in an SQLite database.
/// Each instance of this struct corresponds to a specific table within the SQLite database,
/// and provides structured, typed access to the data stored in that table.
///
/// The SQLite database must contain a table with the same name as the `split` field. This table should
/// have a primary key column named `row_id`, which is used to index the rows in the table. The `row_id`
/// should start at 1, while the corresponding dataset `index` starts at 0, i.e., `row_id` = `index` + 1.
///
/// Table columns can be represented in two ways:
///
/// 1. The table can have a column for each field in the `I` struct. In this case, the column names in the table
///    should match the field names of the `I` struct. The field names can be a subset of the column names and
///    can be in any order. For the supported field types, refer to:
///    - [Serialization field types](https://docs.rs/serde_rusqlite/latest/serde_rusqlite)
///    - [SQLite data types](https://www.sqlite.org/datatype3.html)
///
/// 2. The fields in the `I` struct can be serialized into a single column `item` in the table. In this case, the table
///    should have a single column named `item` of type `BLOB`. This is useful when the `I` struct contains complex fields
///    that cannot be mapped to a SQLite type, such as nested structs, vectors, etc. The serialization is done using
///    [MessagePack](https://msgpack.org/).
///
/// Note: The code automatically figures out which of the above two cases applies, and uses the appropriate
/// method to read the data from the table.
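///
/// # Example
///
/// A minimal usage sketch; the database path and item type are illustrative,
/// and the import paths assume the crate re-exports these types at its root:
///
/// ```rust,ignore
/// use burn_dataset::{Dataset, SqliteDataset};
/// use serde::{Deserialize, Serialize};
///
/// #[derive(Clone, Debug, Serialize, Deserialize)]
/// struct MyItem {
///     label: i64,
///     text: String,
/// }
///
/// // Open the "train" table of an existing database file.
/// let dataset = SqliteDataset::<MyItem>::from_db_file("data/my-dataset.db", "train").unwrap();
///
/// // Dataset index 0 maps to row_id 1.
/// if let Some(item) = dataset.get(0) {
///     println!("first item: {item:?}");
/// }
/// println!("{} items", dataset.len());
/// ```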
#[derive(Debug)]
pub struct SqliteDataset<I> {
    db_file: PathBuf,
    split: String,
    conn_pool: Pool<SqliteConnectionManager>,
    columns: Vec<String>,
    len: usize,
    select_statement: String,
    row_serialized: bool,
    phantom: PhantomData<I>,
}

impl<I> SqliteDataset<I> {
    /// Initializes a `SqliteDataset` from a SQLite database file and a split name.
    pub fn from_db_file<P: AsRef<Path>>(db_file: P, split: &str) -> Result<Self> {
        // Create a connection pool
        let conn_pool = create_conn_pool(&db_file, false)?;

        // Determine how the table is stored
        let row_serialized = Self::check_if_row_serialized(&conn_pool, split)?;

        // Create a select statement and save it
        let select_statement = if row_serialized {
            format!("select item from {split} where row_id = ?")
        } else {
            format!("select * from {split} where row_id = ?")
        };

        // Save the column names and the number of rows
        let (columns, len) = fetch_columns_and_len(&conn_pool, &select_statement, split)?;

        Ok(SqliteDataset {
            db_file: db_file.as_ref().to_path_buf(),
            split: split.to_string(),
            conn_pool,
            columns,
            len,
            select_statement,
            row_serialized,
            phantom: PhantomData,
        })
    }

    /// Returns true if the table has exactly two columns: `row_id` (integer) and `item` (blob).
    ///
    /// This is used to determine whether the table is row serialized.
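    ///
    /// For illustration, a row-serialized split matches the schema produced by
    /// `SqliteDatasetWriter::create_table` below, e.g.
    /// `create table train (row_id integer primary key autoincrement not null, item blob not null)`.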
    fn check_if_row_serialized(
        conn_pool: &Pool<SqliteConnectionManager>,
        split: &str,
    ) -> Result<bool> {
        // This struct is used to store the column name and type
        struct Column {
            name: String,
            ty: String,
        }

        const COLUMN_NAME: usize = 1;
        const COLUMN_TYPE: usize = 2;

        let sql_statement = format!("PRAGMA table_info({split})");

        let conn = conn_pool.get()?;

        let mut stmt = conn.prepare(sql_statement.as_str())?;
        let column_iter = stmt.query_map([], |row| {
            Ok(Column {
                name: row
                    .get::<usize, String>(COLUMN_NAME)
                    .unwrap()
                    .to_lowercase(),
                ty: row
                    .get::<usize, String>(COLUMN_TYPE)
                    .unwrap()
                    .to_lowercase(),
            })
        })?;

        let mut columns: Vec<Column> = vec![];

        for column in column_iter {
            columns.push(column?);
        }

        if columns.len() != 2 {
            Ok(false)
        } else {
            // Check if the column names and types match the expected values
            Ok(columns[0].name == "row_id"
                && columns[0].ty == "integer"
                && columns[1].name == "item"
                && columns[1].ty == "blob")
        }
    }

    /// Get the database file path.
    pub fn db_file(&self) -> PathBuf {
        self.db_file.clone()
    }

    /// Get the split name.
    pub fn split(&self) -> &str {
        self.split.as_str()
    }
}

impl<I> Dataset<I> for SqliteDataset<I>
where
    I: Clone + Send + Sync + DeserializeOwned,
{
    /// Get an item from the dataset.
    fn get(&self, index: usize) -> Option<I> {
        // Row ids start at 1 (one) while the dataset index starts at 0 (zero)
        let row_id = index + 1;

        // Get a connection from the pool
        let connection = self.conn_pool.get().unwrap();
        let mut statement = connection.prepare(self.select_statement.as_str()).unwrap();

        if self.row_serialized {
            // Fetch the single `item` column and deserialize it with MessagePack
            statement
                .query_row([row_id], |row| {
                    // Deserialize item (blob) with MessagePack (rmp-serde)
                    Ok(
                        rmp_serde::from_slice::<I>(row.get_ref(0).unwrap().as_blob().unwrap())
                            .unwrap(),
                    )
                })
                .optional() // Converts Error (not found) to None
                .unwrap()
        } else {
            // Fetch a row with multiple columns and deserialize it with serde_rusqlite
            statement
                .query_row([row_id], |row| {
                    // Deserialize the row with serde_rusqlite
                    Ok(from_row_with_columns::<I>(row, &self.columns).unwrap())
                })
                .optional() // Converts Error (not found) to None
                .unwrap()
        }
    }

    /// Return the number of rows in the dataset.
    fn len(&self) -> usize {
        self.len
    }
}

/// Fetch the column names and the number of rows from the database.
fn fetch_columns_and_len(
    conn_pool: &Pool<SqliteConnectionManager>,
    select_statement: &str,
    split: &str,
) -> Result<(Vec<String>, usize)> {
    // Save the column names
    let connection = conn_pool.get()?;
    let statement = connection.prepare(select_statement)?;
    let columns = columns_from_statement(&statement);

    // Count the number of rows and save it as len
    //
    // NOTE: Using coalesce(max(row_id), 0) instead of count(*) because count(*) is very slow for
    // large tables. coalesce(max(row_id), 0) returns 0 if the table is empty; otherwise it returns
    // the max row_id, which corresponds to the number of rows in the table.
    // The main assumption is that row_id values start at 1 and are contiguous (no gaps);
    // otherwise row_id would not correspond to the dataset index. This holds for all the datasets
    // that we are using.
    let mut statement =
        connection.prepare(format!("select coalesce(max(row_id), 0) from {split}").as_str())?;

    let len = statement.query_row([], |row| {
        let len: usize = row.get(0)?;
        Ok(len)
    })?;
    Ok((columns, len))
}

/// Helper function to create a connection pool
fn create_conn_pool<P: AsRef<Path>>(
    db_file: P,
    write: bool,
) -> Result<Pool<SqliteConnectionManager>> {
    let sqlite_flags = if write {
        OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE
    } else {
        OpenFlags::SQLITE_OPEN_READ_ONLY
    };

    let manager = SqliteConnectionManager::file(db_file).with_flags(sqlite_flags);
    Pool::new(manager).map_err(SqliteDatasetError::ConnectionPool)
}

/// The `SqliteDatasetStorage` struct represents a SQLite database for storing datasets.
/// It consists of an optional name, an optional database file path, and an optional base directory for storage.
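///
/// # Example
///
/// A minimal sketch of the two construction paths; the dataset name and
/// directories are illustrative:
///
/// ```rust,ignore
/// use burn_dataset::SqliteDatasetStorage;
///
/// // Resolves to {base_dir}/my-dataset.db; without `with_base_dir`, the
/// // default base directory is $HOME/.cache/burn-dataset.
/// let storage = SqliteDatasetStorage::from_name("my-dataset").with_base_dir("/tmp/datasets");
///
/// // Alternatively, point directly at a database file.
/// let storage = SqliteDatasetStorage::from_file("/tmp/datasets/my-dataset.db");
/// ```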
#[derive(Clone, Debug)]
pub struct SqliteDatasetStorage {
    name: Option<String>,
    db_file: Option<PathBuf>,
    base_dir: Option<PathBuf>,
}

impl SqliteDatasetStorage {
    /// Creates a new instance of `SqliteDatasetStorage` using a dataset name.
    ///
    /// # Arguments
    ///
    /// * `name` - A string slice that holds the name of the dataset.
    pub fn from_name(name: &str) -> Self {
        SqliteDatasetStorage {
            name: Some(name.to_string()),
            db_file: None,
            base_dir: None,
        }
    }

    /// Creates a new instance of `SqliteDatasetStorage` using a database file path.
    ///
    /// # Arguments
    ///
    /// * `db_file` - A path that represents the database file.
    pub fn from_file<P: AsRef<Path>>(db_file: P) -> Self {
        SqliteDatasetStorage {
            name: None,
            db_file: Some(db_file.as_ref().to_path_buf()),
            base_dir: None,
        }
    }

    /// Sets the base directory for storing the dataset.
    ///
    /// # Arguments
    ///
    /// * `base_dir` - A path that represents the base directory.
    pub fn with_base_dir<P: AsRef<Path>>(mut self, base_dir: P) -> Self {
        self.base_dir = Some(base_dir.as_ref().to_path_buf());
        self
    }

    /// Checks whether the database file exists at the resolved path.
    ///
    /// # Returns
    ///
    /// * A boolean value indicating whether the file exists or not.
    pub fn exists(&self) -> bool {
        self.db_file().exists()
    }

    /// Fetches the database file path.
    ///
    /// # Returns
    ///
    /// * A `PathBuf` instance representing the file path.
    pub fn db_file(&self) -> PathBuf {
        match &self.db_file {
            Some(db_file) => db_file.clone(),
            None => {
                let name = sanitize(self.name.as_ref().expect("Name is not set"));
                Self::base_dir(self.base_dir.to_owned()).join(format!("{name}.db"))
            }
        }
    }

    /// Determines the base directory for storing the dataset.
    ///
    /// # Arguments
    ///
    /// * `base_dir` - An `Option` that may contain a `PathBuf` instance representing the base directory.
    ///
    /// # Returns
    ///
    /// * A `PathBuf` instance representing the base directory.
    pub fn base_dir(base_dir: Option<PathBuf>) -> PathBuf {
        match base_dir {
            Some(base_dir) => base_dir,
            None => {
                let home_dir = dirs::home_dir().expect("Could not get home directory");

                home_dir.join(".cache").join("burn-dataset")
            }
        }
    }

    /// Provides a writer instance for the SQLite dataset.
    ///
    /// # Arguments
    ///
    /// * `overwrite` - A boolean indicating if the existing database file should be overwritten.
    ///
    /// # Returns
    ///
    /// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise.
    pub fn writer<I>(&self, overwrite: bool) -> Result<SqliteDatasetWriter<I>>
    where
        I: Clone + Send + Sync + Serialize + DeserializeOwned,
    {
        SqliteDatasetWriter::new(self.db_file(), overwrite)
    }

    /// Provides a reader instance for the SQLite dataset.
    ///
    /// # Arguments
    ///
    /// * `split` - A string slice that defines the data split for reading (e.g., "train", "test").
    ///
    /// # Returns
    ///
    /// * A `Result` which is `Ok` if the reader could be created, `Err` otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the database file does not exist.
    pub fn reader<I>(&self, split: &str) -> Result<SqliteDataset<I>>
    where
        I: Clone + Send + Sync + Serialize + DeserializeOwned,
    {
        if !self.exists() {
            panic!("The database file does not exist");
        }

        SqliteDataset::from_db_file(self.db_file(), split)
    }
}

/// The `SqliteDatasetWriter` struct is a SQLite database writer dedicated to storing datasets.
/// It retains the current writer's state and its database connection.
///
/// The writer is thread-safe and can be used concurrently from multiple threads
/// (see the example below).
///
/// Typical applications include:
///
/// - Generating a new dataset
/// - Storing preprocessed data or metadata
/// - Growing a dataset's item count after preprocessing
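///
/// # Example
///
/// A minimal write-then-read sketch; the file path, split name, and item type
/// are illustrative, and the import paths assume the crate re-exports these
/// types at its root:
///
/// ```rust,ignore
/// use burn_dataset::{Dataset, SqliteDataset, SqliteDatasetWriter};
/// use serde::{Deserialize, Serialize};
///
/// #[derive(Clone, Debug, Serialize, Deserialize)]
/// struct MyItem {
///     label: i64,
/// }
///
/// // Writes go to a temporary `*.db.tmp` file until `set_completed` persists it.
/// let mut writer = SqliteDatasetWriter::<MyItem>::new("/tmp/my-dataset.db", true).unwrap();
/// let index = writer.write("train", &MyItem { label: 0 }).unwrap();
/// assert_eq!(index, 0);
/// writer.set_completed().unwrap();
///
/// // The persisted file can now be read back as a dataset.
/// let dataset = SqliteDataset::<MyItem>::from_db_file("/tmp/my-dataset.db", "train").unwrap();
/// assert_eq!(dataset.len(), 1);
/// ```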
#[derive(Debug)]
pub struct SqliteDatasetWriter<I> {
    db_file: PathBuf,
    db_file_tmp: Option<Handle<Writable>>,
    splits: Arc<RwLock<HashSet<String>>>,
    overwrite: bool,
    conn_pool: Option<Pool<SqliteConnectionManager>>,
    is_completed: Arc<RwLock<bool>>,
    phantom: PhantomData<I>,
}

impl<I> SqliteDatasetWriter<I>
where
    I: Clone + Send + Sync + Serialize + DeserializeOwned,
{
    /// Creates a new instance of `SqliteDatasetWriter`.
    ///
    /// # Arguments
    ///
    /// * `db_file` - A path that represents the database file.
    /// * `overwrite` - A boolean indicating if the existing database file should be overwritten.
    ///
    /// # Returns
    ///
    /// * A `Result` which is `Ok` if the writer could be created, `Err` otherwise.
    pub fn new<P: AsRef<Path>>(db_file: P, overwrite: bool) -> Result<Self> {
        let writer = Self {
            db_file: db_file.as_ref().to_path_buf(),
            db_file_tmp: None,
            splits: Arc::new(RwLock::new(HashSet::new())),
            overwrite,
            conn_pool: None,
            is_completed: Arc::new(RwLock::new(false)),
            phantom: PhantomData,
        };

        writer.init()
    }

    /// Initializes the dataset writer by creating the database file, tables, and connection pool.
    ///
    /// # Returns
    ///
    /// * A `Result` which is `Ok` if the writer could be initialized, `Err` otherwise.
    fn init(mut self) -> Result<Self> {
        // Remove the db file if it already exists
        if self.db_file.exists() {
            if self.overwrite {
                fs::remove_file(&self.db_file)?;
            } else {
                return Err(SqliteDatasetError::FileExists(self.db_file));
            }
        }

        // Create the database file directory if it does not exist
        let db_file_dir = self
            .db_file
            .parent()
            .ok_or("Unable to get parent directory")?;

        if !db_file_dir.exists() {
            fs::create_dir_all(db_file_dir)?;
        }

        // Create a temp database file name as {base_dir}/{name}.db.tmp
        let mut db_file_tmp = self.db_file.clone();
        db_file_tmp.set_extension("db.tmp");
        if db_file_tmp.exists() {
            fs::remove_file(&db_file_tmp)?;
        }

        // Create the temp database file and wrap it with a gix_tempfile::Handle.
        // This ensures that the temp file is deleted when the writer is dropped
        // or when the process exits via SIGINT or SIGTERM (the tempfile crate
        // does not handle signals).
        gix_tempfile::signal::setup(Default::default());
        self.db_file_tmp = Some(gix_tempfile::writable_at(
            &db_file_tmp,
            ContainingDirectory::Exists,
            AutoRemove::Tempfile,
        )?);

        let conn_pool = create_conn_pool(db_file_tmp, true)?;
        self.conn_pool = Some(conn_pool);

        Ok(self)
    }

    /// Serializes and writes an item to the database. The item is written to the table for the
    /// specified split. If the table does not exist, it is created. If the table exists, the item
    /// is appended to it. The serialization is done using [MessagePack](https://msgpack.org/).
    ///
    /// # Arguments
    ///
    /// * `split` - A string slice that defines the data split for writing (e.g., "train", "test").
    /// * `item` - A reference to the item to be written to the database.
    ///
    /// # Returns
    ///
    /// * A `Result` containing the index of the inserted row if successful, an error otherwise.
    pub fn write(&self, split: &str, item: &I) -> Result<usize> {
        // Acquire the read lock (won't block other reads)
        let is_completed = self.is_completed.read().unwrap();

        // If the writer is completed, return an error
        if *is_completed {
            return Err(SqliteDatasetError::Other(
                "Cannot save to a completed dataset writer",
            ));
        }

        // Create the table for the split if it does not exist
        if !self.splits.read().unwrap().contains(split) {
            self.create_table(split)?;
        }

        // Get a connection from the pool
        let conn_pool = self.conn_pool.as_ref().unwrap();
        let conn = conn_pool.get()?;

        // Serialize the item using MessagePack
        let serialized_item = rmp_serde::to_vec(item)?;

        // Turn off the synchronous and journal modes to speed up writes.
        // We are sacrificing durability for speed, but that's okay because
        // we always recreate the dataset if it is not completed.
        pragma_update_with_error_handling(&conn, "synchronous", "OFF")?;
        pragma_update_with_error_handling(&conn, "journal_mode", "OFF")?;

        // Insert the serialized item into the database
        let insert_statement = format!("insert into {split} (item) values (?)");
        conn.execute(insert_statement.as_str(), [serialized_item])?;

        // Get the primary key of the last inserted row and convert it to an index (row_id - 1)
        let index = (conn.last_insert_rowid() - 1) as usize;

        Ok(index)
    }

    /// Marks the dataset as completed and persists the temporary database file.
    pub fn set_completed(&mut self) -> Result<()> {
        let mut is_completed = self.is_completed.write().unwrap();

        // Force close the connection pool.
        // This is required on Windows, where an open connection pool prevents
        // the db from being persisted by renaming the temp file.
        if let Some(pool) = self.conn_pool.take() {
            std::mem::drop(pool);
        }

        // Rename the database file from tmp to db
        let _file_result = self
            .db_file_tmp
            .take() // take ownership of the temporary file and set to None
            .unwrap() // unwrap the temporary file
            .persist(&self.db_file)?
            .ok_or("Unable to persist the database file")?;

        *is_completed = true;
        Ok(())
    }

    /// Creates a table for the data split.
    ///
    /// Note: this call is idempotent and thread-safe.
    ///
    /// # Arguments
    ///
    /// * `split` - A string slice that defines the data split for the table (e.g., "train", "test").
    ///
    /// # Returns
    ///
    /// * A `Result` which is `Ok` if the table could be created, `Err` otherwise.
    ///
    /// TODO (@antimora): add support for creating a table with columns corresponding to the item fields
    fn create_table(&self, split: &str) -> Result<()> {
        // Check if the split already exists
        if self.splits.read().unwrap().contains(split) {
            return Ok(());
        }

        let conn_pool = self.conn_pool.as_ref().unwrap();
        let connection = conn_pool.get()?;
        let create_table_statement = format!(
            "create table if not exists {split} (row_id integer primary key autoincrement not \
             null, item blob not null)"
        );

        connection.execute(create_table_statement.as_str(), [])?;

        // Record the split so the table is only created once
        self.splits.write().unwrap().insert(split.to_string());

        Ok(())
    }
}

/// Runs a pragma update and ignores the `ExecuteReturnedResults` error.
///
/// Sometimes `ExecuteReturnedResults` is returned when running a pragma update. This is not an
/// actual error and can be ignored. This function runs the pragma update and swallows that
/// specific error.
fn pragma_update_with_error_handling(
    conn: &PooledConnection<SqliteConnectionManager>,
    setting: &str,
    value: &str,
) -> Result<()> {
    let result = conn.pragma_update(None, setting, value);
    if let Err(error) = result {
        if error != serde_rusqlite::rusqlite::Error::ExecuteReturnedResults {
            return Err(SqliteDatasetError::Sql(error));
        }
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    use rayon::prelude::*;
    use rstest::{fixture, rstest};
    use serde::{Deserialize, Serialize};
    use tempfile::{NamedTempFile, TempDir, tempdir};

    use super::*;

    type SqlDs = SqliteDataset<Sample>;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct Sample {
        column_str: String,
        column_bytes: Vec<u8>,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
    }

    #[fixture]
    fn train_dataset() -> SqlDs {
        SqliteDataset::<Sample>::from_db_file("tests/data/sqlite-dataset.db", "train").unwrap()
    }

    #[rstest]
    pub fn len(train_dataset: SqlDs) {
        assert_eq!(train_dataset.len(), 2);
    }

    #[rstest]
    pub fn get_some(train_dataset: SqlDs) {
        let item = train_dataset.get(0).unwrap();
        assert_eq!(item.column_str, "HI1");
        assert_eq!(item.column_bytes, vec![55, 231, 159]);
        assert_eq!(item.column_int, 1);
        assert!(item.column_bool);
        assert_eq!(item.column_float, 1.0);
    }

    #[rstest]
    pub fn get_none(train_dataset: SqlDs) {
        assert_eq!(train_dataset.get(10), None);
    }

    #[rstest]
    pub fn multi_thread(train_dataset: SqlDs) {
        let indices: Vec<usize> = vec![0, 1, 1, 3, 4, 5, 6, 0, 8, 1];
        let results: Vec<Option<Sample>> =
            indices.par_iter().map(|&i| train_dataset.get(i)).collect();

        // Only indices 0 and 1 exist in the two-row dataset, so five lookups succeed
        let match_count = results.iter().filter(|result| result.is_some()).count();

        assert_eq!(match_count, 5);
    }

    #[test]
    fn sqlite_dataset_storage() {
        // Test with non-existing file
        let storage = SqliteDatasetStorage::from_file("non-existing.db");
        assert!(!storage.exists());

        // Test with non-existing name
        let storage = SqliteDatasetStorage::from_name("non-existing.db");
        assert!(!storage.exists());

        // Test with existing file
        let storage = SqliteDatasetStorage::from_file("tests/data/sqlite-dataset.db");
        assert!(storage.exists());
        let result = storage.reader::<Sample>("train");
        assert!(result.is_ok());
        let train = result.unwrap();
        assert_eq!(train.len(), 2);

        // Test get writer
        let temp_file = NamedTempFile::new().unwrap();
        let storage = SqliteDatasetStorage::from_file(temp_file.path());
        assert!(storage.exists());
        let result = storage.writer::<Sample>(true);
        assert!(result.is_ok());
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct Complex {
        column_str: String,
        column_bytes: Vec<u8>,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
        column_complex: Vec<Vec<Vec<[u8; 3]>>>,
    }

    /// Create a temporary directory.
    #[fixture]
    fn tmp_dir() -> TempDir {
        // Create a TempDir. This object will be automatically
        // deleted when it goes out of scope.
        tempdir().unwrap()
    }

    type Writer = SqliteDatasetWriter<Complex>;

    /// Create a SqliteDatasetWriter with a temporary directory.
    /// Make sure to return the temporary directory so that it is not deleted.
    #[fixture]
    fn writer_fixture(tmp_dir: TempDir) -> (Writer, TempDir) {
        let temp_dir_str = tmp_dir.path();
        let storage = SqliteDatasetStorage::from_name("preprocessed").with_base_dir(temp_dir_str);
        let overwrite = true;
        let result = storage.writer::<Complex>(overwrite);
        assert!(result.is_ok());
        let writer = result.unwrap();
        (writer, tmp_dir)
    }

    #[test]
    fn test_new() {
        // Test that the constructor works with overwrite = true
        let test_path = NamedTempFile::new().unwrap();
        let _writer = SqliteDatasetWriter::<Complex>::new(&test_path, true).unwrap();
        assert!(!test_path.path().exists());

        // Test that the constructor errors when the file exists and overwrite = false
        let test_path = NamedTempFile::new().unwrap();
        let result = SqliteDatasetWriter::<Complex>::new(&test_path, false);
        assert!(result.is_err());

        // Test that the constructor works with no existing file
        let temp = NamedTempFile::new().unwrap();
        let test_path = temp.path().to_path_buf();
        assert!(temp.close().is_ok());
        assert!(!test_path.exists());
        let _writer = SqliteDatasetWriter::<Complex>::new(&test_path, true).unwrap();
        assert!(!test_path.exists());
    }

    #[rstest]
    pub fn sqlite_writer_write(writer_fixture: (Writer, TempDir)) {
        // Get the writer from the fixture and tmp_dir (will be deleted after scope)
        let (writer, _tmp_dir) = writer_fixture;

        assert!(writer.overwrite);
        assert!(!writer.db_file.exists());

        let new_item = Complex {
            column_str: "HI1".to_string(),
            column_bytes: vec![1_u8, 2, 3],
            column_int: 0,
            column_bool: true,
            column_float: 1.0,
            column_complex: vec![vec![vec![[1, 23_u8, 3]]]],
        };

        let index = writer.write("train", &new_item).unwrap();
        assert_eq!(index, 0);

        let mut writer = writer;

        writer.set_completed().expect("Failed to set completed");

        assert!(writer.db_file.exists());
        assert!(writer.db_file_tmp.is_none());

        let result = writer.write("train", &new_item);

        // Should fail because the writer is completed
        assert!(result.is_err());

        let dataset = SqliteDataset::<Complex>::from_db_file(writer.db_file, "train").unwrap();

        let fetched_item = dataset.get(0).unwrap();
        assert_eq!(fetched_item, new_item);
        assert_eq!(dataset.len(), 1);
    }

    #[rstest]
    pub fn sqlite_writer_write_multi_thread(writer_fixture: (Writer, TempDir)) {
        // Get the writer from the fixture and tmp_dir (will be deleted after scope)
        let (writer, _tmp_dir) = writer_fixture;

        let writer = Arc::new(writer);
        let record_count = 20;

        let splits = ["train", "test"];

        (0..record_count).into_par_iter().for_each(|index: i64| {
            let thread_id: std::thread::ThreadId = std::thread::current().id();
            let sample = Complex {
                column_str: format!("test_{thread_id:?}_{index}"),
                column_bytes: vec![index as u8, 2, 3],
                column_int: index,
                column_bool: true,
                column_float: 1.0,
                column_complex: vec![vec![vec![[1, index as u8, 3]]]],
            };

            // Half for train and half for test
            let split = splits[index as usize % 2];

            let _index = writer.write(split, &sample).unwrap();
        });

        let mut writer = Arc::try_unwrap(writer).unwrap();

        writer
            .set_completed()
            .expect("Should set completed successfully");

        let train =
            SqliteDataset::<Complex>::from_db_file(writer.db_file.clone(), "train").unwrap();
        let test = SqliteDataset::<Complex>::from_db_file(writer.db_file, "test").unwrap();

        assert_eq!(train.len(), record_count as usize / 2);
        assert_eq!(test.len(), record_count as usize / 2);
    }
}
852}