use super::{DbError, Row, SimpleQueryMessage, Statement, ToSql, ToStatement};
use tokio_postgres::Client;
/// Database connection that wraps a `tokio_postgres` [`Client`] and keeps track of how
/// deeply nested the currently open transaction is.
pub struct Connection {
    /// Underlying client that every query on this connection is sent through.
    client: Client,
    /// Current transaction nesting depth: 0 when no transaction is open, 1 inside the
    /// top-level `BEGIN`, and `n > 1` when `n - 1` savepoints (`pt1`, `pt2`, ...) are active.
    transaction_n: u32,
}
impl From<Client> for Connection {
fn from(client: Client) -> Self {
Self {
client,
transaction_n: 0,
}
}
}
impl Connection {
    /// Name of the savepoint that protects nesting level `depth` (`depth >= 1`).
    ///
    /// `transaction` creates it while `transaction_n == depth`; `commit` and `rollback`
    /// address it while `transaction_n == depth + 1`. Routing all three through this helper
    /// keeps the naming scheme in one place.
    fn savepoint_name(depth: u32) -> String {
        format!("pt{}", depth)
    }

    /// Runs one or more `;`-separated SQL statements with the simple query protocol,
    /// discarding any rows they produce.
    ///
    /// # Errors
    /// Returns a [`DbError`] converted from the underlying `tokio_postgres` error.
    pub async fn batch_execute(&self, query: &str) -> Result<(), DbError> {
        self.client
            .batch_execute(query)
            .await
            .map_err(DbError::from)
    }

    /// Executes `statement` with `params` and returns the row count reported by
    /// `tokio_postgres::Client::execute`.
    ///
    /// `statement` is first resolved through [`ToStatement`] (presumably preparing it on
    /// this connection when a raw query string is given — confirm in `ToStatement`).
    ///
    /// # Errors
    /// Returns a [`DbError`] if resolving the statement or executing it fails.
    pub async fn execute<T>(
        &self,
        statement: &T,
        params: &[&(dyn ToSql + Sync)],
    ) -> Result<u64, DbError>
    where
        T: ?Sized + ToStatement,
    {
        let statement = &statement.__convert().into_statement(self).await?;
        self.client
            .execute(&**statement, params)
            .await
            .map_err(DbError::from)
    }

    /// Prepares `query` on the server and wraps the result in this crate's [`Statement`].
    ///
    /// # Errors
    /// Returns a [`DbError`] if the server rejects the query.
    pub async fn prepare(&self, query: &str) -> Result<Statement, DbError> {
        self.client
            .prepare(query)
            .await
            .map(Statement::from)
            .map_err(DbError::from)
    }

    /// Runs `query` with the simple query protocol and returns every message it produced,
    /// converted to this crate's [`SimpleQueryMessage`].
    ///
    /// # Errors
    /// Returns a [`DbError`] if the query fails.
    pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, DbError> {
        self.client
            .simple_query(query)
            .await
            .map(|messages| messages.into_iter().map(SimpleQueryMessage::from).collect())
            .map_err(DbError::from)
    }

    /// Executes `statement` with `params` and returns all resulting rows.
    ///
    /// # Errors
    /// Returns a [`DbError`] if resolving the statement or running the query fails.
    pub async fn query<T>(
        &self,
        statement: &T,
        params: &[&(dyn ToSql + Sync)],
    ) -> Result<Vec<Row>, DbError>
    where
        T: ?Sized + ToStatement,
    {
        let statement = &statement.__convert().into_statement(self).await?;
        self.client
            .query(&**statement, params)
            .await
            .map(|rows| rows.into_iter().map(Row::from).collect())
            .map_err(DbError::from)
    }

