use crate::{DBResult, Error};
use libsqlite3_sys as sq;
use std::{
cell::RefCell,
collections::{HashMap, VecDeque},
ffi::{c_char, c_int, c_void, CStr, CString},
sync::{Arc, Condvar, Mutex},
};
use super::{check_rcode, PreparedKey, Statement, StatementContext, StatementRow, Transaction};
/// One SQLite connection handle together with its cache of prepared statements.
///
/// When dropped, the statement cache is emptied before the handle itself is closed.
pub(crate) struct ConnectionData {
    /// Raw handle obtained from `sqlite3_open_v2`.
    pub(crate) sqlite: *mut sq::sqlite3,
    /// Prepared statements, keyed by the `u64` produced from a `PreparedKey`.
    pub(super) stmts: RefCell<HashMap<u64, Statement>>,
}

impl Drop for ConnectionData {
    fn drop(&mut self) {
        // Drop every cached statement first; only then is the handle closed.
        let cache = self.stmts.get_mut();
        cache.clear();
        // SAFETY: `sqlite` is the handle opened for this value and is closed only here.
        unsafe {
            sq::sqlite3_close(self.sqlite);
        }
    }
}
impl ConnectionData {
fn new(path: &std::path::Path) -> Result<Self, Error> {
let db_ptr = unsafe {
let cpath = CString::new(path.as_os_str().as_encoded_bytes())?;
let mut db_ptr = std::ptr::null_mut();
check_rcode(
|| None,
sq::sqlite3_open_v2(
cpath.as_ptr(),
&mut db_ptr,
sq::SQLITE_OPEN_READWRITE | sq::SQLITE_OPEN_NOMUTEX | sq::SQLITE_OPEN_CREATE,
std::ptr::null_mut(),
),
)?;
db_ptr
};
if db_ptr.is_null() {
return Err(Error::InternalError(
"sqlite3_open_v2 returned a NULL connection",
));
}
unsafe {
sq::sqlite3_extended_result_codes(db_ptr, 1);
sq::sqlite3_busy_timeout(db_ptr, 100);
}
#[cfg(feature = "regex")]
super::regex::install_regex(db_ptr)?;
let cdata = Self {
sqlite: db_ptr,
stmts: Default::default(),
};
cdata.execute_raw_sql("PRAGMA foreign_keys = ON")?;
Ok(cdata)
}
pub(crate) fn execute_raw_sql(&self, sql: impl AsRef<str>) -> DBResult<()> {
log::trace!(
"executing raw sql on {:?}: {sql}",
self.sqlite,
sql = sql.as_ref()
);
unsafe {
let c_sql = CString::new(sql.as_ref())?;
let mut err = std::ptr::null_mut();
let rcode = sq::sqlite3_exec(
self.sqlite,
c_sql.as_ptr(),
None,
std::ptr::null_mut(),
&mut err,
);
if rcode != sq::SQLITE_OK {
let e = Error::Sqlite {
code: rcode & 0xff,
extended_code: rcode & !0xff,
msg: if err.is_null() {
CStr::from_ptr(sq::sqlite3_errstr(rcode))
} else {
CStr::from_ptr(err)
}
.to_str()?
.to_string(),
sql: Some(sql.as_ref().into()),
};
if !err.is_null() {
sq::sqlite3_free(err.cast());
}
return Err(e);
}
}
Ok(())
}
}
unsafe impl Send for ConnectionData {}
/// Exclusive use of one pooled connection; the connection is handed back to the pool
/// when the lease is dropped.
pub(crate) struct ConnectionLease {
    // Keeps the pool, which owns `conn`, alive for as long as the lease exists.
    pool: Arc<ConnectionPoolData>,
    // Leaked by `ConnectionPoolData::spawn` and only freed when the pool is dropped, so the
    // `'static` lifetime effectively means "as long as `pool`".
    pub(super) conn: &'static ConnectionData,
    // Position of `conn` in the pool's `connections`, returned on release.
    index: usize,
    // Statement (with its cache key) currently driving an iterator from
    // `iter_with_prepared`; it is taken out of the cache while in use.
    active: Option<(u64, Statement)>,
}
impl ConnectionLease {
    /// Runs `sql` directly on the leased connection via `ConnectionData::execute_raw_sql`.
    pub(crate) fn execute_raw_sql(&mut self, sql: impl AsRef<str>) -> DBResult<()> {
        self.conn.execute_raw_sql(sql)
    }

