use std::{
borrow::Cow,
collections::HashMap,
io::{self, Result},
sync::{Arc, OnceLock, RwLock},
task::{Context, Poll},
};
pub use bigdecimal::BigDecimal;
use futures::{future::poll_fn, Future};
/// The column type a caller asks a driver to decode a value as.
///
/// Each variant corresponds to the [`SqlValue`] variant of the same name; the
/// typed accessors on [`Row`] pass one of these to `DriverRow::get`.
///
/// `Eq` and `Hash` are derived (in addition to the original derives) so the type
/// can be used as a key in hash maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SqlType {
    /// Boolean column (`SqlValue::Bool`).
    Bool,
    /// 64-bit integer column (`SqlValue::Int`).
    Int,
    /// 128-bit integer column (`SqlValue::BigInt`).
    BigInt,
    /// Floating point column (`SqlValue::Float`).
    Float,
    /// Arbitrary-precision decimal column (`SqlValue::Decimal`).
    Decimal,
    /// Raw byte column (`SqlValue::Binary`).
    Binary,
    /// Text column (`SqlValue::String`).
    String,
    /// SQL NULL (`SqlValue::Null`).
    Null,
}
/// A single SQL value, either read from a [`Row`] or bound as a parameter.
///
/// Byte and text payloads are `Cow`s so parameters can borrow caller data,
/// while values returned from `Row::get` are `SqlValue<'static>`.
#[derive(Debug, Clone, PartialEq)]
pub enum SqlValue<'a> {
    /// Boolean value.
    Bool(bool),
    /// 64-bit signed integer.
    Int(i64),
    /// 128-bit signed integer.
    BigInt(i128),
    /// Double-precision float.
    Float(f64),
    /// Arbitrary-precision decimal.
    Decimal(BigDecimal),
    /// Raw bytes, borrowed or owned.
    Binary(Cow<'a, [u8]>),
    /// Text, borrowed or owned.
    String(Cow<'a, str>),
    /// SQL NULL.
    Null,
}
/// A statement parameter passed to `exec`/`query`.
///
/// `Debug`, `Clone` and `PartialEq` are derived so parameter lists can be logged,
/// duplicated and compared; all fields already implement these traits.
#[derive(Debug, Clone, PartialEq)]
pub enum SqlParameter<'a> {
    /// A value bound under an explicit parameter name.
    Named(Cow<'a, str>, SqlValue<'a>),
    /// A value without a name; presumably bound by position — TODO confirm
    /// against driver implementations.
    Offset(SqlValue<'a>),
}
pub mod syscall {
    //! Traits implemented by database drivers; the public wrapper types in the
    //! parent module (`Connection`, `Prepare`, `Query`, ...) delegate to these.
    use super::*;

    /// A named driver that can open connections; registered with
    /// `register_rdbc_driver` and used by `open`.
    pub trait Driver: Send + Sync {
        /// Opens a fresh connection to `source_name`.
        fn create_connection(&self, driver_name: &str, source_name: &str) -> Result<Connection>;
    }

    /// A live driver-level connection.
    pub trait DriverConn: Send + Sync {
        /// Polls whether the connection is ready; awaited by `open` and
        /// `Connection::is_ready`.
        fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<()>>;
        /// Starts a transaction.
        fn begin(&self) -> Result<Transaction>;
        /// Prepares `query` for repeated execution.
        fn prepare(&self, query: &str) -> Result<Prepare>;
        /// Executes a statement that does not return rows.
        fn exec(&self, query: &str, params: &[SqlParameter<'_>]) -> Result<Update>;
        /// Executes a statement that returns rows.
        fn query(&self, query: &str, params: &[SqlParameter<'_>]) -> Result<Query>;
    }

    /// Pluggable caching strategy for connections, installed with
    /// `register_pool_strategy`.
    pub trait ConnectionPool: Send + Sync {
        /// Returns a cached connection, or `Ok(None)` so that `open` creates a
        /// new one.
        fn get_connection(
            &self,
            driver_name: &str,
            source_name: &str,
        ) -> Result<Option<Connection>>;
        /// Receives a driver connection back when a `Connection` is dropped.
        fn idle_connection(
            &self,
            driver_name: &str,
            driver_conn: Box<dyn DriverConn>,
        ) -> Result<()>;
    }

    /// One row of a result set.
    pub trait DriverRow: Send + Sync {
        /// Decodes column `index` as `sql_type`.
        fn get(&self, index: usize, sql_type: &SqlType) -> Result<SqlValue<'static>>;
    }

    /// A prepared statement.
    pub trait DriverPrepare: Send + Sync {
        /// Polls whether the statement has finished preparing.
        fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<()>>;
        /// Executes the statement without returning rows.
        fn exec(&self, params: &[SqlParameter<'_>]) -> Result<Update>;
        /// Executes the statement and returns rows.
        fn query(&self, params: &[SqlParameter<'_>]) -> Result<Query>;
    }

    /// The pending outcome of an `exec` call.
    pub trait DriverUpdate: Send {
        /// Polls for completion. The meaning of the two integers is
        /// driver-defined (presumably last insert id / rows affected — TODO confirm).
        fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<(i64, i64)>>;
    }

    /// Column metadata of a result set.
    pub trait DriverTableMetadata: Send + Sync {
        /// Number of columns.
        fn cols(&self) -> Result<usize>;
        /// Name of the column at `offset`.
        fn col_name(&self, offset: usize) -> Result<&str>;
        /// Offset of the first column named `col_name`, found by scanning the
        /// columns in order.
        ///
        /// # Errors
        /// Propagates any error from `cols`/`col_name`, and returns
        /// `InvalidInput` when no column has that name.
        fn col_offset(&self, col_name: &str) -> Result<usize> {
            let total = self.cols()?;
            (0..total)
                .find_map(|offset| match self.col_name(offset) {
                    Ok(name) if name == col_name => Some(Ok(offset)),
                    Ok(_) => None,
                    Err(err) => Some(Err(err)),
                })
                .unwrap_or_else(|| {
                    Err(io::Error::new(
                        io::ErrorKind::InvalidInput,
                        format!("Column with name '{}', not found", col_name),
                    ))
                })
        }
        /// Declared type of the column at `offset`, if the driver knows it.
        fn col_type(&self, offset: usize) -> Result<Option<SqlType>>;
        /// Declared size of the column at `offset`, if the driver knows it.
        fn col_size(&self, offset: usize) -> Result<Option<usize>>;
    }

    /// A result set that yields rows asynchronously.
    pub trait DriverQuery: DriverTableMetadata {
        /// Polls for the next row; `Ok(None)` signals the end of the result set.
        fn poll_next(&self, cx: &mut Context<'_>) -> Poll<Result<Option<Row>>>;
    }

    /// A driver-level transaction.
    pub trait DriverTx: Send + Sync {
        /// Polls whether the transaction has started.
        fn poll_ready(&self, cx: &mut Context<'_>) -> Poll<Result<()>>;
        /// Polls a rollback of the transaction.
        fn poll_rollback(&self, cx: &mut Context<'_>) -> Poll<Result<()>>;
        /// Prepares `query` inside the transaction.
        fn prepare(&self, query: &str) -> Result<Prepare>;
        /// Executes a statement that does not return rows.
        fn exec(&self, query: &str, params: &[SqlParameter<'_>]) -> Result<Update>;
        /// Executes a statement that returns rows.
        fn query(&self, query: &str, params: &[SqlParameter<'_>]) -> Result<Query>;
    }
}
/// A database connection handle.
///
/// Field 0 is the name handed to `ConnectionPool::idle_connection` (as
/// `driver_name`) when the handle is dropped. Field 1 is the boxed driver
/// connection; it is an `Option` so `Drop` can move it out.
pub struct Connection(String, Option<Box<dyn syscall::DriverConn>>);

impl<T: syscall::DriverConn + 'static> From<(String, T)> for Connection {
    /// Wraps a `(name, driver connection)` pair into a `Connection`.
    fn from((name, driver_conn): (String, T)) -> Self {
        let boxed: Box<dyn syscall::DriverConn> = Box::new(driver_conn);
        Connection(name, Some(boxed))
    }
}
impl Drop for Connection {
    /// Returns the driver connection to the registered pool, if one exists.
    ///
    /// Without a pool strategy (see `register_pool_strategy`) the driver connection
    /// is simply dropped with this handle. Errors reported by the pool are logged,
    /// because `drop` cannot return them.
    fn drop(&mut self) {
        let Some(conn_pool) = CONN_POOL.get() else {
            return;
        };
        // Use `let ... else` rather than `unwrap()`: a panic inside `drop` aborts
        // the process if it occurs while already unwinding.
        let Some(driver_conn) = self.1.take() else {
            return;
        };
        if let Err(err) = conn_pool.idle_connection(&self.0, driver_conn) {
            log::error!(
                target: "RDBC",
                "Put idle connection int cache with error, {}",
                err
            );
        }
    }
}
impl Connection {
    /// Borrows the underlying driver connection.
    ///
    /// # Panics
    /// Panics if the driver connection has already been taken.
    pub fn as_driver_conn(&self) -> &dyn syscall::DriverConn {
        self.1.as_deref().unwrap()
    }

    /// Resolves once the driver reports the connection as ready.
    pub async fn is_ready(&self) -> Result<()> {
        let driver = self.as_driver_conn();
        poll_fn(move |cx| driver.poll_ready(cx)).await
    }

    /// Starts a transaction and waits until the driver reports it ready.
    pub async fn begin(&self) -> Result<Transaction> {
        let tx = self.as_driver_conn().begin()?;
        poll_fn(|cx| tx.0.poll_ready(cx)).await?;
        Ok(tx)
    }

    /// Prepares `query` and waits until the driver reports it ready.
    pub async fn prepare(&self, query: &str) -> Result<Prepare> {
        let prepared = self.as_driver_conn().prepare(query)?;
        poll_fn(|cx| prepared.0.poll_ready(cx)).await?;
        Ok(prepared)
    }

    /// Submits a statement that does not return rows; await the returned
    /// [`Update`] for its outcome.
    pub async fn exec(&self, query: &str, params: &[SqlParameter<'_>]) -> Result<Update> {
        let driver = self.as_driver_conn();
        driver.exec(query, params)
    }

    /// Submits a statement that returns rows.
    pub async fn query(&self, query: &str, params: &[SqlParameter<'_>]) -> Result<Query> {
        let driver = self.as_driver_conn();
        driver.query(query, params)
    }
}
/// A prepared statement handle wrapping a driver implementation.
pub struct Prepare(Box<dyn syscall::DriverPrepare>);

impl<T: syscall::DriverPrepare + 'static> From<T> for Prepare {
    /// Boxes a driver prepared statement.
    fn from(driver_prepare: T) -> Self {
        let boxed: Box<dyn syscall::DriverPrepare> = Box::new(driver_prepare);
        Prepare(boxed)
    }
}
impl Prepare {
    /// Borrows the underlying driver prepared statement.
    ///
    /// Kept for backward compatibility. Prefer [`Prepare::as_driver_prepare`],
    /// whose name matches the returned type.
    pub fn as_driver_query(&self) -> &dyn syscall::DriverPrepare {
        self.as_driver_prepare()
    }

    /// Borrows the underlying driver prepared statement.
    ///
    /// Named consistently with `Row::as_driver_row` and
    /// `Transaction::as_driver_tx`.
    pub fn as_driver_prepare(&self) -> &dyn syscall::DriverPrepare {
        &*self.0
    }

    /// Executes the statement with `params`; await the returned [`Update`] for
    /// its outcome.
    pub async fn exec(&self, params: &[SqlParameter<'_>]) -> Result<Update> {
        self.0.exec(params)
    }

    /// Executes the statement with `params`, returning a row stream.
    pub async fn query(&self, params: &[SqlParameter<'_>]) -> Result<Query> {
        self.0.query(params)
    }
}
/// The pending result of an `exec` call; await it to obtain the driver's
/// `(i64, i64)` outcome.
pub struct Update(Box<dyn syscall::DriverUpdate>);

impl<T: syscall::DriverUpdate + 'static> From<T> for Update {
    /// Boxes a driver update.
    fn from(driver_update: T) -> Self {
        let boxed: Box<dyn syscall::DriverUpdate> = Box::new(driver_update);
        Update(boxed)
    }
}

impl Future for Update {
    type Output = Result<(i64, i64)>;

    /// Delegates polling to `DriverUpdate::poll_ready`.
    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let driver_update: &dyn syscall::DriverUpdate = &*self.0;
        driver_update.poll_ready(cx)
    }
}
/// A result-set handle wrapping a driver implementation.
pub struct Query(Box<dyn syscall::DriverQuery>);

impl<T: syscall::DriverQuery + 'static> From<T> for Query {
    /// Boxes a driver query.
    fn from(driver_query: T) -> Self {
        let boxed: Box<dyn syscall::DriverQuery> = Box::new(driver_query);
        Query(boxed)
    }
}
impl Query {
    /// Borrows the underlying driver query.
    pub fn as_driver_query(&self) -> &dyn syscall::DriverQuery {
        self.0.as_ref()
    }

    /// Awaits the next row; `Ok(None)` marks the end of the result set.
    pub async fn next(&self) -> Result<Option<Row>> {
        let driver = self.as_driver_query();
        poll_fn(move |cx| driver.poll_next(cx)).await
    }

    /// Number of columns in the result set.
    pub async fn cols(&self) -> Result<usize> {
        self.as_driver_query().cols()
    }

    /// Name of the column at `offset`.
    pub async fn col_name(&self, offset: usize) -> Result<&str> {
        self.as_driver_query().col_name(offset)
    }

    /// Offset of the column named `col_name`.
    pub async fn col_offset(&self, col_name: &str) -> Result<usize> {
        self.as_driver_query().col_offset(col_name)
    }

    /// Declared type of the column at `offset`, if known.
    pub async fn col_type(&self, offset: usize) -> Result<Option<SqlType>> {
        self.as_driver_query().col_type(offset)
    }

    /// Declared size of the column at `offset`, if known.
    pub async fn col_size(&self, offset: usize) -> Result<Option<usize>> {
        self.as_driver_query().col_size(offset)
    }
}
/// A single result row wrapping a driver implementation.
pub struct Row(Box<dyn syscall::DriverRow>);

impl<T: syscall::DriverRow + 'static> From<T> for Row {
    /// Boxes a driver row.
    fn from(driver_row: T) -> Self {
        let boxed: Box<dyn syscall::DriverRow> = Box::new(driver_row);
        Row(boxed)
    }
}
impl Row {
    /// Borrows the underlying driver row.
    pub fn as_driver_row(&self) -> &dyn syscall::DriverRow {
        &*self.0
    }

    /// Reads column `index`, asking the driver to decode it as `sql_type`.
    pub fn get(&self, index: usize, sql_type: &SqlType) -> Result<SqlValue<'static>> {
        self.0.get(index, sql_type)
    }

    /// Error used by the `ensure_*` accessors when the column is SQL NULL.
    fn null_value_error() -> io::Error {
        io::Error::new(io::ErrorKind::InvalidData, "Column value is null")
    }

    /// Error used by the `get_*` accessors when the driver returns a value of a
    /// different variant than requested.
    fn conversion_error(type_name: &str) -> io::Error {
        io::Error::new(
            io::ErrorKind::InvalidData,
            format!("Can't convert column value to {} type.", type_name),
        )
    }

    /// Reads a non-null bool column.
    ///
    /// # Errors
    /// `InvalidData` if the value is NULL or not a bool; otherwise propagates
    /// driver errors.
    pub fn ensure_bool(&self, index: usize) -> Result<bool> {
        self.get_bool(index)?.ok_or_else(Self::null_value_error)
    }

    /// Reads a nullable bool column (`Ok(None)` for NULL).
    ///
    /// # Errors
    /// `InvalidData` if the driver returns a non-bool value.
    pub fn get_bool(&self, index: usize) -> Result<Option<bool>> {
        match self.get(index, &SqlType::Bool)? {
            SqlValue::Bool(value) => Ok(Some(value)),
            SqlValue::Null => Ok(None),
            _ => Err(Self::conversion_error("bool")),
        }
    }

    /// Reads a non-null `i64` column.
    ///
    /// # Errors
    /// `InvalidData` if the value is NULL or not an int.
    pub fn ensure_int(&self, index: usize) -> Result<i64> {
        self.get_int(index)?.ok_or_else(Self::null_value_error)
    }

    /// Reads a nullable `i64` column (`Ok(None)` for NULL).
    ///
    /// # Errors
    /// `InvalidData` if the driver returns a non-int value.
    pub fn get_int(&self, index: usize) -> Result<Option<i64>> {
        match self.get(index, &SqlType::Int)? {
            SqlValue::Int(value) => Ok(Some(value)),
            SqlValue::Null => Ok(None),
            _ => Err(Self::conversion_error("int")),
        }
    }

    /// Reads a non-null `i128` column.
    ///
    /// # Errors
    /// `InvalidData` if the value is NULL or not a bigint.
    pub fn ensure_bigint(&self, index: usize) -> Result<i128> {
        self.get_bigint(index)?.ok_or_else(Self::null_value_error)
    }

    /// Reads a nullable `i128` column (`Ok(None)` for NULL).
    ///
    /// # Errors
    /// `InvalidData` if the driver returns a non-bigint value.
    pub fn get_bigint(&self, index: usize) -> Result<Option<i128>> {
        match self.get(index, &SqlType::BigInt)? {
            SqlValue::BigInt(value) => Ok(Some(value)),
            SqlValue::Null => Ok(None),
            _ => Err(Self::conversion_error("bigint")),
        }
    }

    /// Reads a non-null decimal column.
    ///
    /// # Errors
    /// `InvalidData` if the value is NULL or not a decimal.
    pub fn ensure_decimal(&self, index: usize) -> Result<BigDecimal> {
        self.get_decimal(index)?.ok_or_else(Self::null_value_error)
    }

    /// Reads a nullable decimal column (`Ok(None)` for NULL).
    ///
    /// # Errors
    /// `InvalidData` if the driver returns a non-decimal value.
    pub fn get_decimal(&self, index: usize) -> Result<Option<BigDecimal>> {
        match self.get(index, &SqlType::Decimal)? {
            SqlValue::Decimal(value) => Ok(Some(value)),
            SqlValue::Null => Ok(None),
            _ => Err(Self::conversion_error("decimal")),
        }
    }

    /// Reads a non-null `f64` column.
    ///
    /// # Errors
    /// `InvalidData` if the value is NULL or not a float.
    pub fn ensure_float(&self, index: usize) -> Result<f64> {
        self.get_float(index)?.ok_or_else(Self::null_value_error)
    }

    /// Reads a nullable `f64` column (`Ok(None)` for NULL).
    ///
    /// # Errors
    /// `InvalidData` if the driver returns a non-float value.
    pub fn get_float(&self, index: usize) -> Result<Option<f64>> {
        match self.get(index, &SqlType::Float)? {
            SqlValue::Float(value) => Ok(Some(value)),
            SqlValue::Null => Ok(None),
            _ => Err(Self::conversion_error("float")),
        }
    }

    /// Reads a non-null binary column as owned bytes.
    ///
    /// # Errors
    /// `InvalidData` if the value is NULL or not binary.
    pub fn ensure_binary(&self, index: usize) -> Result<Vec<u8>> {
        self.get_binary(index)?.ok_or_else(Self::null_value_error)
    }

    /// Reads a nullable binary column as owned bytes (`Ok(None)` for NULL).
    ///
    /// # Errors
    /// `InvalidData` if the driver returns a non-binary value.
    pub fn get_binary(&self, index: usize) -> Result<Option<Vec<u8>>> {
        match self.get(index, &SqlType::Binary)? {
            SqlValue::Binary(value) => Ok(Some(value.into_owned())),
            SqlValue::Null => Ok(None),
            _ => Err(Self::conversion_error("binary")),
        }
    }

    /// Reads a non-null text column as an owned `String`.
    ///
    /// # Errors
    /// `InvalidData` if the value is NULL or not text.
    pub fn ensure_string(&self, index: usize) -> Result<String> {
        self.get_string(index)?.ok_or_else(Self::null_value_error)
    }

    /// Reads a nullable text column as an owned `String` (`Ok(None)` for NULL).
    ///
    /// # Errors
    /// `InvalidData` if the driver returns a non-text value.
    pub fn get_string(&self, index: usize) -> Result<Option<String>> {
        match self.get(index, &SqlType::String)? {
            SqlValue::String(value) => Ok(Some(value.into_owned())),
            SqlValue::Null => Ok(None),
            _ => Err(Self::conversion_error("string")),
        }
    }
}
/// A transaction handle wrapping a driver implementation.
pub struct Transaction(Box<dyn syscall::DriverTx>);

impl<T: syscall::DriverTx + 'static> From<T> for Transaction {
    /// Boxes a driver transaction.
    fn from(driver_tx: T) -> Self {
        let boxed: Box<dyn syscall::DriverTx> = Box::new(driver_tx);
        Transaction(boxed)
    }
}
impl Transaction {
    /// Borrows the underlying driver transaction.
    pub fn as_driver_tx(&self) -> &dyn syscall::DriverTx {
        self.0.as_ref()
    }

    /// Rolls the transaction back, resolving when the driver finishes.
    pub async fn rollback(&self) -> Result<()> {
        let tx = self.as_driver_tx();
        poll_fn(move |cx| tx.poll_rollback(cx)).await
    }

    /// Prepares `query` inside this transaction and waits until it is ready.
    pub async fn prepare(&self, query: &str) -> Result<Prepare> {
        let prepared = self.as_driver_tx().prepare(query)?;
        poll_fn(|cx| prepared.0.poll_ready(cx)).await?;
        Ok(prepared)
    }

    /// Submits a statement without result rows inside this transaction.
    pub async fn exec(&self, query: &str, params: &[SqlParameter<'_>]) -> Result<Update> {
        let tx = self.as_driver_tx();
        tx.exec(query, params)
    }

    /// Submits a statement with result rows inside this transaction.
    pub async fn query(&self, query: &str, params: &[SqlParameter<'_>]) -> Result<Query> {
        let tx = self.as_driver_tx();
        tx.query(query, params)
    }
}
/// Process-wide registry of drivers, keyed by the name they were registered under.
#[derive(Default)]
struct GlobalRegister {
    drivers: RwLock<HashMap<String, Arc<Box<dyn syscall::Driver>>>>,
}

/// Lazily initialised driver registry; access it through `get_register`.
static REGISTER: OnceLock<GlobalRegister> = OnceLock::new();
/// Optional connection pool strategy, set at most once by `register_pool_strategy`.
static CONN_POOL: OnceLock<Box<dyn syscall::ConnectionPool>> = OnceLock::new();

/// Returns the global registry, creating an empty one on first use.
fn get_register() -> &'static GlobalRegister {
    REGISTER.get_or_init(GlobalRegister::default)
}
/// Installs the global connection pool strategy used by `open` and by
/// `Connection`'s `Drop`.
///
/// # Panics
/// Panics if a pool strategy has already been installed.
pub fn register_pool_strategy<P: syscall::ConnectionPool + 'static>(conn_pool: P) {
    let strategy: Box<dyn syscall::ConnectionPool> = Box::new(conn_pool);
    assert!(
        CONN_POOL.set(strategy).is_ok(),
        "Call register_pool_strategy more than once."
    );
}
/// Opens a connection to `source_name` using the driver registered as `driver_name`.
///
/// If a pool strategy is installed, the pool is asked for a cached connection
/// first. Otherwise, or when the pool returns `None`, a new connection is
/// created and awaited until the driver reports it ready.
///
/// # Errors
/// - `NotFound` if no driver is registered under `driver_name`.
/// - `Interrupted` if the registry lock is poisoned.
/// - Any error from the pool, from `create_connection`, or from the readiness poll.
pub async fn open<D: AsRef<str>, S: AsRef<str>>(
    driver_name: D,
    source_name: S,
) -> Result<Connection> {
    let driver_name = driver_name.as_ref();
    let source_name = source_name.as_ref();

    // Clone the driver handle and release the read lock before going further.
    // A `std::sync::RwLockReadGuard` must not be held across the `.await` below:
    // it would make this future `!Send` and block `register_rdbc_driver` while a
    // connection is being established.
    let database = {
        let drivers = get_register()
            .drivers
            .read()
            .map_err(|err| io::Error::new(io::ErrorKind::Interrupted, err.to_string()))?;
        drivers.get(driver_name).cloned()
    };
    let Some(database) = database else {
        return Err(io::Error::new(
            io::ErrorKind::NotFound,
            format!("Unknown driver: {}", driver_name),
        ));
    };

    if let Some(conn_pool) = CONN_POOL.get() {
        log::debug!(target: "RDBC", "Try get connection from cached pool");
        if let Some(conn) = conn_pool.get_connection(driver_name, source_name)? {
            log::debug!(
                target: "RDBC",
                "Get connection from cached pool, driver={}, source={}",
                driver_name,
                source_name,
            );
            return Ok(conn);
        }
    }

    log::debug!(
        target: "RDBC",
        "Create new connection, driver={}, source={}",
        driver_name,
        source_name
    );
    let conn = database.create_connection(driver_name, source_name)?;
    poll_fn(|cx| conn.1.as_ref().unwrap().poll_ready(cx)).await?;
    Ok(conn)
}
/// Registers `database` as the driver used by `open` for `driver_name`.
///
/// # Errors
/// Returns `Interrupted` if the registry lock is poisoned.
///
/// # Panics
/// Panics if a driver is already registered under `driver_name`. The existing
/// driver is kept, and the write lock is released before panicking so the
/// registry is not poisoned.
pub fn register_rdbc_driver<N: AsRef<str>, D: syscall::Driver + 'static>(
    driver_name: N,
    database: D,
) -> Result<()> {
    let driver_name = driver_name.as_ref();
    let mut drivers = get_register()
        .drivers
        .write()
        .map_err(|err| io::Error::new(io::ErrorKind::Interrupted, err.to_string()))?;
    if drivers.contains_key(driver_name) {
        // Panicking while the write guard is alive would poison the lock, so
        // every later `open` would fail. Release the guard first.
        drop(drivers);
        panic!("register driver twice: {}", driver_name);
    }
    drivers.insert(driver_name.to_owned(), Arc::new(Box::new(database)));
    Ok(())
}