pub use futures::future::join_all;
pub use futures::future::try_join_all;
pub use tokio::join;
pub use tokio::try_join;
pub use tokio_postgres::Error as PsqlError;
pub use tokio_postgres::Row;
use bytes::Buf;
use futures::lock::Mutex;
use std::borrow::Cow;
use std::collections::HashMap;
use std::error;
use std::fmt;
use std::sync::Arc;
use tokio_postgres::types::{ToSql, Type};
use tokio_postgres::{
CancelToken, Client, CopyInSink, CopyOutStream, RowStream, SimpleQueryMessage, Statement,
};
use tokio_postgres::{ToStatement, Transaction};
/// State of a single entry in the mutex-guarded (`slow`) map of the statement `Cache`.
#[derive(Clone)]
pub enum StatementState {
/// A `Context::prepared` call is currently preparing this statement. That call keeps
/// the contained mutex locked while it works; other callers lock the same mutex to
/// wait for it and then look the entry up again.
Awaiting(Arc<Mutex<()>>),
/// The statement has been prepared and can be handed out as a clone.
Ready(Statement),
}
/// The database handle a `Context` runs its statements on: either a mutably borrowed
/// client or an owned transaction.
pub enum TransactionOrClient<'a> {
/// Statements run directly on the borrowed client; `commit` and `rollback` are no-ops.
Client(&'a mut Client),
/// Statements run inside this transaction; `commit` and `rollback` are forwarded to it.
Transaction(Transaction<'a>),
}
impl TransactionOrClient<'_> {
pub async fn commit(self) -> Result<(), PsqlError> {
match self {
TransactionOrClient::Client(_) => Ok(()),
TransactionOrClient::Transaction(transaction) => transaction.commit().await,
}
}
pub async fn rollback(self) -> Result<(), PsqlError> {
match self {
TransactionOrClient::Client(_) => Ok(()),
TransactionOrClient::Transaction(transaction) => transaction.rollback().await,
}
}
}
/// A database handle (client or transaction) paired with a cache of prepared statements.
pub struct Context<'a> {
// Prepared statements keyed by their SQL text; consulted and filled by `prepared`.
cache: Cache,
// Where statements are actually sent: the borrowed client or the owned transaction.
toc: TransactionOrClient<'a>,
}
impl<'i> Context<'i> {
pub fn new_direct(client: &'i mut Client) -> Self {
Self {
cache: Default::default(),
toc: TransactionOrClient::Client(client),
}
}
pub fn new_transactional(transaction: Transaction<'i>) -> Self {
Self {
cache: Default::default(),
toc: TransactionOrClient::Transaction(transaction),
}
}
pub async fn prepared<I: Into<Cow<'static, str>>>(
&self,
statement_str: I,
) -> Result<Statement, PsqlError> {
let statement_str = statement_str.into();
loop {
if let Some(statement) = self.cache.fast.get(&statement_str) {
return Ok(statement.clone());
}
let mut slow_locked = self.cache.slow.lock().await;
if let Some(statement) = slow_locked.get(&statement_str) {
match statement {
StatementState::Awaiting(mutex) => {
let local = mutex.clone();
drop(slow_locked); drop(local.lock().await); continue; }
StatementState::Ready(statement) => return Ok(statement.clone()),
};
} else {
let mutex = Arc::new(Mutex::new(()));
let lock = mutex.lock().await;
slow_locked.insert(
statement_str.clone(),
StatementState::Awaiting(mutex.clone()),
);
drop(slow_locked);
let statement = self.prepare(&statement_str).await?;
self.cache
.slow
.lock()
.await
.insert(statement_str, StatementState::Ready(statement.clone()));
drop(lock); return Ok(statement);
}
}
}
pub fn split(mut self) -> (Cache, TransactionOrClient<'i>) {
self.optimize_cache();
self.split_unoptimized()
}
pub fn split_unoptimized(self) -> (Cache, TransactionOrClient<'i>) {
(self.cache, self.toc)
}
pub fn optimize_cache(&mut self) {
self.cache.optimize();
}
pub async fn prepare(&self, query: &str) -> Result<Statement, PsqlError> {
match &self.toc {
TransactionOrClient::Client(client) => client.prepare(query).await,
TransactionOrClient::Transaction(transaction) => transaction.prepare(query).await,
}
}
pub async fn prepare_typed(
&self,
query: &str,
parameter_types: &[Type],
) -> Result<Statement, PsqlError> {
match &self.toc {
TransactionOrClient::Client(client) => {
client.prepare_typed(query, parameter_types).await
}
TransactionOrClient::Transaction(transaction) => {
transaction.prepare_typed(query, parameter_types).await
}
}
}
pub async fn query<T>(
&self,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<Vec<Row>, PsqlError>
where
T: ?Sized + ToStatement,
{
match &self.toc {
TransactionOrClient::Client(client) => client.query(statement, params).await,
TransactionOrClient::Transaction(transaction) => {
transaction.query(statement, params).await
}
}
}
pub async fn query_one<T>(
&self,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<Row, PsqlError>
where
T: ?Sized + ToStatement,
{
match &self.toc {
TransactionOrClient::Client(client) => client.query_one(statement, params).await,
TransactionOrClient::Transaction(transaction) => {
transaction.query_one(statement, params).await
}
}
}
pub async fn query_opt<T>(
&self,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<Option<Row>, PsqlError>
where
T: ?Sized + ToStatement,
{
match &self.toc {
TransactionOrClient::Client(client) => client.query_opt(statement, params).await,
TransactionOrClient::Transaction(transaction) => {
transaction.query_opt(statement, params).await
}
}
}
pub async fn query_raw<'b, T, I>(
&self,
statement: &T,
params: I,
) -> Result<RowStream, PsqlError>
where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'b dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
match &self.toc {
TransactionOrClient::Client(client) => client.query_raw(statement, params).await,
TransactionOrClient::Transaction(transaction) => {
transaction.query_raw(statement, params).await
}
}
}
pub async fn execute<T>(
&self,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<u64, PsqlError>
where
T: ?Sized + ToStatement,
{
match &self.toc {
TransactionOrClient::Client(client) => client.execute(statement, params).await,
TransactionOrClient::Transaction(transaction) => {
transaction.execute(statement, params).await
}
}
}
pub async fn execute_raw<'b, I, T>(&self, statement: &T, params: I) -> Result<u64, PsqlError>
where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'b dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
match &self.toc {
TransactionOrClient::Client(client) => client.execute_raw(statement, params).await,
TransactionOrClient::Transaction(transaction) => {
transaction.execute_raw(statement, params).await
}
}
}
pub async fn copy_in<T, U>(&self, statement: &T) -> Result<CopyInSink<U>, PsqlError>
where
T: ?Sized + ToStatement,
U: Buf + 'static + Send,
{
match &self.toc {
TransactionOrClient::Client(client) => client.copy_in(statement).await,
TransactionOrClient::Transaction(transaction) => transaction.copy_in(statement).await,
}
}
pub async fn copy_out<T>(&self, statement: &T) -> Result<CopyOutStream, PsqlError>
where
T: ?Sized + ToStatement,
{
match &self.toc {
TransactionOrClient::Client(client) => client.copy_out(statement).await,
TransactionOrClient::Transaction(transaction) => transaction.copy_out(statement).await,
}
}
pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, PsqlError> {
match &self.toc {
TransactionOrClient::Client(client) => client.simple_query(query).await,
TransactionOrClient::Transaction(transaction) => transaction.simple_query(query).await,
}
}
pub async fn batch_execute(&self, query: &str) -> Result<(), PsqlError> {
match &self.toc {
TransactionOrClient::Client(client) => client.batch_execute(query).await,
TransactionOrClient::Transaction(transaction) => transaction.batch_execute(query).await,
}
}
pub fn cancel_token(&self) -> CancelToken {
match &self.toc {
TransactionOrClient::Client(client) => client.cancel_token(),
TransactionOrClient::Transaction(transaction) => transaction.cancel_token(),
}
}
}
/// Cache of prepared statements keyed by their SQL text, split into two maps.
#[derive(Default)]
pub struct Cache {
// Plain map that `Context::prepared` reads without locking; filled only by `optimize`.
fast: HashMap<Cow<'static, str>, Statement>,
// Mutex-guarded map that `Context::prepared` fills; entries are either still being
// prepared (`Awaiting`) or finished (`Ready`).
slow: Mutex<HashMap<Cow<'static, str>, StatementState>>,
}
impl Cache {
    /// Builds a [`Context`] that reuses this cache and sends statements through
    /// `transaction`.
    pub fn into_transaction_context(self, transaction: Transaction) -> Context {
        let toc = TransactionOrClient::Transaction(transaction);
        Context { cache: self, toc }
    }