    /// Compiles `sql` into a new [`Statement`] on the leased connection.
    ///
    /// # Errors
    /// Fails if `sql` is longer than SQLite's C `int` length parameter can express, if SQLite
    /// rejects the SQL, or if SQLite produces no statement (e.g. the text contains no SQL).
    fn prepare_query(&self, sql: &str) -> DBResult<Statement> {
        log::trace!("preparing query: {sql}");
        // The byte length is passed as a C int. An `as` cast would wrap an oversized length
        // into a negative value, which tells SQLite to read up to a NUL terminator that a
        // Rust `&str` does not have.
        let Ok(sql_len) = c_int::try_from(sql.len()) else {
            return Err(Error::InternalError(
                "SQL text is too long for sqlite3_prepare_v2",
            ));
        };
        let mut stmt = std::ptr::null_mut();
        // SAFETY: `sql` points to `sql_len` readable bytes and `stmt` is a valid out-pointer.
        unsafe {
            check_rcode(
                || Some(sql),
                sq::sqlite3_prepare_v2(
                    self.conn.sqlite,
                    sql.as_ptr().cast(),
                    sql_len,
                    &mut stmt,
                    std::ptr::null_mut(),
                ),
            )?;
        };
        if stmt.is_null() {
            return Err(Error::InternalError(
                "sqlite3_prepare_v2 returned a NULL stmt",
            ));
        }
        Ok(Statement {
            sqlite: self.conn.sqlite,
            stmt,
        })
    }

    /// Runs `run_query` against the cached statement for `hash_key`. On first use, the SQL
    /// from `build_query` is prepared and cached.
    ///
    /// The connection's statement cache stays mutably borrowed while `run_query` runs.
    pub(crate) fn with_prepared<R>(
        &mut self,
        hash_key: impl PreparedKey,
        build_query: impl Fn() -> DBResult<String>,
        run_query: impl Fn(StatementContext) -> DBResult<R>,
    ) -> DBResult<R> {
        use std::collections::hash_map::Entry;
        let mut stmts = self.conn.stmts.borrow_mut();
        match stmts.entry(hash_key.into_u64()) {
            Entry::Vacant(e) => {
                let sql = build_query()?;
                let stmt = e.insert(self.prepare_query(sql.as_str())?);
                run_query(stmt.make_context()?)
            },
            Entry::Occupied(mut e) => run_query(e.get_mut().make_context()?),
        }
    }

    /// Like [`Self::with_prepared`], but returns an iterator over the result rows.
    ///
    /// The statement is moved out of the cache into `self.active` for as long as the
    /// iterator may use it. Any statement left there by an earlier call is put back in the
    /// cache first, and `Drop` puts back the last one.
    pub(crate) fn iter_with_prepared<D>(
        &mut self,
        data: D,
        hash_key: impl PreparedKey,
        build_query: impl Fn(&D) -> DBResult<String>,
        query_setup: impl Fn(&D, &mut StatementContext) -> DBResult<()>,
    ) -> DBResult<impl Iterator<Item = DBResult<StatementRow<'_>>> + '_> {
        let mut stmts = self.conn.stmts.borrow_mut();
        let key = hash_key.into_u64();
        if let Some((key, s)) = self.active.take() {
            stmts.insert(key, s);
        }
        let active = self.active.insert((
            key,
            match stmts.remove(&key) {
                Some(v) => v,
                None => self.prepare_query(build_query(&data)?.as_str())?,
            },
        ));
        let mut ctx = active.1.make_context()?;
        query_setup(&data, &mut ctx)?;
        Ok(ctx.iter())
    }

