use std::{
hash::{BuildHasher, Hash, Hasher, RandomState},
sync::Arc,
};
use async_trait::async_trait;
use lru_cache::LruCache;
use postgres_types::Type;
#[cfg(feature = "telemetry")]
use telemetry::formatting::strip_query_traceparent;
use tokio::sync::Mutex;
use tokio_postgres::{Client, Error, Statement};
use super::query::{PreparedQuery, QueryMetadata, TypedQuery};
#[async_trait]
pub trait QueryCache: From<CacheSettings> + Send + Sync {
type Query<'a>: PreparedQuery;
async fn get_query<'a>(&self, client: &Client, sql: &'a str, types: &[Type]) -> Result<Self::Query<'a>, Error>;
async fn get_statement(&self, client: &Client, sql: &str, types: &[Type]) -> Result<Statement, Error>;
}
#[derive(Debug, Default)]
pub struct NoOpCache;
#[async_trait]
impl QueryCache for NoOpCache {
type Query<'a> = Statement;
#[inline]
async fn get_query<'a>(&self, client: &Client, sql: &'a str, types: &[Type]) -> Result<Statement, Error> {
self.get_statement(client, sql, types).await
}
#[inline]
async fn get_statement(&self, client: &Client, sql: &str, types: &[Type]) -> Result<Statement, Error> {
client.prepare_typed(sql, types).await
}
}
impl From<CacheSettings> for NoOpCache {
fn from(_: CacheSettings) -> Self {
Self
}
}
#[derive(Debug)]
pub struct PreparedStatementLruCache {
cache: InnerLruCache<Statement>,
}
impl PreparedStatementLruCache {
pub fn with_capacity(capacity: usize) -> Self {
Self {
cache: InnerLruCache::with_capacity(capacity),
}
}
}
#[async_trait]
impl QueryCache for PreparedStatementLruCache {
type Query<'a> = Statement;
#[inline]
async fn get_query<'a>(&self, client: &Client, sql: &'a str, types: &[Type]) -> Result<Statement, Error> {
self.get_statement(client, sql, types).await
}
async fn get_statement(&self, client: &Client, sql: &str, types: &[Type]) -> Result<Statement, Error> {
match self.cache.get(sql, types).await {
Some(statement) => Ok(statement),
None => {
let stmt = client.prepare_typed(sql, types).await?;
self.cache.insert(sql, types, stmt.clone()).await;
Ok(stmt)
}
}
}
}
impl From<CacheSettings> for PreparedStatementLruCache {
fn from(settings: CacheSettings) -> Self {
Self::with_capacity(settings.capacity)
}
}
/// A [`QueryCache`] that caches shared [`QueryMetadata`] instead of statements,
/// so queries that differ only in their traceparent comment can share an entry.
#[derive(Debug)]
pub struct TracingLruCache {
    cache: InnerLruCache<Arc<QueryMetadata>>,
}

impl TracingLruCache {
    /// Creates a cache holding metadata for at most `capacity` distinct
    /// (SQL, parameter types) pairs.
    pub fn with_capacity(capacity: usize) -> Self {
        let cache = InnerLruCache::with_capacity(capacity);
        TracingLruCache { cache }
    }
}
/// Stand-in used when the `telemetry` feature is disabled: no traceparent
/// stripping is performed and the SQL is returned unchanged.
#[cfg(not(feature = "telemetry"))]
fn strip_query_traceparent(query: &str) -> &str {
    query
}
#[async_trait]
impl QueryCache for TracingLruCache {
type Query<'a> = TypedQuery<'a>;
async fn get_query<'a>(&self, client: &Client, sql: &'a str, types: &[Type]) -> Result<TypedQuery<'a>, Error> {
let sql_without_traceparent = strip_query_traceparent(sql);
let metadata = match self.cache.get(sql_without_traceparent, types).await {
Some(metadata) => metadata,
None => {
let stmt = client.prepare_typed(sql_without_traceparent, types).await?;
let metadata = Arc::new(QueryMetadata::from(&stmt));
self.cache
.insert(sql_without_traceparent, types, metadata.clone())
.await;
metadata
}
};
Ok(TypedQuery::from_sql_and_metadata(sql, metadata))
}
async fn get_statement(&self, client: &Client, sql: &str, types: &[Type]) -> Result<Statement, Error> {
client.prepare_typed(sql, types).await
}
}
impl From<CacheSettings> for TracingLruCache {
fn from(settings: CacheSettings) -> Self {
Self::with_capacity(settings.capacity)
}
}
/// Settings from which every [`QueryCache`] implementation is constructed.
#[derive(Debug)]
pub struct CacheSettings {
    /// Maximum number of entries kept by the LRU-backed caches; ignored by
    /// [`NoOpCache`].
    pub capacity: usize,
}
/// Cache key for a (SQL, parameter types) pair.
///
/// Only the 64-bit hash is stored, not the SQL text. Two different pairs whose
/// hashes collide would therefore map to the same cache entry.
#[derive(Debug, PartialEq, Eq, Hash)]
struct QueryKey(u64);

