async_sqlite/
client.rs

1use std::{
2    path::{Path, PathBuf},
3    thread,
4};
5
6use crate::Error;
7
8use crossbeam_channel::{bounded, unbounded, Sender};
9use futures_channel::oneshot;
10use rusqlite::{Connection, OpenFlags};
11
12/// A `ClientBuilder` can be used to create a [`Client`] with custom
13/// configuration.
14///
15/// For more information on creating a sqlite connection, see the
16/// [rusqlite docs](rusqlite::Connection::open()).
17///
18/// # Examples
19///
20/// ```rust
21/// # use async_sqlite::ClientBuilder;
22/// # async fn run() -> Result<(), async_sqlite::Error> {
23/// let client = ClientBuilder::new().path("path/to/db.sqlite3").open().await?;
24///
25/// // ...
26///
27/// client.close().await?;
28/// # Ok(())
29/// # }
30/// ```
31#[derive(Clone, Debug, Default)]
32pub struct ClientBuilder {
33    pub(crate) path: Option<PathBuf>,
34    pub(crate) flags: OpenFlags,
35    pub(crate) journal_mode: Option<JournalMode>,
36    pub(crate) vfs: Option<String>,
37}
38
39impl ClientBuilder {
40    /// Returns a new [`ClientBuilder`] with the default settings.
41    pub fn new() -> Self {
42        Self::default()
43    }
44
45    /// Specify the path of the sqlite3 database to open.
46    ///
47    /// By default, an in-memory database is used.
48    pub fn path<P: AsRef<Path>>(mut self, path: P) -> Self {
49        self.path = Some(path.as_ref().into());
50        self
51    }
52
53    /// Specify the [`OpenFlags`] to use when opening a new connection.
54    ///
55    /// By default, [`OpenFlags::default()`] is used.
56    pub fn flags(mut self, flags: OpenFlags) -> Self {
57        self.flags = flags;
58        self
59    }
60
61    /// Specify the [`JournalMode`] to set when opening a new connection.
62    ///
63    /// By default, no `journal_mode` is explicity set.
64    pub fn journal_mode(mut self, journal_mode: JournalMode) -> Self {
65        self.journal_mode = Some(journal_mode);
66        self
67    }
68
69    /// Specify the name of the [vfs](https://www.sqlite.org/vfs.html) to use.
70    pub fn vfs(mut self, vfs: &str) -> Self {
71        self.vfs = Some(vfs.to_owned());
72        self
73    }
74
75    /// Returns a new [`Client`] that uses the `ClientBuilder` configuration.
76    ///
77    /// # Examples
78    ///
79    /// ```rust
80    /// # use async_sqlite::ClientBuilder;
81    /// # async fn run() -> Result<(), async_sqlite::Error> {
82    /// let client = ClientBuilder::new().open().await?;
83    /// # Ok(())
84    /// # }
85    /// ```
86    pub async fn open(self) -> Result<Client, Error> {
87        Client::open_async(self).await
88    }
89
90    /// Returns a new [`Client`] that uses the `ClientBuilder` configuration,
91    /// blocking the current thread.
92    ///
93    /// # Examples
94    ///
95    /// ```rust
96    /// # use async_sqlite::ClientBuilder;
97    /// # fn run() -> Result<(), async_sqlite::Error> {
98    /// let client = ClientBuilder::new().open_blocking()?;
99    /// # Ok(())
100    /// # }
101    /// ```
102    pub fn open_blocking(self) -> Result<Client, Error> {
103        Client::open_blocking(self)
104    }
105}
106
107enum Command {
108    Func(Box<dyn FnOnce(&mut Connection) + Send>),
109    Shutdown(Box<dyn FnOnce(Result<(), Error>) + Send>),
110}
111
112/// Client represents a single sqlite connection that can be used from async
113/// contexts.
114#[derive(Clone)]
115pub struct Client {
116    conn_tx: Sender<Command>,
117}
118
119impl Client {
120    async fn open_async(builder: ClientBuilder) -> Result<Self, Error> {
121        let (open_tx, open_rx) = oneshot::channel();
122        Self::open(builder, |res| {
123            _ = open_tx.send(res);
124        });
125        open_rx.await?
126    }
127
128    fn open_blocking(builder: ClientBuilder) -> Result<Self, Error> {
129        let (conn_tx, conn_rx) = bounded(1);
130        Self::open(builder, move |res| {
131            _ = conn_tx.send(res);
132        });
133        conn_rx.recv()?
134    }
135
136    fn open<F>(builder: ClientBuilder, func: F)
137    where
138        F: FnOnce(Result<Self, Error>) + Send + 'static,
139    {
140        thread::spawn(move || {
141            let (conn_tx, conn_rx) = unbounded();
142
143            let mut conn = match Client::create_conn(builder) {
144                Ok(conn) => conn,
145                Err(err) => {
146                    func(Err(err));
147                    return;
148                }
149            };
150
151            let client = Self { conn_tx };
152            func(Ok(client));
153
154            while let Ok(cmd) = conn_rx.recv() {
155                match cmd {
156                    Command::Func(func) => func(&mut conn),
157                    Command::Shutdown(func) => match conn.close() {
158                        Ok(()) => {
159                            func(Ok(()));
160                            return;
161                        }
162                        Err((c, e)) => {
163                            conn = c;
164                            func(Err(e.into()));
165                        }
166                    },
167                }
168            }
169        });
170    }
171
172    fn create_conn(mut builder: ClientBuilder) -> Result<Connection, Error> {
173        let path = builder.path.take().unwrap_or_else(|| ":memory:".into());
174        let conn = if let Some(vfs) = builder.vfs.take() {
175            Connection::open_with_flags_and_vfs(path, builder.flags, &vfs)?
176        } else {
177            Connection::open_with_flags(path, builder.flags)?
178        };
179
180        if let Some(journal_mode) = builder.journal_mode.take() {
181            let val = journal_mode.as_str();
182            let out: String =
183                conn.pragma_update_and_check(None, "journal_mode", val, |row| row.get(0))?;
184            if !out.eq_ignore_ascii_case(val) {
185                return Err(Error::PragmaUpdate {
186                    name: "journal_mode",
187                    exp: val,
188                    got: out,
189                });
190            }
191        }
192
193        Ok(conn)
194    }
195
196    /// Invokes the provided function with a [`rusqlite::Connection`].
197    pub async fn conn<F, T>(&self, func: F) -> Result<T, Error>
198    where
199        F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
200        T: Send + 'static,
201    {
202        let (tx, rx) = oneshot::channel();
203        self.conn_tx.send(Command::Func(Box::new(move |conn| {
204            _ = tx.send(func(conn));
205        })))?;
206        Ok(rx.await??)
207    }
208
209    /// Invokes the provided function with a mutable [`rusqlite::Connection`].
210    pub async fn conn_mut<F, T>(&self, func: F) -> Result<T, Error>
211    where
212        F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
213        T: Send + 'static,
214    {
215        let (tx, rx) = oneshot::channel();
216        self.conn_tx.send(Command::Func(Box::new(move |conn| {
217            _ = tx.send(func(conn));
218        })))?;
219        Ok(rx.await??)
220    }
221
222    /// Invokes the provided function with a [`rusqlite::Connection`].
223    ///
224    /// Maps the result error type to a custom error; designed to be
225    /// used in conjunction with [`query_and_then`](https://docs.rs/rusqlite/latest/rusqlite/struct.CachedStatement.html#method.query_and_then).
226    pub async fn conn_and_then<F, T, E>(&self, func: F) -> Result<T, E>
227    where
228        F: FnOnce(&Connection) -> Result<T, E> + Send + 'static,
229        T: Send + 'static,
230        E: From<rusqlite::Error> + From<Error> + Send + 'static,
231    {
232        let (tx, rx) = oneshot::channel();
233        self.conn_tx
234            .send(Command::Func(Box::new(move |conn| {
235                _ = tx.send(func(conn));
236            })))
237            .map_err(Error::from)?;
238        rx.await.map_err(Error::from)?
239    }
240
241    /// Invokes the provided function with a mutable [`rusqlite::Connection`].
242    ///
243    /// Maps the result error type to a custom error; designed to be
244    /// used in conjunction with [`query_and_then`](https://docs.rs/rusqlite/latest/rusqlite/struct.CachedStatement.html#method.query_and_then).
245    pub async fn conn_mut_and_then<F, T, E>(&self, func: F) -> Result<T, E>
246    where
247        F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
248        T: Send + 'static,
249        E: From<rusqlite::Error> + From<Error> + Send + 'static,
250    {
251        let (tx, rx) = oneshot::channel();
252        self.conn_tx
253            .send(Command::Func(Box::new(move |conn| {
254                _ = tx.send(func(conn));
255            })))
256            .map_err(Error::from)?;
257        rx.await.map_err(Error::from)?
258    }
259
260    /// Closes the underlying sqlite connection.
261    ///
262    /// After this method returns, all calls to `self::conn()` or
263    /// `self::conn_mut()` will return an [`Error::Closed`] error.
264    pub async fn close(&self) -> Result<(), Error> {
265        let (tx, rx) = oneshot::channel();
266        let func = Box::new(|res| _ = tx.send(res));
267        if self.conn_tx.send(Command::Shutdown(func)).is_err() {
268            // If the worker thread has already shut down, return Ok here.
269            return Ok(());
270        }
271        // If receiving fails, the connection is already closed.
272        rx.await.unwrap_or(Ok(()))
273    }
274
275    /// Invokes the provided function with a [`rusqlite::Connection`], blocking
276    /// the current thread until completion.
277    pub fn conn_blocking<F, T>(&self, func: F) -> Result<T, Error>
278    where
279        F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
280        T: Send + 'static,
281    {
282        let (tx, rx) = bounded(1);
283        self.conn_tx.send(Command::Func(Box::new(move |conn| {
284            _ = tx.send(func(conn));
285        })))?;
286        Ok(rx.recv()??)
287    }
288
289    /// Invokes the provided function with a mutable [`rusqlite::Connection`],
290    /// blocking the current thread until completion.
291    pub fn conn_mut_blocking<F, T>(&self, func: F) -> Result<T, Error>
292    where
293        F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
294        T: Send + 'static,
295    {
296        let (tx, rx) = bounded(1);
297        self.conn_tx.send(Command::Func(Box::new(move |conn| {
298            _ = tx.send(func(conn));
299        })))?;
300        Ok(rx.recv()??)
301    }
302
303    /// Closes the underlying sqlite connection, blocking the current thread
304    /// until complete.
305    ///
306    /// After this method returns, all calls to `self::conn_blocking()` or
307    /// `self::conn_mut_blocking()` will return an [`Error::Closed`] error.
308    pub fn close_blocking(&self) -> Result<(), Error> {
309        let (tx, rx) = bounded(1);
310        let func = Box::new(move |res| _ = tx.send(res));
311        if self.conn_tx.send(Command::Shutdown(func)).is_err() {
312            return Ok(());
313        }
314        // If receiving fails, the connection is already closed.
315        rx.recv().unwrap_or(Ok(()))
316    }
317}
318
/// The possible sqlite journal modes.
///
/// Each variant corresponds to the value of the same name accepted by
/// `PRAGMA journal_mode`; see [`JournalMode::as_str`] for the exact spelling
/// that is sent to sqlite.
///
/// For more information, please see the [sqlite docs](https://www.sqlite.org/pragma.html#pragma_journal_mode).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum JournalMode {
    /// `PRAGMA journal_mode = DELETE`.
    Delete,
    /// `PRAGMA journal_mode = TRUNCATE`.
    Truncate,
    /// `PRAGMA journal_mode = PERSIST`.
    Persist,
    /// `PRAGMA journal_mode = MEMORY`.
    Memory,
    /// `PRAGMA journal_mode = WAL`.
    Wal,
    /// `PRAGMA journal_mode = OFF`.
    Off,
}

impl JournalMode {
    /// Returns the upper-case keyword that names this journal mode in a
    /// `PRAGMA journal_mode` statement.
    pub fn as_str(&self) -> &'static str {
        match *self {
            JournalMode::Delete => "DELETE",
            JournalMode::Truncate => "TRUNCATE",
            JournalMode::Persist => "PERSIST",
            JournalMode::Memory => "MEMORY",
            JournalMode::Wal => "WAL",
            JournalMode::Off => "OFF",
        }
    }
}