prax_postgres/
statement.rs1use std::num::NonZeroUsize;
4use std::sync::Mutex;
5
6use deadpool_postgres::{Object, Transaction};
7use lru::LruCache;
8use tokio_postgres::Statement;
9use tracing::{debug, trace};
10
11use crate::error::PgResult;
12
13pub struct PreparedStatementCache {
24 max_size: usize,
25 prepared_queries: Mutex<LruCache<String, ()>>,
32}
33
34impl PreparedStatementCache {
35 pub fn new(max_size: usize) -> Self {
39 let cap = NonZeroUsize::new(max_size.max(1)).expect("max(1) ensures non-zero");
40 Self {
41 max_size,
42 prepared_queries: Mutex::new(LruCache::new(cap)),
43 }
44 }
45
46 pub async fn get_or_prepare(&self, client: &Object, sql: &str) -> PgResult<Statement> {
48 let is_cached = {
49 let mut cache = self
50 .prepared_queries
51 .lock()
52 .unwrap_or_else(|e| e.into_inner());
53 if cache.get(sql).is_some() {
54 true
55 } else {
56 cache.put(sql.to_string(), ());
57 false
58 }
59 };
60
61 if is_cached {
62 trace!(sql = %sql, "Using cached prepared statement");
63 } else {
64 trace!(sql = %sql, "Preparing new statement");
65 }
66
67 let stmt = client.prepare_cached(sql).await?;
69 Ok(stmt)
70 }
71
72 pub async fn get_or_prepare_in_txn<'a>(
74 &self,
75 txn: &Transaction<'a>,
76 sql: &str,
77 ) -> PgResult<Statement> {
78 let is_cached = {
79 let mut cache = self
80 .prepared_queries
81 .lock()
82 .unwrap_or_else(|e| e.into_inner());
83 if cache.get(sql).is_some() {
84 true
85 } else {
86 cache.put(sql.to_string(), ());
87 false
88 }
89 };
90
91 if is_cached {
92 trace!(sql = %sql, "Using cached prepared statement (txn)");
93 } else {
94 trace!(sql = %sql, "Preparing new statement (txn)");
95 }
96
97 let stmt = txn.prepare_cached(sql).await?;
98 Ok(stmt)
99 }
100
101 pub fn clear(&self) {
103 let mut cache = self
104 .prepared_queries
105 .lock()
106 .unwrap_or_else(|e| e.into_inner());
107 cache.clear();
108 debug!("Statement cache cleared");
109 }
110
111 pub fn len(&self) -> usize {
113 let cache = self
114 .prepared_queries
115 .lock()
116 .unwrap_or_else(|e| e.into_inner());
117 cache.len()
118 }
119
120 pub fn is_empty(&self) -> bool {
122 self.len() == 0
123 }
124
125 pub fn max_size(&self) -> usize {
127 self.max_size
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed cache is empty and reports its configured limit.
    #[test]
    fn test_cache_creation() {
        let tracker = PreparedStatementCache::new(100);
        assert!(tracker.is_empty());
        assert_eq!(tracker.max_size(), 100);
    }

    /// `clear` removes every entry that was previously recorded.
    #[test]
    fn test_cache_clear() {
        let tracker = PreparedStatementCache::new(100);

        // Seed the private tracking set directly; each lock is released at
        // the end of its statement.
        for sql in ["SELECT 1", "SELECT 2"] {
            tracker
                .prepared_queries
                .lock()
                .unwrap()
                .put(sql.to_string(), ());
        }
        assert_eq!(tracker.len(), 2);

        tracker.clear();
        assert!(tracker.is_empty());
    }

    /// With capacity 2, inserting a third key evicts the least-recently-used one.
    #[test]
    fn test_cache_lru_eviction() {
        let tracker = PreparedStatementCache::new(2);
        let mut entries = tracker.prepared_queries.lock().unwrap();

        entries.put("A".to_string(), ());
        entries.put("B".to_string(), ());
        // Reading "A" promotes it, leaving "B" as the eviction candidate.
        let _ = entries.get("A");
        entries.put("C".to_string(), ());

        assert_eq!(entries.len(), 2);
        assert!(entries.peek("A").is_some());
        assert!(entries.peek("B").is_none(), "B should have been evicted");
        assert!(entries.peek("C").is_some());
    }
}
175}