impl QueryKey {
    /// Hashes `sql` together with `params` using the hasher state `st`.
    fn new<S>(st: &S, sql: &str, params: &[Type]) -> Self
    where
        S: BuildHasher,
    {
        let digest = st.hash_one((sql, params));
        QueryKey(digest)
    }
}
/// Mutex-protected LRU map from [`QueryKey`] to `V`, used by the caching
/// [`QueryCache`] implementations in this module.
///
/// Keys are hashes already computed with `state`, so the map uses
/// [`NoOpHasherBuilder`] and does not hash them a second time. Only the hash
/// is stored, so distinct queries with colliding hashes would share an entry.
#[derive(Debug)]
struct InnerLruCache<V> {
    cache: Mutex<LruCache<QueryKey, V, NoOpHasherBuilder>>,
    // Freshly created per cache instance (`RandomState::new`), so keys are not
    // comparable across caches.
    state: RandomState,
}

impl<V> InnerLruCache<V> {
    /// Creates an empty cache holding at most `capacity` entries.
    fn with_capacity(capacity: usize) -> Self {
        let lru = LruCache::with_hasher(capacity, NoOpHasherBuilder);
        InnerLruCache {
            cache: Mutex::new(lru),
            state: RandomState::new(),
        }
    }

    /// Returns a clone of the value cached for `sql` and `types`, if any, and
    /// emits a trace event reporting the hit or miss together with the cache's
    /// capacity and current size.
    async fn get(&self, sql: &str, types: &[Type]) -> Option<V>
    where
        V: Clone,
    {
        let key = QueryKey::new(&self.state, sql, types);
        let mut lru = self.cache.lock().await;
        let (capacity, stored) = (lru.capacity(), lru.len());
        // `get_mut` (rather than a read-only lookup) lets the LRU refresh the entry's recency.
        let found = lru.get_mut(&key).cloned();

        if found.is_some() {
            tracing::trace!(
                message = "query cache hit",
                query = sql,
                capacity = capacity,
                stored = stored,
            );
        } else {
            tracing::trace!(
                message = "query cache miss",
                query = sql,
                capacity = capacity,
                stored = stored,
            );
        }

        found
    }

    /// Stores `value` for `sql` and `types`. An existing entry for the same key
    /// is replaced, and the LRU evicts as needed to respect its capacity.
    pub async fn insert(&self, sql: &str, types: &[Type], value: V) {
        let key = QueryKey::new(&self.state, sql, types);
        let mut lru = self.cache.lock().await;
        lru.insert(key, value);
    }
}
/// [`BuildHasher`] producing [`NoOpHasher`]s, for maps whose keys are already
/// hash values (a single `u64`).
struct NoOpHasherBuilder;

impl BuildHasher for NoOpHasherBuilder {
    type Hasher = NoOpHasher;

    fn build_hasher(&self) -> NoOpHasher {
        NoOpHasher(None)
    }
}

/// Pass-through hasher: accepts exactly one `u64` and returns it unchanged
/// from `finish`.
///
/// # Panics
///
/// Panics if fed anything other than a single `u64`, if fed twice, or if
/// `finish` is called before any value was written.
struct NoOpHasher(Option<u64>);

impl Hasher for NoOpHasher {
    fn finish(&self) -> u64 {
        let Some(hash) = self.0 else {
            panic!("NoopHasher should have been called with a single u64");
        };
        hash
    }

