use std::{
convert::TryFrom,
env::temp_dir,
fs,
iter,
path::{Path, PathBuf},
sync::{Arc, RwLock, RwLockWriteGuard},
};
use diesel::{
SqliteConnection,
r2d2::{ConnectionManager, PooledConnection},
};
use diesel_migrations::{EmbeddedMigrations, MigrationHarness};
use log::*;
use rand::{Rng, distributions::Alphanumeric, thread_rng};
use serde::{Deserialize, Serialize};
use crate::{
connection_options::PRAGMA_BUSY_TIMEOUT,
error::{SqliteStorageError, StorageError},
sqlite_connection_pool::{PooledDbConnection, SqliteConnectionPool},
};
/// `log` target used by every log statement in this module.
const LOG_TARGET: &str = "common_sqlite::connection";
/// Location of a SQLite database.
///
/// (De)serialised as its connection string: serialisation goes through
/// `From<DbConnectionUrl> for String` and deserialisation through `TryFrom<String>`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(into = "String", try_from = "String")]
pub enum DbConnectionUrl {
    /// A private in-memory database (connection string `:memory:`).
    Memory,
    /// A named in-memory database opened through the URI
    /// `file:<identifier>?mode=memory&cache=shared`; the `String` is `<identifier>`.
    MemoryShared(String),
    /// A database file at the given path.
    File(PathBuf),
}
impl DbConnectionUrl {
pub fn file<P: AsRef<Path>>(path: P) -> Self {
DbConnectionUrl::File(path.as_ref().to_path_buf())
}
pub fn to_url_string(&self) -> String {
use DbConnectionUrl::{File, Memory, MemoryShared};
match self {
Memory => ":memory:".to_owned(),
MemoryShared(identifier) => format!("file:{identifier}?mode=memory&cache=shared"),
File(path) => path
.to_str()
.expect("Invalid non-UTF8 character in database path")
.to_owned(),
}
}
pub fn set_base_path<P: AsRef<Path>>(&mut self, base_path: P) {
if let DbConnectionUrl::File(inner) = self &&
!inner.is_absolute()
{
*inner = base_path.as_ref().join(inner.as_path());
}
}
}
impl From<DbConnectionUrl> for String {
fn from(source: DbConnectionUrl) -> Self {
source.to_url_string()
}
}
impl TryFrom<String> for DbConnectionUrl {
    type Error = String;

    /// Parses a connection string of the shape produced by `DbConnectionUrl::to_url_string`.
    ///
    /// `:memory:` maps to `Memory`, and `file:<id>?mode=memory&cache=shared` maps back to
    /// `MemoryShared(<id>)`, so every variant round-trips through serde. Anything else is
    /// treated as a file path.
    ///
    /// Recognising the shared-memory shape matters. Without it the URI would come back as a
    /// `File`, and `set_base_path` would then prefix it with a directory, corrupting the URI.
    ///
    /// # Errors
    /// Currently never fails; the `String` error type is kept for the serde `try_from` contract.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        // Must mirror the shared-memory URI emitted by `to_url_string`.
        const SHARED_MEMORY_PREFIX: &str = "file:";
        const SHARED_MEMORY_SUFFIX: &str = "?mode=memory&cache=shared";

