use crate::{with_tx, AcquireConnection, AwaitNewBlock, QueryError};
use core::ops::Range;
use essential_node_types::{block_notify::BlockRx, Block};
use essential_types::{solution::SolutionSet, ContentAddress, Key, Value, Word};
use futures::Stream;
use rusqlite_pool::tokio::{AsyncConnectionHandle, AsyncConnectionPool};
use std::{path::PathBuf, sync::Arc, time::Duration};
use thiserror::Error;
use tokio::sync::{AcquireError, TryAcquireError};
/// A cloneable wrapper around an `AsyncConnectionPool` of SQLite connections.
///
/// Built via [`ConnectionPool::new`] or [`ConnectionPool::with_tables`]. Most of
/// the async helper methods acquire a connection and run synchronous `rusqlite`
/// work on tokio's blocking thread pool (see [`ConnectionPool::acquire_then`]).
#[derive(Clone)]
pub struct ConnectionPool(AsyncConnectionPool);
/// A connection acquired from a [`ConnectionPool`].
///
/// Derefs to the wrapped `AsyncConnectionHandle` and exposes the underlying
/// `rusqlite::Connection` through `AsRef`/`AsMut`.
pub struct ConnectionHandle(AsyncConnectionHandle);
/// Configuration used to build a [`ConnectionPool`].
#[derive(Clone, Debug)]
pub struct Config {
/// Connection count handed to `AsyncConnectionPool::new` when the pool is built.
pub conn_limit: usize,
/// Where the database lives: in memory (by id) or on disk (by path).
pub source: Source,
}
/// The location of the SQLite database backing a [`ConnectionPool`].
#[derive(Clone, Debug)]
pub enum Source {
/// An in-memory database identified by this id, opened through SQLite's
/// `memdb` VFS under the URI `file:/{id}`.
Memory(String),
/// An on-disk database file at this path. Missing parent directories are
/// created on a best-effort basis before opening.
Path(PathBuf),
}
/// Error returned by [`ConnectionPool::acquire_then`] and the helpers built on it.
#[derive(Debug, Error)]
pub enum AcquireThenError<E> {
/// Acquiring a connection from the pool failed.
#[error("failed to acquire a DB connection: {0}")]
Acquire(#[from] tokio::sync::AcquireError),
/// The blocking task that ran the closure could not be joined. Per tokio's
/// `JoinError` docs, this happens when the task panicked or was cancelled.
#[error("failed to join task: {0}")]
Join(#[from] tokio::task::JoinError),
/// The error returned by the user-supplied closure itself.
#[error("{0}")]
Inner(E),
}
/// [`AcquireThenError`] whose inner error is a `rusqlite::Error`.
pub type AcquireThenRusqliteError = AcquireThenError<rusqlite::Error>;
/// [`AcquireThenError`] whose inner error is a [`crate::QueryError`].
pub type AcquireThenQueryError = AcquireThenError<crate::QueryError>;
/// Failures collected by [`ConnectionPool::close`]: each connection that failed
/// to close, paired with the error it reported.
///
/// `Display` is implemented by hand further down in this module.
#[derive(Debug, Error)]
pub struct ConnectionCloseErrors(pub Vec<(rusqlite::Connection, rusqlite::Error)>);
impl ConnectionPool {
    /// Create a new connection pool from the given configuration.
    ///
    /// For a [`Source::Path`] database, one of the new connections is used to
    /// switch the database to WAL journal mode.
    ///
    /// # Errors
    ///
    /// Returns any `rusqlite` error raised while opening the connections or
    /// while setting the `journal_mode` pragma.
    ///
    /// # Panics
    ///
    /// For a [`Source::Path`] source, panics if no connection can be acquired
    /// immediately from the freshly built pool (presumably only possible when
    /// `conf.conn_limit` is `0`).
    pub fn new(conf: &Config) -> rusqlite::Result<Self> {
        let conn_pool = Self(new_conn_pool(conf)?);
        if let Source::Path(_) = conf.source {
            let conn = conn_pool
                .try_acquire()
                .expect("pool must have at least one connection");
            conn.pragma_update(None, "journal_mode", "wal")?;
        }
        Ok(conn_pool)
    }

    /// Create a new connection pool (see [`ConnectionPool::new`]) and create the
    /// crate's tables on one of its connections via `with_tx`.
    ///
    /// # Errors
    ///
    /// Returns any `rusqlite` error from [`ConnectionPool::new`] or from table
    /// creation.
    ///
    /// # Panics
    ///
    /// Panics if no connection can be acquired immediately from the freshly
    /// built pool (presumably only possible when `conf.conn_limit` is `0`).
    pub fn with_tables(conf: &Config) -> rusqlite::Result<Self> {
        let conn_pool = Self::new(conf)?;
        // Same invariant (and message) as in `new`, rather than a bare `unwrap`.
        let mut conn = conn_pool
            .try_acquire()
            .expect("pool must have at least one connection");
        with_tx(&mut conn, |tx| crate::create_tables(tx))?;
        Ok(conn_pool)
    }

    /// Asynchronously acquire a connection from the pool.
    ///
    /// # Errors
    ///
    /// Returns the `AcquireError` reported by the underlying pool.
    pub async fn acquire(&self) -> Result<ConnectionHandle, AcquireError> {
        self.0.acquire().await.map(ConnectionHandle)
    }

    /// Try to acquire a connection from the pool without waiting.
    ///
    /// # Errors
    ///
    /// Returns the `TryAcquireError` reported by the underlying pool.
    pub fn try_acquire(&self) -> Result<ConnectionHandle, TryAcquireError> {
        self.0.try_acquire().map(ConnectionHandle)
    }

    /// Close the pool's connections via `AsyncConnectionPool::close`.
    ///
    /// # Errors
    ///
    /// If any connection fails to close, every such connection is returned
    /// together with its error in [`ConnectionCloseErrors`].
    pub fn close(&self) -> Result<(), ConnectionCloseErrors> {
        let errs: Vec<_> = self
            .0
            .close()
            .into_iter()
            .filter_map(Result::err)
            .collect();
        if errs.is_empty() {
            Ok(())
        } else {
            Err(ConnectionCloseErrors(errs))
        }
    }
}
impl ConnectionPool {
    /// Acquire a connection and run `f` with it on tokio's blocking thread pool.
    ///
    /// `f` runs inside `tokio::task::spawn_blocking`, so the synchronous SQLite
    /// calls it makes do not block the async runtime's worker threads.
    ///
    /// # Errors
    ///
    /// - [`AcquireThenError::Acquire`] if a connection could not be acquired.
    /// - [`AcquireThenError::Join`] if the blocking task could not be joined.
    /// - [`AcquireThenError::Inner`] wrapping the error returned by `f`.
    pub async fn acquire_then<F, T, E>(&self, f: F) -> Result<T, AcquireThenError<E>>
    where
        F: 'static + Send + FnOnce(&mut ConnectionHandle) -> Result<T, E>,
        T: 'static + Send,
        E: 'static + Send,
    {
        let mut handle = self.acquire().await?;
        tokio::task::spawn_blocking(move || f(&mut handle))
            .await?
            .map_err(AcquireThenError::Inner)
    }

    /// Create the crate's tables on a pooled connection, via `with_tx`.
    pub async fn create_tables(&self) -> Result<(), AcquireThenRusqliteError> {
        self.acquire_then(|h| with_tx(h, |tx| crate::create_tables(tx)))
            .await
    }

    /// Insert `block` using `crate::insert_block` inside `with_tx`, returning
    /// the `ContentAddress` that `crate::insert_block` produces.
    pub async fn insert_block(
        &self,
        block: Arc<Block>,
    ) -> Result<ContentAddress, AcquireThenRusqliteError> {
        self.acquire_then(move |h| with_tx(h, |tx| crate::insert_block(tx, &block)))
            .await
    }

    /// Mark the block at `block_ca` as finalized by calling
    /// `crate::finalize_block` directly on the connection.
    pub async fn finalize_block(
        &self,
        block_ca: ContentAddress,
    ) -> Result<(), AcquireThenRusqliteError> {
        self.acquire_then(move |h| crate::finalize_block(h, &block_ca))
            .await
    }

    /// Write `value` under `key` for the contract `contract_ca`
    /// (delegates to `crate::update_state`).
    pub async fn update_state(
        &self,
        contract_ca: ContentAddress,
        key: Key,
        value: Value,
    ) -> Result<(), AcquireThenRusqliteError> {
        self.acquire_then(move |h| crate::update_state(h, &contract_ca, &key, &value))
            .await
    }

    /// Delete `key` from the state of the contract `contract_ca`
    /// (delegates to `crate::delete_state`).
    pub async fn delete_state(
        &self,
        contract_ca: ContentAddress,
        key: Key,
    ) -> Result<(), AcquireThenRusqliteError> {
        self.acquire_then(move |h| crate::delete_state(h, &contract_ca, &key))
            .await
    }

    /// Fetch the block stored under `block_address`, or `None` if there is none.
    pub async fn get_block(
        &self,
        block_address: ContentAddress,
    ) -> Result<Option<Block>, AcquireThenQueryError> {
        self.acquire_then(move |h| with_tx(h, |tx| crate::get_block(tx, &block_address)))
            .await
    }

    /// Fetch the solution set stored under `ca` (delegates to
    /// `crate::get_solution_set` inside `with_tx`).
    pub async fn get_solution_set(
        &self,
        ca: ContentAddress,
    ) -> Result<SolutionSet, AcquireThenQueryError> {
        self.acquire_then(move |h| with_tx(h, |tx| crate::get_solution_set(tx, &ca)))
            .await
    }

    /// Read the value stored at `key` for the contract `contract_ca`
    /// (delegates to `crate::query_state`).
    pub async fn query_state(
        &self,
        contract_ca: ContentAddress,
        key: Key,
    ) -> Result<Option<Value>, AcquireThenQueryError> {
        self.acquire_then(move |h| crate::query_state(h, &contract_ca, &key))
            .await
    }

    /// Read the value at `key` for `contract_ca` as of the latest finalized
    /// block, inclusive of that block.
    ///
    /// Returns `Ok(None)` when there is no finalized block, when that block's
    /// header cannot be found, or when no value is stored.
    pub async fn query_latest_finalized_block(
        &self,
        contract_ca: ContentAddress,
        key: Key,
    ) -> Result<Option<Value>, AcquireThenQueryError> {
        self.acquire_then(move |h| {
            // One transaction for all three reads, presumably so they observe a
            // single consistent snapshot of the database.
            let tx = h.transaction()?;
            let Some(addr) = crate::get_latest_finalized_block_address(&tx)? else {
                return Ok(None);
            };
            let Some(header) = crate::get_block_header(&tx, &addr)? else {
                return Ok(None);
            };
            let value = crate::finalized::query_state_inclusive_block(
                &tx,
                &contract_ca,
                &key,
                header.number,
            )?;
            tx.finish()?;
            Ok(value)
        })
        .await
    }

    /// Read the finalized value at `key` for `contract_ca` up to and including
    /// block number `block_number`.
    pub async fn query_state_finalized_inclusive_block(
        &self,
        contract_ca: ContentAddress,
        key: Key,
        block_number: Word,
    ) -> Result<Option<Value>, AcquireThenQueryError> {
        self.acquire_then(move |h| {
            crate::finalized::query_state_inclusive_block(h, &contract_ca, &key, block_number)
        })
        .await
    }

    /// Read the finalized value at `key` for `contract_ca` up to, but
    /// excluding, block number `block_number`.
    pub async fn query_state_finalized_exclusive_block(
        &self,
        contract_ca: ContentAddress,
        key: Key,
        block_number: Word,
    ) -> Result<Option<Value>, AcquireThenQueryError> {
        self.acquire_then(move |h| {
            crate::finalized::query_state_exclusive_block(h, &contract_ca, &key, block_number)
        })
        .await
    }

    /// Read the finalized value at `key` for `contract_ca` up to and including
    /// solution set `solution_set_ix` of block `block_number`.
    pub async fn query_state_finalized_inclusive_solution_set(
        &self,
        contract_ca: ContentAddress,
        key: Key,
        block_number: Word,
        solution_set_ix: u64,
    ) -> Result<Option<Value>, AcquireThenQueryError> {
        self.acquire_then(move |h| {
            crate::finalized::query_state_inclusive_solution_set(
                h,
                &contract_ca,
                &key,
                block_number,
                solution_set_ix,
            )
        })
        .await
    }

    /// Read the finalized value at `key` for `contract_ca` up to, but
    /// excluding, solution set `solution_set_ix` of block `block_number`.
    pub async fn query_state_finalized_exclusive_solution_set(
        &self,
        contract_ca: ContentAddress,
        key: Key,
        block_number: Word,
        solution_set_ix: u64,
    ) -> Result<Option<Value>, AcquireThenQueryError> {
        self.acquire_then(move |h| {
            crate::finalized::query_state_exclusive_solution_set(
                h,
                &contract_ca,
                &key,
                block_number,
                solution_set_ix,
            )
        })
        .await
    }

    /// Fetch the recorded validation progress, if any
    /// (delegates to `crate::get_validation_progress`).
    pub async fn get_validation_progress(
        &self,
    ) -> Result<Option<ContentAddress>, AcquireThenQueryError> {
        self.acquire_then(|h| crate::get_validation_progress(h))
            .await
    }

    /// Fetch the addresses of the block(s) following `current_block`
    /// (delegates to `crate::get_next_block_addresses`).
    pub async fn get_next_block_addresses(
        &self,
        current_block: ContentAddress,
    ) -> Result<Vec<ContentAddress>, AcquireThenQueryError> {
        // Fixed: the argument had been mis-encoded as `¤t_block`
        // (an HTML-entity decoding of `&curren`), which does not compile.
        self.acquire_then(move |h| crate::get_next_block_addresses(h, &current_block))
            .await
    }

    /// Record `block_ca` as the current validation progress
    /// (delegates to `crate::update_validation_progress`).
    pub async fn update_validation_progress(
        &self,
        block_ca: ContentAddress,
    ) -> Result<(), AcquireThenRusqliteError> {
        self.acquire_then(move |h| crate::update_validation_progress(h, &block_ca))
            .await
    }

    /// List the blocks within `block_range` (presumably block numbers),
    /// via `crate::list_blocks` inside `with_tx`.
    pub async fn list_blocks(
        &self,
        block_range: Range<Word>,
    ) -> Result<Vec<Block>, AcquireThenQueryError> {
        self.acquire_then(move |h| with_tx(h, |tx| crate::list_blocks(tx, block_range)))
            .await
    }

    /// List the blocks whose time falls within `range`, one page at a time
    /// (`page_size` per page, page index `page_number`), via
    /// `crate::list_blocks_by_time` inside `with_tx`.
    pub async fn list_blocks_by_time(
        &self,
        range: Range<Duration>,
        page_size: i64,
        page_number: i64,
    ) -> Result<Vec<Block>, AcquireThenQueryError> {
        self.acquire_then(move |h| {
            with_tx(h, |tx| {
                crate::list_blocks_by_time(tx, range, page_size, page_number)
            })
        })
        .await
    }

    /// Stream blocks starting from block number `start_block`.
    ///
    /// `await_new_block` is used to wait for new blocks. The stream is
    /// produced by `crate::subscribe_blocks` using a clone of this pool.
    pub fn subscribe_blocks(
        &self,
        start_block: Word,
        await_new_block: impl AwaitNewBlock,
    ) -> impl Stream<Item = Result<Block, QueryError>> {
        crate::subscribe_blocks(start_block, self.clone(), await_new_block)
    }
}
impl Config {
pub fn new(source: Source, conn_limit: usize) -> Self {
Self { source, conn_limit }
}
pub fn default_conn_limit() -> usize {
num_cpus::get().saturating_mul(4)
}
}
impl Source {
pub fn default_memory() -> Self {
Self::Memory("__default-id".to_string())
}
}
impl AwaitNewBlock for BlockRx {
    /// Wait for the next block notification.
    ///
    /// Returns `Some(())` when `changed` succeeds. Returns `None` when it
    /// reports an error (presumably because the notifying side has gone away).
    async fn await_new_block(&mut self) -> Option<()> {
        let notified = self.changed().await;
        notified.ok()
    }
}
impl AsRef<AsyncConnectionPool> for ConnectionPool {
fn as_ref(&self) -> &AsyncConnectionPool {
&self.0
}
}
impl AsRef<rusqlite::Connection> for ConnectionHandle {
    /// Borrow the underlying `rusqlite::Connection` by following the handle's
    /// `Deref` chain via deref coercion.
    fn as_ref(&self) -> &rusqlite::Connection {
        let conn: &rusqlite::Connection = self;
        conn
    }
}
impl AsMut<rusqlite::Connection> for ConnectionHandle {
    /// Mutably borrow the underlying `rusqlite::Connection` by following the
    /// handle's `DerefMut` chain via deref coercion.
    fn as_mut(&mut self) -> &mut rusqlite::Connection {
        let conn: &mut rusqlite::Connection = self;
        conn
    }
}
impl core::ops::Deref for ConnectionHandle {
type Target = AsyncConnectionHandle;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl core::ops::DerefMut for ConnectionHandle {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl AcquireConnection for ConnectionPool {
    /// Acquire a pooled connection, returning `None` if acquisition fails
    /// (the acquire error itself is discarded).
    async fn acquire_connection(&self) -> Option<impl 'static + AsMut<rusqlite::Connection>> {
        let acquired = self.acquire().await;
        acquired.ok()
    }
}
impl Default for Source {
fn default() -> Self {
Self::default_memory()
}
}
impl Default for Config {
    /// Uses the default in-memory [`Source`] and the connection limit from
    /// [`Config::default_conn_limit`].
    fn default() -> Self {
        let conn_limit = Self::default_conn_limit();
        let source = Source::default();
        Self::new(source, conn_limit)
    }
}
impl core::fmt::Display for ConnectionCloseErrors {
    /// Writes a header line, then one ` {index}: {error}` line per connection
    /// that failed to close. Each line ends with a newline.
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        writeln!(f, "failed to close one or more connections:")?;
        self.0
            .iter()
            .enumerate()
            .try_for_each(|(ix, (_conn, err))| writeln!(f, " {ix}: {err}"))
    }
}
fn new_conn_pool(conf: &Config) -> rusqlite::Result<AsyncConnectionPool> {
AsyncConnectionPool::new(conf.conn_limit, || new_conn(&conf.source))
}
/// Open a single SQLite connection for `source` and apply its pragmas.
///
/// - `Source::Memory` opens the named database through [`new_mem_conn`].
/// - `Source::Path` tries to create the parent directory (errors ignored) and
///   opens the file. It then disables `trusted_schema` and sets `synchronous`
///   to `1`, which SQLite defines as `NORMAL`.
///
/// In both cases the `foreign_keys` pragma is then enabled.
///
/// # Errors
///
/// Returns the first `rusqlite` error from opening the connection or from
/// setting a pragma.
pub(crate) fn new_conn(source: &Source) -> rusqlite::Result<rusqlite::Connection> {
    let conn = match source {
        Source::Memory(id) => new_mem_conn(id)?,
        Source::Path(path) => {
            // Best-effort: a failure here is ignored. If the directory is
            // really missing, presumably `open` below reports the problem.
            if let Some(parent_dir) = path.parent() {
                let _ = std::fs::create_dir_all(parent_dir);
            }
            let file_conn = rusqlite::Connection::open(path)?;
            file_conn.pragma_update(None, "trusted_schema", false)?;
            file_conn.pragma_update(None, "synchronous", 1)?;
            file_conn
        }
    };
    conn.pragma_update(None, "foreign_keys", true)?;
    Ok(conn)
}
/// Open the in-memory database named `id` through SQLite's `memdb` VFS,
/// using the URI `file:/{id}` and the default `OpenFlags`.
///
/// NOTE(review): since SQLite 3.36, memdb names beginning with `/` are shared
/// by all connections in the same process. That is presumably why this form
/// is used, so every pooled connection sees the same database.
fn new_mem_conn(id: &str) -> rusqlite::Result<rusqlite::Connection> {
    let uri = format!("file:/{id}");
    let flags = rusqlite::OpenFlags::default();
    rusqlite::Connection::open_with_flags_and_vfs(uri, flags, "memdb")
}