    fn write(&mut self, _bytes: &[u8]) {
        // Only `write_u64` is supported; every other `write_*` default ends up here.
        panic!("NoopHasher should only be called with u64")
    }

    fn write_u64(&mut self, value: u64) {
        if self.0.is_some() {
            panic!("NoopHasher should only be called once");
        }
        self.0 = Some(value);
    }
}
#[cfg(test)]
mod tests {
    // Apart from the `noop_hasher_*` tests, everything here connects to the
    // database at `CONN_STR` through `run_with_client`.
    use super::*;
    use std::future::Future;

    pub(crate) use crate::connector::postgres::url::PostgresNativeUrl;
    use crate::{
        connector::{MakeTlsConnectorManager, PostgresFlavour},
        tests::test_api::postgres::CONN_STR,
    };
    use url::Url;

    #[tokio::test]
    async fn noop_cache_returns_new_queries_every_time() {
        run_with_client(|client| async move {
            let cache = NoOpCache;
            let sql = "SELECT $1";
            let types = [Type::INT4];

            let stmt1 = cache.get_query(&client, sql, &types).await.unwrap();
            let stmt2 = cache.get_query(&client, sql, &types).await.unwrap();
            assert_ne!(stmt1.name(), stmt2.name());
        })
        .await;
    }

    #[tokio::test]
    async fn noop_cache_returns_new_statements_every_time() {
        run_with_client(|client| async move {
            let cache = NoOpCache;
            let sql = "SELECT $1";
            let types = [Type::INT4];

            let stmt1 = cache.get_statement(&client, sql, &types).await.unwrap();
            let stmt2 = cache.get_statement(&client, sql, &types).await.unwrap();
            assert_ne!(stmt1.name(), stmt2.name());
        })
        .await;
    }

    #[tokio::test]
    async fn prepared_statement_lru_cache_reuses_queries_within_capacity() {
        run_with_client(|client| async move {
            let cache = PreparedStatementLruCache::with_capacity(3);
            let sql = "SELECT $1";
            let types = [Type::INT4];

            let stmt1 = cache.get_query(&client, sql, &types).await.unwrap();
            let stmt2 = cache.get_query(&client, sql, &types).await.unwrap();
            assert_eq!(stmt1.name(), stmt2.name());

            // Three other entries push the first one out of a capacity-3 cache.
            for typ in [Type::INT8, Type::INT4_ARRAY, Type::INT8_ARRAY] {
                cache.get_query(&client, sql, &[typ]).await.unwrap();
            }

            let stmt3 = cache.get_query(&client, sql, &types).await.unwrap();
            assert_ne!(stmt1.name(), stmt3.name());
        })
        .await;
    }

    #[tokio::test]
    async fn prepared_statement_lru_cache_reuses_statements_within_capacity() {
        run_with_client(|client| async move {
            let cache = PreparedStatementLruCache::with_capacity(3);
            let sql = "SELECT $1";
            let types = [Type::INT4];

            let stmt1 = cache.get_statement(&client, sql, &types).await.unwrap();
            let stmt2 = cache.get_statement(&client, sql, &types).await.unwrap();
            assert_eq!(stmt1.name(), stmt2.name());

            // Fill the cache through `get_statement` itself so this test exercises
            // the statement path end to end (it previously used `get_query` here).
            for typ in [Type::INT8, Type::INT4_ARRAY, Type::INT8_ARRAY] {
                cache.get_statement(&client, sql, &[typ]).await.unwrap();
            }

            let stmt3 = cache.get_statement(&client, sql, &types).await.unwrap();
            assert_ne!(stmt1.name(), stmt3.name());
        })
        .await;
    }

