1use crate::{
4 error::{
5 AcquireThenError, AcquireThenQueryError, AcquireThenRusqliteError, ConnectionCloseErrors,
6 },
7 with_tx,
8};
9use essential_builder_types::SolutionSetFailure;
10use essential_types::{solution::SolutionSet, ContentAddress};
11use rusqlite_pool::tokio::{AsyncConnectionHandle, AsyncConnectionPool};
12use std::{ops::Range, path::PathBuf, sync::Arc, time::Duration};
13use tokio::sync::{AcquireError, TryAcquireError};
14
/// A pool of SQLite connections, wrapping `rusqlite_pool`'s tokio-aware
/// [`AsyncConnectionPool`].
///
/// Cloning is derived from the inner pool; NOTE(review): presumably clones
/// share the same underlying connections — confirm against `rusqlite_pool`.
#[derive(Clone)]
pub struct ConnectionPool(AsyncConnectionPool);
20
/// A connection taken from a [`ConnectionPool`].
///
/// Derefs to the wrapped [`AsyncConnectionHandle`] (and, through it, to a
/// `rusqlite::Connection`).
pub struct ConnectionHandle(AsyncConnectionHandle);
25
26#[derive(Clone, Debug)]
28pub struct Config {
29 pub conn_limit: usize,
31 pub source: Source,
33}
34
/// The location of the SQLite database backing a [`ConnectionPool`].
#[derive(Debug, Clone)]
pub enum Source {
    /// An in-memory database identified by this name, opened through
    /// SQLite's `memdb` VFS.
    Memory(String),
    /// An on-disk database file at this path.
    Path(PathBuf),
}
43
44impl ConnectionPool {
45 pub fn new(conf: &Config) -> rusqlite::Result<Self> {
47 let conn_pool = Self(new_conn_pool(conf)?);
48 if let Source::Path(_) = conf.source {
49 let conn = conn_pool
50 .try_acquire()
51 .expect("pool must have at least one connection");
52 conn.pragma_update(None, "journal_mode", "wal")?;
53 }
54 Ok(conn_pool)
55 }
56
57 pub fn with_tables(conf: &Config) -> rusqlite::Result<Self> {
60 let conn_pool = Self::new(conf)?;
61 let mut conn = conn_pool.try_acquire().unwrap();
62 with_tx(&mut conn, |tx| crate::create_tables(tx))?;
63 Ok(conn_pool)
64 }
65
66 pub async fn acquire(&self) -> Result<ConnectionHandle, AcquireError> {
71 self.0.acquire().await.map(ConnectionHandle)
72 }
73
74 pub fn try_acquire(&self) -> Result<ConnectionHandle, TryAcquireError> {
80 self.0.try_acquire().map(ConnectionHandle)
81 }
82
83 pub fn close(&self) -> Result<(), ConnectionCloseErrors> {
85 let res = self.0.close();
86 let errs: Vec<_> = res.into_iter().filter_map(Result::err).collect();
87 if !errs.is_empty() {
88 return Err(ConnectionCloseErrors(errs));
89 }
90 Ok(())
91 }
92}
93
94impl ConnectionPool {
96 pub async fn acquire_then<F, T, E>(&self, f: F) -> Result<T, AcquireThenError<E>>
101 where
102 F: 'static + Send + FnOnce(&mut ConnectionHandle) -> Result<T, E>,
103 T: 'static + Send,
104 E: 'static + Send,
105 {
106 let mut handle = self.acquire().await?;
108
109 tokio::task::spawn_blocking(move || f(&mut handle))
111 .await?
112 .map_err(AcquireThenError::Inner)
113 }
114
115 pub async fn create_tables(&self) -> Result<(), AcquireThenRusqliteError> {
117 self.acquire_then(|h| with_tx(h, |tx| crate::create_tables(tx)))
118 .await
119 }
120
121 pub async fn insert_solution_set_submission(
123 &self,
124 solution_set: Arc<SolutionSet>,
125 timestamp: Duration,
126 ) -> Result<ContentAddress, AcquireThenRusqliteError> {
127 self.acquire_then(move |h| {
128 with_tx(h, |tx| {
129 crate::insert_solution_set_submission(tx, &solution_set, timestamp)
130 })
131 })
132 .await
133 }
134
135 pub async fn insert_solution_set_failure(
137 &self,
138 solution_set_ca: ContentAddress,
139 failure: SolutionSetFailure<'static>,
140 ) -> Result<(), AcquireThenRusqliteError> {
141 self.acquire_then(move |h| crate::insert_solution_set_failure(h, &solution_set_ca, failure))
142 .await
143 }
144
145 pub async fn get_solution_set(
147 &self,
148 ca: ContentAddress,
149 ) -> Result<Option<SolutionSet>, AcquireThenQueryError> {
150 self.acquire_then(move |h| crate::get_solution_set(h, &ca))
151 .await
152 }
153
154 pub async fn list_solution_sets(
156 &self,
157 time_range: Range<Duration>,
158 limit: i64,
159 ) -> Result<Vec<(ContentAddress, SolutionSet, Duration)>, AcquireThenQueryError> {
160 self.acquire_then(move |h| crate::list_solution_sets(h, time_range, limit))
161 .await
162 }
163
164 pub async fn list_submissions(
166 &self,
167 time_range: Range<Duration>,
168 limit: i64,
169 ) -> Result<Vec<(ContentAddress, Duration)>, AcquireThenRusqliteError> {
170 self.acquire_then(move |h| crate::list_submissions(h, time_range, limit))
171 .await
172 }
173
174 pub async fn latest_solution_set_failures(
176 &self,
177 solution_set_ca: ContentAddress,
178 limit: u32,
179 ) -> Result<Vec<SolutionSetFailure<'static>>, AcquireThenRusqliteError> {
180 self.acquire_then(move |h| crate::latest_solution_set_failures(h, &solution_set_ca, limit))
181 .await
182 }
183
184 pub async fn list_solution_set_failures(
186 &self,
187 offset: u32,
188 limit: u32,
189 ) -> Result<Vec<SolutionSetFailure<'static>>, AcquireThenRusqliteError> {
190 self.acquire_then(move |h| crate::list_solution_set_failures(h, offset, limit))
191 .await
192 }
193
194 pub async fn delete_solution_set(
196 &self,
197 ca: ContentAddress,
198 ) -> Result<(), AcquireThenRusqliteError> {
199 self.acquire_then(move |h| crate::delete_solution_set(h, &ca))
200 .await
201 }
202
203 pub async fn delete_solution_sets(
205 &self,
206 cas: impl 'static + IntoIterator<Item = ContentAddress> + Send,
207 ) -> Result<(), AcquireThenRusqliteError> {
208 self.acquire_then(|h| with_tx(h, |tx| crate::delete_solution_sets(tx, cas)))
209 .await
210 }
211
212 pub async fn delete_oldest_solution_set_failures(
214 &self,
215 keep_limit: u32,
216 ) -> Result<(), AcquireThenRusqliteError> {
217 self.acquire_then(move |h| crate::delete_oldest_solution_set_failures(h, keep_limit))
218 .await
219 }
220}
221
222impl Config {
223 pub fn default_conn_limit() -> usize {
229 num_cpus::get().saturating_mul(4)
231 }
232}
233
234impl Source {
235 pub fn default_memory() -> Self {
237 Self::Memory("__default-id".to_string())
239 }
240}
241
242impl AsRef<rusqlite::Connection> for ConnectionHandle {
243 fn as_ref(&self) -> &rusqlite::Connection {
244 self
245 }
246}
247
248impl core::ops::Deref for ConnectionHandle {
249 type Target = AsyncConnectionHandle;
250 fn deref(&self) -> &Self::Target {
251 &self.0
252 }
253}
254
255impl core::ops::DerefMut for ConnectionHandle {
256 fn deref_mut(&mut self) -> &mut Self::Target {
257 &mut self.0
258 }
259}
260
261impl Default for Source {
262 fn default() -> Self {
263 Self::default_memory()
264 }
265}
266
267impl Default for Config {
268 fn default() -> Self {
269 Self {
270 conn_limit: Self::default_conn_limit(),
271 source: Source::default(),
272 }
273 }
274}
275
276fn new_conn_pool(conf: &Config) -> rusqlite::Result<AsyncConnectionPool> {
278 AsyncConnectionPool::new(conf.conn_limit, || new_conn(&conf.source))
279}
280
281fn new_conn(source: &Source) -> rusqlite::Result<rusqlite::Connection> {
283 let conn = match source {
284 Source::Memory(id) => new_mem_conn(id),
285 Source::Path(p) => {
286 if let Some(dir) = p.parent() {
287 let _ = std::fs::create_dir_all(dir);
288 }
289 let conn = rusqlite::Connection::open(p)?;
290 conn.pragma_update(None, "trusted_schema", false)?;
291 conn.pragma_update(None, "synchronous", 1)?;
292 Ok(conn)
293 }
294 }?;
295 conn.pragma_update(None, "foreign_keys", true)?;
296 Ok(conn)
297}
298
299fn new_mem_conn(id: &str) -> rusqlite::Result<rusqlite::Connection> {
301 let conn_str = format!("file:/{id}");
302 rusqlite::Connection::open_with_flags_and_vfs(conn_str, Default::default(), "memdb")
303}