use std::collections::HashMap;
use std::sync::RwLock;
use deadpool_postgres::{Object, Transaction};
use tokio_postgres::Statement;
use tracing::debug;
use crate::error::PgResult;
/// Records the SQL strings that have gone through this cache's
/// `get_or_prepare*` methods, starting eviction once it holds `max_size`
/// entries.
#[derive(Debug)]
pub struct PreparedStatementCache {
    /// Entry count at which eviction is triggered before a new SQL string is
    /// recorded.
    max_size: usize,
    /// SQL text seen so far. The `bool` value is always `true`, so the map is
    /// effectively used as a set.
    prepared_queries: RwLock<HashMap<String, bool>>,
}
impl PreparedStatementCache {
    /// Creates an empty cache that starts evicting once it has recorded
    /// `max_size` SQL strings.
    pub fn new(max_size: usize) -> Self {
        Self {
            max_size,
            prepared_queries: RwLock::new(HashMap::new()),
        }
    }

    /// Prepares `sql` on `client` via `prepare_cached` and records the SQL text
    /// in this cache.
    ///
    /// The returned `Statement` always comes from `client.prepare_cached`; this
    /// struct's map is not passed to that call. The map only selects which
    /// debug message is logged and what [`len`](Self::len) reports. It is keyed
    /// by SQL text alone and does not track which connection prepared the
    /// statement. The "Using cached" message therefore does not by itself show
    /// that `client` holds the statement.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `prepare_cached`. In that case `sql` is
    /// not recorded.
    pub async fn get_or_prepare(&self, client: &Object, sql: &str) -> PgResult<Statement> {
        let already_recorded = self.is_recorded(sql);
        if already_recorded {
            debug!(sql = %sql, "Using cached prepared statement");
        } else {
            debug!(sql = %sql, "Preparing new statement");
        }
        // No lock guard is alive across this await.
        let stmt = client.prepare_cached(sql).await?;
        // Record only after a successful prepare, so a statement that failed
        // (e.g. a syntax error) is not later reported as cached.
        if !already_recorded {
            self.record(sql);
        }
        Ok(stmt)
    }

    /// Transaction variant of [`get_or_prepare`](Self::get_or_prepare).
    ///
    /// Prepares `sql` via `txn.prepare_cached` and records the SQL text in the
    /// same map used by the non-transactional path.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `prepare_cached`. In that case `sql` is
    /// not recorded.
    pub async fn get_or_prepare_in_txn<'a>(
        &self,
        txn: &Transaction<'a>,
        sql: &str,
    ) -> PgResult<Statement> {
        let already_recorded = self.is_recorded(sql);
        if already_recorded {
            debug!(sql = %sql, "Using cached prepared statement (txn)");
        } else {
            debug!(sql = %sql, "Preparing new statement (txn)");
        }
        // No lock guard is alive across this await.
        let stmt = txn.prepare_cached(sql).await?;
        // Record only after a successful prepare (see `get_or_prepare`).
        if !already_recorded {
            self.record(sql);
        }
        Ok(stmt)
    }

    /// Removes every recorded SQL string.
    pub fn clear(&self) {
        self.write_map().clear();
        debug!("Statement cache cleared");
    }

    /// Number of SQL strings currently recorded.
    pub fn len(&self) -> usize {
        self.read_map().len()
    }

    /// Returns `true` when no SQL strings are recorded.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// The configured eviction threshold.
    pub fn max_size(&self) -> usize {
        self.max_size
    }

    /// Returns `true` if `sql` has already been recorded.
    fn is_recorded(&self, sql: &str) -> bool {
        self.read_map().contains_key(sql)
    }

    /// Records `sql`, evicting entries first if the map has reached `max_size`.
    fn record(&self, sql: &str) {
        let mut cache = self.write_map();
        // Another task may have recorded the same SQL since our read-locked
        // check. In that case there is nothing to insert and nothing to evict.
        if cache.contains_key(sql) {
            return;
        }
        let len = cache.len();
        if len >= self.max_size {
            // Drop half of the entries, but always at least enough that the
            // insert below keeps the map within `max(max_size, 1)`. Plain
            // `len / 2` is 0 when `len == 1`, which let the map outgrow small
            // limits. `len >= max_size` guarantees the subtraction cannot
            // underflow.
            let evict = (len / 2).max(len + 1 - self.max_size);
            // HashMap iteration order is unspecified, so which entries are
            // evicted is arbitrary.
            let victims: Vec<String> = cache.keys().take(evict).cloned().collect();
            for key in &victims {
                cache.remove(key);
            }
        }
        cache.insert(sql.to_owned(), true);
    }

    /// Shared access to the map.
    ///
    /// A poisoned lock is recovered rather than propagated as a panic: the map
    /// is only ever mutated by whole-entry insert/remove/clear, so its contents
    /// stay valid even if a holder panicked.
    fn read_map(&self) -> std::sync::RwLockReadGuard<'_, HashMap<String, bool>> {
        self.prepared_queries
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Exclusive access to the map, recovering from poisoning like `read_map`.
    fn write_map(&self) -> std::sync::RwLockWriteGuard<'_, HashMap<String, bool>> {
        self.prepared_queries
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built cache reports its configured limit and holds nothing.
    #[test]
    fn test_cache_creation() {
        let fresh = PreparedStatementCache::new(100);
        assert_eq!(fresh.max_size(), 100);
        assert!(fresh.is_empty());
    }

    /// `clear` drops every recorded entry.
    #[test]
    fn test_cache_clear() {
        let subject = PreparedStatementCache::new(100);
        // Seed the map directly; going through `get_or_prepare` would need a
        // live database connection. The temporary write guard is released at
        // the end of this statement.
        subject
            .prepared_queries
            .write()
            .unwrap()
            .extend(["SELECT 1", "SELECT 2"].map(|q| (q.to_string(), true)));
        assert_eq!(subject.len(), 2);
        subject.clear();
        assert!(subject.is_empty());
    }
}