    #[tokio::test]
    async fn tracing_lru_cache_reuses_queries_within_capacity() {
        run_with_client(|client| async move {
            let cache = TracingLruCache::with_capacity(3);
            let sql = "SELECT $1";
            let types = [Type::INT4];

            let q1 = cache.get_query(&client, sql, &types).await.unwrap();
            let q2 = cache.get_query(&client, sql, &types).await.unwrap();
            assert!(
                Arc::ptr_eq(&q1.metadata, &q2.metadata),
                "q1 and q2 should re-use the same metadata"
            );

            for typ in [Type::INT8, Type::INT4_ARRAY, Type::INT8_ARRAY] {
                cache.get_query(&client, sql, &[typ]).await.unwrap();
            }

            let q3 = cache.get_query(&client, sql, &types).await.unwrap();
            assert!(
                !Arc::ptr_eq(&q1.metadata, &q3.metadata),
                "q1 and q3 should not re-use the same metadata"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn tracing_lru_cache_reuses_queries_with_different_traceparent() {
        run_with_client(|client| async move {
            let cache = TracingLruCache::with_capacity(1);
            let sql1 = "SELECT $1 /* traceparent=00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01 */";
            let sql2 = "SELECT $1 /* traceparent=00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-02 */";
            let types = [Type::INT4];

            let q1 = cache.get_query(&client, sql1, &types).await.unwrap();
            assert_eq!(q1.sql, sql1);
            let q2 = cache.get_query(&client, sql2, &types).await.unwrap();
            assert_eq!(q2.sql, sql2);

            assert!(
                Arc::ptr_eq(&q1.metadata, &q2.metadata),
                "q1 and q2 should re-use the same metadata"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn tracing_lru_cache_returns_new_statements_every_time() {
        run_with_client(|client| async move {
            let cache = TracingLruCache::with_capacity(1);
            let sql = "SELECT $1";
            let types = [Type::INT4];

            let q1 = cache.get_statement(&client, sql, &types).await.unwrap();
            let q2 = cache.get_statement(&client, sql, &types).await.unwrap();
            assert_ne!(q1.name(), q2.name());
        })
        .await;
    }

    #[tokio::test]
    async fn zero_sized_tracing_lru_cache_returns_new_metadata_every_time() {
        run_with_client(|client| async move {
            let cache = TracingLruCache::with_capacity(0);
            let sql = "SELECT $1";
            let types = [Type::INT4];

            let q1 = cache.get_query(&client, sql, &types).await.unwrap();
            let q2 = cache.get_query(&client, sql, &types).await.unwrap();
            assert!(
                !Arc::ptr_eq(&q1.metadata, &q2.metadata),
                "q1 and q2 should not re-use the same metadata"
            );
        })
        .await;
    }

    #[test]
    fn noop_hasher_returns_the_input_as_hash() {
        assert_eq!(NoOpHasherBuilder.hash_one(0xdeadc0deu64), 0xdeadc0de);
        assert_eq!(NoOpHasherBuilder.hash_one(0xcafeu64), 0xcafe);
    }

    #[test]
    #[should_panic(expected = "NoopHasher should only be called with u64")]
    fn noop_hasher_doesnt_accept_non_u64_input() {
        NoOpHasherBuilder.hash_one("hello");
    }

    #[test]
    #[should_panic(expected = "NoopHasher should only be called once")]
    fn noop_hasher_doesnt_accept_a_second_u64() {
        let mut hasher = NoOpHasherBuilder.build_hasher();
        hasher.write_u64(1);
        hasher.write_u64(2);
    }

    #[test]
    #[should_panic(expected = "NoopHasher should have been called with a single u64")]
    fn noop_hasher_finish_requires_a_write() {
        let _ = NoOpHasherBuilder.build_hasher().finish();
    }

    /// Connects to the database at `CONN_STR` (as plain Postgres, with TLS set
    /// up by `MakeTlsConnectorManager`) and runs `test` with the client, driving
    /// the connection future on a `LocalSet` until `test` completes.
    async fn run_with_client<Func, Fut>(test: Func)
    where
        Func: FnOnce(Client) -> Fut,
        Fut: Future<Output = ()>,
    {
        let url = Url::parse(&CONN_STR).unwrap();
        let mut pg_url = PostgresNativeUrl::new(url).unwrap();
        pg_url.set_flavour(PostgresFlavour::Postgres);

        let tls_manager = MakeTlsConnectorManager::new(pg_url.clone());
        let tls = tls_manager.get_connector().await.unwrap();

        let (client, conn) = pg_url.to_config().connect(tls).await.unwrap();

        let set = tokio::task::LocalSet::new();
        set.spawn_local(conn);
        set.run_until(test(client)).await
    }
}