1use std::path::Path;
2
3use flume::{self, Receiver, Sender};
4use futures::channel::oneshot;
5use rusqlite::{Connection, OptionalExtension, Params, Row};
6
7type Task = Box<dyn FnOnce(&mut Connection) + Send + 'static>;
8
/// Bound of the task queue created in `setup_database`; once this many tasks
/// are waiting, further submissions wait for room.
const CAPACITY: usize = 10;
10
11#[derive(Debug, Clone)]
12pub struct StorageHandle {
17 tx: Sender<Task>,
18}
19
20impl StorageHandle {
21 pub async fn open(path: impl AsRef<Path>) -> rusqlite::Result<Self> {
27 let tx = setup_database(path).await?;
28 Ok(Self { tx })
29 }
30
31 pub async fn open_in_memory() -> rusqlite::Result<Self> {
33 Self::open(":memory:").await
34 }
35
36 pub async fn execute<P>(
38 &self,
39 sql: impl AsRef<str>,
40 params: P,
41 ) -> rusqlite::Result<usize>
42 where
43 P: Params + Clone + Send + 'static,
44 {
45 let sql = sql.as_ref().to_string();
46 self.submit_task(move |conn| {
47 let mut prepared = conn.prepare_cached(&sql)?;
48 prepared.execute(params)
49 })
50 .await
51 }
52
53 pub async fn execute_many<P>(
57 &self,
58 sql: impl AsRef<str>,
59 param_set: Vec<P>,
60 ) -> rusqlite::Result<usize>
61 where
62 P: Params + Clone + Send + 'static,
63 {
64 let sql = sql.as_ref().to_string();
65 self.submit_task(move |conn| {
66 let tx = conn.transaction()?;
67
68 let mut total = 0;
69 {
70 let mut prepared = tx.prepare_cached(&sql)?;
71
72 for params in param_set {
73 total += prepared.execute(params)?;
74 }
75 }
76
77 tx.commit()?;
78 Ok(total)
79 })
80 .await
81 }
82
83 pub async fn fetch_one<P, T>(
85 &self,
86 sql: impl AsRef<str>,
87 params: P,
88 ) -> rusqlite::Result<Option<T>>
89 where
90 P: Params + Send + 'static,
91 T: FromRow + Send + 'static,
92 {
93 let sql = sql.as_ref().to_string();
94
95 self.submit_task(move |conn| {
96 let mut prepared = conn.prepare_cached(&sql)?;
97 prepared.query_row(params, T::from_row).optional()
98 })
99 .await
100 }
101
102 pub async fn fetch_many<P, T>(
104 &self,
105 sql: impl AsRef<str>,
106 param_sets: Vec<P>,
107 ) -> rusqlite::Result<Vec<T>>
108 where
109 P: Params + Send + 'static,
110 T: FromRow + Send + 'static,
111 {
112 let sql = sql.as_ref().to_string();
113
114 self.submit_task(move |conn| {
115 let mut prepared = conn.prepare_cached(&sql)?;
116 let mut rows = Vec::with_capacity(param_sets.len());
117
118 for params in param_sets {
119 if let Some(row) = prepared.query_row(params, T::from_row).optional()? {
120 rows.push(row);
121 }
122 }
123
124 Ok(rows)
125 })
126 .await
127 }
128
129 pub async fn fetch_all<P, T>(
131 &self,
132 sql: impl AsRef<str>,
133 params: P,
134 ) -> rusqlite::Result<Vec<T>>
135 where
136 P: Params + Send + 'static,
137 T: FromRow + Send + 'static,
138 {
139 let sql = sql.as_ref().to_string();
140
141 self.submit_task(move |conn| {
142 let mut prepared = conn.prepare_cached(&sql)?;
143 let mut iter = prepared.query(params)?;
144
145 let mut rows = Vec::with_capacity(4);
146 while let Some(row) = iter.next()? {
147 rows.push(T::from_row(row)?);
148 }
149
150 Ok(rows)
151 })
152 .await
153 }
154
155 async fn submit_task<CB, T>(&self, inner: CB) -> rusqlite::Result<T>
160 where
161 T: Send + 'static,
162 CB: FnOnce(&mut Connection) -> rusqlite::Result<T> + Send + 'static,
163 {
164 let (tx, rx) = oneshot::channel();
165
166 let cb = move |conn: &mut Connection| {
167 let res = inner(conn);
168 let _ = tx.send(res);
169 };
170
171 self.tx
172 .send_async(Box::new(cb))
173 .await
174 .expect("send message");
175
176 rx.await.unwrap()
177 }
178}
179
180pub trait FromRow: Sized {
185 fn from_row(row: &Row) -> rusqlite::Result<Self>;
186}
187
188async fn setup_database(path: impl AsRef<Path>) -> rusqlite::Result<Sender<Task>> {
189 let path = path.as_ref().to_path_buf();
190 let (tx, rx) = flume::bounded(CAPACITY);
191
192 tokio::task::spawn_blocking(move || setup_disk_handle(&path, rx))
193 .await
194 .expect("spawn background runner")?;
195
196 Ok(tx)
197}
198
199fn setup_disk_handle(path: &Path, tasks: Receiver<Task>) -> rusqlite::Result<()> {
200 let disk = Connection::open(path)?;
201
202 disk.query_row("pragma journal_mode = WAL;", (), |_r| Ok(()))?;
203 disk.execute("pragma synchronous = normal;", ())?;
204 disk.execute("pragma temp_store = memory;", ())?;
205
206 std::thread::spawn(move || run_tasks(disk, tasks));
207
208 Ok(())
209}
210
211fn run_tasks(mut conn: Connection, tasks: Receiver<Task>) {
213 while let Ok(task) = tasks.recv() {
214 (task)(&mut conn);
215 }
216}
217
#[cfg(test)]
mod tests {
    use std::env::temp_dir;