        if value == ":memory:" {
            return Ok(Self::Memory);
        }
        if let Some(identifier) = value
            .strip_prefix(SHARED_MEMORY_PREFIX)
            .and_then(|rest| rest.strip_suffix(SHARED_MEMORY_SUFFIX))
        {
            return Ok(Self::MemoryShared(identifier.to_owned()));
        }
        Ok(Self::File(PathBuf::from(value)))
    }
}
lazy_static::lazy_static! {
    /// Process-wide lock serialising database migrations.
    ///
    /// `DbConnection::connect_and_migrate` holds the write half while it connects and migrates;
    /// `DbConnection::migration_lock_active` probes it with a non-blocking read.
    static ref DB_WRITE_LOCK: Arc<RwLock<()>> = Arc::new(RwLock::default());
}
/// Handle to a SQLite connection pool.
///
/// `Clone` duplicates the `SqliteConnectionPool` handle (presumably sharing one underlying
/// pool; confirm in `sqlite_connection_pool`). `Drop` is implemented for this type, so it
/// runs for every clone that is dropped.
#[derive(Clone)]
pub struct DbConnection {
    pool: SqliteConnectionPool,
}
impl DbConnection {
pub fn connect_url(db_url: &DbConnectionUrl, sqlite_pool_size: Option<usize>) -> Result<Self, StorageError> {
debug!(target: LOG_TARGET, "Connecting to database using '{db_url:?}'");
if let DbConnectionUrl::File(path) = db_url &&
let Some(parent) = path.parent()
{
std::fs::create_dir_all(parent)?;
}
let mut pool = SqliteConnectionPool::new(
db_url.to_url_string(),
sqlite_pool_size.unwrap_or(1),
true,
true,
PRAGMA_BUSY_TIMEOUT,
);
pool.create_pool()?;
debug!(target: LOG_TARGET, "{}", pool);
Ok(Self::new(pool))
}
fn acquire_migration_write_lock() -> Result<RwLockWriteGuard<'static, ()>, StorageError> {
match DB_WRITE_LOCK.write() {
Ok(value) => Ok(value),
Err(err) => Err(StorageError::DatabaseMigrationLockError(format!(
"Failed to acquire write lock for database migration: {err}"
))),
}
}
#[inline]
pub fn migration_lock_active() -> bool {
DB_WRITE_LOCK.try_read().is_err()
}
pub fn connect_and_migrate(
db_url: &DbConnectionUrl,
migrations: EmbeddedMigrations,
sqlite_pool_size: Option<usize>,
) -> Result<Self, StorageError> {
let _lock = Self::acquire_migration_write_lock()?;
let conn = Self::connect_url(db_url, sqlite_pool_size)?;
let output = conn.migrate(migrations)?;
debug!(target: LOG_TARGET, "Database migration: {}", output.trim());
Ok(conn)
}
fn temp_db_dir() -> PathBuf {
temp_dir().join("tari-temp")
}
pub fn connect_temp_file_and_migrate(migrations: EmbeddedMigrations) -> Result<Self, StorageError> {
fn prefixed_string(prefix: &str, len: usize) -> String {
let mut rng = thread_rng();
let rand_str = iter::repeat(())
.map(|_| rng.sample(Alphanumeric) as char)
.take(len)
.collect::<String>();
format!("{prefix}{rand_str}")
}
let path = DbConnection::temp_db_dir().join(prefixed_string("data-", 20));
fs::create_dir_all(&path)?;
let db_url = DbConnectionUrl::File(path.join("my_temp.db"));
DbConnection::connect_and_migrate(&db_url, migrations, Some(10))
}
fn new(pool: SqliteConnectionPool) -> Self {
Self { pool }
}
pub fn get_pooled_connection(&self) -> Result<PooledConnection<ConnectionManager<SqliteConnection>>, StorageError> {
self.pool.get_pooled_connection().map_err(StorageError::DieselR2d2Error)
}
pub fn migrate(&self, migrations: EmbeddedMigrations) -> Result<String, StorageError> {
let mut conn = self.get_pooled_connection()?;
let result: Vec<String> = conn
.run_pending_migrations(migrations)
.map(|v| v.into_iter().map(|b| format!("Running migration {b}")).collect())
.map_err(|err| StorageError::DatabaseMigrationFailed(format!("Database migration failed {err}")))?;
Ok(result.join("\r\n"))
}
#[cfg(test)]
pub(crate) fn db_path(&self) -> PathBuf {
self.pool.db_path()
}
}
impl Drop for DbConnection {
    /// Removes the database's directory when that directory lies under the temp-database root
    /// (`DbConnection::temp_db_dir`), such as those created by `connect_temp_file_and_migrate`.
    ///
    /// NOTE(review): `DbConnection` derives `Clone`, so this runs whenever *any* clone is
    /// dropped. For a temp database that removes the directory while other clones may still be
    /// using it. Also, `starts_with` matches the temp root itself, so a database placed directly
    /// in the root would remove the whole root. Consider tying cleanup to the last handle.
    fn drop(&mut self) {
        let db_file = self.pool.db_path();
        if !db_file.exists() {
            return;
        }
        let Some(parent) = db_file.parent() else {
            return;
        };
        if !parent.starts_with(DbConnection::temp_db_dir()) {
            return;
        }

        debug!(target: LOG_TARGET, "DbConnection - Dropping database: {}", db_file.display());
        let pool_state = self.pool.cleanup();
        debug!(target: LOG_TARGET, "DbConnection - Pool stats before cleanup: {pool_state:?}");
        debug!(target: LOG_TARGET, "DbConnection - Cleaning up tempdir: {}", parent.display());
        match fs::remove_dir_all(parent) {
            Ok(()) => {
                debug!(target: LOG_TARGET, "Temp dir cleaned up: {}", parent.display());
            },
            Err(e) => {
                error!(target: LOG_TARGET, "Failed to clean up temp dir: {e}");
            },
        }
    }
}
impl PooledDbConnection for DbConnection {
type Error = SqliteStorageError;
fn get_pooled_connection(&self) -> Result<PooledConnection<ConnectionManager<SqliteConnection>>, Self::Error> {
let conn = self.pool.get_pooled_connection()?;
Ok(conn)
}
}
#[cfg(test)]
mod test {
    use std::sync::Arc;
    use diesel::{
        RunQueryDsl,
        connection::SimpleConnection,
        dsl::sql,
        sql_types::{Integer, Text},
    };
    use diesel_migrations::embed_migrations;
    use tokio::{sync::Barrier, task::JoinSet};
    use super::*;

