Skip to main content

async_sqlite/
client.rs

1use std::{
2    path::{Path, PathBuf},
3    thread,
4};
5
6use crate::Error;
7
8use crossbeam_channel::{bounded, unbounded, Receiver, Sender, TrySendError};
9use futures_channel::oneshot;
10use rusqlite::{Connection, OpenFlags};
11
12/// A `ClientBuilder` can be used to create a [`Client`] with custom
13/// configuration.
14///
15/// For more information on creating a sqlite connection, see the
16/// [rusqlite docs](rusqlite::Connection::open()).
17///
18/// # Examples
19///
20/// ```rust
21/// # use async_sqlite::ClientBuilder;
22/// # async fn run() -> Result<(), async_sqlite::Error> {
23/// let client = ClientBuilder::new().path("path/to/db.sqlite3").open().await?;
24///
25/// // ...
26///
27/// client.close().await?;
28/// # Ok(())
29/// # }
30/// ```
31#[derive(Clone, Debug, Default)]
32pub struct ClientBuilder {
33    pub(crate) path: Option<PathBuf>,
34    pub(crate) flags: OpenFlags,
35    pub(crate) journal_mode: Option<JournalMode>,
36    pub(crate) vfs: Option<String>,
37    pub(crate) queue_capacity: Option<usize>,
38}
39
40impl ClientBuilder {
41    /// Returns a new [`ClientBuilder`] with the default settings.
42    pub fn new() -> Self {
43        Self::default()
44    }
45
46    /// Specify the path of the sqlite3 database to open.
47    ///
48    /// By default, an in-memory database is used.
49    pub fn path<P: AsRef<Path>>(mut self, path: P) -> Self {
50        self.path = Some(path.as_ref().into());
51        self
52    }
53
54    /// Specify the [`OpenFlags`] to use when opening a new connection.
55    ///
56    /// By default, [`OpenFlags::default()`] is used.
57    pub fn flags(mut self, flags: OpenFlags) -> Self {
58        self.flags = flags;
59        self
60    }
61
62    /// Specify the [`JournalMode`] to set when opening a new connection.
63    ///
64    /// By default, no `journal_mode` is explicity set.
65    pub fn journal_mode(mut self, journal_mode: JournalMode) -> Self {
66        self.journal_mode = Some(journal_mode);
67        self
68    }
69
70    /// Specify the name of the [vfs](https://www.sqlite.org/vfs.html) to use.
71    pub fn vfs(mut self, vfs: &str) -> Self {
72        self.vfs = Some(vfs.to_owned());
73        self
74    }
75
76    /// Limit the number of commands that may wait in the worker queue.
77    ///
78    /// By default, the queue is unbounded. If a capacity is configured, calls
79    /// return [`Error::QueueFull`] when that many commands are already waiting
80    /// for the worker thread. A capacity of `0` allows a command to be accepted
81    /// only when the worker is ready to receive it immediately.
82    pub fn queue_capacity(mut self, queue_capacity: usize) -> Self {
83        self.queue_capacity = Some(queue_capacity);
84        self
85    }
86
87    /// Returns a new [`Client`] that uses the `ClientBuilder` configuration.
88    ///
89    /// # Examples
90    ///
91    /// ```rust
92    /// # use async_sqlite::ClientBuilder;
93    /// # async fn run() -> Result<(), async_sqlite::Error> {
94    /// let client = ClientBuilder::new().open().await?;
95    /// # Ok(())
96    /// # }
97    /// ```
98    pub async fn open(self) -> Result<Client, Error> {
99        Client::open_async(self).await
100    }
101
102    /// Returns a new [`Client`] that uses the `ClientBuilder` configuration,
103    /// blocking the current thread.
104    ///
105    /// # Examples
106    ///
107    /// ```rust
108    /// # use async_sqlite::ClientBuilder;
109    /// # fn run() -> Result<(), async_sqlite::Error> {
110    /// let client = ClientBuilder::new().open_blocking()?;
111    /// # Ok(())
112    /// # }
113    /// ```
114    pub fn open_blocking(self) -> Result<Client, Error> {
115        Client::open_blocking(self)
116    }
117}
118
119enum Command {
120    Func(Box<dyn QueuedFunc>),
121    Shutdown(Box<dyn QueuedShutdown>),
122}
123
124trait QueuedFunc: Send {
125    fn is_canceled(&self) -> bool;
126    fn execute(self: Box<Self>, conn: &mut Connection);
127}
128
129struct AsyncFunc<F, T, E> {
130    tx: oneshot::Sender<Result<T, E>>,
131    func: F,
132}
133
134impl<F, T, E> QueuedFunc for AsyncFunc<F, T, E>
135where
136    F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
137    T: Send + 'static,
138    E: Send + 'static,
139{
140    fn is_canceled(&self) -> bool {
141        self.tx.is_canceled()
142    }
143
144    fn execute(self: Box<Self>, conn: &mut Connection) {
145        let Self { tx, func } = *self;
146        _ = tx.send(func(conn));
147    }
148}
149
150struct BlockingFunc<F, T, E> {
151    tx: Sender<Result<T, E>>,
152    func: F,
153}
154
155impl<F, T, E> QueuedFunc for BlockingFunc<F, T, E>
156where
157    F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
158    T: Send + 'static,
159    E: Send + 'static,
160{
161    fn is_canceled(&self) -> bool {
162        false
163    }
164
165    fn execute(self: Box<Self>, conn: &mut Connection) {
166        let Self { tx, func } = *self;
167        _ = tx.send(func(conn));
168    }
169}
170
171trait QueuedShutdown: Send {
172    fn is_canceled(&self) -> bool;
173    fn respond(self: Box<Self>, res: Result<(), Error>);
174}
175
176struct AsyncShutdown {
177    tx: oneshot::Sender<Result<(), Error>>,
178}
179
180impl QueuedShutdown for AsyncShutdown {
181    fn is_canceled(&self) -> bool {
182        self.tx.is_canceled()
183    }
184
185    fn respond(self: Box<Self>, res: Result<(), Error>) {
186        _ = self.tx.send(res);
187    }
188}
189
190struct BlockingShutdown {
191    tx: Sender<Result<(), Error>>,
192}
193
194impl QueuedShutdown for BlockingShutdown {
195    fn is_canceled(&self) -> bool {
196        false
197    }
198
199    fn respond(self: Box<Self>, res: Result<(), Error>) {
200        _ = self.tx.send(res);
201    }
202}
203
204fn run_catching<F, T>(conn: &mut Connection, func: F) -> Result<T, Error>
205where
206    F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error>,
207{
208    match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| func(conn))) {
209        Ok(res) => res.map_err(Error::from),
210        Err(p) => {
211            rollback_if_needed(conn);
212            Err(Error::Panic {
213                message: panic_message(&*p),
214            })
215        }
216    }
217}
218
219fn run_catching_and_then<F, T, E>(conn: &mut Connection, func: F) -> Result<T, E>
220where
221    F: FnOnce(&mut Connection) -> Result<T, E>,
222    E: From<Error>,
223{
224    match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| func(conn))) {
225        Ok(res) => res,
226        Err(p) => {
227            rollback_if_needed(conn);
228            Err(E::from(Error::Panic {
229                message: panic_message(&*p),
230            }))
231        }
232    }
233}
234
235fn rollback_if_needed(conn: &mut Connection) {
236    if !conn.is_autocommit() {
237        let _ = conn.execute_batch("ROLLBACK");
238    }
239}
240
/// Extracts a human-readable message from a panic payload.
///
/// Payloads holding a `&'static str` or a `String` are returned as text. Any
/// other payload type yields the placeholder `"panic"`.
fn panic_message(p: &(dyn std::any::Any + Send)) -> String {
    p.downcast_ref::<&'static str>()
        .map(|msg| (*msg).to_owned())
        .or_else(|| p.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "panic".to_owned())
}
250
251/// Client represents a single sqlite connection that can be used from async
252/// contexts.
253#[derive(Clone)]
254pub struct Client {
255    conn_tx: Sender<Command>,
256}
257
258impl Client {
259    async fn open_async(builder: ClientBuilder) -> Result<Self, Error> {
260        let (open_tx, open_rx) = oneshot::channel();
261        Self::open(builder, |res| {
262            _ = open_tx.send(res);
263        });
264        open_rx.await?
265    }
266
267    fn open_blocking(builder: ClientBuilder) -> Result<Self, Error> {
268        let (conn_tx, conn_rx) = bounded(1);
269        Self::open(builder, move |res| {
270            _ = conn_tx.send(res);
271        });
272        conn_rx.recv()?
273    }
274
275    fn open<F>(builder: ClientBuilder, func: F)
276    where
277        F: FnOnce(Result<Self, Error>) + Send + 'static,
278    {
279        thread::spawn(move || {
280            let (conn_tx, conn_rx) = match builder.queue_capacity {
281                Some(queue_capacity) => bounded(queue_capacity),
282                None => unbounded(),
283            };
284
285            let mut conn = match Client::create_conn(builder) {
286                Ok(conn) => conn,
287                Err(err) => {
288                    func(Err(err));
289                    return;
290                }
291            };
292
293            let client = Self { conn_tx };
294            func(Ok(client));
295
296            while let Ok(cmd) = conn_rx.recv() {
297                match cmd {
298                    Command::Func(func) => {
299                        if !func.is_canceled() {
300                            func.execute(&mut conn);
301                        }
302                    }
303                    Command::Shutdown(func) => {
304                        if !func.is_canceled() {
305                            match conn.close() {
306                                Ok(()) => {
307                                    func.respond(Ok(()));
308                                    return;
309                                }
310                                Err((c, e)) => {
311                                    conn = c;
312                                    func.respond(Err(e.into()));
313                                }
314                            }
315                        }
316                    }
317                }
318            }
319        });
320    }
321
322    fn create_conn(mut builder: ClientBuilder) -> Result<Connection, Error> {
323        let path = builder.path.take().unwrap_or_else(|| ":memory:".into());
324        let conn = if let Some(vfs) = builder.vfs.take() {
325            Connection::open_with_flags_and_vfs(path, builder.flags, vfs.as_str())?
326        } else {
327            Connection::open_with_flags(path, builder.flags)?
328        };
329
330        if let Some(journal_mode) = builder.journal_mode.take() {
331            let val = journal_mode.as_str();
332            let out: String =
333                conn.pragma_update_and_check(None, "journal_mode", val, |row| row.get(0))?;
334            if !out.eq_ignore_ascii_case(val) {
335                return Err(Error::PragmaUpdate {
336                    name: "journal_mode",
337                    exp: val,
338                    got: out,
339                });
340            }
341        }
342
343        Ok(conn)
344    }
345
346    fn enqueue_async<F, T, E>(
347        &self,
348        func: F,
349    ) -> Result<oneshot::Receiver<Result<T, E>>, TrySendError<Command>>
350    where
351        F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
352        T: Send + 'static,
353        E: Send + 'static,
354    {
355        let (tx, rx) = oneshot::channel();
356        self.conn_tx
357            .try_send(Command::Func(Box::new(AsyncFunc { tx, func })))?;
358        Ok(rx)
359    }
360
361    fn enqueue_blocking<F, T, E>(
362        &self,
363        func: F,
364    ) -> Result<Receiver<Result<T, E>>, TrySendError<Command>>
365    where
366        F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
367        T: Send + 'static,
368        E: Send + 'static,
369    {
370        let (tx, rx) = bounded(1);
371        self.conn_tx
372            .try_send(Command::Func(Box::new(BlockingFunc { tx, func })))?;
373        Ok(rx)
374    }
375
376    /// Invokes the provided function with a [`rusqlite::Connection`].
377    pub async fn conn<F, T>(&self, func: F) -> Result<T, Error>
378    where
379        F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
380        T: Send + 'static,
381    {
382        let rx = self
383            .enqueue_async(move |conn| run_catching(conn, |conn| func(conn)))
384            .map_err(Error::from)?;
385        rx.await?
386    }
387
388    /// Invokes the provided function with a mutable [`rusqlite::Connection`].
389    pub async fn conn_mut<F, T>(&self, func: F) -> Result<T, Error>
390    where
391        F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
392        T: Send + 'static,
393    {
394        let rx = self
395            .enqueue_async(move |conn| run_catching(conn, func))
396            .map_err(Error::from)?;
397        rx.await?
398    }
399
400    /// Invokes the provided function with a [`rusqlite::Connection`].
401    ///
402    /// Maps the result error type to a custom error; designed to be
403    /// used in conjunction with [`query_and_then`](https://docs.rs/rusqlite/latest/rusqlite/struct.CachedStatement.html#method.query_and_then).
404    pub async fn conn_and_then<F, T, E>(&self, func: F) -> Result<T, E>
405    where
406        F: FnOnce(&Connection) -> Result<T, E> + Send + 'static,
407        T: Send + 'static,
408        E: From<rusqlite::Error> + From<Error> + Send + 'static,
409    {
410        let rx = self
411            .enqueue_async(move |conn| run_catching_and_then(conn, |conn| func(conn)))
412            .map_err(Error::from)?;
413        rx.await.map_err(Error::from)?
414    }
415
416    /// Invokes the provided function with a mutable [`rusqlite::Connection`].
417    ///
418    /// Maps the result error type to a custom error; designed to be
419    /// used in conjunction with [`query_and_then`](https://docs.rs/rusqlite/latest/rusqlite/struct.CachedStatement.html#method.query_and_then).
420    pub async fn conn_mut_and_then<F, T, E>(&self, func: F) -> Result<T, E>
421    where
422        F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
423        T: Send + 'static,
424        E: From<rusqlite::Error> + From<Error> + Send + 'static,
425    {
426        let rx = self
427            .enqueue_async(move |conn| run_catching_and_then(conn, func))
428            .map_err(Error::from)?;
429        rx.await.map_err(Error::from)?
430    }
431
432    /// Closes the underlying sqlite connection.
433    ///
434    /// After this method returns, all calls to `self::conn()` or
435    /// `self::conn_mut()` will return an [`Error::Closed`] error.
436    pub async fn close(&self) -> Result<(), Error> {
437        let (tx, rx) = oneshot::channel();
438        match self
439            .conn_tx
440            .try_send(Command::Shutdown(Box::new(AsyncShutdown { tx })))
441        {
442            Ok(()) => {}
443            Err(TrySendError::Disconnected(_)) => {
444                // If the worker thread has already shut down, return Ok here.
445                return Ok(());
446            }
447            Err(err) => return Err(err.into()),
448        }
449        // If receiving fails, the connection is already closed.
450        rx.await.unwrap_or(Ok(()))
451    }
452
453    /// Invokes the provided function with a [`rusqlite::Connection`], blocking
454    /// the current thread until completion.
455    pub fn conn_blocking<F, T>(&self, func: F) -> Result<T, Error>
456    where
457        F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
458        T: Send + 'static,
459    {
460        let rx = self
461            .enqueue_blocking(move |conn| run_catching(conn, |conn| func(conn)))
462            .map_err(Error::from)?;
463        rx.recv()?
464    }
465
466    /// Invokes the provided function with a mutable [`rusqlite::Connection`],
467    /// blocking the current thread until completion.
468    pub fn conn_mut_blocking<F, T>(&self, func: F) -> Result<T, Error>
469    where
470        F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
471        T: Send + 'static,
472    {
473        let rx = self
474            .enqueue_blocking(move |conn| run_catching(conn, func))
475            .map_err(Error::from)?;
476        rx.recv()?
477    }
478
479    /// Invokes the provided function with a [`rusqlite::Connection`],
480    /// blocking the current thread until completion.
481    ///
482    /// Maps the result error type to a custom error; designed to be
483    /// used in conjunction with [`query_and_then`](https://docs.rs/rusqlite/latest/rusqlite/struct.CachedStatement.html#method.query_and_then).
484    pub fn conn_and_then_blocking<F, T, E>(&self, func: F) -> Result<T, E>
485    where
486        F: FnOnce(&Connection) -> Result<T, E> + Send + 'static,
487        T: Send + 'static,
488        E: From<rusqlite::Error> + From<Error> + Send + 'static,
489    {
490        let rx = self
491            .enqueue_blocking(move |conn| run_catching_and_then(conn, |conn| func(conn)))
492            .map_err(Error::from)?;
493        rx.recv().map_err(Error::from)?
494    }
495
496    /// Invokes the provided function with a mutable [`rusqlite::Connection`],
497    /// blocking the current thread until completion.
498    ///
499    /// Maps the result error type to a custom error; designed to be
500    /// used in conjunction with [`query_and_then`](https://docs.rs/rusqlite/latest/rusqlite/struct.CachedStatement.html#method.query_and_then).
501    pub fn conn_mut_and_then_blocking<F, T, E>(&self, func: F) -> Result<T, E>
502    where
503        F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
504        T: Send + 'static,
505        E: From<rusqlite::Error> + From<Error> + Send + 'static,
506    {
507        let rx = self
508            .enqueue_blocking(move |conn| run_catching_and_then(conn, func))
509            .map_err(Error::from)?;
510        rx.recv().map_err(Error::from)?
511    }
512
513    /// Closes the underlying sqlite connection, blocking the current thread
514    /// until complete.
515    ///
516    /// After this method returns, all calls to `self::conn_blocking()` or
517    /// `self::conn_mut_blocking()` will return an [`Error::Closed`] error.
518    pub fn close_blocking(&self) -> Result<(), Error> {
519        let (tx, rx) = bounded(1);
520        match self
521            .conn_tx
522            .try_send(Command::Shutdown(Box::new(BlockingShutdown { tx })))
523        {
524            Ok(()) => {}
525            Err(TrySendError::Disconnected(_)) => return Ok(()),
526            Err(err) => return Err(err.into()),
527        }
528        // If receiving fails, the connection is already closed.
529        rx.recv().unwrap_or(Ok(()))
530    }
531}
532
/// The possible sqlite journal modes.
///
/// For more information, please see the [sqlite docs](https://www.sqlite.org/pragma.html#pragma_journal_mode).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum JournalMode {
    /// `DELETE`: the rollback journal is deleted at the end of each transaction.
    Delete,
    /// `TRUNCATE`: the rollback journal is truncated to zero length instead of deleted.
    Truncate,
    /// `PERSIST`: the rollback journal is kept and its header is overwritten.
    Persist,
    /// `MEMORY`: the rollback journal is kept in memory.
    Memory,
    /// `WAL`: a write-ahead log is used instead of a rollback journal.
    Wal,
    /// `OFF`: the rollback journal is disabled.
    Off,
}

impl JournalMode {
    /// Returns the appropriate string representation of the journal mode,
    /// i.e. the value passed to `PRAGMA journal_mode`.
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::Delete => "DELETE",
            Self::Truncate => "TRUNCATE",
            Self::Persist => "PERSIST",
            Self::Memory => "MEMORY",
            Self::Wal => "WAL",
            Self::Off => "OFF",
        }
    }
}