    /// Executes `statement` with `params` and returns its single resulting row.
    ///
    /// # Errors
    /// Returns a [`DbError`] if the query fails or, per `tokio_postgres::Client::query_one`,
    /// if it does not return exactly one row.
    pub async fn query_one<T>(
        &self,
        statement: &T,
        params: &[&(dyn ToSql + Sync)],
    ) -> Result<Row, DbError>
    where
        T: ?Sized + ToStatement,
    {
        let statement = &statement.__convert().into_statement(self).await?;
        self.client
            .query_one(&**statement, params)
            .await
            .map(Row::from)
            .map_err(DbError::from)
    }

    /// Executes `statement` with `params` and returns the resulting row, if any.
    ///
    /// # Errors
    /// Returns a [`DbError`] if the query fails or, per `tokio_postgres::Client::query_opt`,
    /// if it returns more than one row.
    pub async fn query_opt<T>(
        &self,
        statement: &T,
        params: &[&(dyn ToSql + Sync)],
    ) -> Result<Option<Row>, DbError>
    where
        T: ?Sized + ToStatement,
    {
        let statement = &statement.__convert().into_statement(self).await?;
        self.client
            .query_opt(&**statement, params)
            .await
            .map(|maybe_row| maybe_row.map(Row::from))
            .map_err(DbError::from)
    }

    /// Opens a transaction, or a nested one if a transaction is already open.
    ///
    /// At depth 0 this issues `BEGIN`; at depth `n > 0` it creates savepoint `pt{n}`.
    /// The nesting depth is only incremented once the server has accepted the statement.
    ///
    /// # Errors
    /// Returns a [`DbError`] if the `BEGIN`/`SAVEPOINT` statement fails; the depth is
    /// then left unchanged.
    pub async fn transaction(&mut self) -> Result<(), DbError> {
        let query = match self.transaction_n {
            0 => String::from("BEGIN"),
            depth => format!("SAVEPOINT {}", Self::savepoint_name(depth)),
        };
        self.batch_execute(&query).await?;
        self.transaction_n += 1;
        Ok(())
    }

    /// Commits the innermost open transaction.
    ///
    /// At depth 1 this issues `COMMIT`; at deeper levels it releases the savepoint that
    /// `transaction` created for the innermost level, folding its work into the parent.
    ///
    /// # Errors
    /// Returns a [`DbError`] if no transaction is open, or if the statement fails. On
    /// failure the depth is left unchanged.
    /// NOTE(review): a failed top-level `COMMIT` generally ends the transaction on the
    /// server while the depth stays at 1; callers are expected to follow up with `rollback`.
    pub async fn commit(&mut self) -> Result<(), DbError> {
        let query = match self.transaction_n {
            0 => return Err(DbError::new("Not in a transaction", None)),
            1 => String::from("COMMIT"),
            depth => format!("RELEASE SAVEPOINT {}", Self::savepoint_name(depth - 1)),
        };
        self.batch_execute(&query).await?;
        self.transaction_n -= 1;
        Ok(())
    }

    /// Rolls back the innermost open transaction.
    ///
    /// At depth 1 this issues `ROLLBACK`; at deeper levels it rolls back to the savepoint
    /// that `transaction` created for the innermost level and then releases it.
    ///
    /// # Errors
    /// Returns a [`DbError`] if no transaction is open, or if the statement fails. On
    /// failure the depth is left unchanged.
    pub async fn rollback(&mut self) -> Result<(), DbError> {
        let query = match self.transaction_n {
            0 => return Err(DbError::new("Not in a transaction", None)),
            1 => String::from("ROLLBACK"),
            depth => {
                // `ROLLBACK TO SAVEPOINT` keeps the savepoint alive. Once the depth drops it is
                // never addressed again (the next nested transaction re-creates the same name),
                // so release it to avoid accumulating savepoints inside a long outer transaction.
                let name = Self::savepoint_name(depth - 1);
                format!("ROLLBACK TO SAVEPOINT {0}; RELEASE SAVEPOINT {0}", name)
            }
        };
        self.batch_execute(&query).await?;
        self.transaction_n -= 1;
        Ok(())
    }
}