use std::convert::TryFrom;
use std::fmt;
use std::num::NonZeroU32;
use std::sync::{Arc, RwLock};
use diesel::{
connection::SimpleConnection,
r2d2::{ConnectionManager, CustomizeConnection, Pool, PooledConnection},
sqlite,
};
use crate::error::{InternalError, InvalidStateError};
use super::{Backend, Connection, WriteExclusiveExecute};
/// A connection checked out of the r2d2 pool that backs a [`SqliteBackend`].
///
/// The pooled handle is reachable only inside `crate::state::merkle::sql`.
pub struct SqliteConnection(
    pub(in crate::state::merkle::sql) PooledConnection<ConnectionManager<sqlite::SqliteConnection>>,
);

impl Connection for SqliteConnection {
    type ConnectionType = sqlite::SqliteConnection;

    /// Borrows the diesel connection wrapped by the pooled handle.
    fn as_inner(&self) -> &Self::ConnectionType {
        let Self(pooled) = self;
        // Deref coercion turns `&PooledConnection<_>` into `&sqlite::SqliteConnection`.
        pooled
    }
}
/// A [`Backend`] for SQLite databases, built on an r2d2 connection pool.
///
/// The pool sits behind an `RwLock`:
/// - [`WriteExclusiveExecute::execute_write`] holds the write side of the lock for the whole
///   call, so it never runs concurrently with `execute_read`, which holds the read side.
/// - [`Backend::connection`] holds the read side only while checking a connection out.
///
/// Cloning copies the `Arc`, so all clones share the same pool and lock.
#[derive(Clone)]
pub struct SqliteBackend {
    connection_pool: Arc<RwLock<Pool<ConnectionManager<sqlite::SqliteConnection>>>>,
}
impl WriteExclusiveExecute for SqliteBackend {
fn execute_write<F, T>(&self, f: F) -> Result<T, InternalError>
where
F: Fn(&Self::Connection) -> Result<T, InternalError>,
{
let write_pool = self.connection_pool.write().map_err(|_| {
InternalError::with_message("SqliteBackend connection pool lock was poisoned".into())
})?;
let conn = write_pool
.get()
.map(SqliteConnection)
.map_err(|err| InternalError::from_source(Box::new(err)))?;
f(&conn)
}
fn execute_read<F, T>(&self, f: F) -> Result<T, InternalError>
where
F: Fn(&Self::Connection) -> Result<T, InternalError>,
{
let read_pool = self.connection_pool.read().map_err(|_| {
InternalError::with_message("SqliteBackend connection pool lock was poisoned".into())
})?;
let conn = read_pool
.get()
.map(SqliteConnection)
.map_err(|err| InternalError::from_source(Box::new(err)))?;
f(&conn)
}
}
impl Backend for SqliteBackend {
type Connection = SqliteConnection;
fn connection(&self) -> Result<Self::Connection, InternalError> {
self.connection_pool
.read()
.map_err(|_| {
InternalError::with_message(
"SqliteBackend connection pool lock was poisoned".into(),
)
})?
.get()
.map(SqliteConnection)
.map_err(|err| InternalError::from_source(Box::new(err)))
}
}
impl From<Pool<ConnectionManager<sqlite::SqliteConnection>>> for SqliteBackend {
    /// Wraps a bare pool in a new, unshared `Arc<RwLock<_>>`.
    fn from(pool: Pool<ConnectionManager<sqlite::SqliteConnection>>) -> Self {
        let connection_pool = Arc::new(RwLock::new(pool));
        Self { connection_pool }
    }
}
impl From<Arc<RwLock<Pool<ConnectionManager<sqlite::SqliteConnection>>>>> for SqliteBackend {
    /// Adopts an already-shared pool.
    ///
    /// The backend uses the same lock as every other holder of this `Arc`.
    fn from(connection_pool: Arc<RwLock<Pool<ConnectionManager<sqlite::SqliteConnection>>>>) -> Self {
        Self { connection_pool }
    }
}
impl From<SqliteBackend> for Arc<RwLock<Pool<ConnectionManager<sqlite::SqliteConnection>>>> {
fn from(backend: SqliteBackend) -> Self {
backend.connection_pool
}
}
/// A [`Connection`] that borrows an existing SQLite connection instead of owning a pooled one.
#[cfg(feature = "state-merkle-sql-in-transaction")]
pub struct BorrowedSqliteConnection<'a>(&'a sqlite::SqliteConnection);

