datacake_sqlite/
db.rs

1use std::path::Path;
2
3use flume::{self, Receiver, Sender};
4use futures::channel::oneshot;
5use rusqlite::{Connection, OptionalExtension, Params, Row};
6
7type Task = Box<dyn FnOnce(&mut Connection) + Send + 'static>;
8
9const CAPACITY: usize = 10;
10
11#[derive(Debug, Clone)]
12/// A asynchronous wrapper around a SQLite database.
13///
14/// These operations will be ran in a background thread preventing
15/// any IO operations from blocking the async context.
16pub struct StorageHandle {
17    tx: Sender<Task>,
18}
19
20impl StorageHandle {
21    /// Connects to the SQLite database.
22    ///
23    /// This spawns 1 background threads with actions being executed within that thread.
24    ///
25    /// This approach reduces the affect of writes blocking reads and vice-versa.
26    pub async fn open(path: impl AsRef<Path>) -> rusqlite::Result<Self> {
27        let tx = setup_database(path).await?;
28        Ok(Self { tx })
29    }
30
31    /// Connects to a new in-memory SQLite database.
32    pub async fn open_in_memory() -> rusqlite::Result<Self> {
33        Self::open(":memory:").await
34    }
35
36    /// Execute a SQL statement with some provided parameters.
37    pub async fn execute<P>(
38        &self,
39        sql: impl AsRef<str>,
40        params: P,
41    ) -> rusqlite::Result<usize>
42    where
43        P: Params + Clone + Send + 'static,
44    {
45        let sql = sql.as_ref().to_string();
46        self.submit_task(move |conn| {
47            let mut prepared = conn.prepare_cached(&sql)?;
48            prepared.execute(params)
49        })
50        .await
51    }
52
53    /// Execute a SQL statement several times with some provided parameters.
54    ///
55    /// The statement is executed within the same transaction.
56    pub async fn execute_many<P>(
57        &self,
58        sql: impl AsRef<str>,
59        param_set: Vec<P>,
60    ) -> rusqlite::Result<usize>
61    where
62        P: Params + Clone + Send + 'static,
63    {
64        let sql = sql.as_ref().to_string();
65        self.submit_task(move |conn| {
66            let tx = conn.transaction()?;
67
68            let mut total = 0;
69            {
70                let mut prepared = tx.prepare_cached(&sql)?;
71
72                for params in param_set {
73                    total += prepared.execute(params)?;
74                }
75            }
76
77            tx.commit()?;
78            Ok(total)
79        })
80        .await
81    }
82
83    /// Fetch a single row from a given SQL statement with some provided parameters.
84    pub async fn fetch_one<P, T>(
85        &self,
86        sql: impl AsRef<str>,
87        params: P,
88    ) -> rusqlite::Result<Option<T>>
89    where
90        P: Params + Send + 'static,
91        T: FromRow + Send + 'static,
92    {
93        let sql = sql.as_ref().to_string();
94
95        self.submit_task(move |conn| {
96            let mut prepared = conn.prepare_cached(&sql)?;
97            prepared.query_row(params, T::from_row).optional()
98        })
99        .await
100    }
101
102    /// Fetch a many rows by executing the same statement on several sets of parameters.
103    pub async fn fetch_many<P, T>(
104        &self,
105        sql: impl AsRef<str>,
106        param_sets: Vec<P>,
107    ) -> rusqlite::Result<Vec<T>>
108    where
109        P: Params + Send + 'static,
110        T: FromRow + Send + 'static,
111    {
112        let sql = sql.as_ref().to_string();
113
114        self.submit_task(move |conn| {
115            let mut prepared = conn.prepare_cached(&sql)?;
116            let mut rows = Vec::with_capacity(param_sets.len());
117
118            for params in param_sets {
119                if let Some(row) = prepared.query_row(params, T::from_row).optional()? {
120                    rows.push(row);
121                }
122            }
123
124            Ok(rows)
125        })
126        .await
127    }
128
129    /// Fetch a all rows from a given SQL statement with some provided parameters.
130    pub async fn fetch_all<P, T>(
131        &self,
132        sql: impl AsRef<str>,
133        params: P,
134    ) -> rusqlite::Result<Vec<T>>
135    where
136        P: Params + Send + 'static,
137        T: FromRow + Send + 'static,
138    {
139        let sql = sql.as_ref().to_string();
140
141        self.submit_task(move |conn| {
142            let mut prepared = conn.prepare_cached(&sql)?;
143            let mut iter = prepared.query(params)?;
144
145            let mut rows = Vec::with_capacity(4);
146            while let Some(row) = iter.next()? {
147                rows.push(T::from_row(row)?);
148            }
149
150            Ok(rows)
151        })
152        .await
153    }
154
155    /// Submits a writer task to execute.
156    ///
157    /// This executes the callback on the memory view connection which should be
158    /// significantly faster to modify or read.
159    async fn submit_task<CB, T>(&self, inner: CB) -> rusqlite::Result<T>
160    where
161        T: Send + 'static,
162        CB: FnOnce(&mut Connection) -> rusqlite::Result<T> + Send + 'static,
163    {
164        let (tx, rx) = oneshot::channel();
165
166        let cb = move |conn: &mut Connection| {
167            let res = inner(conn);
168            let _ = tx.send(res);
169        };
170
171        self.tx
172            .send_async(Box::new(cb))
173            .await
174            .expect("send message");
175
176        rx.await.unwrap()
177    }
178}
179
180/// A helper trait for converting between a Row reference and the given type.
181///
182/// This is required due to the nature of rows being tied to the database connection
183/// which cannot be shared outside of the thread the actor runs in.
184pub trait FromRow: Sized {
185    fn from_row(row: &Row) -> rusqlite::Result<Self>;
186}
187
188async fn setup_database(path: impl AsRef<Path>) -> rusqlite::Result<Sender<Task>> {
189    let path = path.as_ref().to_path_buf();
190    let (tx, rx) = flume::bounded(CAPACITY);
191
192    tokio::task::spawn_blocking(move || setup_disk_handle(&path, rx))
193        .await
194        .expect("spawn background runner")?;
195
196    Ok(tx)
197}
198
199fn setup_disk_handle(path: &Path, tasks: Receiver<Task>) -> rusqlite::Result<()> {
200    let disk = Connection::open(path)?;
201
202    disk.query_row("pragma journal_mode = WAL;", (), |_r| Ok(()))?;
203    disk.execute("pragma synchronous = normal;", ())?;
204    disk.execute("pragma temp_store = memory;", ())?;
205
206    std::thread::spawn(move || run_tasks(disk, tasks));
207
208    Ok(())
209}
210
211/// Runs all tasks received with a mutable reference to the given connection.
212fn run_tasks(mut conn: Connection, tasks: Receiver<Task>) {
213    while let Ok(task) = tasks.recv() {
214        (task)(&mut conn);
215    }
216}
217
#[cfg(test)]
mod tests {
    use std::env::temp_dir;
    use std::path::PathBuf;