    /// Runs `PRAGMA foreign_key_check` and returns one `"table@rowid"` entry per reported
    /// violation. The list is empty when every foreign key is satisfied.
    ///
    /// # Errors
    /// Returns `Error::InternalError` if the pragma fails; SQLite's message is logged.
    pub(crate) fn check_foreign_keys(&self) -> DBResult<Vec<String>> {
        log::debug!("Running foreign key check...");
        let mut failures: Vec<String> = vec![];

        /// `sqlite3_exec` row callback: appends `"<column 0>@<column 1>"` to the
        /// `Vec<String>` passed through `data`.
        ///
        /// It must never panic, because unwinding out of an `extern "C"` function into
        /// SQLite is not allowed. Missing or NULL columns become empty strings, and
        /// non-UTF-8 text is converted lossily.
        unsafe extern "C" fn fk_callback(
            data: *mut c_void,
            ncol: c_int,
            rowdata: *mut *mut c_char,
            _colnames: *mut *mut c_char,
        ) -> c_int {
            let Some(failures) = data.cast::<Vec<String>>().as_mut() else {
                // A non-zero return makes sqlite3_exec stop and report failure.
                return sq::SQLITE_ABORT;
            };
            let rows: &[*mut c_char] = match usize::try_from(ncol) {
                Ok(n) if !rowdata.is_null() => std::slice::from_raw_parts(rowdata, n),
                _ => &[],
            };
            let column = |idx: usize| match rows.get(idx) {
                Some(&ptr) if !ptr.is_null() => {
                    CStr::from_ptr(ptr).to_string_lossy().into_owned()
                },
                _ => String::new(),
            };
            failures.push(format!("{}@{}", column(0), column(1)));
            sq::SQLITE_OK
        }

        let mut err: *mut c_char = std::ptr::null_mut();
        let failure_ptr: *mut Vec<String> = &mut failures;
        // SAFETY: the connection is open, the literal is NUL-terminated with no interior
        // NUL, and `failure_ptr` points at `failures`, which outlives the call.
        let rcode = unsafe {
            sq::sqlite3_exec(
                self.conn.sqlite,
                CStr::from_bytes_with_nul_unchecked(b"PRAGMA foreign_key_check\0").as_ptr(),
                Some(fk_callback),
                failure_ptr as *mut c_void,
                &mut err,
            )
        };
        // A non-NULL `err` is allocated by SQLite and must always be released.
        let detail = if err.is_null() {
            None
        } else {
            // SAFETY: `err` is a NUL-terminated string owned by us until sqlite3_free.
            Some(unsafe {
                let msg = CStr::from_ptr(err).to_string_lossy().into_owned();
                sq::sqlite3_free(err.cast());
                msg
            })
        };
        if rcode == sq::SQLITE_OK {
            if failures.is_empty() {
                log::trace!("foreign keys check out!");
            } else {
                log::debug!("{} foreign key violation(s) found", failures.len());
            }
            Ok(failures)
        } else {
            log::warn!(
                "PRAGMA foreign_key_check failed with code {rcode}: {}",
                detail.as_deref().unwrap_or("no error message")
            );
            Err(Error::InternalError(
                "failed to execute PRAGMA foreign_key_check",
            ))
        }
    }
}
impl AsRef<ConnectionLease> for ConnectionLease {
    /// Identity conversion, so a lease satisfies `AsRef<ConnectionLease>` bounds directly.
    fn as_ref(&self) -> &ConnectionLease {
        self
    }
}

impl Drop for ConnectionLease {
    /// Puts any statement held for an in-flight iterator back into the connection's
    /// statement cache, then returns the connection to the pool.
    fn drop(&mut self) {
        if let Some((key, stmt)) = self.active.take() {
            let mut cache = self.conn.stmts.borrow_mut();
            cache.insert(key, stmt);
        }
        let slot = self.index;
        self.pool.release(slot);
    }
}
/// Shared state behind a `ConnectionPool`: the open connections plus the bookkeeping of
/// which ones are currently free.
struct ConnectionPoolData {
    // Database path every connection is opened against.
    path: std::path::PathBuf,
    // Notified whenever a connection becomes available; blocking `acquire` waits on it.
    available_condition: Condvar,
    // Indices into `connections` that are not currently leased.
    available: Mutex<VecDeque<usize>>,
    // Wakers registered by async acquirers (`PoolPoller`) waiting for a free connection.
    waiting: Mutex<VecDeque<std::task::Waker>>,
    // Connections leaked with `Box::leak` in `spawn`; reclaimed in `Drop`.
    connections: Vec<&'static ConnectionData>,
}
// SAFETY (NOTE(review): confirm): mutable bookkeeping lives behind `Mutex`es, `connections`
// is only modified through `&mut self` (`spawn`, `Drop`), and each connection index is in
// the hands of at most one lease at a time via `available`.
unsafe impl Sync for ConnectionPoolData {}
unsafe impl Send for ConnectionPoolData {}
impl ConnectionPoolData {
    /// Opens `count` more connections to `self.path` and marks each one available.
    ///
    /// # Errors
    /// Stops at, and returns, the first failure. Connections opened before it stay in
    /// the pool.
    fn spawn(&mut self, count: usize) -> Result<(), Error> {
        let mut alock = self.available.lock()?;
        for _ in 0..count {
            let nconn = ConnectionData::new(&self.path)?;
            alock.push_back(self.connections.len());
            // Leaked so leases can hold a plain `&'static` reference; `Drop` reclaims it.
            self.connections.push(Box::leak(Box::new(nconn)));
            self.available_condition.notify_one();
        }
        Ok(())
    }

