use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use rusqlite::types::{ToSqlOutput, Value as SqliteValue, ValueRef};
use rusqlite::{Connection, ToSql};
use tokio::sync::{oneshot, OwnedSemaphorePermit, Semaphore};
use crate::driver::ExecuteResult;
use crate::error::OrmError;
use crate::row::Row;
use crate::value::Value;
/// Busy timeout, in milliseconds, installed on every connection the pool opens.
const BUSY_TIMEOUT_MS: u32 = 5000;
/// How long a checkout waits for a free pool slot unless `with_acquire_timeout` overrides it.
const DEFAULT_ACQUIRE_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Where the pool's connections are opened from, as decoded from the database URL.
#[derive(Clone, Debug)]
enum Source {
    /// A database file; the path is handed unchanged to `Connection::open`.
    File(String),
    /// An in-memory database, opened with `Connection::open_in_memory`.
    Memory,
}
/// State shared by every clone of a `SqlitePool` (and by its pinned handles).
// Field order is also drop order; keep it stable.
struct Inner {
    /// Where new connections are opened from.
    source: Source,
    /// Connections not currently checked out; reused last-in first-out (`push` / `pop`).
    idle: Mutex<Vec<Connection>>,
    /// One permit per connection that may be checked out at the same time.
    semaphore: Arc<Semaphore>,
    /// Statements submitted through the pool or its pinned handles; read by `statement_count`.
    /// Incremented on submission, before a permit is acquired.
    statements: AtomicU64,
    /// How long `acquire_permit` waits for a permit before giving up.
    acquire_timeout: Duration,
}
impl Inner {
    /// Waits for a pool permit, giving up after `acquire_timeout`.
    ///
    /// The permit is owned (`OwnedSemaphorePermit`), so it can be moved into a blocking worker
    /// and released only when that worker finishes with the connection.
    ///
    /// # Errors
    ///
    /// Returns a connection error if the semaphore has been closed, or if no permit became
    /// available within `acquire_timeout`.
    async fn acquire_permit(&self) -> crate::Result<OwnedSemaphorePermit> {
        let acquire = Arc::clone(&self.semaphore).acquire_owned();
        match tokio::time::timeout(self.acquire_timeout, acquire).await {
            Ok(Ok(permit)) => Ok(permit),
            Ok(Err(_)) => Err(OrmError::connection("connection pool is closed")),
            // `Duration`'s `Debug` form keeps sub-second timeouts readable ("50ms");
            // `as_secs()` truncated them to a misleading "0s".
            Err(_) => Err(OrmError::connection(format!(
                "timed out after {:?} waiting for a database connection",
                self.acquire_timeout
            ))),
        }
    }
}
impl Inner {
fn open(&self) -> crate::Result<Connection> {
let conn = match &self.source {
Source::File(path) => Connection::open(path)
.map_err(|e| OrmError::connection(format!("cannot open `{path}`")).with_source(e))?,
Source::Memory => Connection::open_in_memory()
.map_err(|e| OrmError::connection("cannot open in-memory database").with_source(e))?,
};
conn.busy_timeout(std::time::Duration::from_millis(u64::from(BUSY_TIMEOUT_MS)))
.map_err(|e| OrmError::connection("cannot set busy timeout").with_source(e))?;
conn.pragma_update(None, "foreign_keys", "ON")
.map_err(|e| OrmError::connection("cannot enable foreign keys").with_source(e))?;
if matches!(self.source, Source::File(_)) {
conn.pragma_update(None, "journal_mode", "WAL")
.map_err(|e| OrmError::connection("cannot enable WAL").with_source(e))?;
}
conn.set_prepared_statement_cache_capacity(200);
Ok(conn)
}
}
/// A cloneable handle to a pool of SQLite connections.
///
/// Clones share one `Inner` through an `Arc`, so they see the same idle connections, permit
/// limit and statement counter. Database work runs on Tokio's blocking thread pool
/// (`spawn_blocking`), and the number of connections checked out at once is capped by a
/// semaphore: `max_connections` for file databases, one for `:memory:`.
#[derive(Clone)]
pub struct SqlitePool {
    inner: Arc<Inner>,
}
impl SqlitePool {
pub fn new(url: &str, max_connections: u32) -> crate::Result<Self> {
if max_connections == 0 {
return Err(OrmError::configuration("max_connections must be at least 1"));
}
let source = parse_url(url)?;
let permits = match source {
Source::Memory => 1,
Source::File(_) => max_connections as usize,
};
Ok(Self {
inner: Arc::new(Inner {
source,
idle: Mutex::new(Vec::new()),
semaphore: Arc::new(Semaphore::new(permits)),
statements: AtomicU64::new(0),
acquire_timeout: DEFAULT_ACQUIRE_TIMEOUT,
}),
})
}
pub fn with_acquire_timeout(mut self, timeout: Duration) -> Self {
if let Some(inner) = Arc::get_mut(&mut self.inner) {
inner.acquire_timeout = if timeout.is_zero() {
DEFAULT_ACQUIRE_TIMEOUT
} else {
timeout
};
}
self
}
pub async fn fetch_all(&self, sql: String, params: Vec<Value>) -> crate::Result<Vec<Row>> {
self.with_connection(move |conn| fetch_all(conn, &sql, ¶ms))
.await
}
pub async fn execute(&self, sql: String, params: Vec<Value>) -> crate::Result<ExecuteResult> {
self.with_connection(move |conn| execute(conn, &sql, ¶ms))
.await
}
pub async fn execute_batch(&self, sql: String) -> crate::Result<()> {
self.with_connection(move |conn| execute_batch(conn, &sql))
.await
}
pub fn statement_count(&self) -> u64 {
self.inner.statements.load(Ordering::Relaxed)
}
pub(crate) async fn acquire_pinned(&self) -> crate::Result<PinnedSqlite> {
let permit = self.inner.acquire_permit().await?;
let checked_out = lock(&self.inner.idle).pop();
let conn = match checked_out {
Some(conn) => conn,
None => {
let inner = Arc::clone(&self.inner);
tokio::task::spawn_blocking(move || inner.open())
.await
.map_err(|e| OrmError::connection(format!("database worker failed: {e}")))??
}
};
Ok(PinnedSqlite {
inner: Arc::clone(&self.inner),
conn: Arc::new(Mutex::new(Some(conn))),
gate: tokio::sync::Mutex::new(()),
_permit: permit,
})
}
pub async fn close(&self) {
let drained = {
let mut idle = lock(&self.inner.idle);
std::mem::take(&mut *idle)
};
drop(drained);
}
async fn with_connection<F, T>(&self, work: F) -> crate::Result<T>
where
F: FnOnce(&mut Connection) -> crate::Result<T> + Send + 'static,
T: Send + 'static,
{
self.inner.statements.fetch_add(1, Ordering::Relaxed);
let permit = self.inner.acquire_permit().await?;
let checked_out = lock(&self.inner.idle).pop();
let inner = Arc::clone(&self.inner);
let (tx, rx) = oneshot::channel();
tokio::task::spawn_blocking(move || {
let _permit = permit;
let mut conn = match checked_out {
Some(conn) => conn,
None => match inner.open() {
Ok(conn) => conn,
Err(error) => {
let _ = tx.send(Err(error));
return;
}
},
};
let result = work(&mut conn);
let healthy = result.is_ok() || conn.execute_batch("SELECT 1;").is_ok();
if healthy {
lock(&inner.idle).push(conn);
}
let _ = tx.send(result);
});
rx.await
.map_err(|_| OrmError::query("database worker was dropped before completing"))?
}
}
/// Locks the idle-connection list, recovering the guard if an earlier holder panicked.
///
/// Poisoning only records that some thread panicked while holding the lock. Rather than
/// propagate that panic to every later caller, the pool keeps using the list as it was left.
fn lock(mutex: &Mutex<Vec<Connection>>) -> std::sync::MutexGuard<'_, Vec<Connection>> {
    match mutex.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Parses a database URL into a `Source`.
