async_sqlite/
pool.rs

1use std::{
2    num::NonZeroUsize,
3    path::{Path, PathBuf},
4    sync::{
5        atomic::{AtomicU64, Ordering::Relaxed},
6        Arc,
7    },
8    thread::available_parallelism,
9};
10
11use crate::{Client, ClientBuilder, Error, JournalMode};
12
13use futures_util::future::join_all;
14use rusqlite::{Connection, OpenFlags};
15
16/// A `PoolBuilder` can be used to create a [`Pool`] with custom
17/// configuration.
18///
19/// See [`Client`] for more information.
20///
21/// # Examples
22///
23/// ```rust
24/// # use async_sqlite::PoolBuilder;
25/// # async fn run() -> Result<(), async_sqlite::Error> {
26/// let pool = PoolBuilder::new().path("path/to/db.sqlite3").open().await?;
27///
28/// // ...
29///
30/// pool.close().await?;
31/// # Ok(())
32/// # }
33/// ```
34#[derive(Clone, Debug, Default)]
35pub struct PoolBuilder {
36    path: Option<PathBuf>,
37    flags: OpenFlags,
38    journal_mode: Option<JournalMode>,
39    vfs: Option<String>,
40    num_conns: Option<usize>,
41}
42
43impl PoolBuilder {
44    /// Returns a new [`PoolBuilder`] with the default settings.
45    pub fn new() -> Self {
46        Self::default()
47    }
48
49    /// Specify the path of the sqlite3 database to open.
50    ///
51    /// By default, an in-memory database is used.
52    pub fn path<P: AsRef<Path>>(mut self, path: P) -> Self {
53        self.path = Some(path.as_ref().into());
54        self
55    }
56
57    /// Specify the [`OpenFlags`] to use when opening a new connection.
58    ///
59    /// By default, [`OpenFlags::default()`] is used.
60    pub fn flags(mut self, flags: OpenFlags) -> Self {
61        self.flags = flags;
62        self
63    }
64
65    /// Specify the [`JournalMode`] to set when opening a new connection.
66    ///
67    /// By default, no `journal_mode` is explicity set.
68    pub fn journal_mode(mut self, journal_mode: JournalMode) -> Self {
69        self.journal_mode = Some(journal_mode);
70        self
71    }
72
73    /// Specify the name of the [vfs](https://www.sqlite.org/vfs.html) to use.
74    pub fn vfs(mut self, vfs: &str) -> Self {
75        self.vfs = Some(vfs.to_owned());
76        self
77    }
78
79    /// Specify the number of sqlite connections to open as part of the pool.
80    ///
81    /// Defaults to the number of logical CPUs of the current system. Values
82    /// less than `1` are clamped to `1`.
83    ///
84    /// ```
85    /// use async_sqlite::PoolBuilder;
86    ///
87    /// let builder = PoolBuilder::new().num_conns(2);
88    /// ```
89    pub fn num_conns(mut self, num_conns: usize) -> Self {
90        self.num_conns = Some(num_conns.max(1));
91        self
92    }
93
94    /// Returns a new [`Pool`] that uses the `PoolBuilder` configuration.
95    ///
96    /// # Examples
97    ///
98    /// ```rust
99    /// # use async_sqlite::PoolBuilder;
100    /// # async fn run() -> Result<(), async_sqlite::Error> {
101    /// let pool = PoolBuilder::new().open().await?;
102    /// # Ok(())
103    /// # }
104    /// ```
105    pub async fn open(self) -> Result<Pool, Error> {
106        let num_conns = self.get_num_conns();
107        let opens = (0..num_conns).map(|_| {
108            ClientBuilder {
109                path: self.path.clone(),
110                flags: self.flags,
111                journal_mode: self.journal_mode,
112                vfs: self.vfs.clone(),
113            }
114            .open()
115        });
116        let clients = join_all(opens)
117            .await
118            .into_iter()
119            .collect::<Result<Vec<Client>, Error>>()?;
120        Ok(Pool {
121            state: Arc::new(State {
122                clients,
123                counter: AtomicU64::new(0),
124            }),
125        })
126    }
127
128    /// Returns a new [`Pool`] that uses the `PoolBuilder` configuration,
129    /// blocking the current thread.
130    ///
131    /// # Examples
132    ///
133    /// ```rust
134    /// # use async_sqlite::PoolBuilder;
135    /// # fn run() -> Result<(), async_sqlite::Error> {
136    /// let pool = PoolBuilder::new().open_blocking()?;
137    /// # Ok(())
138    /// # }
139    /// ```
140    pub fn open_blocking(self) -> Result<Pool, Error> {
141        let num_conns = self.get_num_conns();
142        let clients = (0..num_conns)
143            .map(|_| {
144                ClientBuilder {
145                    path: self.path.clone(),
146                    flags: self.flags,
147                    journal_mode: self.journal_mode,
148                    vfs: self.vfs.clone(),
149                }
150                .open_blocking()
151            })
152            .collect::<Result<Vec<Client>, Error>>()?;
153        Ok(Pool {
154            state: Arc::new(State {
155                clients,
156                counter: AtomicU64::new(0),
157            }),
158        })
159    }
160
161    fn get_num_conns(&self) -> usize {
162        self.num_conns.unwrap_or_else(|| {
163            available_parallelism()
164                .unwrap_or_else(|_| NonZeroUsize::new(1).unwrap())
165                .into()
166        })
167    }
168}
169
170/// A simple Pool of sqlite connections.
171///
172/// A Pool has the same API as an individual [`Client`].
173#[derive(Clone)]
174pub struct Pool {
175    state: Arc<State>,
176}
177
178struct State {
179    clients: Vec<Client>,
180    counter: AtomicU64,
181}
182
183impl Pool {
184    /// Invokes the provided function with a [`rusqlite::Connection`].
185    pub async fn conn<F, T>(&self, func: F) -> Result<T, Error>
186    where
187        F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
188        T: Send + 'static,
189    {
190        self.get().conn(func).await
191    }
192
193    /// Invokes the provided function with a mutable [`rusqlite::Connection`].
194    pub async fn conn_mut<F, T>(&self, func: F) -> Result<T, Error>
195    where
196        F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
197        T: Send + 'static,
198    {
199        self.get().conn_mut(func).await
200    }
201
202    /// Closes the underlying sqlite connections.
203    ///
204    /// After this method returns, all calls to `self::conn()` or
205    /// `self::conn_mut()` will return an [`Error::Closed`] error.
206    pub async fn close(&self) -> Result<(), Error> {
207        let closes = self.state.clients.iter().map(|client| client.close());
208        let res = join_all(closes).await;
209        res.into_iter().collect::<Result<Vec<_>, Error>>()?;
210        Ok(())
211    }
212
213    /// Invokes the provided function with a [`rusqlite::Connection`], blocking
214    /// the current thread.
215    pub fn conn_blocking<F, T>(&self, func: F) -> Result<T, Error>
216    where
217        F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
218        T: Send + 'static,
219    {
220        self.get().conn_blocking(func)
221    }
222
223    /// Invokes the provided function with a mutable [`rusqlite::Connection`],
224    /// blocking the current thread.
225    pub fn conn_mut_blocking<F, T>(&self, func: F) -> Result<T, Error>
226    where
227        F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
228        T: Send + 'static,
229    {
230        self.get().conn_mut_blocking(func)
231    }
232
233    /// Closes the underlying sqlite connections, blocking the current thread.
234    ///
235    /// After this method returns, all calls to `self::conn_blocking()` or
236    /// `self::conn_mut_blocking()` will return an [`Error::Closed`] error.
237    pub fn close_blocking(&self) -> Result<(), Error> {
238        self.state
239            .clients
240            .iter()
241            .try_for_each(|client| client.close_blocking())
242    }
243
244    fn get(&self) -> &Client {
245        let n = self.state.counter.fetch_add(1, Relaxed);
246        &self.state.clients[n as usize % self.state.clients.len()]
247    }
248
249    /// Runs a function on all connections in the pool asynchronously.
250    ///
251    /// The function is executed on each connection concurrently.
252    pub async fn conn_for_each<F, T>(&self, func: F) -> Vec<Result<T, Error>>
253    where
254        F: Fn(&Connection) -> Result<T, rusqlite::Error> + Send + Sync + 'static,
255        T: Send + 'static,
256    {
257        let func = Arc::new(func);
258        let futures = self.state.clients.iter().map(|client| {
259            let func = func.clone();
260            async move { client.conn(move |conn| func(conn)).await }
261        });
262        join_all(futures).await
263    }
264
265    /// Runs a function on all connections in the pool, blocking the current thread.
266    pub fn conn_for_each_blocking<F, T>(&self, func: F) -> Vec<Result<T, Error>>
267    where
268        F: Fn(&Connection) -> Result<T, rusqlite::Error> + Send + Sync + 'static,
269        T: Send + 'static,
270    {
271        let func = Arc::new(func);
272        self.state
273            .clients
274            .iter()
275            .map(|client| {
276                let func = func.clone();
277                client.conn_blocking(move |conn| func(conn))
278            })
279            .collect()
280    }
281}