    /// Builds a lease for connection `index`, which the caller has already taken out of
    /// `available`.
    fn lease_for(self: &Arc<Self>, index: usize) -> ConnectionLease {
        ConnectionLease {
            pool: Arc::clone(self),
            conn: self.connections[index],
            index,
            active: None,
        }
    }

    /// Takes a free connection if one exists, without blocking.
    fn try_acquire(self: &Arc<Self>) -> Result<Option<ConnectionLease>, Error> {
        let index = self.available.lock()?.pop_back();
        Ok(index.map(|index| self.lease_for(index)))
    }

    /// Takes a free connection, blocking the current thread until one is released.
    fn acquire(self: &Arc<Self>) -> Result<ConnectionLease, Error> {
        let mut alock = self.available.lock()?;
        let index = loop {
            if let Some(index) = alock.pop_back() {
                break index;
            }
            alock = self.available_condition.wait(alock)?;
        };
        drop(alock);
        Ok(self.lease_for(index))
    }

    /// Returns connection `conn` to the pool and wakes everyone waiting for one.
    ///
    /// If the `available` lock is poisoned, the connection is taken out of rotation with a
    /// warning instead of being returned.
    fn release(self: &Arc<Self>, conn: usize) {
        match self.available.lock() {
            Ok(mut alock) => alock.push_back(conn),
            Err(_) => {
                log::warn!("Dropping connection due to poisoned lock");
                return;
            },
        }
        // Notifying after unlocking is fine: blocked `acquire` callers re-check the queue
        // under the lock before sleeping again.
        self.available_condition.notify_one();
        // Wake every registered async waiter, not just the first. The queue can hold stale
        // wakers (futures polled more than once, or dropped while pending), and giving the
        // only wakeup to one of those would leave a live waiter asleep. Waiters that lose
        // the race re-register on their next poll. A poisoned `waiting` lock is recovered
        // rather than panicking, because this runs from `ConnectionLease::drop`.
        let wakers = std::mem::take(
            &mut *self
                .waiting
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner),
        );
        for waker in wakers {
            waker.wake();
        }
    }
}
impl Drop for ConnectionPoolData {
    /// Reclaims and closes every connection leaked by `spawn`, in the order opened.
    fn drop(&mut self) {
        let leaked = std::mem::take(&mut self.connections);
        for cdata in leaked {
            let raw = (cdata as *const ConnectionData).cast_mut();
            // SAFETY: every entry came from `Box::leak` in `spawn`, and each lease holds an
            // `Arc` to this pool, so no lease can still be using it here (assumes nothing
            // keeps a copied `&'static ConnectionData` beyond its lease — NOTE(review)).
            let owned = unsafe { Box::from_raw(raw) };
            drop(owned);
        }
    }
}
/// Settings used by [`ConnectionPool`] when it opens its connections.
///
/// Build one with [`ConnectionPoolConfig::new`] and the `with_*` methods, or convert from a
/// path-like value (`&str`, `String`, `PathBuf`, `&Path`), which uses the default settings.
#[derive(Debug, Clone)]
pub struct ConnectionPoolConfig {
    /// Database file to open.
    path: std::path::PathBuf,
    /// Number of connections the pool opens.
    pool_size: usize,
    /// Whether the pool switches the database to WAL journal mode when it is created.
    force_wal: bool,
}

impl ConnectionPoolConfig {
    /// Pool size used unless [`Self::with_pool_size`] overrides it.
    const DEFAULT_POOL_SIZE: usize = 2;

    /// Creates a configuration for the database at `path`, using the default pool size and
    /// leaving the journal mode untouched.
    pub fn new(path: impl Into<std::path::PathBuf>) -> Self {
        Self {
            path: path.into(),
            pool_size: Self::DEFAULT_POOL_SIZE,
            force_wal: false,
        }
    }

    /// Sets how many connections the pool opens.
    pub fn with_pool_size(mut self, pool_size: usize) -> Self {
        self.pool_size = pool_size;
        self
    }

    /// Requests `PRAGMA journal_mode = wal` when the pool is created.
    pub fn with_wal(mut self) -> Self {
        self.force_wal = true;
        self
    }
}