///
/// Accepted forms, after trimming surrounding whitespace:
///
/// - `sqlite://<path>`, `sqlite:<path>`, or a bare `<path>` select a database file;
/// - `:memory:`, or an empty path after the scheme, selects an in-memory database.
///
/// # Errors
///
/// Returns a configuration error if the URL is blank or the path has a `..` component.
fn parse_url(url: &str) -> crate::Result<Source> {
    let url = url.trim();
    if url.is_empty() {
        return Err(OrmError::configuration("database url is empty"));
    }

    // "sqlite://" has to be tried before its own prefix "sqlite:".
    let path = ["sqlite://", "sqlite:"]
        .iter()
        .find_map(|scheme| url.strip_prefix(*scheme))
        .unwrap_or(url);

    if path.is_empty() || path == ":memory:" {
        return Ok(Source::Memory);
    }

    // Refuse any `..` component so a configured path cannot climb out of its directory.
    let climbs_out = std::path::Path::new(path)
        .components()
        .any(|part| part == std::path::Component::ParentDir);
    if climbs_out {
        return Err(OrmError::configuration(
            "database path must not contain `..` components",
        ));
    }

    Ok(Source::File(path.to_owned()))
}
impl ToSql for Value {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
let value = match self {
Value::Null => SqliteValue::Null,
Value::Bool(b) => SqliteValue::Integer(i64::from(*b)),
Value::Int(i) => SqliteValue::Integer(*i),
Value::Real(r) => SqliteValue::Real(*r),
Value::Text(s) => SqliteValue::Text(s.clone()),
Value::Blob(b) => SqliteValue::Blob(b.clone()),
Value::Timestamp(ts) => SqliteValue::Text(format_timestamp(ts)?),
Value::Json(j) => SqliteValue::Text(j.to_string()),
Value::Uuid(u) => SqliteValue::Text(u.to_string()),
Value::Array(items) => SqliteValue::Text(format!("{items:?}")),
};
Ok(ToSqlOutput::Owned(value))
}
}
fn format_timestamp(ts: &time::OffsetDateTime) -> rusqlite::Result<String> {
ts.format(&time::format_description::well_known::Rfc3339)
.map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))
}
/// Copies one borrowed SQLite column value into an owned `Value`.
///
/// Each SQLite storage class maps one-to-one onto a variant: NULL to `Null`, INTEGER to `Int`,
/// REAL to `Real`, TEXT to `Text`, BLOB to `Blob`.
///
/// # Errors
///
/// Returns a conversion error if a TEXT value is not valid UTF-8.
fn read_value(raw: ValueRef<'_>) -> crate::Result<Value> {
    let value = match raw {
        ValueRef::Null => Value::Null,
        ValueRef::Integer(int) => Value::Int(int),
        ValueRef::Real(float) => Value::Real(float),
        ValueRef::Blob(bytes) => Value::Blob(Vec::from(bytes)),
        ValueRef::Text(bytes) => {
            let decoded = std::str::from_utf8(bytes)
                .map(String::from)
                .map_err(|_| OrmError::conversion("column text is not valid UTF-8"))?;
            Value::Text(decoded)
        }
    };
    Ok(value)
}
fn fetch_all(conn: &mut Connection, sql: &str, params: &[Value]) -> crate::Result<Vec<Row>> {
let mut statement = conn
.prepare_cached(sql)
.map_err(|e| OrmError::query("failed to prepare statement").with_source(e))?;
let column_names: Arc<[String]> = statement
.column_names()
.into_iter()
.map(str::to_string)
.collect::<Vec<_>>()
.into();
let column_count = column_names.len();
let mut rows = statement
.query(rusqlite::params_from_iter(params.iter()))
.map_err(|e| OrmError::query("query execution failed").with_source(e))?;
let mut out = Vec::new();
while let Some(row) = rows
.next()
.map_err(|e| OrmError::query("reading a row failed").with_source(e))?
{
let mut values = Vec::with_capacity(column_count);
for index in 0..column_count {
let raw = row
.get_ref(index)
.map_err(|e| OrmError::query("reading a column failed").with_source(e))?;
values.push(read_value(raw)?);
}
out.push(Row::with_columns(Arc::clone(&column_names), values));
}
Ok(out)
}
/// Executes a parameterless, possibly multi-statement SQL batch.
///
/// # Errors
///
/// Returns a query error wrapping the rusqlite failure.
fn execute_batch(conn: &mut Connection, sql: &str) -> crate::Result<()> {
    match conn.execute_batch(sql) {
        Ok(()) => Ok(()),
        Err(e) => Err(OrmError::query("statement batch failed").with_source(e)),
    }
}
/// Executes one statement through the prepared-statement cache.
///
/// Returns the number of rows it changed and the connection's last insert rowid, read after the
/// statement has run.
///
/// # Errors
///
/// Returns a query error if preparing or executing the statement fails.
fn execute(conn: &mut Connection, sql: &str, params: &[Value]) -> crate::Result<ExecuteResult> {
    let rows_affected = {
        let mut stmt = conn
            .prepare_cached(sql)
            .map_err(|e| OrmError::query("failed to prepare statement").with_source(e))?;
        let changed = stmt
            .execute(rusqlite::params_from_iter(params.iter()))
            .map_err(|e| OrmError::query("statement execution failed").with_source(e))?;
        changed as u64
    };
    let last_insert_rowid = conn.last_insert_rowid();
    Ok(ExecuteResult {
        rows_affected,
        last_insert_rowid,
    })
}
/// A pool connection held for exclusive use by one owner (created by `acquire_pinned`).
///
/// Calls through the handle run one at a time on the same connection. The handle owns a pool
/// permit until it is dropped.
// NOTE(review): if a `run` future is cancelled and the handle is then dropped while the worker
// still holds the connection, the permit is released before that worker finishes. The connection
// is then freed with the shared slot rather than re-pooled. Confirm this is acceptable.
pub(crate) struct PinnedSqlite {
    /// Pool state; used for the statement counter and to return the connection on drop.
    inner: Arc<Inner>,
    /// The connection, or `None` while a blocking worker owns it. Shared with that worker so it
    /// can put the connection back after the work finishes.
    conn: Arc<Mutex<Option<Connection>>>,
    /// Serialises `run` calls so only one worker uses the connection at a time.
    gate: tokio::sync::Mutex<()>,
    /// Pool slot held for the handle's whole lifetime.
    _permit: OwnedSemaphorePermit,
}
impl PinnedSqlite {
fn take_conn(&self) -> crate::Result<Connection> {
self.conn
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.take()
.ok_or_else(|| OrmError::query("pinned connection is already in use"))
}
async fn run<F, T>(&self, work: F) -> crate::Result<T>
where
F: FnOnce(&mut Connection) -> crate::Result<T> + Send + 'static,
T: Send + 'static,
{
let _gate = self.gate.lock().await;
self.inner.statements.fetch_add(1, Ordering::Relaxed);
let mut conn = self.take_conn()?;
let slot = Arc::clone(&self.conn);
let (tx, rx) = oneshot::channel();
tokio::task::spawn_blocking(move || {
let result = work(&mut conn);
*slot.lock().unwrap_or_else(|poisoned| poisoned.into_inner()) = Some(conn);
let _ = tx.send(result);
});
rx.await
.map_err(|_| OrmError::query("database worker was dropped before completing"))?
}
pub(crate) async fn fetch_all(
&self,
sql: String,
params: Vec<Value>,
) -> crate::Result<Vec<Row>> {
self.run(move |conn| fetch_all(conn, &sql, ¶ms)).await
}
pub(crate) async fn execute(
&self,
sql: String,
params: Vec<Value>,
) -> crate::Result<ExecuteResult> {
self.run(move |conn| execute(conn, &sql, ¶ms)).await
}
pub(crate) async fn execute_batch(&self, sql: String) -> crate::Result<()> {
self.run(move |conn| execute_batch(conn, &sql)).await
}
pub(crate) fn rollback_now(&self) {
if let Ok(conn) = self.take_conn() {
if let Err(error) = conn.execute_batch("ROLLBACK") {
eprintln!("tork-orm: failed to rollback transaction on drop: {error}");
}
*self.conn.lock().unwrap_or_else(|poisoned| poisoned.into_inner()) = Some(conn);
}
}
}
impl Drop for PinnedSqlite {
    /// Hands the pinned connection back to the pool's idle list.
    ///
    /// - A transaction still open (connection not in autocommit mode) is rolled back first, so
    ///   the next borrower does not inherit it or its locks.
    /// - If that rollback fails, or the pool's semaphore has been closed, the connection is
    ///   dropped (closed) instead of pooled.
    /// - If a worker still holds the connection (a cancelled `run`), the slot is empty and
    ///   nothing is returned here.
    fn drop(&mut self) {
        // Take the connection in its own statement so the slot lock is not held during ROLLBACK.
        let taken = self
            .conn
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .take();
        if let Some(conn) = taken {
            let clean = conn.is_autocommit() || conn.execute_batch("ROLLBACK").is_ok();
            if clean && !self.inner.semaphore.is_closed() {
                lock(&self.inner.idle).push(conn);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Test-only helper: how many connections are parked in the idle list right now.
    impl SqlitePool {
        fn idle_len(&self) -> usize {
            lock(&self.inner.idle).len()
        }
    }

    // Recursive CTE counting to 4,000,000. It is meant to still be running when the 1 ms timeout
    // in `cancelled_query_returns_its_connection_to_the_pool` cancels the caller.
    const SLOW_QUERY: &str = "WITH RECURSIVE c(n) AS (SELECT 1 UNION ALL \
        SELECT n + 1 FROM c WHERE n < 4000000) SELECT count(*) FROM c";

    // A checkout must fail with a "timed out" error once the acquire timeout elapses. The pool
    // must work again after the blocking pinned handle is dropped.
    #[tokio::test]
    async fn checkout_times_out_instead_of_hanging_forever() {
        // `:memory:` pools have a single permit, so holding it pinned starves `fetch_all`.
        let pool = SqlitePool::new(":memory:", 1)
            .unwrap()
            .with_acquire_timeout(Duration::from_millis(50));
        let pinned = pool.acquire_pinned().await.unwrap();
        let start = std::time::Instant::now();
        let result = pool.fetch_all("SELECT 1".into(), vec![]).await;
        let waited = start.elapsed();
        let error = result.expect_err("checkout should time out");
        assert!(
            error.to_string().contains("timed out"),
            "expected a timeout error, got: {error}"
        );
        assert!(waited < Duration::from_secs(5), "must fail fast, not hang");
        // Dropping the pinned handle releases the only permit.
        drop(pinned);
        let rows = pool.fetch_all("SELECT 1".into(), vec![]).await.unwrap();
        assert_eq!(rows.len(), 1);
    }

    // Cancelling the caller's future must not leak the connection: the blocking worker finishes
    // and returns it to the idle list.
    #[tokio::test]
    async fn cancelled_query_returns_its_connection_to_the_pool() {
        let pool = SqlitePool::new(":memory:", 1).unwrap();
        let cancelled = tokio::time::timeout(
            Duration::from_millis(1),
            pool.fetch_all(SLOW_QUERY.into(), vec![]),
        )
        .await;
        assert!(
            cancelled.is_err(),
            "the slow query should be cancelled by the timeout"
        );
        // Poll for up to 200 * 25 ms = 5 s while the worker finishes the slow query.
        let mut returned = false;
        for _ in 0..200 {
            if pool.idle_len() == 1 {
                returned = true;
                break;
            }
            tokio::time::sleep(Duration::from_millis(25)).await;
        }
        assert!(
            returned,
            "the cancelled query's connection was never returned to the pool"
        );
        let rows = pool.fetch_all("SELECT 1".into(), vec![]).await.unwrap();
        assert_eq!(rows.len(), 1);
    }

    // `..` components are rejected with or without the `sqlite://` scheme; plain relative,
    // absolute and `:memory:` URLs are accepted.
    #[test]
    fn rejects_parent_directory_traversal_in_the_path() {
        assert!(SqlitePool::new("sqlite://../../etc/passwd", 1).is_err());
        assert!(SqlitePool::new("../secret.db", 1).is_err());
        assert!(SqlitePool::new("app.db", 1).is_ok());
        assert!(SqlitePool::new("/tmp/tork-test.db", 1).is_ok());
        assert!(SqlitePool::new(":memory:", 1).is_ok());
    }

    // A statement error alone must not cost the pool its connection; the `SELECT 1` health
    // probe keeps it.
    #[tokio::test]
    async fn a_query_error_keeps_a_healthy_connection() {
        let pool = SqlitePool::new(":memory:", 1).unwrap();
        let result = pool
            .fetch_all("SELECT * FROM does_not_exist".into(), vec![])
            .await;
        assert!(result.is_err(), "the query should fail");
        assert_eq!(pool.idle_len(), 1, "a healthy connection stays in the pool");
        let rows = pool.fetch_all("SELECT 1".into(), vec![]).await.unwrap();
        assert_eq!(rows.len(), 1);
    }

    // Error messages use fixed texts and must not echo the submitted SQL, neither in `message()`
    // nor in `Display` output.
    #[tokio::test]
    async fn query_errors_do_not_embed_the_raw_sql() {
        let pool = SqlitePool::new(":memory:", 1).unwrap();
        let error = pool
            .fetch_all("SELECT secret_column FROM secret_table".into(), vec![])
            .await
            .expect_err("a missing table should error");
        assert!(
            !error.message().contains("secret_table")
                && !error.message().contains("secret_column"),
            "the error message leaked SQL: {}",
            error.message()
        );
        assert!(
            !error.to_string().contains("secret_table"),
            "the error display leaked SQL: {error}"
        );
    }
}