    use super::*;

    /// Removes a test database and its WAL side files when dropped, so the
    /// temp directory is cleaned up even if the suite panics.
    struct TempDb(PathBuf);

    impl Drop for TempDb {
        fn drop(&mut self) {
            for suffix in ["", "-wal", "-shm"] {
                let mut file = self.0.as_os_str().to_owned();
                file.push(suffix);
                // Best effort: the file may not exist or may still be open.
                let _ = std::fs::remove_file(file);
            }
        }
    }

    #[tokio::test]
    async fn test_memory_storage_handle() {
        let handle = StorageHandle::open_in_memory().await.expect("open DB");

        run_storage_handle_suite(handle).await;
    }

    #[tokio::test]
    async fn test_disk_storage_handle() {
        let db = TempDb(temp_dir().join(uuid::Uuid::new_v4().to_string()));
        let handle = StorageHandle::open(&db.0).await.expect("open DB");

        run_storage_handle_suite(handle).await;
    }

    #[derive(Debug, Eq, PartialEq)]
    struct Person {
        id: i32,
        name: String,
        data: String,
    }

    impl FromRow for Person {
        fn from_row(row: &Row) -> rusqlite::Result<Self> {
            Ok(Self {
                id: row.get(0)?,
                name: row.get(1)?,
                data: row.get(2)?,
            })
        }
    }

    async fn run_storage_handle_suite(handle: StorageHandle) {
        handle
            .execute(
                "CREATE TABLE person (
                    id    INTEGER PRIMARY KEY,
                    name  TEXT NOT NULL,
                    data  BLOB
                )",
                (), // empty list of parameters.
            )
            .await
            .expect("create table");

        let res = handle
            .fetch_one::<_, Person>("SELECT id, name, data FROM person;", ())
            .await
            .expect("execute statement");
        assert!(res.is_none(), "Expected no rows to be returned.");

        handle
            .execute(
                "INSERT INTO person (id, name, data) VALUES (1, 'cf8', 'tada');",
                (),
            )
            .await
            .expect("Insert row");

        let res = handle
            .fetch_one::<_, Person>("SELECT id, name, data FROM person;", ())
            .await
            .expect("execute statement");
        assert_eq!(
            res,
            Some(Person {
                id: 1,
                name: "cf8".to_string(),
                data: "tada".to_string()
            }),
        );

        handle
            .execute(
                "INSERT INTO person (id, name, data) VALUES (2, 'cf6', 'tada2');",
                (),
            )
            .await
            .expect("Insert row");

        let res = handle
            .fetch_all::<_, Person>(
                "SELECT id, name, data FROM person ORDER BY id ASC;",
                (),
            )
            .await
            .expect("execute statement");
        assert_eq!(
            res,
            vec![
                Person {
                    id: 1,
                    name: "cf8".to_string(),
                    data: "tada".to_string()
                },
                Person {
                    id: 2,
                    name: "cf6".to_string(),
                    data: "tada2".to_string()
                },
            ],
        );

        let changed = handle
            .execute_many(
                "INSERT INTO person (id, name, data) VALUES (?, ?, ?);",
                vec![(3, "cf3", "tada3"), (4, "cf4", "tada4")],
            )
            .await
            .expect("Insert many rows");
        assert_eq!(changed, 2);

        // Id 99 does not exist and must simply be skipped.
        let res = handle
            .fetch_many::<_, Person>(
                "SELECT id, name, data FROM person WHERE id = ?;",
                vec![(3,), (99,), (4,)],
            )
            .await
            .expect("execute statement");
        assert_eq!(
            res,
            vec![
                Person {
                    id: 3,
                    name: "cf3".to_string(),
                    data: "tada3".to_string()
                },
                Person {
                    id: 4,
                    name: "cf4".to_string(),
                    data: "tada4".to_string()
                },
            ],
        );
    }
}