use std::num::NonZeroUsize;
use std::sync::Mutex;
use deadpool_postgres::{Object, Transaction};
use lru::LruCache;
use tokio_postgres::Statement;
use tracing::{debug, trace};
use crate::error::PgResult;
/// Bounded, LRU-ordered record of SQL strings requested through this type.
///
/// The underlying map stores `()` values only — no [`Statement`] is kept here.
/// Statements themselves are obtained from `prepare_cached` on the client or
/// transaction; this structure only tracks which SQL texts have been seen, for
/// bookkeeping and trace logging.
#[derive(Debug)]
pub struct PreparedStatementCache {
    /// Capacity requested by the caller of `new`. May be `0`, in which case the
    /// underlying LRU is still created with a capacity of `1`.
    max_size: usize,
    /// SQL text -> `()`, ordered by recency of use; guarded by a std mutex that
    /// is never held across an `.await`.
    prepared_queries: Mutex<LruCache<String, ()>>,
}
impl PreparedStatementCache {
pub fn new(max_size: usize) -> Self {
let cap = NonZeroUsize::new(max_size.max(1)).expect("max(1) ensures non-zero");
Self {
max_size,
prepared_queries: Mutex::new(LruCache::new(cap)),
}
}
pub async fn get_or_prepare(&self, client: &Object, sql: &str) -> PgResult<Statement> {
let is_cached = {
let mut cache = self
.prepared_queries
.lock()
.unwrap_or_else(|e| e.into_inner());
if cache.get(sql).is_some() {
true
} else {
cache.put(sql.to_string(), ());
false
}
};
if is_cached {
trace!(sql = %sql, "Using cached prepared statement");
} else {
trace!(sql = %sql, "Preparing new statement");
}
let stmt = client.prepare_cached(sql).await?;
Ok(stmt)
}
pub async fn get_or_prepare_in_txn<'a>(
&self,
txn: &Transaction<'a>,
sql: &str,
) -> PgResult<Statement> {
let is_cached = {
let mut cache = self
.prepared_queries
.lock()
.unwrap_or_else(|e| e.into_inner());
if cache.get(sql).is_some() {
true
} else {
cache.put(sql.to_string(), ());
false
}
};
if is_cached {
trace!(sql = %sql, "Using cached prepared statement (txn)");
} else {
trace!(sql = %sql, "Preparing new statement (txn)");
}
let stmt = txn.prepare_cached(sql).await?;
Ok(stmt)
}
pub fn clear(&self) {
let mut cache = self
.prepared_queries
.lock()
.unwrap_or_else(|e| e.into_inner());
cache.clear();
debug!("Statement cache cleared");
}
pub fn len(&self) -> usize {
let cache = self
.prepared_queries
.lock()
.unwrap_or_else(|e| e.into_inner());
cache.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn max_size(&self) -> usize {
self.max_size
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built tracker reports its configured size and holds nothing.
    #[test]
    fn test_cache_creation() {
        let tracker = PreparedStatementCache::new(100);
        assert!(tracker.is_empty());
        assert_eq!(tracker.max_size(), 100);
    }

    /// `clear` drops every tracked entry.
    #[test]
    fn test_cache_clear() {
        let tracker = PreparedStatementCache::new(100);
        for query in ["SELECT 1", "SELECT 2"] {
            tracker
                .prepared_queries
                .lock()
                .unwrap()
                .put(query.to_string(), ());
        }
        assert_eq!(tracker.len(), 2);

        tracker.clear();
        assert!(tracker.is_empty());
    }

    /// Touching an entry protects it; the least-recently used one is evicted
    /// when capacity is exceeded.
    #[test]
    fn test_cache_lru_eviction() {
        let tracker = PreparedStatementCache::new(2);
        let mut guard = tracker.prepared_queries.lock().unwrap();
        guard.put("A".to_string(), ());
        guard.put("B".to_string(), ());
        // Promote "A" so that "B" becomes the eviction candidate.
        let _ = guard.get("A");
        guard.put("C".to_string(), ());

        assert_eq!(guard.len(), 2);
        assert!(guard.peek("A").is_some());
        assert!(guard.peek("B").is_none(), "B should have been evicted");
        assert!(guard.peek("C").is_some());
    }
}