use core::cell::UnsafeCell;
use core::error::Error;
use core::ffi::CStr;
use core::fmt;
use core::ops::{Deref, DerefMut};
use core::ptr::NonNull;
use core::sync::atomic::{AtomicU64, Ordering};
use alloc::boxed::Box;
#[cfg(feature = "std")]
use alloc::ffi::CString;
use alloc::sync::Arc;
use alloc::vec::Vec;
#[cfg(feature = "std")]
use std::path::Path;
use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore};
use crate::{Connection, Error as SqllError, NotThreadSafe, OpenOptions};
/// Private module implementing the sealed-trait pattern for `ConnectionSetup`.
///
/// The module itself is not `pub`, so code outside this crate cannot name
/// `Sealed` and therefore cannot add new `ConnectionSetup` implementors.
mod sealed_connection_setup {
    use super::EmptySetup;
    use crate::{Connection, Error};

    pub trait Sealed {}

    // Any setup closure with the expected signature.
    impl<F: Fn(&mut Connection) -> Result<(), Error>> Sealed for F {}

    // The no-op setup.
    impl Sealed for EmptySetup {}
}
/// A hook that runs against a freshly opened [`Connection`] before the pool
/// builds its `Statements` from that connection.
///
/// This trait is sealed. It is implemented for [`EmptySetup`] and for every
/// closure of type `Fn(&mut Connection) -> Result<(), SqllError>`.
pub trait ConnectionSetup: self::sealed_connection_setup::Sealed {
    /// Configures `c`. Returning an error aborts construction of the pool.
    fn setup(&self, c: &mut Connection) -> Result<(), SqllError>;
}
/// A [`ConnectionSetup`] that leaves the connection untouched.
///
/// [`PoolBuilder::new`] uses it as the initial setup for both the read and
/// write connections. Because it is `#[non_exhaustive]`, it can only be
/// constructed from within this crate.
#[non_exhaustive]
pub struct EmptySetup;
impl ConnectionSetup for EmptySetup {
    /// Does nothing and always succeeds.
    fn setup(&self, _: &mut Connection) -> Result<(), SqllError> {
        Result::Ok(())
    }
}
/// Lets any closure with the right signature act as a setup hook.
impl<F> ConnectionSetup for F
where
    F: Fn(&mut Connection) -> Result<(), SqllError>,
{
    /// Invokes the closure with the connection and returns its result.
    #[inline]
    fn setup(&self, c: &mut Connection) -> Result<(), SqllError> {
        let callback = self;
        callback(c)
    }
}
/// Builder for a [`Pool`].
///
/// It carries the open options, the number of read connections, and the
/// setup hooks for the read connections (`RB`) and the write connection
/// (`WB`). [`PoolBuilder::new`] starts with [`EmptySetup`] for both hooks.
pub struct PoolBuilder<RB, WB> {
    /// Options used to open every connection in the pool.
    open_options: OpenOptions,
    /// Number of read connections to open; the pool rejects 0 and values
    /// above 64 when it is built.
    read_concurrency: usize,
    /// Setup hook run on the write connection.
    write_builder: WB,
    /// Setup hook run on each read connection.
    read_builder: RB,
}
impl PoolBuilder<EmptySetup, EmptySetup> {
pub fn new(open_options: OpenOptions, read_concurrency: usize) -> Self {
Self {
open_options,
read_concurrency,
write_builder: EmptySetup,
read_builder: EmptySetup,
}
}
}
impl<RB, WB> PoolBuilder<RB, WB>
where
    RB: ConnectionSetup,
    WB: ConnectionSetup,
{
    /// Replaces the setup hook for the write connection.
    ///
    /// When the pool is opened, the hook runs once on the write connection,
    /// after it has been opened and before `W::build` is called on it.
    pub fn with_write_setup<T>(self, write_builder: T) -> PoolBuilder<RB, T>
    where
        T: Fn(&mut Connection) -> Result<(), SqllError>,
    {
        PoolBuilder {
            open_options: self.open_options,
            read_concurrency: self.read_concurrency,
            read_builder: self.read_builder,
            write_builder,
        }
    }

    /// Replaces the setup hook for the read connections.
    ///
    /// When the pool is opened, the hook runs on every read connection, after
    /// it has been opened and before `R::build` is called on it.
    pub fn with_read_setup<T>(self, read_builder: T) -> PoolBuilder<T, WB>
    where
        T: Fn(&mut Connection) -> Result<(), SqllError>,
    {
        PoolBuilder {
            open_options: self.open_options,
            read_concurrency: self.read_concurrency,
            read_builder,
            write_builder: self.write_builder,
        }
    }

    /// Opens the pool at `path`.
    ///
    /// # Errors
    ///
    /// Fails if `path` is not valid UTF-8 or contains a nul byte. Otherwise it
    /// fails for the same reasons as [`PoolBuilder::open_c_str`].
    #[inline]
    #[cfg(feature = "std")]
    pub fn open<R, W>(self, path: impl AsRef<Path>) -> Result<Pool<R, W>, PoolError>
    where
        R: IsReadOnly,
        W: Statements,
    {
        // Delegate to `Pool::new` so the path-to-C-string rules are defined in
        // exactly one place.
        Pool::new(
            self.open_options,
            path,
            self.read_concurrency,
            self.write_builder,
            self.read_builder,
        )
    }

    /// Opens the pool at the nul-terminated `path`.
    ///
    /// Unlike [`PoolBuilder::open`], this does not require the `std` feature.
    ///
    /// # Errors
    ///
    /// Fails if the configured read concurrency is 0 or greater than 64. It
    /// also fails if any connection cannot be opened, if a setup hook returns
    /// an error, or if building `W` or `R` from a connection fails.
    pub fn open_c_str<R, W>(self, path: &CStr) -> Result<Pool<R, W>, PoolError>
    where
        R: IsReadOnly,
        W: Statements,
    {
        Pool::new_c_str(
            self.open_options,
            path,
            self.read_concurrency,
            self.write_builder,
            self.read_builder,
        )
    }
}
/// A value built from a single [`Connection`] when a [`Pool`] is created.
///
/// The pool builds its write value with `W::build` and each read value with
/// `R::build`.
pub trait Statements: Sized {
    /// Builds `Self` from the given connection.
    #[doc(hidden)]
    fn build(c: &mut Connection) -> Result<Self, PoolError>;
}
/// Marker for [`Statements`] that may be used as the read values of a
/// [`Pool`]. Read values are built on connections opened read-only.
///
/// # Safety
///
/// NOTE(review): this module does not state the exact contract. Presumably
/// an implementor must guarantee that every statement in `Self` is read-only
/// (compare the `index_not_read_only` / `field_not_read_only` constructors on
/// `PoolError`). Confirm this against whatever implements the trait.
pub unsafe trait IsReadOnly
where
    Self: Statements,
{
}
/// Error returned when building a [`Pool`] or acquiring one of its guards.
///
/// The concrete cause is stored in a private error kind. It is rendered by
/// `Display`, and wrapped errors are exposed through [`Error::source`].
pub struct PoolError {
    kind: ErrorKind,
}
impl PoolError {
#[inline]
#[doc(hidden)]
pub fn index_not_read_only(index: usize) -> Self {
Self::from(ErrorKind::IndexNotReadOnly(index))
}
#[inline]
#[doc(hidden)]
pub fn field_not_read_only(name: &'static str) -> Self {
Self::from(ErrorKind::FieldNotReadOnly(name))
}
#[inline]
#[doc(hidden)]
pub fn index_prepare_failed(index: usize, source: SqllError) -> Self {
Self::from(ErrorKind::IndexPrepareFailed(index, source))
}
#[inline]
#[doc(hidden)]
pub fn field_prepare_failed(name: &'static str, source: SqllError) -> Self {
Self::from(ErrorKind::FieldPrepareFailed(name, source))
}
#[inline]
#[doc(hidden)]
pub fn index_not_thread_safe(index: usize, source: NotThreadSafe) -> Self {
Self::from(ErrorKind::IndexNotThreadSafe(index, source))
}
#[inline]
#[doc(hidden)]
pub fn field_not_thread_safe(name: &'static str, source: NotThreadSafe) -> Self {
Self::from(ErrorKind::FieldNotThreadSafe(name, source))
}
#[inline]
#[doc(hidden)]
pub fn index_build_failed(index: usize, source: impl Error + Send + Sync + 'static) -> Self {
Self::from(ErrorKind::IndexBuildFailed(index, Box::new(source)))
}
#[inline]
#[doc(hidden)]
pub fn field_build_failed(
name: &'static str,
source: impl Error + Send + Sync + 'static,
) -> Self {
Self::from(ErrorKind::FieldBuildFailed(name, Box::new(source)))
}
}
impl fmt::Display for PoolError {
    /// Writes the human-readable message of the underlying error kind.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.kind, f)
    }
}
impl fmt::Debug for PoolError {
    /// Delegates to the derived `Debug` output of the error kind, which shows
    /// the variant name and its fields without the `PoolError` wrapper.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.kind, f)
    }
}
impl Error for PoolError {
    /// Returns the wrapped database or thread-safety error, if there is one.
    ///
    /// The `*BuildFailed` variants return `None`. Presumably this is
    /// deliberate: their boxed cause is already rendered in this error's
    /// `Display` output, so exposing it here too would repeat it in
    /// error-chain reports.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        // Exhaustive on purpose: adding a new error kind must force a
        // decision about whether it exposes a source.
        match &self.kind {
            ErrorKind::BuildWriteConnection(source)
            | ErrorKind::BuildReadConnection(_, source)
            | ErrorKind::SetupWriteConnection(source)
            | ErrorKind::SetupReadConnection(_, source)
            | ErrorKind::IndexPrepareFailed(_, source)
            | ErrorKind::FieldPrepareFailed(_, source) => Some(source),
            ErrorKind::IndexNotThreadSafe(_, source)
            | ErrorKind::FieldNotThreadSafe(_, source) => Some(source),
            ErrorKind::IndexBuildFailed(..) | ErrorKind::FieldBuildFailed(..) => None,
            #[cfg(feature = "std")]
            ErrorKind::NotUtf8Path | ErrorKind::NulByteInPath => None,
            ErrorKind::NoConnections
            | ErrorKind::ZeroConcurrency
            | ErrorKind::TooManyConnections(_)
            | ErrorKind::IndexNotReadOnly(_)
            | ErrorKind::FieldNotReadOnly(_) => None,
        }
    }
}
/// Internal classification of every [`PoolError`].
#[derive(Debug)]
enum ErrorKind {
    /// The path given to `Pool::new` / `PoolBuilder::open` is not valid UTF-8.
    #[cfg(feature = "std")]
    NotUtf8Path,
    /// The path given to `Pool::new` / `PoolBuilder::open` contains a nul byte.
    #[cfg(feature = "std")]
    NulByteInPath,
    /// Opening the read-write connection failed.
    BuildWriteConnection(SqllError),
    /// Opening the read connection with the given index failed.
    BuildReadConnection(usize, SqllError),
    /// The write connection's setup hook returned an error.
    SetupWriteConnection(SqllError),
    /// The setup hook for the read connection with the given index returned an error.
    SetupReadConnection(usize, SqllError),
    /// Acquiring permits from the pool's semaphore failed.
    NoConnections,
    /// The requested read concurrency was 0.
    ZeroConcurrency,
    /// The requested read concurrency was greater than 64.
    TooManyConnections(usize),
    // The variants below come from the hidden `PoolError` constructors.
    // Presumably `Index*` refers to positional (tuple) statements and
    // `Field*` to named fields; confirm against the implementing code.
    IndexNotReadOnly(usize),
    FieldNotReadOnly(&'static str),
    IndexPrepareFailed(usize, SqllError),
    FieldPrepareFailed(&'static str, SqllError),
    IndexNotThreadSafe(usize, NotThreadSafe),
    FieldNotThreadSafe(&'static str, NotThreadSafe),
    IndexBuildFailed(usize, Box<dyn Error + Send + Sync + 'static>),
    FieldBuildFailed(&'static str, Box<dyn Error + Send + Sync + 'static>),
}
impl From<ErrorKind> for PoolError {
#[inline]
fn from(kind: ErrorKind) -> Self {
Self { kind }
}
}
impl fmt::Display for ErrorKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
#[cfg(feature = "std")]
Self::NotUtf8Path => write!(f, "path is not valid UTF-8"),
#[cfg(feature = "std")]
Self::NulByteInPath => write!(f, "path contains a nul byte"),
Self::BuildWriteConnection(_) => write!(f, "building write connection"),
Self::BuildReadConnection(index, _) => write!(f, "building read connection #{index}"),
Self::SetupWriteConnection(_) => write!(f, "setting up write connection"),
Self::SetupReadConnection(index, _) => write!(f, "setting up read connection #{index}"),
Self::NoConnections => write!(f, "no connections available"),
Self::ZeroConcurrency => write!(f, "read concurrency must be at least 1"),
Self::TooManyConnections(concurrency) => {
write!(
f,
"read concurrency {concurrency} exceeds the maximum of 64"
)
}
Self::IndexNotReadOnly(index) => write!(f, "index {index} is not read-only"),
Self::FieldNotReadOnly(name) => write!(f, "field {name} is not read-only"),
Self::IndexPrepareFailed(index, _) => write!(f, "index {index} prepare failed"),
Self::FieldPrepareFailed(name, _) => write!(f, "field {name} prepare failed"),
Self::IndexNotThreadSafe(index, _) => write!(f, "index {index} is not thread-safe"),
Self::FieldNotThreadSafe(name, _) => write!(f, "field {name} is not thread-safe"),
Self::IndexBuildFailed(index, source) => {
write!(f, "index {index} build failed: {source}")
}
Self::FieldBuildFailed(name, source) => {
write!(f, "field {name} build failed: {source}")
}
}
}
}
/// A pool with one write value `W` and 1 to 64 read values `R`. Each value is
/// built from its own database connection.
///
/// Read values are borrowed through [`Pool::shared`] and the write value
/// through [`Pool::exclusive`]. Both take the pool as `Arc<Self>`.
pub struct Pool<R, W> {
    // One slot per read connection. Slot `i` is dereferenced only by the
    // `SharedGuard` that set bit `i` of `used_read`.
    read_pool: Box<[UnsafeCell<R>]>,
    // Bitmask of read slots currently borrowed. A `u64` is why concurrency is
    // capped at 64.
    used_read: AtomicU64,
    // The write value. It is only dereferenced through an `ExclusiveGuard` or
    // `as_write_mut`.
    write: UnsafeCell<W>,
    // Number of read slots. It is also the semaphore's total permit count.
    concurrency: usize,
    // `shared` takes one permit; `exclusive` takes all `concurrency` permits.
    semaphore: Arc<Semaphore>,
}
// SAFETY: `Pool` owns its `R` and `W` values inside `UnsafeCell`s. Everything
// else it holds (`AtomicU64`, `usize`, `Arc<Semaphore>`) is thread-safe, so
// moving the pool to another thread only requires that `R` and `W` can move.
unsafe impl<R: Send, W: Send> Send for Pool<R, W> {}
// SAFETY: a shared `Pool` never hands out aliasing access to `R` or `W`.
// Read slot `i` is dereferenced only by the `SharedGuard` that set bit `i` in
// `used_read`. `write` is dereferenced only by an `ExclusiveGuard`, which
// holds every semaphore permit. Each value is therefore used by one thread at
// a time, which requires `Send` but not `Sync`.
unsafe impl<R: Send, W: Send> Sync for Pool<R, W> {}
impl<R, W> Pool<R, W>
where
    R: IsReadOnly,
    W: Statements,
{
    /// Opens a pool at `path` with `read_concurrency` read connections.
    ///
    /// The path is converted to a nul-terminated C string and then handled
    /// exactly like [`Pool::new_c_str`].
    ///
    /// # Errors
    ///
    /// Returns an error if `path` is not valid UTF-8 or contains a nul byte.
    /// It also returns an error for any reason listed on [`Pool::new_c_str`].
    #[inline]
    #[cfg(feature = "std")]
    pub fn new(
        open_options: OpenOptions,
        path: impl AsRef<Path>,
        read_concurrency: usize,
        write_builder: impl ConnectionSetup,
        read_builder: impl ConnectionSetup,
    ) -> Result<Self, PoolError> {
        let path = path.as_ref();
        // Reject non-UTF-8 paths and paths with interior nul bytes before
        // building the C string.
        let Some(bytes) = path.to_str() else {
            return Err(PoolError::from(ErrorKind::NotUtf8Path));
        };
        let Ok(string) = CString::new(bytes) else {
            return Err(PoolError::from(ErrorKind::NulByteInPath));
        };
        Self::_new_c_str(
            open_options,
            &string,
            read_concurrency,
            write_builder,
            read_builder,
        )
    }

    /// Opens a pool at the nul-terminated `path`.
    ///
    /// The write connection is opened with `read_write()` and set up with
    /// `write_builder`. Then `read_concurrency` read connections are opened
    /// with `no_create().read_only()`, and each is set up with `read_builder`.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - `read_concurrency` is 0 or greater than 64;
    /// - any connection fails to open;
    /// - a setup hook fails;
    /// - building `W` or `R` from a connection fails.
    pub fn new_c_str(
        open_options: OpenOptions,
        path: &CStr,
        read_concurrency: usize,
        write_builder: impl ConnectionSetup,
        read_builder: impl ConnectionSetup,
    ) -> Result<Self, PoolError> {
        Self::_new_c_str(
            open_options,
            path,
            read_concurrency,
            write_builder,
            read_builder,
        )
    }

    // Shared implementation of `new` and `new_c_str`.
    fn _new_c_str(
        open_options: OpenOptions,
        path: &CStr,
        concurrency: usize,
        write_builder: impl ConnectionSetup,
        read_builder: impl ConnectionSetup,
    ) -> Result<Self, PoolError> {
        if concurrency == 0 {
            return Err(PoolError::from(ErrorKind::ZeroConcurrency));
        }
        // `used_read` is a u64 bitmask with one bit per read slot.
        if concurrency > 64 {
            return Err(PoolError::from(ErrorKind::TooManyConnections(concurrency)));
        }
        let mut write = open_options
            .clone()
            .read_write()
            .open_c_str(path)
            .map_err(ErrorKind::BuildWriteConnection)?;
        write_builder
            .setup(&mut write)
            .map_err(ErrorKind::SetupWriteConnection)?;
        // NOTE(review): the `Connection` binding goes out of scope when this
        // function returns. Presumably `W` keeps alive whatever it needs from
        // it — confirm.
        let write = W::build(&mut write)?;
        let mut read_pool = Vec::with_capacity(concurrency);
        for index in 0..concurrency {
            // Read connections never create the database file.
            let mut read = open_options
                .clone()
                .no_create()
                .read_only()
                .open_c_str(path)
                .map_err(move |error| ErrorKind::BuildReadConnection(index, error))?;
            read_builder
                .setup(&mut read)
                .map_err(move |error| ErrorKind::SetupReadConnection(index, error))?;
            let read = R::build(&mut read)?;
            read_pool.push(UnsafeCell::new(read));
        }
        Ok(Self {
            read_pool: Box::from(read_pool),
            used_read: AtomicU64::new(0),
            write: UnsafeCell::new(write),
            concurrency,
            // One permit per read slot.
            semaphore: Arc::new(Semaphore::new(concurrency)),
        })
    }

    /// Returns mutable access to the first read value without going through
    /// a guard.
    pub fn as_read_mut(&mut self) -> &mut R {
        // SAFETY: `&mut self` rules out any live guard, because guards hold an
        // `Arc<Self>`. Slot 0 exists because `concurrency >= 1` is enforced at
        // construction.
        unsafe { &mut *self.read_pool[0].get() }
    }

    /// Returns mutable access to the write value without going through a
    /// guard.
    pub fn as_write_mut(&mut self) -> &mut W {
        // SAFETY: `&mut self` rules out any live `ExclusiveGuard`, because it
        // holds an `Arc<Self>`.
        unsafe { &mut *self.write.get() }
    }

    /// Borrows one read value, waiting until a read slot is free.
    ///
    /// Takes one semaphore permit and then claims the lowest free bit in
    /// `used_read`. The slot and permit are released when the returned
    /// [`SharedGuard`] is dropped.
    ///
    /// # Errors
    ///
    /// Returns an error if acquiring a semaphore permit fails.
    pub async fn shared(self: Arc<Self>) -> Result<SharedGuard<R, W>, PoolError> {
        let result = self.semaphore.clone().acquire_owned().await;
        let permit = match result {
            Ok(permit) => permit,
            Err(AcquireError { .. }) => return Err(PoolError::from(ErrorKind::NoConnections)),
        };
        // Claim the lowest clear bit with a CAS, retrying if another task
        // changed the mask meanwhile. Holding a permit means at most
        // `concurrency - 1` other bits are set. `SharedGuard::drop` clears its
        // bit before its permit is released. So the lowest clear bit is always
        // below `concurrency` (<= 64), and the shift cannot overflow.
        let index = loop {
            let value = self.used_read.load(Ordering::Acquire);
            let index = value.trailing_ones();
            if self
                .used_read
                .compare_exchange(
                    value,
                    value | (1 << index),
                    Ordering::AcqRel,
                    Ordering::Acquire,
                )
                .is_ok()
            {
                break index as usize;
            }
        };
        // SAFETY: bit `index` was just set by this task's successful CAS. No
        // other guard can access this slot until `SharedGuard::drop` clears
        // that bit.
        let read = unsafe {
            let slice = self.read_pool[index].get();
            NonNull::new_unchecked(&mut *slice)
        };
        Ok(SharedGuard {
            read,
            pool: self,
            index,
            _permit: permit,
        })
    }

    /// Borrows the write value, waiting until no read guard is alive.
    ///
    /// Takes all `concurrency` semaphore permits, so no [`SharedGuard`] or
    /// other [`ExclusiveGuard`] can exist at the same time.
    ///
    /// # Errors
    ///
    /// Returns an error if acquiring the semaphore permits fails.
    pub async fn exclusive(self: Arc<Self>) -> Result<ExclusiveGuard<R, W>, PoolError> {
        // `concurrency` is validated to be in 1..=64 at construction, so the
        // cast to u32 is lossless.
        let result = self
            .semaphore
            .clone()
            .acquire_many_owned(self.concurrency as u32)
            .await;
        let permit = match result {
            Ok(permit) => permit,
            Err(AcquireError { .. }) => return Err(PoolError::from(ErrorKind::NoConnections)),
        };
        // SAFETY: holding every permit excludes all other guards. `&mut self`
        // accessors are impossible while this `Arc` is alive.
        let write = unsafe { NonNull::new_unchecked(&mut *self.write.get()) };
        Ok(ExclusiveGuard {
            write,
            _pool: self,
            _permit: permit,
        })
    }
}
/// Guard giving exclusive access to one read value of a [`Pool`].
///
/// Returned by [`Pool::shared`]. It dereferences to `R`. Dropping it frees the
/// read slot and returns its semaphore permit.
pub struct SharedGuard<R, W> {
    // Points into `pool.read_pool[index]`; valid while `pool` is alive.
    read: NonNull<R>,
    // Keeps the pool, and therefore the pointed-to slot, alive.
    pool: Arc<Pool<R, W>>,
    // Slot index. `Drop` clears bit `index` of `pool.used_read`.
    index: usize,
    // Dropped after `Drop::drop` runs, so the slot bit is cleared before the
    // permit becomes available to another `shared` caller.
    _permit: OwnedSemaphorePermit,
}
impl<R, W> Deref for SharedGuard<R, W> {
    type Target = R;

