use std::ops::Deref;
use std::sync::Arc;
use tokio_postgres::{Client, Row, Socket};
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::types::ToSql;
/// Minimal asynchronous query interface shared by every connection-like type in this crate.
///
/// Each method takes a SQL `statement` plus its positional `params` and returns a `Send`
/// future, so callers can await it from spawned tasks. Implementors must be `Sync`.
pub trait DatabaseConnection: Sync {
    /// Runs `statement` with `params` and resolves to one [`Row`].
    ///
    /// The implementations in this file forward to the client's `query_one`.
    ///
    /// # Errors
    /// Resolves to the crate error when the underlying call fails.
    fn query_single(
        &self,
        statement: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> impl Send + std::future::Future<Output = crate::Result<Row>>;

    /// Runs `statement` with `params` and resolves to all returned [`Row`]s.
    ///
    /// The implementations in this file forward to the client's `query`.
    ///
    /// # Errors
    /// Resolves to the crate error when the underlying call fails.
    fn query_many(
        &self,
        statement: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> impl Send + std::future::Future<Output = crate::Result<Vec<Row>>>;

    /// Runs `statement` with `params` and resolves to a `u64` count.
    ///
    /// The implementations in this file forward to the client's `execute`; the count is
    /// presumably the number of affected rows (confirm against the tokio_postgres docs).
    ///
    /// # Errors
    /// Resolves to the crate error when the underlying call fails.
    fn execute_query(
        &self,
        statement: &str,
        params: &[&(dyn ToSql + Sync)],
    ) -> impl Send + std::future::Future<Output = crate::Result<u64>>;
}
/// Owns a tokio_postgres [`Client`].
///
/// Built through `CrashOrmDatabaseConnection::new`; it dereferences to the inner
/// [`Client`], so the client's methods can be called on it directly.
pub struct CrashOrmDatabaseConnection {
// The wrapped client; exposed read-only through the `Deref` impl.
client: Client,
}
impl CrashOrmDatabaseConnection {
    /// Opens a connection described by `config` using the given `tls` connector.
    ///
    /// `config` is handed unchanged to `tokio_postgres::connect`. The connection half
    /// returned by that call is awaited on its own spawned Tokio task, so this must be
    /// called from within a Tokio runtime. If that task ends with an error, the error is
    /// only written to stderr; it is not reported back to the holder of the client.
    ///
    /// # Errors
    /// Returns the crate error converted from whatever `tokio_postgres::connect` reports.
    pub async fn new<T>(config: &str, tls: T) -> crate::Result<Self>
    where
        T: MakeTlsConnect<Socket>,
        <T as MakeTlsConnect<Socket>>::Stream: Send + 'static,
    {
        let (client, connection) = tokio_postgres::connect(config, tls).await?;

        // The task is detached on purpose: its handle is dropped and it runs until the
        // connection future completes.
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                eprintln!("connection error: {}", e);
            }
        });

        Ok(Self { client })
    }

    /// Connects to the database used by the test suite, without TLS.
    ///
    /// Reads the connection string from the `DATABASE_URL` environment variable and falls
    /// back to a local default when the variable is unset or not valid Unicode.
    #[cfg(test)]
    pub async fn test() -> crate::Result<Self> {
        // `unwrap_or_else` builds the fallback string only when it is actually needed.
        let url = std::env::var("DATABASE_URL").unwrap_or_else(|_| {
            String::from("postgresql://crash_orm:postgres@localhost/crash_orm_test")
        });

        Self::new(&url, tokio_postgres::NoTls).await
    }
}
/// Exposes the wrapped [`Client`] so its methods can be called directly on a
/// [`CrashOrmDatabaseConnection`].
impl Deref for CrashOrmDatabaseConnection {
    type Target = Client;

    /// Borrows the inner client.
    fn deref(&self) -> &Client {
        let Self { client } = self;
        client
    }
}
/// Generates a `DatabaseConnection` impl for a type on which `query_one`, `query` and
/// `execute` can be called (directly or through auto-deref).
///
/// Each trait method forwards its arguments unchanged and converts the resulting error
/// into the crate error type.
macro_rules! impl_database_connection {
    ($conn:ty) => {
        impl DatabaseConnection for $conn {
            // Forwards to `query_one`.
            async fn query_single(
                &self,
                statement: &str,
                params: &[&(dyn ToSql + Sync)],
            ) -> crate::Result<Row> {
                let row = self.query_one(statement, params).await?;
                Ok(row)
            }

            // Forwards to `query`.
            async fn query_many(
                &self,
                statement: &str,
                params: &[&(dyn ToSql + Sync)],
            ) -> crate::Result<Vec<Row>> {
                let rows = self.query(statement, params).await?;
                Ok(rows)
            }

            // Forwards to `execute`.
            async fn execute_query(
                &self,
                statement: &str,
                params: &[&(dyn ToSql + Sync)],
            ) -> crate::Result<u64> {
                let count = self.execute(statement, params).await?;
                Ok(count)
            }
        }
    };
}
// Implement `DatabaseConnection` for the wrapper type and for the raw tokio_postgres `Client`.
// The wrapper reaches `query_one`/`query`/`execute` through its `Deref<Target = Client>` impl.
impl_database_connection!(CrashOrmDatabaseConnection);
impl_database_connection!(Client);
impl<T: DatabaseConnection + Send> DatabaseConnection for Arc<T> {
async fn query_single(
&self,
statement: &str,
params: &[&(dyn ToSql + Sync)],
) -> crate::Result<Row> {
self.deref().query_single(statement, params).await
}
async fn query_many(
&self,
statement: &str,
params: &[&(dyn ToSql + Sync)],
) -> crate::Result<Vec<Row>> {
self.deref().query_many(statement, params).await
}
async fn execute_query(
&self,
statement: &str,
params: &[&(dyn ToSql + Sync)],
) -> crate::Result<u64> {
self.deref().execute_query(statement, params).await
}
}
#[cfg(test)]
mod tests {
    use crate::prelude::CrashOrmDatabaseConnection;

    /// Connects to the test database and checks that a text parameter comes back unchanged.
    ///
    /// Needs a reachable database; see `CrashOrmDatabaseConnection::test` for how the
    /// connection string is chosen.
    #[tokio::test]
    async fn test_connection() {
        // `expect` puts the underlying error in the failure message, which a bare
        // `assert!(result.is_ok())` would hide.
        let connection = CrashOrmDatabaseConnection::test()
            .await
            .expect("failed to connect to the test database");

        let row = connection
            .query_one("SELECT $1::TEXT;", &[&"hello world"])
            .await
            .expect("round-trip query failed");

        let column: &str = row.get(0);
        assert_eq!(column, "hello world");
    }
}