use super::GenericClient;
use crate::error::Error;
use async_trait::async_trait;
use futures::lock::Mutex;
use postgres_types::ToSql;
use std::collections::HashMap;
use std::hash::Hash;
use std::mem;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use tokio_postgres::{error::Error as SqlError, RowStream, Statement};
/// A [`GenericClient`] wrapper that memoizes statements prepared through
/// `prepare_static`.
///
/// The statement cache sits behind an `Arc`, so every clone of a `Caching`
/// value reads from and writes to the same cache.
#[derive(Clone)]
pub struct Caching<C: GenericClient> {
    /// The wrapped client that every call is ultimately forwarded to.
    client: C,
    /// Prepared statements keyed by the identity of their `&'static str` SQL.
    cache: Cache,
}

/// Shared statement cache guarded by an async-aware mutex.
type Cache = Arc<Mutex<DynamicCache<StrKey, Statement>>>;
/// Cache key for a string, made of the string's start address and byte length.
///
/// Equality is by identity, not content: the same SQL text stored at two
/// different addresses produces two different keys. When built through
/// `StrKey::new`, which only accepts `&'static str`, the backing memory is
/// never released, so a given `(ptr, len)` pair cannot later denote other text.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct StrKey {
    /// Address of the string's first byte.
    ptr: usize,
    /// Length of the string in bytes.
    len: usize,
}

/// A key/value store that starts out as a linearly scanned `Vec` and is
/// rebuilt as a `HashMap` once it holds `K::LINEAR_CUTOFF` entries and more
/// need to be inserted.
#[derive(Debug, Clone, PartialEq, Eq)]
enum DynamicCache<K: DynamicKey, V> {
    /// Small representation: pairs searched front to back.
    Linear(Vec<(K, V)>),
    /// Large representation.
    Hash(HashMap<K, V>),
}

/// Key types that can be stored in a [`DynamicCache`].
trait DynamicKey: Eq + Hash {
    /// Number of entries the linear representation holds before the cache is
    /// converted into a hash map.
    const LINEAR_CUTOFF: usize;
}
impl<C> Caching<C>
where
C: GenericClient,
{
pub fn new(client: C) -> Caching<C> {
Caching {
client,
cache: Cache::default(),
}
}
pub fn into_inner(self) -> C {
self.client
}
}
/// Converting a client into a `Caching` wrapper is equivalent to `Caching::new`.
impl<C: GenericClient> From<C> for Caching<C> {
    fn from(client: C) -> Self {
        Self::new(client)
    }
}
/// Read-only access to the wrapped client.
impl<C: GenericClient> Deref for Caching<C> {
    type Target = C;

    /// Borrows the wrapped client.
    fn deref(&self) -> &C {
        &self.client
    }
}

/// Mutable access to the wrapped client.
impl<C: GenericClient> DerefMut for Caching<C> {
    /// Mutably borrows the wrapped client.
    fn deref_mut(&mut self) -> &mut C {
        &mut self.client
    }
}
/// Every operation is forwarded to the wrapped client; only `prepare_static`
/// consults and fills the statement cache.
#[async_trait]
impl<C> GenericClient for Caching<C>
where
    C: GenericClient + Send + Sync,
{
    /// Forwards to the wrapped client. Dynamic SQL is never cached.
    async fn prepare(&self, sql: &str) -> Result<Statement, SqlError> {
        <C as GenericClient>::prepare(&self.client, sql).await
    }

    /// Returns the cached statement for `sql`, preparing and caching it on a miss.
    ///
    /// The cache lock is released between the lookup and the insert, so
    /// concurrent first calls for the same `sql` may each prepare it.
    async fn prepare_static(&self, sql: &'static str) -> Result<Statement, SqlError> {
        match self.get_cached(sql).await {
            Some(hit) => Ok(hit),
            None => {
                let fresh = <C as GenericClient>::prepare_static(&self.client, sql).await?;
                self.cache(sql, fresh.clone()).await;
                Ok(fresh)
            }
        }
    }

    /// Forwards to the wrapped client unchanged.
    async fn execute_raw<'a>(
        &'a self,
        statement: &Statement,
        parameters: &[&'a (dyn ToSql + Sync)],
    ) -> Result<u64, SqlError> {
        <C as GenericClient>::execute_raw(&self.client, statement, parameters).await
    }

    /// Forwards to the wrapped client unchanged.
    async fn query_raw<'a>(
        &'a self,
        statement: &Statement,
        parameters: &[&'a (dyn ToSql + Sync)],
    ) -> Result<RowStream, SqlError> {
        <C as GenericClient>::query_raw(&self.client, statement, parameters).await
    }
}
impl<C: GenericClient> Caching<C> {
    /// Looks up the statement cached for `sql` and returns a clone of it, if any.
    async fn get_cached(&self, sql: &'static str) -> Option<Statement> {
        let key = StrKey::new(sql);
        let guard = self.cache.lock().await;
        guard.get(&key).cloned()
    }

    /// Records `statement` in the cache under the identity of `sql`.
    async fn cache(&self, sql: &'static str, statement: Statement) {
        let key = StrKey::new(sql);
        self.cache.lock().await.insert(key, statement);
    }
}
impl StrKey {
    /// Builds the identity key of a `'static` string: its start address and byte length.
    pub fn new(text: &'static str) -> StrKey {
        let ptr = text.as_ptr() as usize;
        let len = text.len();
        Self { ptr, len }
    }
}

impl DynamicKey for StrKey {
    /// Up to 64 statements are kept in the linearly scanned representation.
    const LINEAR_CUTOFF: usize = 64;
}
impl<K, V> DynamicCache<K, V>
where
K: DynamicKey,
{
pub fn get(&self, index: &K) -> Option<&V> {
match self {
DynamicCache::Linear(pairs) => pairs
.iter()
.find(|(key, _)| K::eq(key, &index))
.map(|(_, value)| value),
DynamicCache::Hash(map) => map.get(index),
}
}
pub fn insert(&mut self, key: K, value: V) {
match self {
DynamicCache::Linear(pairs) if pairs.len() >= K::LINEAR_CUTOFF => {
let map = mem::take(pairs).into_iter().collect();
*self = DynamicCache::Hash(map);
self.insert(key, value);
}
DynamicCache::Linear(pairs) => {
pairs.push((key, value));
}
DynamicCache::Hash(map) => {
map.insert(key, value);
}
}
}
}
impl<K, V> Default for DynamicCache<K, V>
where
K: DynamicKey,
{
fn default() -> Self {
DynamicCache::Linear(Vec::new())
}
}
macro_rules! impl_cached_transaction {
($client:ty, $transaction:ty) => {
impl Caching<$client> {
pub async fn transaction(&mut self) -> Result<Caching<$transaction>, Error> {
<$client>::transaction(self)
.await
.map(Caching::new)
.map_err(Error::BeginTransaction)
}
}
};
}
impl_cached_transaction!(tokio_postgres::Client, tokio_postgres::Transaction<'_>);
impl_cached_transaction!(
tokio_postgres::Transaction<'_>,
tokio_postgres::Transaction<'_>
);