    /// Returns a shared reference to the borrowed read value.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let ptr = self.read.as_ptr();
        // SAFETY: this guard owns slot `index` (its bit is set in `used_read`)
        // and holds an `Arc` to the pool, so the pointee is alive and nobody
        // else accesses it.
        unsafe { &*ptr }
    }
}
impl<R, W> DerefMut for SharedGuard<R, W> {
    /// Returns a mutable reference to the borrowed read value.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        let ptr = self.read.as_ptr();
        // SAFETY: as in `deref`. In addition, `&mut self` guarantees this is
        // the only live reference derived from the guard.
        unsafe { &mut *ptr }
    }
}
// SAFETY: the guard gives `&mut R` to the thread that holds it. It also owns
// an `Arc<Pool<R, W>>` that may drop the pool on that thread. Both are sound
// when `R` and `W` are `Send`.
unsafe impl<R: Send, W: Send> Send for SharedGuard<R, W> {}
impl<R, W> Drop for SharedGuard<R, W> {
    /// Marks this guard's read slot as free again.
    ///
    /// The semaphore permit is released afterwards, when the `_permit` field
    /// is dropped. A newly admitted `shared` caller therefore always finds a
    /// clear bit.
    fn drop(&mut self) {
        let mask: u64 = 1 << self.index;
        self.pool.used_read.fetch_and(!mask, Ordering::AcqRel);
    }
}
/// Guard giving exclusive access to the write value of a [`Pool`].
///
/// Returned by [`Pool::exclusive`]. While it is alive it holds every semaphore
/// permit, so no [`SharedGuard`] can exist at the same time.
pub struct ExclusiveGuard<R, W> {
    // Points to the pool's `write` cell; valid while `_pool` is alive.
    write: NonNull<W>,
    // Keeps the pool alive for as long as `write` may be dereferenced.
    _pool: Arc<Pool<R, W>>,
    // All `concurrency` permits, released when the guard is dropped.
    _permit: OwnedSemaphorePermit,
}
impl<R, W> Deref for ExclusiveGuard<R, W> {
    type Target = W;

    /// Returns a shared reference to the pool's write value.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let ptr = self.write.as_ptr();
        // SAFETY: the guard holds every semaphore permit and an `Arc` to the
        // pool, so the write value is alive and no other guard can touch it.
        unsafe { &*ptr }
    }
}
impl<R, W> DerefMut for ExclusiveGuard<R, W> {
    /// Returns a mutable reference to the pool's write value.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        let ptr = self.write.as_ptr();
        // SAFETY: as in `deref`. In addition, `&mut self` guarantees this is
        // the only live reference derived from the guard.
        unsafe { &mut *ptr }
    }
}
// SAFETY: the guard gives `&mut W` to whichever thread holds it. It also owns
// an `Arc<Pool<R, W>>` that may drop the pool, including its `R` and `W`
// values, on that thread. Sending the guard is therefore only sound when
// both `R` and `W` are `Send`, the same bounds as `SharedGuard`'s `Send` impl.
unsafe impl<R, W> Send for ExclusiveGuard<R, W>
where
    R: Send,
    W: Send,
{
}