#[cfg(feature = "state-merkle-sql-in-transaction")]
impl<'a> Connection for BorrowedSqliteConnection<'a> {
    type ConnectionType = sqlite::SqliteConnection;

    /// Returns the borrowed connection.
    fn as_inner(&self) -> &Self::ConnectionType {
        // The field is a shared reference (`Copy`), so it can be copied out of `*self`.
        let Self(inner) = *self;
        inner
    }
}
/// A [`Backend`] that performs every operation on one SQLite connection supplied by the caller.
///
/// This type never begins, commits or rolls back a transaction itself.
/// NOTE(review): as the name suggests, the caller presumably already has a transaction open on
/// this connection; confirm against call sites.
#[cfg(feature = "state-merkle-sql-in-transaction")]
pub struct InTransactionSqliteBackend<'a> {
    connection: &'a sqlite::SqliteConnection,
}

#[cfg(feature = "state-merkle-sql-in-transaction")]
impl<'a> InTransactionSqliteBackend<'a> {
    /// Creates a backend that borrows `connection` for its whole lifetime `'a`.
    pub fn new(connection: &'a sqlite::SqliteConnection) -> Self {
        InTransactionSqliteBackend { connection }
    }
}
#[cfg(feature = "state-merkle-sql-in-transaction")]
impl<'a> Backend for InTransactionSqliteBackend<'a> {
    type Connection = BorrowedSqliteConnection<'a>;

    /// Hands out another borrow of the wrapped connection; this never returns an error.
    fn connection(&self) -> Result<Self::Connection, InternalError> {
        let borrowed = BorrowedSqliteConnection(self.connection);
        Ok(borrowed)
    }
}
#[cfg(feature = "state-merkle-sql-in-transaction")]
impl<'a> WriteExclusiveExecute for InTransactionSqliteBackend<'a> {
    /// Runs `f` directly on the borrowed connection.
    ///
    /// No lock is taken; this backend only ever wraps the single connection it was given.
    fn execute_write<F, T>(&self, f: F) -> Result<T, InternalError>
    where
        F: Fn(&Self::Connection) -> Result<T, InternalError>,
    {
        let borrowed = BorrowedSqliteConnection(self.connection);
        f(&borrowed)
    }

    /// Runs `f` directly on the borrowed connection.
    ///
    /// This is identical to `execute_write` for this backend.
    fn execute_read<F, T>(&self, f: F) -> Result<T, InternalError>
    where
        F: Fn(&Self::Connection) -> Result<T, InternalError>,
    {
        let borrowed = BorrowedSqliteConnection(self.connection);
        f(&borrowed)
    }
}
#[cfg(feature = "state-merkle-sql-in-transaction")]
impl<'a> Clone for InTransactionSqliteBackend<'a> {
    /// Copies the reference; the clone uses the very same underlying connection.
    fn clone(&self) -> Self {
        let connection = self.connection;
        InTransactionSqliteBackend { connection }
    }
}
#[cfg(feature = "state-merkle-sql-in-transaction")]
impl<'a> From<&'a sqlite::SqliteConnection> for InTransactionSqliteBackend<'a> {
    /// Equivalent to [`InTransactionSqliteBackend::new`].
    fn from(connection: &'a sqlite::SqliteConnection) -> Self {
        Self { connection }
    }
}
/// Default value for `PRAGMA mmap_size` used by `SqliteBackendBuilder`.
///
/// The value is 100 MiB, i.e. 100 * 2^20 = 104,857,600 bytes.
pub const DEFAULT_MMAP_SIZE: i64 = 100 << 20;
/// Settings for SQLite's `synchronous` pragma.
///
/// `Display` renders the keyword used in `PRAGMA synchronous=...`.
#[derive(Debug)]
pub enum Synchronous {
    /// `PRAGMA synchronous=OFF`
    Off,
    /// `PRAGMA synchronous=NORMAL`
    Normal,
    /// `PRAGMA synchronous=FULL`
    Full,
}

