essential_builder_db/
pool.rs

1//! Provides an async-friendly [`ConnectionPool`] implementation.
2
3use crate::{
4    error::{
5        AcquireThenError, AcquireThenQueryError, AcquireThenRusqliteError, ConnectionCloseErrors,
6    },
7    with_tx,
8};
9use essential_builder_types::SolutionSetFailure;
10use essential_types::{solution::SolutionSet, ContentAddress};
11use rusqlite_pool::tokio::{AsyncConnectionHandle, AsyncConnectionPool};
12use std::{ops::Range, path::PathBuf, sync::Arc, time::Duration};
13use tokio::sync::{AcquireError, TryAcquireError};
14
15/// Access to the builder's DB connection pool and DB-access-related methods.
16///
17/// The handle is safe to clone and share between threads.
18#[derive(Clone)]
19pub struct ConnectionPool(AsyncConnectionPool);
20
21/// A temporary connection handle to a builder's [`ConnectionPool`].
22///
23/// Provides `Deref`, `DerefMut` impls for the inner [`rusqlite::Connection`].
24pub struct ConnectionHandle(AsyncConnectionHandle);
25
26/// Builder configuration related to the database.
27#[derive(Clone, Debug)]
28pub struct Config {
29    /// The number of simultaneous connections to the database to maintain.
30    pub conn_limit: usize,
31    /// How to source the builder's database.
32    pub source: Source,
33}
34
/// Where the builder's database is stored.
#[derive(Clone, Debug)]
pub enum Source {
    /// An in-memory database, identified by the given (unique) string ID.
    Memory(String),
    /// A database file located at the given path.
    Path(PathBuf),
}
43
44impl ConnectionPool {
45    /// Create the connection pool from the given configuration.
46    pub fn new(conf: &Config) -> rusqlite::Result<Self> {
47        let conn_pool = Self(new_conn_pool(conf)?);
48        if let Source::Path(_) = conf.source {
49            let conn = conn_pool
50                .try_acquire()
51                .expect("pool must have at least one connection");
52            conn.pragma_update(None, "journal_mode", "wal")?;
53        }
54        Ok(conn_pool)
55    }
56
57    /// Create the connection pool from the given configuration and ensure the DB tables have been
58    /// created if they do not already exist before returning.
59    pub fn with_tables(conf: &Config) -> rusqlite::Result<Self> {
60        let conn_pool = Self::new(conf)?;
61        let mut conn = conn_pool.try_acquire().unwrap();
62        with_tx(&mut conn, |tx| crate::create_tables(tx))?;
63        Ok(conn_pool)
64    }
65
66    /// Acquire a temporary database [`ConnectionHandle`] from the inner pool.
67    ///
68    /// In the case that all connections are busy, waits for the first available
69    /// connection.
70    pub async fn acquire(&self) -> Result<ConnectionHandle, AcquireError> {
71        self.0.acquire().await.map(ConnectionHandle)
72    }
73
74    /// Attempt to synchronously acquire a temporary database [`ConnectionHandle`]
75    /// from the inner pool.
76    ///
77    /// Returns `Err` in the case that all database connections are busy or if
78    /// the builder has been closed.
79    pub fn try_acquire(&self) -> Result<ConnectionHandle, TryAcquireError> {
80        self.0.try_acquire().map(ConnectionHandle)
81    }
82
83    /// Close a connection pool, returning a `ConnectionCloseErrors` in the case of any errors.
84    pub fn close(&self) -> Result<(), ConnectionCloseErrors> {
85        let res = self.0.close();
86        let errs: Vec<_> = res.into_iter().filter_map(Result::err).collect();
87        if !errs.is_empty() {
88            return Err(ConnectionCloseErrors(errs));
89        }
90        Ok(())
91    }
92}
93
94/// Short-hand methods for async DB access.
95impl ConnectionPool {
96    /// Asynchronous access to the builder's DB via the given function.
97    ///
98    /// Requests and awaits a connection from the connection pool, then spawns a
99    /// blocking task for the given function providing access to the connection handle.
100    pub async fn acquire_then<F, T, E>(&self, f: F) -> Result<T, AcquireThenError<E>>
101    where
102        F: 'static + Send + FnOnce(&mut ConnectionHandle) -> Result<T, E>,
103        T: 'static + Send,
104        E: 'static + Send,
105    {
106        // Acquire a handle.
107        let mut handle = self.acquire().await?;
108
109        // Spawn the given DB connection access function on a task.
110        tokio::task::spawn_blocking(move || f(&mut handle))
111            .await?
112            .map_err(AcquireThenError::Inner)
113    }
114
115    /// Acquire a connection and call [`crate::create_tables`].
116    pub async fn create_tables(&self) -> Result<(), AcquireThenRusqliteError> {
117        self.acquire_then(|h| with_tx(h, |tx| crate::create_tables(tx)))
118            .await
119    }
120
121    /// Acquire a connection and call [`crate::insert_solution_set_submission`].
122    pub async fn insert_solution_set_submission(
123        &self,
124        solution_set: Arc<SolutionSet>,
125        timestamp: Duration,
126    ) -> Result<ContentAddress, AcquireThenRusqliteError> {
127        self.acquire_then(move |h| {
128            with_tx(h, |tx| {
129                crate::insert_solution_set_submission(tx, &solution_set, timestamp)
130            })
131        })
132        .await
133    }
134
135    /// Acquire a connection and call [`crate::insert_solution_set_failure`].
136    pub async fn insert_solution_set_failure(
137        &self,
138        solution_set_ca: ContentAddress,
139        failure: SolutionSetFailure<'static>,
140    ) -> Result<(), AcquireThenRusqliteError> {
141        self.acquire_then(move |h| crate::insert_solution_set_failure(h, &solution_set_ca, failure))
142            .await
143    }
144
145    /// Acquire a connection and call [`crate::get_solution_set`].
146    pub async fn get_solution_set(
147        &self,
148        ca: ContentAddress,
149    ) -> Result<Option<SolutionSet>, AcquireThenQueryError> {
150        self.acquire_then(move |h| crate::get_solution_set(h, &ca))
151            .await
152    }
153
154    /// Acquire a connection and call [`crate::list_solution_sets`].
155    pub async fn list_solution_sets(
156        &self,
157        time_range: Range<Duration>,
158        limit: i64,
159    ) -> Result<Vec<(ContentAddress, SolutionSet, Duration)>, AcquireThenQueryError> {
160        self.acquire_then(move |h| crate::list_solution_sets(h, time_range, limit))
161            .await
162    }
163
164    /// Acquire a connection and call [`crate::list_submissions`].
165    pub async fn list_submissions(
166        &self,
167        time_range: Range<Duration>,
168        limit: i64,
169    ) -> Result<Vec<(ContentAddress, Duration)>, AcquireThenRusqliteError> {
170        self.acquire_then(move |h| crate::list_submissions(h, time_range, limit))
171            .await
172    }
173
174    /// Acquire a connection and call [`crate::latest_solution_set_failures`].
175    pub async fn latest_solution_set_failures(
176        &self,
177        solution_set_ca: ContentAddress,
178        limit: u32,
179    ) -> Result<Vec<SolutionSetFailure<'static>>, AcquireThenRusqliteError> {
180        self.acquire_then(move |h| crate::latest_solution_set_failures(h, &solution_set_ca, limit))
181            .await
182    }
183
184    /// Acquire a connection and call [`crate::list_solution_set_failures`].
185    pub async fn list_solution_set_failures(
186        &self,
187        offset: u32,
188        limit: u32,
189    ) -> Result<Vec<SolutionSetFailure<'static>>, AcquireThenRusqliteError> {
190        self.acquire_then(move |h| crate::list_solution_set_failures(h, offset, limit))
191            .await
192    }
193
194    /// Acquire a connection and call [`crate::delete_solution_set`].
195    pub async fn delete_solution_set(
196        &self,
197        ca: ContentAddress,
198    ) -> Result<(), AcquireThenRusqliteError> {
199        self.acquire_then(move |h| crate::delete_solution_set(h, &ca))
200            .await
201    }
202
203    /// Delete the given set of solution sets in a single transaction.
204    pub async fn delete_solution_sets(
205        &self,
206        cas: impl 'static + IntoIterator<Item = ContentAddress> + Send,
207    ) -> Result<(), AcquireThenRusqliteError> {
208        self.acquire_then(|h| with_tx(h, |tx| crate::delete_solution_sets(tx, cas)))
209            .await
210    }
211
212    /// Acquire a connection and call [`crate::delete_oldest_solution_set_failures`].
213    pub async fn delete_oldest_solution_set_failures(
214        &self,
215        keep_limit: u32,
216    ) -> Result<(), AcquireThenRusqliteError> {
217        self.acquire_then(move |h| crate::delete_oldest_solution_set_failures(h, keep_limit))
218            .await
219    }
220}
221
222impl Config {
223    /// The default connection limit.
224    ///
225    /// This default uses the number of available CPUs as a heuristic for a
226    /// default connection limit. Specifically, it multiplies the number of
227    /// available CPUs by 4.
228    pub fn default_conn_limit() -> usize {
229        // TODO: Unsure if wasm-compatible? May want a feature for this?
230        num_cpus::get().saturating_mul(4)
231    }
232}
233
234impl Source {
235    /// A temporary, in-memory DB with a default ID.
236    pub fn default_memory() -> Self {
237        // Default ID cannot be an empty string.
238        Self::Memory("__default-id".to_string())
239    }
240}
241
242impl AsRef<rusqlite::Connection> for ConnectionHandle {
243    fn as_ref(&self) -> &rusqlite::Connection {
244        self
245    }
246}
247
248impl core::ops::Deref for ConnectionHandle {
249    type Target = AsyncConnectionHandle;
250    fn deref(&self) -> &Self::Target {
251        &self.0
252    }
253}
254
255impl core::ops::DerefMut for ConnectionHandle {
256    fn deref_mut(&mut self) -> &mut Self::Target {
257        &mut self.0
258    }
259}
260
261impl Default for Source {
262    fn default() -> Self {
263        Self::default_memory()
264    }
265}
266
267impl Default for Config {
268    fn default() -> Self {
269        Self {
270            conn_limit: Self::default_conn_limit(),
271            source: Source::default(),
272        }
273    }
274}
275
276/// Initialise the connection pool from the given configuration.
277fn new_conn_pool(conf: &Config) -> rusqlite::Result<AsyncConnectionPool> {
278    AsyncConnectionPool::new(conf.conn_limit, || new_conn(&conf.source))
279}
280
281/// Create a new connection given a DB source.
282fn new_conn(source: &Source) -> rusqlite::Result<rusqlite::Connection> {
283    let conn = match source {
284        Source::Memory(id) => new_mem_conn(id),
285        Source::Path(p) => {
286            if let Some(dir) = p.parent() {
287                let _ = std::fs::create_dir_all(dir);
288            }
289            let conn = rusqlite::Connection::open(p)?;
290            conn.pragma_update(None, "trusted_schema", false)?;
291            conn.pragma_update(None, "synchronous", 1)?;
292            Ok(conn)
293        }
294    }?;
295    conn.pragma_update(None, "foreign_keys", true)?;
296    Ok(conn)
297}
298
299/// Create an in-memory connection with the given ID
300fn new_mem_conn(id: &str) -> rusqlite::Result<rusqlite::Connection> {
301    let conn_str = format!("file:/{id}");
302    rusqlite::Connection::open_with_flags_and_vfs(conn_str, Default::default(), "memdb")
303}