use super::metrics::{
create_connection_use_time_metric, create_write_txn_duration_metric, Histogram,
};
use crate::db::conn::PConn;
use crate::db::databases::DATABASE_HANDLES;
use crate::db::guard::{PConnGuard, PTxnGuard};
use crate::db::kind::{DbKind, DbKindT};
use crate::db::pool::{initialize_connection, new_connection_pool, ConnectionPool, PoolConfig};
use crate::error::{DatabaseError, DatabaseResult};
use derive_more::Into;
use holochain_util::log_elapsed;
use parking_lot::Mutex;
use rusqlite::trace::{TraceEvent, TraceEventCodes};
use rusqlite::*;
use shrinkwraprs::Shrinkwrap;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use std::{collections::HashMap, path::Path};
use std::{path::PathBuf, sync::atomic::AtomicUsize};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tracing::Instrument;
// Maximum time (ms) to wait for a read/write semaphore permit before failing
// with `DatabaseError::Timeout`. Mutable at runtime via `set_acquire_timeout`
// (test_utils only).
static ACQUIRE_TIMEOUT_MS: AtomicU64 = AtomicU64::new(10_000);
// Maximum time (ms) allowed for a `spawn_blocking` database closure to be
// scheduled AND run to completion before the transaction is abandoned with a
// timeout error.
static THREAD_ACQUIRE_TIMEOUT_MS: AtomicU64 = AtomicU64::new(30_000);
/// A mutable borrow of a rusqlite [`Transaction`], tagged at the type level
/// with the database kind `D` so that transaction closures can only be applied
/// to the database they were written for.
#[derive(derive_more::Deref, derive_more::DerefMut, derive_more::Into)]
pub struct Txn<'a, 'txn, D: DbKindT> {
    /// The underlying transaction; `Deref`/`DerefMut` expose it directly.
    #[deref]
    #[deref_mut]
    #[into]
    txn: &'a mut Transaction<'txn>,
    // Zero-sized marker carrying the database kind; never inhabited at runtime.
    db_kind: std::marker::PhantomData<D>,
}
/// Wrap a raw rusqlite transaction borrow as a kind-tagged [`Txn`].
impl<'a, 'txn, D: DbKindT> From<&'a mut Transaction<'txn>> for Txn<'a, 'txn, D> {
    fn from(txn: &'a mut Transaction<'txn>) -> Self {
        // The phantom marker is zero-sized; only the borrow is stored.
        Self {
            txn,
            db_kind: PhantomData,
        }
    }
}
/// Read-only access to a database of kind `Kind`.
///
/// Implemented by both [`DbRead`] and [`DbWrite`] so callers can be generic
/// over either handle when they only need to read.
#[async_trait::async_trait]
pub trait ReadAccess<Kind: DbKindT>: Clone + Into<DbRead<Kind>> {
    /// Run the closure `f` inside a read transaction on a blocking thread,
    /// returning whatever the closure produces.
    async fn read_async<E, R, F>(&self, f: F) -> Result<R, E>
    where
        E: From<DatabaseError> + Send + 'static,
        F: FnOnce(&Txn<Kind>) -> Result<R, E> + Send + 'static,
        R: Send + 'static;

    /// The kind of database this handle refers to.
    fn kind(&self) -> &Kind;
}
/// A write handle can also be used for reads: delegate to the wrapped
/// [`DbRead`] so reads go through the read semaphore, not the write one.
#[async_trait::async_trait]
impl<Kind: DbKindT> ReadAccess<Kind> for DbWrite<Kind> {
    #[cfg_attr(feature = "instrument", tracing::instrument(skip_all, fields(kind = ?self.kind)))]
    async fn read_async<E, R, F>(&self, f: F) -> Result<R, E>
    where
        E: From<DatabaseError> + Send + 'static,
        F: FnOnce(&Txn<Kind>) -> Result<R, E> + Send + 'static,
        R: Send + 'static,
    {
        // `DbWrite` is a newtype over `DbRead` (via Shrinkwrap), so borrow the
        // inner handle and forward.
        let db: &DbRead<Kind> = self.as_ref();
        DbRead::read_async(db, f).await
    }