impl fmt::Display for Synchronous {
    /// Writes the upper-case pragma keyword. Width and fill flags are ignored.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let keyword = match self {
            Self::Off => "OFF",
            Self::Normal => "NORMAL",
            Self::Full => "FULL",
        };
        f.write_str(keyword)
    }
}
/// Settings for SQLite's `journal_mode` pragma.
///
/// `Display` renders the keyword used in `PRAGMA journal_mode=...`.
#[derive(Debug)]
pub enum JournalMode {
    /// `PRAGMA journal_mode=DELETE`
    Delete,
    /// `PRAGMA journal_mode=TRUNCATE`
    Truncate,
    /// `PRAGMA journal_mode=PERSIST`
    Persist,
    /// `PRAGMA journal_mode=MEMORY`
    Memory,
    /// `PRAGMA journal_mode=WAL`
    Wal,
    /// `PRAGMA journal_mode=OFF`
    Off,
}

impl fmt::Display for JournalMode {
    /// Writes the upper-case pragma keyword. Width and fill flags are ignored.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let keyword = match self {
            Self::Delete => "DELETE",
            Self::Truncate => "TRUNCATE",
            Self::Persist => "PERSIST",
            Self::Memory => "MEMORY",
            Self::Wal => "WAL",
            Self::Off => "OFF",
        };
        f.write_str(keyword)
    }
}
/// Builder for [`SqliteBackend`].
///
/// A connection path is required: `build` returns an error when none was provided.
pub struct SqliteBackendBuilder {
    // Database file path, or ":memory:" for an in-memory database.
    connection_path: Option<String>,
    // When false, `build` rejects a database file that does not already exist.
    create: bool,
    // Requested maximum number of pooled connections, if any.
    pool_size: Option<u32>,
    // Value written into `PRAGMA mmap_size`.
    memory_map_size: i64,
    // Value for `PRAGMA journal_mode`; `None` emits no journal_mode pragma.
    journal_mode: Option<JournalMode>,
    // Value for `PRAGMA synchronous`; `None` makes `build` fall back to NORMAL.
    synchronous: Option<Synchronous>,
}
impl SqliteBackendBuilder {
    /// Creates a builder with these defaults:
    /// - no connection path;
    /// - `create` disabled;
    /// - no explicit pool size;
    /// - [`DEFAULT_MMAP_SIZE`] as the memory-map size;
    /// - no journal-mode or synchronous override.
    pub fn new() -> Self {
        Self {
            connection_path: None,
            create: false,
            pool_size: None,
            memory_map_size: DEFAULT_MMAP_SIZE,
            journal_mode: None,
            synchronous: None,
        }
    }

    /// Uses an in-memory database (`:memory:`) instead of a file.
    ///
    /// SQLite gives every connection to `:memory:` its own separate database. For this path,
    /// `build` therefore always uses exactly one pooled connection that the pool never retires,
    /// regardless of [`with_connection_pool_size`](Self::with_connection_pool_size).
    pub fn with_memory_database(mut self) -> Self {
        self.connection_path = Some(":memory:".into());
        self.pool_size = Some(1);
        self
    }

    /// Sets the database path.
    ///
    /// Passing `":memory:"` behaves like
    /// [`with_memory_database`](Self::with_memory_database).
    pub fn with_connection_path<S: Into<String>>(mut self, connection_path: S) -> Self {
        self.connection_path = Some(connection_path.into());
        self
    }

    /// Allows `build` to proceed when the database file does not exist yet.
    pub fn with_create(mut self) -> Self {
        self.create = true;
        self
    }

    /// Sets the maximum number of pooled connections.
    ///
    /// This is ignored for `:memory:`, which always uses a single connection.
    pub fn with_connection_pool_size(mut self, pool_size: NonZeroU32) -> Self {
        self.pool_size = Some(pool_size.get());
        self
    }

    /// Sets `PRAGMA mmap_size`. Values above `i64::MAX` are clamped to `i64::MAX`.
    pub fn with_memory_map_size(mut self, memory_map_size: u64) -> Self {
        self.memory_map_size = i64::try_from(memory_map_size).unwrap_or(i64::MAX);
        self
    }

    /// Sets `PRAGMA journal_mode`.
    ///
    /// Choosing [`JournalMode::Wal`] also selects [`Synchronous::Full`], but only when no
    /// synchronous setting has been chosen yet. An explicit
    /// [`with_synchronous`](Self::with_synchronous) therefore wins regardless of call order.
    pub fn with_journal_mode(mut self, journal_mode: JournalMode) -> Self {
        if matches!(journal_mode, JournalMode::Wal) && self.synchronous.is_none() {
            self.synchronous = Some(Synchronous::Full);
        }
        self.journal_mode = Some(journal_mode);
        self
    }