    /// Builds a [`Context`] that reuses this cache and sends statements directly to
    /// `client`.
    pub fn into_client_context(self, client: &mut Client) -> Context {
        let toc = TransactionOrClient::Client(client);
        Context { cache: self, toc }
    }

    /// Iterates over the `(SQL text, statement)` pairs stored in the fast map.
    /// Entries that are still in the slow map are not visited.
    pub fn iter_over_fast(&self) -> impl Iterator<Item = (&str, &Statement)> {
        self.fast
            .iter()
            .map(|(text, statement)| (&**text, statement))
    }

    /// Empties the slow map, moving every `Ready` statement into the fast map.
    ///
    /// `Awaiting` placeholders are dropped. `&mut self` gives direct access to the map
    /// behind the mutex, so nothing is locked here.
    pub fn optimize(&mut self) {
        for (text, state) in self.slow.get_mut().drain() {
            if let StatementState::Ready(statement) = state {
                self.fast.insert(text, statement);
            }
        }
    }
}
/// Errors reported by this module.
#[derive(Debug)]
pub enum Error {
/// An error coming from `tokio_postgres`; also produced by the `From<PsqlError>` conversion.
Psql(PsqlError),
/// A variant index was not expected; displayed as "Unexpected variant index: {index}".
UnexpectedVariant(usize),
/// No entry is known for the given id; displayed as "Id {id} is unknown".
NoEntryFoundForId(i32),
/// The row has an error and cannot be loaded.
RowUnloadable,
}
impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
#[allow(clippy::match_same_arms)]
match self {
Error::Psql(psql) => psql.source(),
Error::UnexpectedVariant(_) => None,
Error::NoEntryFoundForId(_) => None,
Error::RowUnloadable => None,
}
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Error::Psql(psql) => psql.fmt(f),
Error::UnexpectedVariant(index) => write!(f, "Unexpected variant index: {}", index),
Error::NoEntryFoundForId(id) => write!(f, "Id {} is unknown", id),
Error::RowUnloadable => write!(f, "The row has an error and cannot be loaded"),
}
}
}
impl From<PsqlError> for Error {
fn from(psql: PsqlError) -> Self {
Error::Psql(psql)
}
}