use crate::{value::Value, EResult, Error, OID};
use once_cell::sync::OnceCell;
use sqlx::encode::IsNull;
use sqlx::error::BoxDynError;
use sqlx::postgres::{self, PgConnectOptions, PgPool, PgPoolOptions};
use sqlx::sqlite::{self, SqliteConnectOptions, SqlitePool, SqlitePoolOptions};
use sqlx::{database, ConnectOptions, Database, Decode, Encode};
use sqlx::{Postgres, Sqlite, Type};
use std::borrow::Cow;
use std::str::FromStr;
use std::time::Duration;
pub mod prelude {
pub use super::{db_init, db_pool, DbKind, DbPool, Transaction};
}
/// Process-wide database pool: set once by `db_init` and read by `db_pool`.
static DB_POOL: OnceCell<DbPool> = OnceCell::new();
/// `OID`s are stored in SQLite as plain text.
impl Type<Sqlite> for OID {
    fn type_info() -> sqlite::SqliteTypeInfo {
        <&str as Type<Sqlite>>::type_info()
    }
}
/// `OID`s are stored in PostgreSQL as text.
///
/// Column compatibility is delegated to sqlx's own rules for `str`, so every column type
/// that the `&str`-based decoder can read is accepted (at least `TEXT` and `VARCHAR`).
/// This avoids allocating a type-name string on every compatibility check.
impl Type<Postgres> for OID {
    fn type_info() -> postgres::PgTypeInfo {
        <str as Type<Postgres>>::type_info()
    }
    fn compatible(ty: &postgres::PgTypeInfo) -> bool {
        <str as Type<Postgres>>::compatible(ty)
    }
}
/// Decodes an `OID` from any backend whose values can be read as a borrowed `&str`,
/// by parsing that text.
impl<'r, DB: Database> Decode<'r, DB> for OID
where
    &'r str: Decode<'r, DB>,
{
    fn decode(value: <DB as database::HasValueRef<'r>>::ValueRef) -> Result<Self, BoxDynError> {
        let text = <&'r str as Decode<'r, DB>>::decode(value)?;
        text.parse::<Self>().map_err(|e| e.into())
    }
}
/// Binds an `OID` to a SQLite statement as an owned text value.
impl<'q> Encode<'q, Sqlite> for OID {
    fn encode(self, args: &mut Vec<sqlite::SqliteArgumentValue<'q>>) -> IsNull {
        // The owned and borrowed paths produce exactly the same argument.
        <Self as Encode<'q, Sqlite>>::encode_by_ref(&self, args)
    }
    fn encode_by_ref(&self, args: &mut Vec<sqlite::SqliteArgumentValue<'q>>) -> IsNull {
        let text: String = self.to_string();
        args.push(sqlite::SqliteArgumentValue::Text(Cow::Owned(text)));
        IsNull::No
    }
    fn size_hint(&self) -> usize {
        let text = self.as_str();
        text.len()
    }
}
impl Encode<'_, Postgres> for OID {
fn encode_by_ref(&self, buf: &mut postgres::PgArgumentBuffer) -> IsNull {
<&str as Encode<Postgres>>::encode(self.as_str(), buf)
}
fn size_hint(&self) -> usize {
self.as_str().len()
}
}
/// `Value`s are stored in SQLite as JSON text, so they share the text type rules.
impl Type<Sqlite> for Value {
    fn type_info() -> sqlite::SqliteTypeInfo {
        <&str as Type<Sqlite>>::type_info()
    }
    fn compatible(ty: &sqlite::SqliteTypeInfo) -> bool {
        <str as Type<Sqlite>>::compatible(ty)
    }
}
/// `Value`s are exchanged with PostgreSQL as `JSONB`.
impl Type<Postgres> for Value {
    fn type_info() -> postgres::PgTypeInfo {
        const JSONB: &str = "JSONB";
        postgres::PgTypeInfo::with_name(JSONB)
    }
}
/// Binds a `Value` to a SQLite statement as its JSON text representation.
///
/// Panics if `serde_json` fails to serialize the value.
impl Encode<'_, Sqlite> for Value {
    fn encode_by_ref(&self, buf: &mut Vec<sqlite::SqliteArgumentValue<'_>>) -> IsNull {
        let json = serde_json::to_string(self).expect("serde_json failed to convert to string");
        <String as Encode<Sqlite>>::encode(json, buf)
    }
}
/// Decodes a `Value` from a SQLite text column holding JSON.
impl<'r> Decode<'r, Sqlite> for Value {
    fn decode(value: sqlite::SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
        let json = <&str as Decode<Sqlite>>::decode(value)?;
        Ok(serde_json::from_str::<Self>(json)?)
    }
}
/// Binds a `Value` to a PostgreSQL statement in the binary `JSONB` wire format.
///
/// Panics if `serde_json` fails to serialize the value.
impl<'q> Encode<'q, Postgres> for Value {
    fn encode_by_ref(&self, buf: &mut postgres::PgArgumentBuffer) -> IsNull {
        // Binary JSONB is a one-byte format version (1) followed by the JSON text.
        const JSONB_VERSION: u8 = 1;
        buf.push(JSONB_VERSION);
        let out: &mut Vec<u8> = buf;
        serde_json::to_writer(out, self)
            .expect("failed to serialize to JSON for encoding on transmission to the database");
        IsNull::No
    }
}
/// Decodes a `Value` from a PostgreSQL `JSONB` column.
///
/// In the binary wire format a `JSONB` value is prefixed with a one-byte format version
/// (`1`). In the text format, used for example by queries sent without bind parameters,
/// there is no such prefix. A `0x01` byte can never start a valid JSON document, so the
/// prefix is stripped only when present. The previous behavior always dropped the first
/// byte, which corrupted text-format values and panicked on an empty buffer.
impl<'r> Decode<'r, Postgres> for Value {
    fn decode(value: postgres::PgValueRef<'r>) -> Result<Self, BoxDynError> {
        let buf = value.as_bytes()?;
        let json = match buf.split_first() {
            Some((&1, rest)) => rest,
            _ => buf,
        };
        serde_json::from_slice::<Self>(json).map_err(Into::into)
    }
}
/// Returns the process-wide database pool.
///
/// # Panics
///
/// Panics if `db_init` has not yet completed successfully.
#[allow(clippy::module_name_repetitions)]
#[inline]
pub fn db_pool() -> &'static DbPool {
    DB_POOL
        .get()
        .expect("database pool is not initialized; call db_init first")
}
/// A connection pool for one of the supported database backends.
///
/// Cloning is cheap: both wrapped sqlx pool types are reference-counted handles.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone)]
pub enum DbPool {
    /// Pool of SQLite connections.
    Sqlite(SqlitePool),
    /// Pool of PostgreSQL connections.
    Postgres(PgPool),
}
/// Identifies which database backend a pool or transaction is bound to.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DbKind {
    /// SQLite backend.
    Sqlite,
    /// PostgreSQL backend.
    Postgres,
}
impl DbPool {
    /// Starts a new transaction using this pool.
    ///
    /// # Errors
    ///
    /// Returns the underlying `sqlx` error if the transaction cannot be started.
    pub async fn begin(&self) -> Result<Transaction<'_>, sqlx::Error> {
        let tx = match self {
            Self::Sqlite(pool) => Transaction::Sqlite(pool.begin().await?),
            Self::Postgres(pool) => Transaction::Postgres(pool.begin().await?),
        };
        Ok(tx)
    }
    /// Returns which backend this pool is bound to.
    pub fn kind(&self) -> DbKind {
        match *self {
            Self::Sqlite(_) => DbKind::Sqlite,
            Self::Postgres(_) => DbKind::Postgres,
        }
    }
    /// Executes the SQL statement `q` on the pool, discarding the query result.
    ///
    /// # Errors
    ///
    /// Returns an error if the statement fails.
    pub async fn execute(&self, q: &str) -> EResult<()> {
        match self {
            Self::Sqlite(pool) => sqlx::query(q).execute(pool).await.map(drop)?,
            Self::Postgres(pool) => sqlx::query(q).execute(pool).await.map(drop)?,
        }
        Ok(())
    }
}
pub enum Transaction<'c> {
Sqlite(sqlx::Transaction<'c, sqlx::sqlite::Sqlite>),
Postgres(sqlx::Transaction<'c, sqlx::postgres::Postgres>),
}
impl<'c> Transaction<'c> {
    /// Commits the transaction, consuming it.
    ///
    /// # Errors
    ///
    /// Returns the underlying `sqlx` error if the commit fails.
    pub async fn commit(self) -> Result<(), sqlx::Error> {
        match self {
            Self::Sqlite(tx) => tx.commit().await,
            Self::Postgres(tx) => tx.commit().await,
        }
    }
    /// Returns which backend this transaction is bound to.
    pub fn kind(&self) -> DbKind {
        match *self {
            Self::Sqlite(_) => DbKind::Sqlite,
            Self::Postgres(_) => DbKind::Postgres,
        }
    }
    /// Executes the SQL statement `q` inside this transaction, discarding the query result.
    ///
    /// # Errors
    ///
    /// Returns an error if the statement fails.
    pub async fn execute(&mut self, q: &str) -> EResult<()> {
        match self {
            Self::Sqlite(tx) => sqlx::query(q).execute(tx).await.map(drop)?,
            Self::Postgres(tx) => sqlx::query(q).execute(tx).await.map(drop)?,
        }
        Ok(())
    }
}
/// Initializes the process-wide database pool returned by `db_pool`.
///
/// The backend is chosen from the scheme prefix of `conn`, using the same prefixes sqlx
/// itself dispatches on:
///
/// * `sqlite:` (e.g. `sqlite:///var/lib/app.db` or `sqlite::memory:`) selects SQLite.
///   The database file is created if missing, `synchronous` is set to `EXTRA`, and
///   `timeout` is also used as the busy timeout.
/// * `postgres:` or `postgresql:` (e.g. `postgres://user@host/db`) selects PostgreSQL.
///
/// For both backends:
///
/// * `pool_size` is the maximum number of pooled connections.
/// * `timeout` is the connection-acquire timeout.
/// * Every statement is logged at `TRACE` level.
/// * Statements slower than `timeout` are logged at `WARN` level.
///
/// # Errors
///
/// Returns an error if:
///
/// * the scheme is not supported;
/// * `conn` cannot be parsed;
/// * the pool fails to connect;
/// * the pool has already been initialized.
#[allow(clippy::module_name_repetitions)]
pub async fn db_init(conn: &str, pool_size: u32, timeout: Duration) -> EResult<()> {
    let pool = if conn.starts_with("sqlite:") {
        let mut opts = SqliteConnectOptions::from_str(conn)?
            .create_if_missing(true)
            .synchronous(sqlx::sqlite::SqliteSynchronous::Extra)
            .busy_timeout(timeout);
        opts.log_statements(log::LevelFilter::Trace)
            .log_slow_statements(log::LevelFilter::Warn, timeout);
        DbPool::Sqlite(
            SqlitePoolOptions::new()
                .max_connections(pool_size)
                .acquire_timeout(timeout)
                .connect_with(opts)
                .await?,
        )
    } else if conn.starts_with("postgres:") || conn.starts_with("postgresql:") {
        // `postgresql://` is the standard libpq URI scheme; `postgres://` is its common alias.
        let mut opts = PgConnectOptions::from_str(conn)?;
        opts.log_statements(log::LevelFilter::Trace)
            .log_slow_statements(log::LevelFilter::Warn, timeout);
        DbPool::Postgres(
            PgPoolOptions::new()
                .max_connections(pool_size)
                .acquire_timeout(timeout)
                .connect_with(opts)
                .await?,
        )
    } else {
        return Err(Error::unsupported("Unsupported database kind"));
    };
    DB_POOL
        .set(pool)
        .map_err(|_| Error::core("unable to set DB_POOL"))?;
    Ok(())
}