use super::{Connection, Parameter};
use std::{
collections::HashMap,
hash::Hash,
sync::Arc,
time::{Duration, Instant},
};
use thiserror::Error;
use tokio::sync::RwLock;
use tonic::{async_trait, Status};
use tracing::Instrument;
use uuid::Uuid;
/// Errors produced by the transaction pool.
///
/// `C` is the error type reported by the underlying connection.
#[derive(Error, Debug)]
pub enum Error<C: std::error::Error + 'static> {
    /// A statement issued on the underlying connection failed; displays exactly
    /// as the wrapped connection error.
    #[error(transparent)]
    Connection(C),
    /// The wrapped connection pool could not hand out a connection.
    #[error("Error retrieving connection from transaction pool")]
    ConnectionFailure,
    /// No transaction is cached under the requested key: it was never begun,
    /// was already committed or rolled back, or was removed as stale.
    #[error("Requested transaction has not been initialized or was cleaned up due to inactivity")]
    Uninitialized,
}
impl<C> From<Error<C>> for Status
where
C: std::error::Error + Into<Status> + 'static,
{
fn from(error: Error<C>) -> Self {
match error {
Error::Connection(error) => error.into(),
Error::ConnectionFailure => Status::resource_exhausted(error.to_string()),
Error::Uninitialized => Status::not_found(error.to_string()),
}
}
}
impl<C> From<C> for Error<C>
where
C: std::error::Error + Into<Status> + 'static,
{
fn from(connection_error: C) -> Self {
Self::Connection(connection_error)
}
}
/// Delay, in seconds, between successive scans for stale transactions.
const VACUUM_POLLING_INTERVAL_SECONDS: u64 = 1;
/// Idle time, in seconds, after which a cached transaction is rolled back.
const INACTIVE_THRESHOLD_SECONDS: u64 = 30;
/// Maximum age, in seconds, of a cached transaction before it is rolled back
/// regardless of activity (half an hour).
const TRANSACTION_LIFETIME_LIMIT_SECONDS: u64 = 60 * 30;
/// An open transaction: a connection checked out of the underlying pool and kept
/// so that several requests can issue statements on it.
///
/// Clones share the same connection and the same activity timestamp.
pub struct Transaction<C: Connection> {
    /// The connection the transaction runs on.
    connection: Arc<C>,
    /// When the transaction was created; compared against the lifetime limit.
    created_at: Instant,
    /// When the transaction was last used. It sits behind a lock because it is
    /// written by statement execution and read by the vacuum task.
    last_used_at: Arc<RwLock<Instant>>,
}
impl<C: Connection> Transaction<C> {
    /// Wraps `connection` as a freshly started transaction.
    ///
    /// `created_at` and `last_used_at` both start at the current instant. Within
    /// this file the only caller, `Pool::begin`, issues `BEGIN` on the connection
    /// before wrapping it.
    fn new(connection: Arc<C>) -> Self {
        let started = Instant::now();
        let last_used_at = Arc::new(RwLock::new(started));
        Self {
            created_at: started,
            last_used_at,
            connection,
        }
    }
}
// Implemented by hand because `#[derive(Clone)]` would require `C: Clone`, and
// only the `Arc` handles need cloning.
impl<C: Connection> Clone for Transaction<C> {
    /// Returns another handle to the same connection and the same shared
    /// activity timestamp; no new connection is opened.
    fn clone(&self) -> Self {
        // Destructuring makes adding a field a compile error here until handled.
        let Self {
            connection,
            created_at,
            last_used_at,
        } = self;
        Self {
            connection: Arc::clone(connection),
            created_at: *created_at,
            last_used_at: Arc::clone(last_used_at),
        }
    }
}
/// Lets a cached transaction be used wherever a plain `Connection` is expected.
///
/// Every statement refreshes the shared `last_used_at` timestamp that the vacuum
/// task uses to detect idle transactions.
#[async_trait]
impl<C> Connection for Transaction<C>
where
    C: Connection + Send + Sync + 'static,
{
    type Error = C::Error;
    type RowStream = C::RowStream;

    /// Runs `statement` with `parameters` on the transaction's connection.
    ///
    /// The activity timestamp is refreshed both before and after the query.
    /// Refreshing before means a long-idle transaction that starts a statement is
    /// not rolled back by the vacuum task while that statement is in flight.
    /// Refreshing after counts a failed statement as activity too, instead of
    /// leaving the stale timestamp from before the call.
    #[tracing::instrument(skip(self, parameters))]
    async fn query(
        &self,
        statement: &str,
        parameters: &[Parameter],
    ) -> Result<Self::RowStream, Self::Error> {
        tracing::trace!("Querying transaction Connection");
        *self.last_used_at.write().await = Instant::now();
        let result = self.connection.query(statement, parameters).await;
        *self.last_used_at.write().await = Instant::now();
        result
    }

    /// Runs `query` as a batch on the transaction's connection.
    ///
    /// The activity timestamp is refreshed before and after, for the same reasons
    /// as in `query`.
    #[tracing::instrument(skip(self))]
    async fn batch(&self, query: &str) -> Result<(), Self::Error> {
        tracing::trace!("Executing batch query on transaction Connection");
        *self.last_used_at.write().await = Instant::now();
        let result = self.connection.batch(query).await;
        *self.last_used_at.write().await = Instant::now();
        result
    }
}
#[derive(Debug, Clone, Eq, Hash, PartialEq)]
pub struct Key<K>
where
K: Hash + Eq,
{
key: K,
transaction_id: Uuid,
}
impl<K: Hash + Eq> Key<K> {
    /// Builds the cache key for transaction `transaction_id`, which was opened
    /// against pool key `key`.
    pub fn new(key: K, transaction_id: Uuid) -> Self {
        Self {
            transaction_id,
            key,
        }
    }
}
/// Open transactions, indexed by pool key plus transaction id.
type TransactionMap<K, C> = HashMap<Key<K>, Transaction<C>>;
/// A pool of open transactions layered on top of a connection pool.
///
/// Cloning yields another handle to the same connection pool and the same
/// transaction map.
pub struct Pool<P>
where
    P: super::Pool,
    P::Key: Clone + Eq + Hash,
{
    /// Connection pool that new transactions take connections from.
    pool: Arc<P>,
    /// Open transactions. This map is shared with the background vacuum task
    /// spawned in `Pool::new`.
    transactions: Arc<RwLock<TransactionMap<P::Key, P::Connection>>>,
}
// Implemented by hand because `#[derive(Clone)]` would require `P: Clone`.
impl<P> Clone for Pool<P>
where
    P: super::Pool,
    P::Key: Clone + Eq + Hash,
{
    /// Returns another handle that shares the connection pool and the
    /// transaction map; nothing is deep-copied.
    fn clone(&self) -> Self {
        let Self { pool, transactions } = self;
        Self {
            pool: Arc::clone(pool),
            transactions: Arc::clone(transactions),
        }
    }
}
impl<P> Pool<P>
where
P: super::Pool + 'static,
P::Key: Hash + Eq + Send + Sync + Clone + 'static,
P::Connection: 'static,
<P::Connection as Connection>::Error: Send + Sync + 'static,
{
#[tracing::instrument(skip(pool))]
pub fn new(pool: Arc<P>) -> Self {
tracing::debug!("Creating transaction pool from connection pool");
let transactions = Arc::new(RwLock::new(HashMap::new()));
let cache = Self {
pool,
transactions: Arc::clone(&transactions),
};
let shared_cache = cache.clone();
let polling_interval = Duration::from_secs(VACUUM_POLLING_INTERVAL_SECONDS);
let inactive_limit = Duration::from_secs(INACTIVE_THRESHOLD_SECONDS);
let created_at_limit = Duration::from_secs(TRANSACTION_LIFETIME_LIMIT_SECONDS);
tokio::spawn(
async move {
loop {
tokio::time::sleep(polling_interval).await;
let now = Instant::now();
let mut rollback_queue = vec![];
for (transaction_key, transaction) in transactions.read().await.iter() {
let last_used_at = transaction.last_used_at.read().await;
let is_inactive = (now - *last_used_at) > inactive_limit;
let is_too_old = (now - transaction.created_at) > created_at_limit;
if is_inactive || is_too_old {
rollback_queue.push(transaction_key.clone());
}
}
for transaction_key in rollback_queue.into_iter() {
if let Err(error) = shared_cache
.rollback(transaction_key.transaction_id, transaction_key.key)
.await
{
tracing::error!(%error, "Error removing stale transaction from cache");
}
}
}
}
.instrument(tracing::info_span!("vacuum")),
);
cache
}
#[tracing::instrument(skip(self))]
pub async fn begin(
&self,
key: P::Key,
) -> Result<Uuid, Error<<P::Connection as Connection>::Error>> {
let transaction_id = Uuid::new_v4();
tracing::trace!(%transaction_id, "Beginning transaction");
let transaction_key = Key {
key: key.clone(),
transaction_id,
};
let connection = self
.pool
.get_connection(key)
.await
.map_err(|_| Error::ConnectionFailure)?;
connection.batch("BEGIN").await.map_err(Error::Connection)?;
let transaction = Transaction::new(Arc::new(connection));
self.transactions
.write()
.await
.insert(transaction_key, transaction);
tracing::trace!(%transaction_id, "Transaction successfully cached");
Ok(transaction_id)
}
#[tracing::instrument(skip(self))]
pub async fn commit(
&self,
transaction_id: Uuid,
key: P::Key,
) -> Result<(), Error<<P::Connection as Connection>::Error>> {
tracing::trace!("Committing active transaction");
self.remove(transaction_id, key)
.await?
.connection
.batch("COMMIT")
.await
.map_err(Error::Connection)?;
Ok(())
}
#[tracing::instrument(skip(self))]
pub async fn rollback(
&self,
transaction_id: Uuid,
key: P::Key,
) -> Result<(), Error<<P::Connection as Connection>::Error>> {
tracing::trace!("Rolling back active transaction");
self.remove(transaction_id, key)
.await?
.connection
.batch("ROLLBACK")
.await
.map_err(Error::Connection)?;
Ok(())
}
#[tracing::instrument(skip(self))]
async fn remove(
&self,
transaction_id: Uuid,
key: P::Key,
) -> Result<Transaction<P::Connection>, Error<<P::Connection as Connection>::Error>> {
tracing::trace!("Removing transaction from the cache");
let transaction = self
.transactions
.write()
.await
.remove(&Key {
key,
transaction_id,
})
.ok_or(Error::Uninitialized)?;
Ok(transaction)
}
}
#[async_trait]
impl<P> super::Pool for Pool<P>
where
P: super::Pool,
P::Key: Hash + Eq + Send + Sync + Clone,
P::Connection: 'static,
<P::Connection as Connection>::Error: Send + Sync + Into<Status> + 'static,
{
type Key = Key<P::Key>;
type Connection = Transaction<P::Connection>;
type Error = Error<<Self::Connection as Connection>::Error>;
#[tracing::instrument(
skip(self, key),
fields(
?key = key.key,
%transaction_id = key.transaction_id
)
)]
async fn get_connection(&self, key: Self::Key) -> Result<Self::Connection, Self::Error> {
let transaction = self
.transactions
.read()
.await
.get(&key)
.cloned()
.ok_or(Error::Uninitialized)?;
Ok(transaction)
}
}