    /// Sets `PRAGMA synchronous`. When never set, `build` uses [`Synchronous::Normal`].
    pub fn with_synchronous(mut self, synchronous: Synchronous) -> Self {
        self.synchronous = Some(synchronous);
        self
    }

    /// Validates the configuration and creates the backend.
    ///
    /// Every connection the pool opens runs the configured PRAGMAs. One connection is checked
    /// out before returning, so connection problems are reported here rather than on first use.
    ///
    /// # Errors
    ///
    /// Returns an [`InvalidStateError`] if:
    /// - no connection path was given;
    /// - the database file does not exist and `with_create` was not called;
    /// - the pool cannot be built, or cannot hand out a connection.
    pub fn build(self) -> Result<SqliteBackend, InvalidStateError> {
        let path = self.connection_path.ok_or_else(|| {
            InvalidStateError::with_message("must provide a sqlite connection URI".into())
        })?;
        let is_memory = path == ":memory:";

        let mmap_size = self.memory_map_size;
        let journal_mode_opt = self.journal_mode;
        let synchronous_opt = self.synchronous.or(Some(Synchronous::Normal));

        if !self.create && !is_memory && !std::path::Path::new(&path).exists() {
            return Err(InvalidStateError::with_message(format!(
                "Database file '{}' does not exist",
                path
            )));
        }

        let connection_manager = ConnectionManager::<diesel::sqlite::SqliteConnection>::new(&path);
        let mut pool_builder = Pool::builder();
        if is_memory {
            // Each connection to ":memory:" opens its own database, which vanishes when that
            // connection closes. Keep exactly one connection, and turn off r2d2's idle timeout
            // and max lifetime, so the pool never swaps it for a fresh, empty database.
            pool_builder = pool_builder
                .max_size(1)
                .idle_timeout(None)
                .max_lifetime(None);
        } else if let Some(pool_size) = self.pool_size {
            pool_builder = pool_builder.max_size(pool_size);
        }

        let pool = pool_builder
            .connection_customizer(Box::new(SqliteCustomizer::new(
                mmap_size,
                journal_mode_opt,
                synchronous_opt,
            )))
            .build(connection_manager)
            .map_err(|err| InvalidStateError::with_message(err.to_string()))?;

        // Check one connection out up front so failures (e.g. a failing PRAGMA) surface here.
        let _conn = pool
            .get()
            .map_err(|err| InvalidStateError::with_message(err.to_string()))?;

        Ok(SqliteBackend::from(pool))
    }
}
impl Default for SqliteBackendBuilder {
fn default() -> Self {
Self::new()
}
}
/// r2d2 connection customizer that runs a fixed batch of PRAGMA statements on each connection
/// handed to `on_acquire`.
#[derive(Debug)]
struct SqliteCustomizer {
    // Semicolon-separated SQL batch executed by `on_acquire`.
    on_connect_sql: String,
}
impl SqliteCustomizer {
fn new(
mmap_size: i64,
journal_mode: Option<JournalMode>,
synchronous: Option<Synchronous>,
) -> Self {
let mut on_connect_sql = format!("PRAGMA mmap_size={};", mmap_size);
if let Some(journal_mode) = journal_mode {
on_connect_sql.push_str("PRAGMA journal_mode=");
on_connect_sql.push_str(&journal_mode.to_string());
on_connect_sql.push(';');
if matches!(journal_mode, JournalMode::Wal) {
on_connect_sql += r#"
PRAGMA wal_checkpoint(truncate);
PRAGMA busy_timeout = 2000;
"#;
}
}
if let Some(synchronous) = synchronous {
on_connect_sql.push_str("PRAGMA synchronous=");
on_connect_sql.push_str(&synchronous.to_string());
on_connect_sql.push(';');
}
on_connect_sql += "PRAGMA foreign_keys = ON;";
Self { on_connect_sql }
}
}
impl CustomizeConnection<diesel::sqlite::SqliteConnection, diesel::r2d2::Error>
for SqliteCustomizer
{
fn on_acquire(
&self,
conn: &mut diesel::sqlite::SqliteConnection,
) -> Result<(), diesel::r2d2::Error> {
conn.batch_execute(&self.on_connect_sql)
.map_err(diesel::r2d2::Error::QueryError)
}
}