Skip to main content

async_sqlite/
client.rs

1use std::{
2    path::{Path, PathBuf},
3    thread,
4};
5
6use crate::Error;
7
8use crossbeam_channel::{bounded, unbounded, Sender};
9use futures_channel::oneshot;
10use rusqlite::{Connection, OpenFlags};
11
12/// A `ClientBuilder` can be used to create a [`Client`] with custom
13/// configuration.
14///
15/// For more information on creating a sqlite connection, see the
16/// [rusqlite docs](rusqlite::Connection::open()).
17///
18/// # Examples
19///
20/// ```rust
21/// # use async_sqlite::ClientBuilder;
22/// # async fn run() -> Result<(), async_sqlite::Error> {
23/// let client = ClientBuilder::new().path("path/to/db.sqlite3").open().await?;
24///
25/// // ...
26///
27/// client.close().await?;
28/// # Ok(())
29/// # }
30/// ```
31#[derive(Clone, Debug, Default)]
32pub struct ClientBuilder {
33    pub(crate) path: Option<PathBuf>,
34    pub(crate) flags: OpenFlags,
35    pub(crate) journal_mode: Option<JournalMode>,
36    pub(crate) vfs: Option<String>,
37}
38
39impl ClientBuilder {
40    /// Returns a new [`ClientBuilder`] with the default settings.
41    pub fn new() -> Self {
42        Self::default()
43    }
44
45    /// Specify the path of the sqlite3 database to open.
46    ///
47    /// By default, an in-memory database is used.
48    pub fn path<P: AsRef<Path>>(mut self, path: P) -> Self {
49        self.path = Some(path.as_ref().into());
50        self
51    }
52
53    /// Specify the [`OpenFlags`] to use when opening a new connection.
54    ///
55    /// By default, [`OpenFlags::default()`] is used.
56    pub fn flags(mut self, flags: OpenFlags) -> Self {
57        self.flags = flags;
58        self
59    }
60
61    /// Specify the [`JournalMode`] to set when opening a new connection.
62    ///
63    /// By default, no `journal_mode` is explicity set.
64    pub fn journal_mode(mut self, journal_mode: JournalMode) -> Self {
65        self.journal_mode = Some(journal_mode);
66        self
67    }
68
69    /// Specify the name of the [vfs](https://www.sqlite.org/vfs.html) to use.
70    pub fn vfs(mut self, vfs: &str) -> Self {
71        self.vfs = Some(vfs.to_owned());
72        self
73    }
74
75    /// Returns a new [`Client`] that uses the `ClientBuilder` configuration.
76    ///
77    /// # Examples
78    ///
79    /// ```rust
80    /// # use async_sqlite::ClientBuilder;
81    /// # async fn run() -> Result<(), async_sqlite::Error> {
82    /// let client = ClientBuilder::new().open().await?;
83    /// # Ok(())
84    /// # }
85    /// ```
86    pub async fn open(self) -> Result<Client, Error> {
87        Client::open_async(self).await
88    }
89
90    /// Returns a new [`Client`] that uses the `ClientBuilder` configuration,
91    /// blocking the current thread.
92    ///
93    /// # Examples
94    ///
95    /// ```rust
96    /// # use async_sqlite::ClientBuilder;
97    /// # fn run() -> Result<(), async_sqlite::Error> {
98    /// let client = ClientBuilder::new().open_blocking()?;
99    /// # Ok(())
100    /// # }
101    /// ```
102    pub fn open_blocking(self) -> Result<Client, Error> {
103        Client::open_blocking(self)
104    }
105}
106
107enum Command {
108    Func(Box<dyn FnOnce(&mut Connection) + Send>),
109    Shutdown(Box<dyn FnOnce(Result<(), Error>) + Send>),
110}
111
112/// Client represents a single sqlite connection that can be used from async
113/// contexts.
114#[derive(Clone)]
115pub struct Client {
116    conn_tx: Sender<Command>,
117}
118
119impl Client {
120    async fn open_async(builder: ClientBuilder) -> Result<Self, Error> {
121        let (open_tx, open_rx) = oneshot::channel();
122        Self::open(builder, |res| {
123            _ = open_tx.send(res);
124        });
125        open_rx.await?
126    }
127
128    fn open_blocking(builder: ClientBuilder) -> Result<Self, Error> {
129        let (conn_tx, conn_rx) = bounded(1);
130        Self::open(builder, move |res| {
131            _ = conn_tx.send(res);
132        });
133        conn_rx.recv()?
134    }
135
136    fn open<F>(builder: ClientBuilder, func: F)
137    where
138        F: FnOnce(Result<Self, Error>) + Send + 'static,
139    {
140        thread::spawn(move || {
141            let (conn_tx, conn_rx) = unbounded();
142
143            let mut conn = match Client::create_conn(builder) {
144                Ok(conn) => conn,
145                Err(err) => {
146                    func(Err(err));
147                    return;
148                }
149            };
150
151            let client = Self { conn_tx };
152            func(Ok(client));
153
154            while let Ok(cmd) = conn_rx.recv() {
155                match cmd {
156                    Command::Func(func) => {
157                        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
158                            func(&mut conn);
159                        }));
160                    }
161                    Command::Shutdown(func) => match conn.close() {
162                        Ok(()) => {
163                            func(Ok(()));
164                            return;
165                        }
166                        Err((c, e)) => {
167                            conn = c;
168                            func(Err(e.into()));
169                        }
170                    },
171                }
172            }
173        });
174    }
175
176    fn create_conn(mut builder: ClientBuilder) -> Result<Connection, Error> {
177        let path = builder.path.take().unwrap_or_else(|| ":memory:".into());
178        let conn = if let Some(vfs) = builder.vfs.take() {
179            Connection::open_with_flags_and_vfs(path, builder.flags, vfs.as_str())?
180        } else {
181            Connection::open_with_flags(path, builder.flags)?
182        };
183
184        if let Some(journal_mode) = builder.journal_mode.take() {
185            let val = journal_mode.as_str();
186            let out: String =
187                conn.pragma_update_and_check(None, "journal_mode", val, |row| row.get(0))?;
188            if !out.eq_ignore_ascii_case(val) {
189                return Err(Error::PragmaUpdate {
190                    name: "journal_mode",
191                    exp: val,
192                    got: out,
193                });
194            }
195        }
196
197        Ok(conn)
198    }
199
200    /// Invokes the provided function with a [`rusqlite::Connection`].
201    pub async fn conn<F, T>(&self, func: F) -> Result<T, Error>
202    where
203        F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
204        T: Send + 'static,
205    {
206        let (tx, rx) = oneshot::channel();
207        self.conn_tx.send(Command::Func(Box::new(move |conn| {
208            _ = tx.send(func(conn));
209        })))?;
210        Ok(rx.await??)
211    }
212
213    /// Invokes the provided function with a mutable [`rusqlite::Connection`].
214    pub async fn conn_mut<F, T>(&self, func: F) -> Result<T, Error>
215    where
216        F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
217        T: Send + 'static,
218    {
219        let (tx, rx) = oneshot::channel();
220        self.conn_tx.send(Command::Func(Box::new(move |conn| {
221            _ = tx.send(func(conn));
222        })))?;
223        Ok(rx.await??)
224    }
225
226    /// Invokes the provided function with a [`rusqlite::Connection`].
227    ///
228    /// Maps the result error type to a custom error; designed to be
229    /// used in conjunction with [`query_and_then`](https://docs.rs/rusqlite/latest/rusqlite/struct.CachedStatement.html#method.query_and_then).
230    pub async fn conn_and_then<F, T, E>(&self, func: F) -> Result<T, E>
231    where
232        F: FnOnce(&Connection) -> Result<T, E> + Send + 'static,
233        T: Send + 'static,
234        E: From<rusqlite::Error> + From<Error> + Send + 'static,
235    {
236        let (tx, rx) = oneshot::channel();
237        self.conn_tx
238            .send(Command::Func(Box::new(move |conn| {
239                _ = tx.send(func(conn));
240            })))
241            .map_err(Error::from)?;
242        rx.await.map_err(Error::from)?
243    }
244
245    /// Invokes the provided function with a mutable [`rusqlite::Connection`].
246    ///
247    /// Maps the result error type to a custom error; designed to be
248    /// used in conjunction with [`query_and_then`](https://docs.rs/rusqlite/latest/rusqlite/struct.CachedStatement.html#method.query_and_then).
249    pub async fn conn_mut_and_then<F, T, E>(&self, func: F) -> Result<T, E>
250    where
251        F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
252        T: Send + 'static,
253        E: From<rusqlite::Error> + From<Error> + Send + 'static,
254    {
255        let (tx, rx) = oneshot::channel();
256        self.conn_tx
257            .send(Command::Func(Box::new(move |conn| {
258                _ = tx.send(func(conn));
259            })))
260            .map_err(Error::from)?;
261        rx.await.map_err(Error::from)?
262    }
263
264    /// Closes the underlying sqlite connection.
265    ///
266    /// After this method returns, all calls to `self::conn()` or
267    /// `self::conn_mut()` will return an [`Error::Closed`] error.
268    pub async fn close(&self) -> Result<(), Error> {
269        let (tx, rx) = oneshot::channel();
270        let func = Box::new(|res| _ = tx.send(res));
271        if self.conn_tx.send(Command::Shutdown(func)).is_err() {
272            // If the worker thread has already shut down, return Ok here.
273            return Ok(());
274        }
275        // If receiving fails, the connection is already closed.
276        rx.await.unwrap_or(Ok(()))
277    }
278
279    /// Invokes the provided function with a [`rusqlite::Connection`], blocking
280    /// the current thread until completion.
281    pub fn conn_blocking<F, T>(&self, func: F) -> Result<T, Error>
282    where
283        F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
284        T: Send + 'static,
285    {
286        let (tx, rx) = bounded(1);
287        self.conn_tx.send(Command::Func(Box::new(move |conn| {
288            _ = tx.send(func(conn));
289        })))?;
290        Ok(rx.recv()??)
291    }
292
293    /// Invokes the provided function with a mutable [`rusqlite::Connection`],
294    /// blocking the current thread until completion.
295    pub fn conn_mut_blocking<F, T>(&self, func: F) -> Result<T, Error>
296    where
297        F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
298        T: Send + 'static,
299    {
300        let (tx, rx) = bounded(1);
301        self.conn_tx.send(Command::Func(Box::new(move |conn| {
302            _ = tx.send(func(conn));
303        })))?;
304        Ok(rx.recv()??)
305    }
306
307    /// Closes the underlying sqlite connection, blocking the current thread
308    /// until complete.
309    ///
310    /// After this method returns, all calls to `self::conn_blocking()` or
311    /// `self::conn_mut_blocking()` will return an [`Error::Closed`] error.
312    pub fn close_blocking(&self) -> Result<(), Error> {
313        let (tx, rx) = bounded(1);
314        let func = Box::new(move |res| _ = tx.send(res));
315        if self.conn_tx.send(Command::Shutdown(func)).is_err() {
316            return Ok(());
317        }
318        // If receiving fails, the connection is already closed.
319        rx.recv().unwrap_or(Ok(()))
320    }
321}
322
/// The possible sqlite journal modes.
///
/// For more information, please see the [sqlite docs](https://www.sqlite.org/pragma.html#pragma_journal_mode).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum JournalMode {
    /// Corresponds to `PRAGMA journal_mode = DELETE`.
    Delete,
    /// Corresponds to `PRAGMA journal_mode = TRUNCATE`.
    Truncate,
    /// Corresponds to `PRAGMA journal_mode = PERSIST`.
    Persist,
    /// Corresponds to `PRAGMA journal_mode = MEMORY`.
    Memory,
    /// Corresponds to `PRAGMA journal_mode = WAL`.
    Wal,
    /// Corresponds to `PRAGMA journal_mode = OFF`.
    Off,
}

impl JournalMode {
    /// Returns the appropriate string representation of the journal mode.
    ///
    /// The value is the upper-case keyword used with `PRAGMA journal_mode`.
    /// This is a `const fn`, so it can also be used in constant contexts.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Delete => "DELETE",
            Self::Truncate => "TRUNCATE",
            Self::Persist => "PERSIST",
            Self::Memory => "MEMORY",
            Self::Wal => "WAL",
            Self::Off => "OFF",
        }
    }
}