prax_postgres/
statement.rs1use std::collections::HashMap;
4use std::sync::RwLock;
5
6use deadpool_postgres::{Object, Transaction};
7use tokio_postgres::Statement;
8use tracing::debug;
9
10use crate::error::PgResult;
11
/// Bounded record of which SQL strings have been prepared through this cache.
///
/// The statements themselves are prepared (and held) by `prepare_cached` on the
/// pooled client or transaction. This type only remembers the SQL text, which is
/// used for debug logging and size reporting. Entries are evicted once
/// `max_size` SQL strings are tracked.
#[derive(Debug)]
pub struct PreparedStatementCache {
    /// Number of tracked SQL strings at which eviction kicks in.
    max_size: usize,
    /// SQL text of statements recorded as prepared. The `bool` value is always
    /// `true`; only the presence of a key is meaningful.
    prepared_queries: RwLock<HashMap<String, bool>>,
}
24
25impl PreparedStatementCache {
26 pub fn new(max_size: usize) -> Self {
28 Self {
29 max_size,
30 prepared_queries: RwLock::new(HashMap::new()),
31 }
32 }
33
34 pub async fn get_or_prepare(&self, client: &Object, sql: &str) -> PgResult<Statement> {
36 let is_cached = {
38 let cache = self.prepared_queries.read().unwrap();
39 cache.contains_key(sql)
40 };
41
42 if is_cached {
43 debug!(sql = %sql, "Using cached prepared statement");
44 } else {
45 debug!(sql = %sql, "Preparing new statement");
46
47 let mut cache = self.prepared_queries.write().unwrap();
49 if cache.len() >= self.max_size {
50 let to_remove: Vec<_> = cache.keys().take(cache.len() / 2).cloned().collect();
53 for key in to_remove {
54 cache.remove(&key);
55 }
56 }
57 cache.insert(sql.to_string(), true);
58 }
59
60 let stmt = client.prepare_cached(sql).await?;
62 Ok(stmt)
63 }
64
65 pub async fn get_or_prepare_in_txn<'a>(
67 &self,
68 txn: &Transaction<'a>,
69 sql: &str,
70 ) -> PgResult<Statement> {
71 let is_cached = {
73 let cache = self.prepared_queries.read().unwrap();
74 cache.contains_key(sql)
75 };
76
77 if is_cached {
78 debug!(sql = %sql, "Using cached prepared statement (txn)");
79 } else {
80 debug!(sql = %sql, "Preparing new statement (txn)");
81
82 let mut cache = self.prepared_queries.write().unwrap();
83 if cache.len() >= self.max_size {
84 let to_remove: Vec<_> = cache.keys().take(cache.len() / 2).cloned().collect();
85 for key in to_remove {
86 cache.remove(&key);
87 }
88 }
89 cache.insert(sql.to_string(), true);
90 }
91
92 let stmt = txn.prepare_cached(sql).await?;
93 Ok(stmt)
94 }
95
96 pub fn clear(&self) {
98 let mut cache = self.prepared_queries.write().unwrap();
99 cache.clear();
100 debug!("Statement cache cleared");
101 }
102
103 pub fn len(&self) -> usize {
105 let cache = self.prepared_queries.read().unwrap();
106 cache.len()
107 }
108
109 pub fn is_empty(&self) -> bool {
111 self.len() == 0
112 }
113
114 pub fn max_size(&self) -> usize {
116 self.max_size
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built cache reports its limit and holds nothing.
    #[test]
    fn test_cache_creation() {
        let fresh = PreparedStatementCache::new(100);
        assert!(fresh.is_empty());
        assert_eq!(fresh.max_size(), 100);
    }

    /// `clear` empties a cache that was seeded with entries.
    #[test]
    fn test_cache_clear() {
        let cache = PreparedStatementCache::new(100);

        // Seed the tracking map directly; preparing real statements would
        // require a live database connection.
        cache
            .prepared_queries
            .write()
            .unwrap()
            .extend(["SELECT 1", "SELECT 2"].iter().map(|&sql| (sql.to_owned(), true)));

        assert_eq!(cache.len(), 2);
        cache.clear();
        assert!(cache.is_empty());
    }
}