    fn kind(&self) -> &Kind {
        self.0.kind()
    }
}
/// Trait forwarding onto the inherent [`DbRead::read_async`] implementation.
#[async_trait::async_trait]
impl<Kind: DbKindT> ReadAccess<Kind> for DbRead<Kind> {
    #[cfg_attr(feature = "instrument", tracing::instrument(skip_all, fields(kind = ?self.kind)))]
    async fn read_async<E, R, F>(&self, f: F) -> Result<R, E>
    where
        E: From<DatabaseError> + Send + 'static,
        F: FnOnce(&Txn<Kind>) -> Result<R, E> + Send + 'static,
        R: Send + 'static,
    {
        DbRead::read_async(self, f).await
    }

    fn kind(&self) -> &Kind {
        &self.kind
    }
}
/// A read-only handle to a database. Cloning is cheap: the connection pool,
/// semaphores and metrics are all shared behind `Arc`s.
#[derive(Clone)]
pub struct DbRead<Kind: DbKindT> {
    // Which database this handle refers to.
    kind: Kind,
    // Filesystem path of the database file (default/empty for in-memory).
    path: PathBuf,
    // Pool of SQLite connections shared by all clones of this handle.
    connection_pool: ConnectionPool,
    // Gates exclusive write transactions (single permit per kind).
    write_semaphore: Arc<Semaphore>,
    // Gates short-lived read transactions.
    read_semaphore: Arc<Semaphore>,
    // Gates long-lived read transactions (see `get_read_txn`).
    long_read_semaphore: Arc<Semaphore>,
    // Optional SQLite trace callback installed on checked-out connections.
    statement_trace_fn: Option<fn(TraceEvent)>,
    // Total reader budget, used only for saturation reporting.
    max_readers: usize,
    // Number of tasks currently waiting for a read permit.
    num_readers: Arc<AtomicUsize>,
    // Histogram of how long connections are held once checked out.
    use_time_metric: Histogram,
    // Histogram of write-transaction durations.
    write_txn_metric: Histogram,
}
/// Manual `Debug`: only show the identifying fields; pools, semaphores and
/// metrics carry no useful debug information.
impl<Kind: DbKindT> std::fmt::Debug for DbRead<Kind> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("DbRead");
        dbg.field("kind", &self.kind);
        dbg.field("path", &self.path);
        dbg.field("max_readers", &self.max_readers);
        dbg.field("num_readers", &self.num_readers);
        dbg.finish()
    }
}
impl<Kind: DbKindT> DbRead<Kind> {
pub fn kind(&self) -> &Kind {
&self.kind
}
pub fn path(&self) -> &PathBuf {
&self.path
}
#[cfg_attr(feature = "instrument", tracing::instrument(skip_all, fields(kind = ?self.kind)))]
pub async fn read_async<E, R, F>(&self, f: F) -> Result<R, E>
where
E: From<DatabaseError> + Send + 'static,
F: FnOnce(&Txn<Kind>) -> Result<R, E> + Send + 'static,
R: Send + 'static,
{
let mut conn = self
.checkout_connection(self.read_semaphore.clone())
.await?;
let start = tokio::time::Instant::now();
let span = tracing::info_span!("spawn_blocking");
tokio::time::timeout(std::time::Duration::from_millis(THREAD_ACQUIRE_TIMEOUT_MS.load(Ordering::Acquire)), tokio::task::spawn_blocking(move || {
let _s = span.enter();
log_elapsed!([10, 100, 1000], start, "read_async:before-closure");
let r = conn.execute_in_read_txn(|mut txn| f(&Txn::from(&mut txn)));
log_elapsed!([10, 100, 1000], start, "read_async:after-closure");
r
}).in_current_span()).in_current_span().await.map_err(|e| {
tracing::error!("Failed to claim a thread to run the database read transaction. It's likely that the program is out of threads.");
DatabaseError::Timeout(e)
})?.map_err(DatabaseError::from)?
}
#[cfg_attr(feature = "instrument", tracing::instrument)]
pub async fn get_read_txn(&self) -> DatabaseResult<PTxnGuard> {
let conn = self
.checkout_connection(self.long_read_semaphore.clone())
.await?;
Ok(conn.into())
}
#[cfg_attr(feature = "instrument", tracing::instrument)]
async fn checkout_connection(&self, semaphore: Arc<Semaphore>) -> DatabaseResult<PConnGuard> {
let waiting = self.num_readers.fetch_add(1, Ordering::Relaxed);
if waiting > self.max_readers {
let s = tracing::info_span!("holochain_perf", kind = ?self.kind().kind());
s.in_scope(|| {
tracing::info!(
"Database read connection is saturated. Util {:.2}%",
waiting as f64 / self.max_readers as f64 * 100.0
)
});
} else {
tracing::trace!("checkout_connection ready to acquire semaphore");
}
let permit = acquire_semaphore_permit(semaphore).await?;
self.num_readers.fetch_sub(1, Ordering::Relaxed);
let conn = self.get_connection_from_pool()?;
if self.statement_trace_fn.is_some() {
conn.trace_v2(
TraceEventCodes::SQLITE_TRACE_PROFILE,
self.statement_trace_fn,
);
}
Ok(PConnGuard::new(conn, permit, self.use_time_metric.clone()))
}
#[cfg_attr(feature = "instrument", tracing::instrument)]
fn get_connection_from_pool(&self) -> DatabaseResult<PConn> {
let now = Instant::now();
let r = Ok(PConn::new(self.connection_pool.get()?));
let el = now.elapsed();
if el.as_millis() > 20 {
tracing::info!("Connection pool took {:?} to be freed", el);
} else {
tracing::trace!("Got connection");
}
r
}
#[cfg(all(any(test, feature = "test_utils"), not(loom)))]
pub fn test_read<R, F>(&self, f: F) -> R
where
F: FnOnce(&Txn<Kind>) -> R + Send + 'static,
R: Send + 'static,
{
holochain_util::tokio_helper::block_forever_on(async {
self.read_async(move |txn| -> DatabaseResult<R> { Ok(f(txn)) })
.await
.unwrap()
})
}
}
/// A read-write handle to a database. Wraps [`DbRead`] (exposed via
/// `Shrinkwrap`'s `Deref`) and adds write-transaction methods gated by a
/// single-permit write semaphore per database kind.
#[derive(Clone, Debug, Shrinkwrap, Into)]
pub struct DbWrite<Kind: DbKindT>(DbRead<Kind>);
impl<Kind: DbKindT + Send + Sync + 'static> DbWrite<Kind> {
    /// Open (or create) the database for `kind` under `path_prefix`,
    /// deduplicating handles through the global `DATABASE_HANDLES` cache.
    pub fn open_with_pool_config(
        path_prefix: &Path,
        kind: Kind,
        pool_config: PoolConfig,
    ) -> DatabaseResult<Self> {
        DATABASE_HANDLES.get_or_insert(&kind, path_prefix, |kind| {
            Self::new(Some(path_prefix), kind, pool_config, None)
        })
    }

    /// Create a new write handle.
    ///
    /// With `path_prefix == None` an in-memory database is used. Otherwise the
    /// file is probed; on failure it may be migrated from unencrypted form
    /// (env var `HOLOCHAIN_MIGRATE_UNENCRYPTED=true`), wiped if the kind
    /// allows it, or the error is returned.
    ///
    /// # Errors
    /// `DatabaseError::DatabaseMissing` if the parent directory cannot be
    /// created, plus any pool/SQLite error from initialization.
    pub fn new(
        path_prefix: Option<&Path>,
        kind: Kind,
        pool_config: PoolConfig,
        statement_trace_fn: Option<fn(TraceEvent)>,
    ) -> DatabaseResult<Self> {
        let path = match path_prefix {
            Some(path_prefix) => {
                let path = path_prefix.join(kind.filename());
                let parent = path
                    .parent()
                    .ok_or_else(|| DatabaseError::DatabaseMissing(path_prefix.to_owned()))?;
                if !parent.is_dir() {
                    std::fs::create_dir_all(parent)
                        .map_err(|_e| DatabaseError::DatabaseMissing(parent.to_owned()))?;
                }
                match Self::check_database_file(&path, &pool_config) {
                    Ok(path) => path,
                    Err(err) => {
                        // Probe failed: try migration, wipe, or bail out.
                        if "true"
                            == std::env::var("HOLOCHAIN_MIGRATE_UNENCRYPTED")
                                .unwrap_or_default()
                                .as_str()
                        {
                            #[cfg(feature = "sqlite-encrypted")]
                            encrypt_unencrypted_database(&path, &pool_config)?;
                            // NOTE(review): without the "sqlite-encrypted"
                            // feature this branch is a no-op and control falls
                            // through to the re-check below, which will likely
                            // fail again — confirm this is intended.
                        }
                        else if kind.if_corrupt_wipe() {
                            std::fs::remove_file(&path)?;
                        } else {
                            return Err(err.into());
                        }
                        // Re-probe after migration/wipe; any remaining failure
                        // is fatal.
                        match Self::check_database_file(&path, &pool_config) {
                            Ok(path) => path,
                            Err(e) => return Err(e.into()),
                        }
                    }
                }
            }
            None => None,
        };
        let max_readers = pool_config.max_readers as usize;
        // Split the reader budget: half for short reads, the remainder for
        // long-lived reads, each with at least one permit.
        let max_short_readers = std::cmp::max((pool_config.max_readers / 2) as usize, 1);
        let max_long_readers = std::cmp::max(max_readers.saturating_sub(max_short_readers), 1);
        let pool = new_connection_pool(path.as_ref().map(|p| p.as_ref()), pool_config);
        let mut conn = pool.get()?;
        // WAL allows concurrent readers while a writer is active.
        conn.pragma_update(None, "journal_mode", "WAL".to_string())?;
        crate::table::initialize_database(&mut conn, kind.kind())?;
        let use_time_metric = create_connection_use_time_metric(kind.kind());
        let write_txn_metric = create_write_txn_duration_metric(kind.kind());
        let db_read = DbRead {
            write_semaphore: Self::get_write_semaphore(kind.kind()),
            read_semaphore: Self::get_read_semaphore(kind.kind(), max_short_readers),
            long_read_semaphore: Self::get_long_read_semaphore(kind.kind(), max_long_readers),
            max_readers,
            num_readers: Arc::new(AtomicUsize::new(0)),
            kind: kind.clone(),
            path: path.unwrap_or_default(),
            connection_pool: pool,
            statement_trace_fn,
            use_time_metric,
            write_txn_metric,
        };
        Ok(DbWrite(db_read))
    }

    /// Run `f` inside an exclusive read-write transaction, acquiring the
    /// write permit first.
    #[cfg_attr(feature = "instrument", tracing::instrument(skip_all, fields(kind = ?self.kind)))]
    pub async fn write_async<E, R, F>(&self, f: F) -> Result<R, E>
    where
        E: From<DatabaseError> + Send + 'static,
        F: FnOnce(&mut Txn<Kind>) -> Result<R, E> + Send + 'static,
        R: Send + 'static,
    {
        let permit = acquire_semaphore_permit(self.0.write_semaphore.clone()).await?;
        self.write_async_with_permit(permit, f)
            .await
            .map(|(r, _permit)| r)
    }

    /// Like [`Self::write_async`], but with a pre-acquired write permit
    /// (see [`Self::acquire_write_permit`]); the permit is handed back with
    /// the result so callers can chain writes without re-queuing.
    #[cfg_attr(feature = "instrument", tracing::instrument(skip_all, fields(kind = ?self.kind)))]
    pub async fn write_async_with_permit<E, R, F>(
        &self,
        permit: OwnedSemaphorePermit,
        f: F,
    ) -> Result<(R, OwnedSemaphorePermit), E>
    where
        E: From<DatabaseError> + Send + 'static,
        F: FnOnce(&mut Txn<Kind>) -> Result<R, E> + Send + 'static,
        R: Send + 'static,
    {
        let mut conn = self.get_connection_from_pool()?;
        let write_txn_metric = self.0.write_txn_metric.clone();
        let start = tokio::time::Instant::now();
        let span = tracing::info_span!("spawn_blocking");
        // The timeout bounds both scheduling of the blocking task and the
        // closure's execution.
        tokio::time::timeout(std::time::Duration::from_millis(THREAD_ACQUIRE_TIMEOUT_MS.load(Ordering::Acquire)), tokio::task::spawn_blocking(move || {
            let _s = span.enter();
            log_elapsed!([10, 100, 1000], start, "write_async:before-closure");
            let txn_start = std::time::Instant::now();
            let r = conn.execute_in_exclusive_rw_txn(|txn| f(&mut Txn::from(txn)));
            // Record the transaction duration regardless of success.
            write_txn_metric.record(txn_start.elapsed().as_secs_f64(), &[]);
            log_elapsed!([10, 100, 1000], start, "write_async:after-closure");
            r.map(|r| (r, permit))
        }).in_current_span()).in_current_span().await.map_err(|e| {
            tracing::error!("Failed to claim a thread to run the database write transaction. It's likely that the program is out of threads.");
            DatabaseError::Timeout(e)
        })?.map_err(DatabaseError::from)?
    }

    /// Acquire the single write permit for this database without starting a
    /// transaction yet.
    pub async fn acquire_write_permit(&self) -> DatabaseResult<OwnedSemaphorePermit> {
        acquire_semaphore_permit(self.0.write_semaphore.clone()).await
    }

    // One write permit per database kind, shared process-wide so multiple
    // handles to the same database serialize their writes.
    fn get_write_semaphore(kind: DbKind) -> Arc<Semaphore> {
        static MAP: once_cell::sync::Lazy<Mutex<HashMap<DbKind, Arc<Semaphore>>>> =
            once_cell::sync::Lazy::new(|| Mutex::new(HashMap::new()));
        MAP.lock()
            .entry(kind)
            .or_insert_with(|| Arc::new(Semaphore::new(1)))
            .clone()
    }

    // Process-wide short-read semaphore per kind.
    // NOTE(review): `num_permits` only takes effect the first time a kind is
    // seen; later opens with a different pool config reuse the cached
    // semaphore — confirm this is intended.
    fn get_read_semaphore(kind: DbKind, num_permits: usize) -> Arc<Semaphore> {
        static MAP: once_cell::sync::Lazy<Mutex<HashMap<DbKind, Arc<Semaphore>>>> =
            once_cell::sync::Lazy::new(|| Mutex::new(HashMap::new()));
        MAP.lock()
            .entry(kind)
            .or_insert_with(|| Arc::new(Semaphore::new(num_permits)))
            .clone()
    }

    // Process-wide long-read semaphore per kind; same first-creation caveat
    // as `get_read_semaphore`.
    fn get_long_read_semaphore(kind: DbKind, num_permits: usize) -> Arc<Semaphore> {
        static MAP: once_cell::sync::Lazy<Mutex<HashMap<DbKind, Arc<Semaphore>>>> =
            once_cell::sync::Lazy::new(|| Mutex::new(HashMap::new()));
        MAP.lock()
            .entry(kind)
            .or_insert_with(|| Arc::new(Semaphore::new(num_permits)))
            .clone()
    }

    // Probe the database file by opening and initializing a throwaway
    // connection; returns the canonical path SQLite reports.
    // NOTE(review): `synchronous = 0` only affects this probe connection,
    // which is dropped immediately — confirm the pragma is intentional here.
    fn check_database_file(
        path: &Path,
        pool_config: &PoolConfig,
    ) -> rusqlite::Result<Option<PathBuf>> {
        Connection::open(path)
            .and_then(|mut c| {
                initialize_connection(&mut c, pool_config)?;
                c.pragma_update(None, "synchronous", "0".to_string())?;
                Ok(c.path().map(PathBuf::from))
            })
    }

    /// Test helper: open a file-backed database with default pool settings.
    #[cfg(any(test, feature = "test_utils"))]
    pub fn test(path: &Path, kind: Kind) -> DatabaseResult<Self> {
        Self::new(Some(path), kind, PoolConfig::default(), None)
    }

    /// Test helper: open an in-memory database with a statement profiler that
    /// logs each statement's duration at debug level.
    #[cfg(any(test, feature = "test_utils"))]
    pub fn test_in_mem(kind: Kind) -> DatabaseResult<Self> {
        Self::new(
            None,
            kind,
            PoolConfig::default(),
            Some(|trace_event| {
                match trace_event {
                    TraceEvent::Profile(stmt, dur) => {
                        tracing::debug!("SQLITE TRACE: {} took {:?}", stmt.sql(), dur);
                    }
                    // Other trace events are ignored.
                    _ => {
                    }
                }
            }),
        )
    }

    /// Test helper: run a write closure synchronously, panicking on error.
    #[cfg(all(any(test, feature = "test_utils"), not(loom)))]
    pub fn test_write<R, F>(&self, f: F) -> R
    where
        F: FnOnce(&mut Txn<Kind>) -> R + Send + 'static,
        R: Send + 'static,
    {
        holochain_util::tokio_helper::block_forever_on(async {
            self.write_async(|txn| -> DatabaseResult<R> { Ok(f(txn)) })
                .await
                .unwrap()
        })
    }

    /// Test helper: configured maximum size of the connection pool.
    #[cfg(any(test, feature = "test_utils"))]
    pub fn connection_pool_max_size(&self) -> u32 {
        self.0.connection_pool.max_size()
    }

    /// Test helper: currently available short-read permits.
    #[cfg(any(test, feature = "test_utils"))]
    pub fn available_short_reader_count(&self) -> usize {
        self.read_semaphore.available_permits()
    }

    /// Test helper: currently available long-read permits.
    #[cfg(any(test, feature = "test_utils"))]
    pub fn available_long_reader_count(&self) -> usize {
        self.0.long_read_semaphore.available_permits()
    }
}
/// Migrate a plaintext SQLite database to an encrypted (SQLCipher) one,
/// in place.
///
/// Exports the database into a sibling `<stem>-encrypted` file using
/// `sqlcipher_export`, then replaces the original file with the encrypted
/// copy. The statement/pragma sequence below is order-sensitive.
///
/// # Errors
/// Any SQLite error from the export, or I/O errors from the final
/// remove/rename.
#[cfg(feature = "sqlite-encrypted")]
pub fn encrypt_unencrypted_database(path: &Path, pool_config: &PoolConfig) -> DatabaseResult<()> {
    // Build "<parent>/<stem>-encrypted" next to the original file.
    let encrypted_path = path
        .parent()
        .ok_or_else(|| DatabaseError::DatabaseMissing(path.to_owned()))?
        .join(
            path.file_stem()
                .and_then(|s| s.to_str())
                .ok_or_else(|| DatabaseError::DatabaseMissing(path.to_owned()))?
                .to_string()
                + "-encrypted",
        );
    tracing::warn!(
        "Attempting encryption of unencrypted database: {:?} -> {:?}",
        path,
        encrypted_path
    );
    // Scope the connection so it is closed before the file swap below.
    {
        let conn = Connection::open(path)?;
        // Compact first, then hold an exclusive transaction for the export.
        conn.execute("VACUUM", ())?;
        conn.execute("BEGIN EXCLUSIVE", ())?;
        {
            // Keep the unlocked key material locked in memory only for the
            // duration of the ATTACH.
            let mut mutex_guard = pool_config.key.unlocked.lock().unwrap();
            let lock = mutex_guard.lock();
            conn.execute(
                "ATTACH DATABASE :db_name AS encrypted KEY :key",
                rusqlite::named_params! {
                    ":db_name": encrypted_path.to_str(),
                    // NOTE(review): byte range [15..82] presumably selects the
                    // key portion of the unlocked secret buffer — confirm
                    // against the key container's layout.
                    ":key": &lock[15..82],
                },
            )?;
        }
        // Configure the attached encrypted database: explicit salt, cipher
        // compatibility level 4, and a plaintext header (needed because the
        // salt is stored externally).
        let mut batch = "PRAGMA encrypted.cipher_salt = \"x'".to_string();
        for b in &pool_config.key.salt {
            batch.push_str(&format!("{b:02X}"));
        }
        batch.push_str("'\";\n");
        batch.push_str("PRAGMA encrypted.cipher_compatibility = 4;\n");
        batch.push_str("PRAGMA encrypted.cipher_plaintext_header_size = 32;\n");
        conn.execute_batch(&batch)?;
        // Copy all content into the attached encrypted database.
        conn.query_row("SELECT sqlcipher_export('encrypted')", (), |_| Ok(0))?;
        conn.execute("COMMIT", ())?;
        conn.execute("DETACH DATABASE encrypted", ())?;
        conn.close().map_err(|(_, err)| err)?;
    }
    // Swap the encrypted copy into place of the original plaintext file.
    std::fs::remove_file(path)?;
    std::fs::rename(encrypted_path, path)?;
    Ok(())
}
/// Override the semaphore acquire timeout (milliseconds). Test-only.
#[cfg(feature = "test_utils")]
pub fn set_acquire_timeout(timeout_ms: u64) {
    // Use a Release store to pair with the Acquire loads of this atomic
    // elsewhere in the file; a Relaxed store left the Acquire loads with
    // nothing to synchronize against.
    ACQUIRE_TIMEOUT_MS.store(timeout_ms, Ordering::Release);
}
/// Acquire an owned permit from `semaphore`, bounded by `ACQUIRE_TIMEOUT_MS`.
///
/// # Errors
/// `DatabaseError::Timeout` when the permit cannot be obtained in time, and
/// `DatabaseError::Other` if the semaphore was unexpectedly closed.
#[cfg_attr(feature = "instrument", tracing::instrument)]
async fn acquire_semaphore_permit(
    semaphore: Arc<Semaphore>,
) -> DatabaseResult<OwnedSemaphorePermit> {
    // Short random id correlating the paired trace lines below.
    let id = nanoid::nanoid!(7);
    tracing::trace!(?id, "??? acquire semaphore permit");

    let wait = std::time::Duration::from_millis(ACQUIRE_TIMEOUT_MS.load(Ordering::Acquire));
    let outcome = tokio::time::timeout(wait, semaphore.acquire_owned()).await;
    tracing::trace!(?id, permit = ?outcome, " !!! semaphore permit obtained");

    match outcome {
        // Outer error: the timeout elapsed before a permit was granted.
        Err(elapsed) => Err(DatabaseError::Timeout(elapsed)),
        // Inner error: the semaphore was closed, which should never happen.
        Ok(acquired) => acquired.map_err(|e| {
            tracing::error!(
                "Semaphore should not be closed but got an error while acquiring a permit, {:?}",
                e
            );
            DatabaseError::Other(e.into())
        }),
    }
}