1use crate::{with_tx, AcquireConnection, AwaitNewBlock, QueryError};
7use core::ops::Range;
8use essential_node_types::{block_notify::BlockRx, Block};
9use essential_types::{solution::SolutionSet, ContentAddress, Key, Value, Word};
10use futures::Stream;
11use rusqlite_pool::tokio::{AsyncConnectionHandle, AsyncConnectionPool};
12use std::{path::PathBuf, sync::Arc, time::Duration};
13use thiserror::Error;
14use tokio::sync::{AcquireError, TryAcquireError};
15
/// An async-friendly pool of SQLite connections.
///
/// A thin newtype around `rusqlite_pool::tokio::AsyncConnectionPool`. Cloning
/// delegates to that type's `Clone` impl (presumably the clones share the same
/// underlying connections — confirm against `rusqlite_pool`).
#[derive(Clone)]
pub struct ConnectionPool(AsyncConnectionPool);
21
/// A connection checked out from a [`ConnectionPool`].
///
/// Derefs to `AsyncConnectionHandle` and, through it, to
/// `rusqlite::Connection`, so it can be used anywhere a connection is expected.
pub struct ConnectionHandle(AsyncConnectionHandle);
26
/// Configuration used to build a [`ConnectionPool`].
#[derive(Clone, Debug)]
pub struct Config {
    /// Connection limit passed to `AsyncConnectionPool::new` when the pool is built.
    pub conn_limit: usize,
    /// Where the database that every pooled connection opens lives.
    pub source: Source,
}
35
/// The location of the SQLite database backing a [`ConnectionPool`].
#[derive(Clone, Debug)]
pub enum Source {
    /// An in-memory database, opened through SQLite's `memdb` VFS with the URI
    /// `file:/{id}` where `id` is the contained string.
    /// NOTE(review): per the SQLite memdb docs, names beginning with `/` are
    /// presumably shared by all connections in the process — confirm.
    Memory(String),
    /// An on-disk database file at the given path. Missing parent directories
    /// are created on a best-effort basis when each connection is opened.
    Path(PathBuf),
}
44
/// Error returned by [`ConnectionPool::acquire_then`] and the helpers built on it.
#[derive(Debug, Error)]
pub enum AcquireThenError<E> {
    /// Acquiring a connection from the pool failed.
    #[error("failed to acquire a DB connection: {0}")]
    Acquire(#[from] tokio::sync::AcquireError),
    /// The `spawn_blocking` task that ran the closure could not be joined.
    #[error("failed to join task: {0}")]
    Join(#[from] tokio::task::JoinError),
    /// The closure itself returned an error; displayed exactly as the inner error.
    #[error("{0}")]
    Inner(E),
}
58
/// [`AcquireThenError`] for closures that fail with a `rusqlite::Error`.
pub type AcquireThenRusqliteError = AcquireThenError<rusqlite::Error>;

/// [`AcquireThenError`] for closures that fail with a `crate::QueryError`.
pub type AcquireThenQueryError = AcquireThenError<crate::QueryError>;
64
/// One or more connections failed to close during [`ConnectionPool::close`].
///
/// Each entry pairs the connection that could not be closed with the
/// `rusqlite::Error` reported for it. `Display` is implemented by hand further
/// down in this module (there is no `#[error(...)]` attribute here).
#[derive(Debug, Error)]
pub struct ConnectionCloseErrors(pub Vec<(rusqlite::Connection, rusqlite::Error)>);
68
69impl ConnectionPool {
70 pub fn new(conf: &Config) -> rusqlite::Result<Self> {
75 let conn_pool = Self(new_conn_pool(conf)?);
76 if let Source::Path(_) = conf.source {
77 let conn = conn_pool
78 .try_acquire()
79 .expect("pool must have at least one connection");
80 conn.pragma_update(None, "journal_mode", "wal")?;
81 }
82 Ok(conn_pool)
83 }
84
85 pub fn with_tables(conf: &Config) -> rusqlite::Result<Self> {
101 let conn_pool = Self::new(conf)?;
102 let mut conn = conn_pool.try_acquire().unwrap();
103 with_tx(&mut conn, |tx| crate::create_tables(tx))?;
104 Ok(conn_pool)
105 }
106
107 pub async fn acquire(&self) -> Result<ConnectionHandle, AcquireError> {
112 self.0.acquire().await.map(ConnectionHandle)
113 }
114
115 pub fn try_acquire(&self) -> Result<ConnectionHandle, TryAcquireError> {
121 self.0.try_acquire().map(ConnectionHandle)
122 }
123
124 pub fn close(&self) -> Result<(), ConnectionCloseErrors> {
126 let res = self.0.close();
127 let errs: Vec<_> = res.into_iter().filter_map(Result::err).collect();
128 if !errs.is_empty() {
129 return Err(ConnectionCloseErrors(errs));
130 }
131 Ok(())
132 }
133}
134
135impl ConnectionPool {
137 pub async fn acquire_then<F, T, E>(&self, f: F) -> Result<T, AcquireThenError<E>>
142 where
143 F: 'static + Send + FnOnce(&mut ConnectionHandle) -> Result<T, E>,
144 T: 'static + Send,
145 E: 'static + Send,
146 {
147 let mut handle = self.acquire().await?;
149
150 tokio::task::spawn_blocking(move || f(&mut handle))
152 .await?
153 .map_err(AcquireThenError::Inner)
154 }
155
156 pub async fn create_tables(&self) -> Result<(), AcquireThenRusqliteError> {
158 self.acquire_then(|h| with_tx(h, |tx| crate::create_tables(tx)))
159 .await
160 }
161
162 pub async fn insert_block(
165 &self,
166 block: Arc<Block>,
167 ) -> Result<ContentAddress, AcquireThenRusqliteError> {
168 self.acquire_then(move |h| with_tx(h, |tx| crate::insert_block(tx, &block)))
169 .await
170 }
171
172 pub async fn finalize_block(
175 &self,
176 block_ca: ContentAddress,
177 ) -> Result<(), AcquireThenRusqliteError> {
178 self.acquire_then(move |h| crate::finalize_block(h, &block_ca))
179 .await
180 }
181
182 pub async fn update_state(
184 &self,
185 contract_ca: ContentAddress,
186 key: Key,
187 value: Value,
188 ) -> Result<(), AcquireThenRusqliteError> {
189 self.acquire_then(move |h| crate::update_state(h, &contract_ca, &key, &value))
190 .await
191 }
192
193 pub async fn delete_state(
195 &self,
196 contract_ca: ContentAddress,
197 key: Key,
198 ) -> Result<(), AcquireThenRusqliteError> {
199 self.acquire_then(move |h| crate::delete_state(h, &contract_ca, &key))
200 .await
201 }
202
203 pub async fn get_block(
205 &self,
206 block_address: ContentAddress,
207 ) -> Result<Option<Block>, AcquireThenQueryError> {
208 self.acquire_then(move |h| with_tx(h, |tx| crate::get_block(tx, &block_address)))
209 .await
210 }
211
212 pub async fn get_solution_set(
214 &self,
215 ca: ContentAddress,
216 ) -> Result<SolutionSet, AcquireThenQueryError> {
217 self.acquire_then(move |h| with_tx(h, |tx| crate::get_solution_set(tx, &ca)))
218 .await
219 }
220
221 pub async fn query_state(
223 &self,
224 contract_ca: ContentAddress,
225 key: Key,
226 ) -> Result<Option<Value>, AcquireThenQueryError> {
227 self.acquire_then(move |h| crate::query_state(h, &contract_ca, &key))
228 .await
229 }
230
231 pub async fn query_latest_finalized_block(
234 &self,
235 contract_ca: ContentAddress,
236 key: Key,
237 ) -> Result<Option<Value>, AcquireThenQueryError> {
238 self.acquire_then(move |h| {
239 let tx = h.transaction()?;
240 let Some(addr) = crate::get_latest_finalized_block_address(&tx)? else {
241 return Ok(None);
242 };
243 let Some(header) = crate::get_block_header(&tx, &addr)? else {
244 return Ok(None);
245 };
246 let value = crate::finalized::query_state_inclusive_block(
247 &tx,
248 &contract_ca,
249 &key,
250 header.number,
251 )?;
252 tx.finish()?;
253 Ok(value)
254 })
255 .await
256 }
257
258 pub async fn query_state_finalized_inclusive_block(
261 &self,
262 contract_ca: ContentAddress,
263 key: Key,
264 block_number: Word,
265 ) -> Result<Option<Value>, AcquireThenQueryError> {
266 self.acquire_then(move |h| {
267 crate::finalized::query_state_inclusive_block(h, &contract_ca, &key, block_number)
268 })
269 .await
270 }
271
272 pub async fn query_state_finalized_exclusive_block(
275 &self,
276 contract_ca: ContentAddress,
277 key: Key,
278 block_number: Word,
279 ) -> Result<Option<Value>, AcquireThenQueryError> {
280 self.acquire_then(move |h| {
281 crate::finalized::query_state_exclusive_block(h, &contract_ca, &key, block_number)
282 })
283 .await
284 }
285
286 pub async fn query_state_finalized_inclusive_solution_set(
289 &self,
290 contract_ca: ContentAddress,
291 key: Key,
292 block_number: Word,
293 solution_set_ix: u64,
294 ) -> Result<Option<Value>, AcquireThenQueryError> {
295 self.acquire_then(move |h| {
296 crate::finalized::query_state_inclusive_solution_set(
297 h,
298 &contract_ca,
299 &key,
300 block_number,
301 solution_set_ix,
302 )
303 })
304 .await
305 }
306
307 pub async fn query_state_finalized_exclusive_solution_set(
310 &self,
311 contract_ca: ContentAddress,
312 key: Key,
313 block_number: Word,
314 solution_set_ix: u64,
315 ) -> Result<Option<Value>, AcquireThenQueryError> {
316 self.acquire_then(move |h| {
317 crate::finalized::query_state_exclusive_solution_set(
318 h,
319 &contract_ca,
320 &key,
321 block_number,
322 solution_set_ix,
323 )
324 })
325 .await
326 }
327
328 pub async fn get_validation_progress(
330 &self,
331 ) -> Result<Option<ContentAddress>, AcquireThenQueryError> {
332 self.acquire_then(|h| crate::get_validation_progress(h))
333 .await
334 }
335
336 pub async fn get_next_block_addresses(
338 &self,
339 current_block: ContentAddress,
340 ) -> Result<Vec<ContentAddress>, AcquireThenQueryError> {
341 self.acquire_then(move |h| crate::get_next_block_addresses(h, ¤t_block))
342 .await
343 }
344
345 pub async fn update_validation_progress(
347 &self,
348 block_ca: ContentAddress,
349 ) -> Result<(), AcquireThenRusqliteError> {
350 self.acquire_then(move |h| crate::update_validation_progress(h, &block_ca))
351 .await
352 }
353
354 pub async fn list_blocks(
356 &self,
357 block_range: Range<Word>,
358 ) -> Result<Vec<Block>, AcquireThenQueryError> {
359 self.acquire_then(move |h| with_tx(h, |tx| crate::list_blocks(tx, block_range)))
360 .await
361 }
362
363 pub async fn list_blocks_by_time(
365 &self,
366 range: Range<Duration>,
367 page_size: i64,
368 page_number: i64,
369 ) -> Result<Vec<Block>, AcquireThenQueryError> {
370 self.acquire_then(move |h| {
371 with_tx(h, |tx| {
372 crate::list_blocks_by_time(tx, range, page_size, page_number)
373 })
374 })
375 .await
376 }
377
378 pub fn subscribe_blocks(
380 &self,
381 start_block: Word,
382 await_new_block: impl AwaitNewBlock,
383 ) -> impl Stream<Item = Result<Block, QueryError>> {
384 crate::subscribe_blocks(start_block, self.clone(), await_new_block)
385 }
386}
387
388impl Config {
389 pub fn new(source: Source, conn_limit: usize) -> Self {
391 Self { source, conn_limit }
392 }
393
394 pub fn default_conn_limit() -> usize {
400 num_cpus::get().saturating_mul(4)
402 }
403}
404
405impl Source {
406 pub fn default_memory() -> Self {
408 Self::Memory("__default-id".to_string())
410 }
411}
412
413impl AwaitNewBlock for BlockRx {
414 async fn await_new_block(&mut self) -> Option<()> {
415 self.changed().await.ok()
416 }
417}
418
419impl AsRef<AsyncConnectionPool> for ConnectionPool {
420 fn as_ref(&self) -> &AsyncConnectionPool {
421 &self.0
422 }
423}
424
425impl AsRef<rusqlite::Connection> for ConnectionHandle {
426 fn as_ref(&self) -> &rusqlite::Connection {
427 self
428 }
429}
430
431impl AsMut<rusqlite::Connection> for ConnectionHandle {
432 fn as_mut(&mut self) -> &mut rusqlite::Connection {
433 self
434 }
435}
436
437impl core::ops::Deref for ConnectionHandle {
438 type Target = AsyncConnectionHandle;
439 fn deref(&self) -> &Self::Target {
440 &self.0
441 }
442}
443
444impl core::ops::DerefMut for ConnectionHandle {
445 fn deref_mut(&mut self) -> &mut Self::Target {
446 &mut self.0
447 }
448}
449
450impl AcquireConnection for ConnectionPool {
451 async fn acquire_connection(&self) -> Option<impl 'static + AsMut<rusqlite::Connection>> {
452 self.acquire().await.ok()
453 }
454}
455
456impl Default for Source {
457 fn default() -> Self {
458 Self::default_memory()
459 }
460}
461
462impl Default for Config {
463 fn default() -> Self {
464 Self {
465 conn_limit: Self::default_conn_limit(),
466 source: Source::default(),
467 }
468 }
469}
470
471impl core::fmt::Display for ConnectionCloseErrors {
472 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
473 writeln!(f, "failed to close one or more connections:")?;
474 for (ix, (_conn, err)) in self.0.iter().enumerate() {
475 writeln!(f, " {ix}: {err}")?;
476 }
477 Ok(())
478 }
479}
480
481fn new_conn_pool(conf: &Config) -> rusqlite::Result<AsyncConnectionPool> {
483 AsyncConnectionPool::new(conf.conn_limit, || new_conn(&conf.source))
484}
485
486pub(crate) fn new_conn(source: &Source) -> rusqlite::Result<rusqlite::Connection> {
488 let conn = match source {
489 Source::Memory(id) => new_mem_conn(id),
490 Source::Path(p) => {
491 if let Some(dir) = p.parent() {
492 let _ = std::fs::create_dir_all(dir);
493 }
494 let conn = rusqlite::Connection::open(p)?;
495 conn.pragma_update(None, "trusted_schema", false)?;
496 conn.pragma_update(None, "synchronous", 1)?;
497 Ok(conn)
498 }
499 }?;
500 conn.pragma_update(None, "foreign_keys", true)?;
501 Ok(conn)
502}
503
504fn new_mem_conn(id: &str) -> rusqlite::Result<rusqlite::Connection> {
506 let conn_str = format!("file:/{id}");
507 rusqlite::Connection::open_with_flags_and_vfs(conn_str, Default::default(), "memdb")
508}