    /// A migrated temp-file database is queryable, and dropping the handle removes its file.
    #[tokio::test]
    async fn connect_and_migrate() {
        const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./test/migrations");
        let db_conn = DbConnection::connect_temp_file_and_migrate(MIGRATIONS).unwrap();
        let path = db_conn.db_path();
        let mut pool_conn = db_conn.get_pooled_connection().unwrap();
        let count: i32 = sql::<Integer>("SELECT COUNT(*) FROM test_table")
            .get_result(&mut pool_conn)
            .unwrap();
        assert_eq!(count, 0);
        assert!(path.exists());
        // The checked-out connection is released before the handle, presumably so the
        // temp-directory cleanup in `Drop for DbConnection` runs with nothing checked out.
        drop(pool_conn);
        drop(db_conn);
        assert!(!path.exists());
    }

    /// Many concurrent exclusive writers and readers against one pool must all succeed, and
    /// every writer's row must be committed.
    ///
    /// NOTE(review): writers use `BEGIN EXCLUSIVE` and hold it for `HOLD_MS`, so they run one
    /// at a time. The test therefore takes roughly WRITERS * HOLD_MS (about 16 s) or more.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn stress_connect_and_migrate_contention() {
        const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./test/migrations");
        let db = DbConnection::connect_temp_file_and_migrate(MIGRATIONS).unwrap();
        // NOTE(review): this `c` stays checked out until the end of the test. The shadowing
        // `let mut c` below does not drop it, so the stress phase runs with one fewer pooled
        // connection.
        let mut c = db.get_pooled_connection().unwrap();
        // Checkpoint the WAL as often as possible to add I/O contention.
        sql::<Integer>("PRAGMA wal_autocheckpoint = 1;")
            .execute(&mut c)
            .unwrap();
        // Sanity-check the pool's connection settings: WAL mode and at least the configured
        // busy timeout.
        let mode: String = sql::<Text>("PRAGMA journal_mode;").get_result(&mut c).unwrap();
        assert!(mode.eq_ignore_ascii_case("wal"));
        let busy: String = sql::<Text>("PRAGMA busy_timeout;").get_result(&mut c).unwrap();
        assert!(busy.parse::<u128>().unwrap() >= PRAGMA_BUSY_TIMEOUT.as_millis());
        const WRITERS: usize = 160;
        const READERS: usize = 320;
        const HOLD_MS: u64 = 100;
        // Every task waits on the barrier so all writers and readers start together.
        let barrier = Arc::new(Barrier::new(WRITERS + READERS));
        let mut tasks = JoinSet::new();
        // NOTE(review): each `db2` clone is dropped when its task finishes. Dropping a clone
        // runs `Drop for DbConnection`, which removes the temp database directory mid-test.
        for _ in 0..WRITERS {
            let synchronization_barrier = barrier.clone();
            let db2 = db.clone();
            tasks.spawn(async move {
                synchronization_barrier.wait().await;
                // Blocking SQLite work runs off the async worker threads.
                tokio::task::spawn_blocking(move || {
                    let mut conn = db2.get_pooled_connection().expect("writer checkout");
                    conn.batch_execute("BEGIN EXCLUSIVE;").unwrap();
                    sql::<Integer>("INSERT INTO test_table DEFAULT VALUES;")
                        .execute(&mut conn)
                        .unwrap();
                    // Hold the exclusive transaction open to force contention.
                    std::thread::sleep(std::time::Duration::from_millis(HOLD_MS));
                    conn.batch_execute("COMMIT;").unwrap();
                })
                .await
                .expect("writer join");
            });
        }
        for _ in 0..READERS {
            let b = barrier.clone();
            let db2 = db.clone();
            tasks.spawn(async move {
                b.wait().await;
                tokio::task::spawn_blocking(move || {
                    let mut conn = db2.get_pooled_connection().expect("reader checkout");
                    // A few short reads, spaced out, on the same connection.
                    for _ in 0..3 {
                        let _: i32 = sql::<Integer>("SELECT COUNT(*) FROM test_table")
                            .get_result(&mut conn)
                            .expect("reader select");
                        std::thread::sleep(std::time::Duration::from_millis(10));
                    }
                })
                .await
                .expect("reader join");
            });
        }
        while let Some(res) = tasks.join_next().await {
            res.expect("task panicked");
        }
        // Each writer inserted exactly one row.
        let mut c = db.get_pooled_connection().unwrap();
        let count: i32 = sql::<Integer>("SELECT COUNT(*) FROM test_table")
            .get_result(&mut c)
            .unwrap();
        assert_eq!(count as usize, WRITERS);
    }
}