impl<'l> From<&'l str> for ConnectionPoolConfig {
    fn from(value: &'l str) -> Self {
        Self::new(value)
    }
}

impl From<String> for ConnectionPoolConfig {
    fn from(value: String) -> Self {
        Self::new(value)
    }
}

impl From<std::path::PathBuf> for ConnectionPoolConfig {
    fn from(value: std::path::PathBuf) -> Self {
        Self::new(value)
    }
}

impl From<&std::path::Path> for ConnectionPoolConfig {
    fn from(value: &std::path::Path) -> Self {
        Self::new(value)
    }
}
#[derive(Clone)]
pub struct ConnectionPool {
data: Arc<ConnectionPoolData>,
}
impl ConnectionPool {
pub(crate) fn new<CPC: Into<ConnectionPoolConfig>>(cpc: CPC) -> Result<Self, Error> {
let cpc = cpc.into();
log::trace!("Creating ConnectionPool for {}", cpc.path.display());
let mut pooldata = ConnectionPoolData {
path: cpc.path,
available_condition: Condvar::default(),
available: Default::default(),
waiting: Default::default(),
connections: vec![],
};
pooldata.spawn(cpc.pool_size)?;
let data = Arc::new(pooldata);
if cpc.force_wal {
let mut lease = data.acquire()?;
lease.execute_raw_sql("PRAGMA journal_mode = wal;")?;
}
Ok(Self { data })
}
pub(crate) fn acquire(&self) -> Result<ConnectionLease, Error> {
self.data.acquire()
}
pub fn start(&self) -> Result<Transaction, Error> {
Transaction::new(self.data.acquire()?)
}
pub fn start_async(
&self,
) -> impl std::future::Future<Output = Result<Transaction, Error>> + '_ {
PoolPoller { data: &self.data }
}
pub fn run_transaction<T>(
&self,
retries: usize,
mut f: impl FnMut(&mut Transaction) -> Result<T, microrm::Error>,
) -> Result<T, microrm::Error> {
for _ in 0..retries {
let Ok(mut txn) = self.start() else {
continue;
};
match (f)(&mut txn).and_then(|t| {
txn.commit()?;
Ok(t)
}) {
Ok(t) => return Ok(t),
Err(Error::TransactionAbort) => continue,
Err(e) => return Err(e),
}
}
Err(Error::TransactionAbort)
}
pub async fn run_transaction_async<T>(
&self,
retries: usize,
mut f: impl FnMut(&mut Transaction) -> Result<T, microrm::Error>,
) -> Result<T, microrm::Error> {
for _ in 0..retries {
let Ok(mut txn) = self.start_async().await else {
continue;
};
match (f)(&mut txn).and_then(|t| {
txn.commit()?;
Ok(t)
}) {
Ok(t) => return Ok(t),
Err(Error::TransactionAbort) => continue,
Err(e) => return Err(e),
}
}
Err(Error::TransactionAbort)
}
}
impl std::fmt::Debug for ConnectionPool {
    /// Shows the database path and how many connections the pool holds.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let pool = &*self.data;
        let mut out = f.debug_struct("ConnectionPool");
        out.field("path", &pool.path.display());
        out.field("connections", &pool.connections.len());
        out.finish_non_exhaustive()
    }
}
/// Future behind `ConnectionPool::start_async`: resolves to a new `Transaction` once a
/// pooled connection is free.
struct PoolPoller<'l> {
    data: &'l Arc<ConnectionPoolData>,
}

impl std::future::Future for PoolPoller<'_> {
    type Output = DBResult<Transaction>;

    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        use std::task::Poll;
        match self.data.try_acquire() {
            Ok(Some(lease)) => return Poll::Ready(Transaction::new(lease)),
            Err(e) => return Poll::Ready(Err(e)),
            Ok(None) => {},
        }
        // Nothing was free. Register for a wakeup while holding `available`. `release`
        // pushes its connection under this same lock before it wakes waiters, so it cannot
        // land between our emptiness check and the registration. If it could, the wakeup
        // would be lost and this future would stay pending forever.
        let available = match self.data.available.lock() {
            Ok(guard) => guard,
            Err(e) => return Poll::Ready(Err(e.into())),
        };
        if available.is_empty() {
            self.data
                .waiting
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .push_back(cx.waker().clone());
        } else {
            // A connection was released after `try_acquire` looked; ask to be polled again.
            cx.waker().wake_by_ref();
        }
        Poll::Pending
    }
}