    use super::*;

    #[tokio::test]
    async fn test_memory_storage_handle() {
        let handle = StorageHandle::open_in_memory().await.expect("open DB");

        run_storage_handle_suite(handle).await;
    }

    #[tokio::test]
    async fn test_disk_storage_handle() {
        let path = temp_dir().join(uuid::Uuid::new_v4().to_string());
        let handle = StorageHandle::open(&path).await.expect("open DB");

        run_storage_handle_suite(handle).await;

        // Best-effort removal of the database file and the `-wal` / `-shm`
        // side files SQLite creates in WAL mode. Errors are ignored: a side
        // file may not exist, and the worker thread may not have closed the
        // connection yet.
        for suffix in ["", "-wal", "-shm"] {
            let mut file = path.clone().into_os_string();
            file.push(suffix);
            let _ = std::fs::remove_file(file);
        }
    }

    #[derive(Debug, Eq, PartialEq)]
    struct Person {
        id: i32,
        name: String,
        data: String,
    }

    impl Person {
        fn new(id: i32, name: &str, data: &str) -> Self {
            Self {
                id,
                name: name.to_string(),
                data: data.to_string(),
            }
        }
    }

    impl FromRow for Person {
        fn from_row(row: &Row) -> rusqlite::Result<Self> {
            Ok(Self {
                id: row.get(0)?,
                name: row.get(1)?,
                data: row.get(2)?,
            })
        }
    }

    /// Exercises every public query method against a fresh `person` table.
    async fn run_storage_handle_suite(handle: StorageHandle) {
        handle
            .execute(
                "CREATE TABLE person (
                    id INTEGER PRIMARY KEY,
                    name TEXT NOT NULL,
                    data BLOB
                )",
                (),
            )
            .await
            .expect("create table");

        let res = handle
            .fetch_one::<_, Person>("SELECT id, name, data FROM person;", ())
            .await
            .expect("execute statement");
        assert!(res.is_none(), "Expected no rows to be returned.");

        handle
            .execute(
                "INSERT INTO person (id, name, data) VALUES (1, 'cf8', 'tada');",
                (),
            )
            .await
            .expect("Insert row");

        let res = handle
            .fetch_one::<_, Person>("SELECT id, name, data FROM person;", ())
            .await
            .expect("execute statement");
        assert_eq!(res, Some(Person::new(1, "cf8", "tada")));

        handle
            .execute(
                "INSERT INTO person (id, name, data) VALUES (2, 'cf6', 'tada2');",
                (),
            )
            .await
            .expect("Insert row");

        let res = handle
            .fetch_all::<_, Person>(
                "SELECT id, name, data FROM person ORDER BY id ASC;",
                (),
            )
            .await
            .expect("execute statement");
        assert_eq!(
            res,
            vec![Person::new(1, "cf8", "tada"), Person::new(2, "cf6", "tada2")],
        );

        // `execute_many` reports the total number of changed rows.
        let inserted = handle
            .execute_many(
                "INSERT INTO person (id, name, data) VALUES (?1, ?2, ?3);",
                vec![(3_i32, "cf3", "tada3"), (4_i32, "cf4", "tada4")],
            )
            .await
            .expect("insert rows");
        assert_eq!(inserted, 2);

        // `fetch_many` skips parameter sets that match no row.
        let res = handle
            .fetch_many::<_, Person>(
                "SELECT id, name, data FROM person WHERE id = ?1;",
                vec![(1_i32,), (99_i32,), (4_i32,)],
            )
            .await
            .expect("execute statement");
        assert_eq!(
            res,
            vec![Person::new(1, "cf8", "tada"), Person::new(4, "cf4", "tada4")],
        );

        // A failing batch is rolled back as a whole: id 5 must not persist
        // even though its insert ran before the duplicate-key failure.
        let res = handle
            .execute_many(
                "INSERT INTO person (id, name, data) VALUES (?1, ?2, ?3);",
                vec![(5_i32, "cf5", "tada5"), (1_i32, "dup", "dup")],
            )
            .await;
        assert!(res.is_err(), "Expected duplicate primary key to fail.");

        let res = handle
            .fetch_one::<_, Person>(
                "SELECT id, name, data FROM person WHERE id = ?1;",
                (5_i32,),
            )
            .await
            .expect("execute statement");
        assert!(res.is_none(), "Expected failed batch to be rolled back.");
    }
}