Skip to main content

prax_postgres/
statement.rs

1//! Prepared statement caching.
2
3use std::num::NonZeroUsize;
4use std::sync::Mutex;
5
6use deadpool_postgres::{Object, Transaction};
7use lru::LruCache;
8use tokio_postgres::Statement;
9use tracing::{debug, trace};
10
11use crate::error::PgResult;
12
13/// A cache for prepared statements.
14///
15/// Tracks which SQL strings have been prepared so we emit a `trace!`
16/// for hits vs. misses. Eviction is true LRU via [`lru::LruCache`] —
17/// when the cache reaches `max_size` the least-recently-used entry is
18/// dropped on the next insert.
19///
20/// The cache is keyed on the SQL string; the actual `Statement` is
21/// fetched from `client.prepare_cached` on every call (deadpool reuses
22/// its own per-connection cache).
23pub struct PreparedStatementCache {
24    max_size: usize,
25    /// LRU cache of SQL strings we've seen. The value is `()` because
26    /// the real `Statement` lives in deadpool-postgres' per-connection
27    /// cache; we just need to know whether we've encountered the SQL
28    /// before for tracing/metrics. `Mutex` (not `RwLock`) because every
29    /// `get_or_prepare` mutates LRU order, so the read-only path
30    /// doesn't exist.
31    prepared_queries: Mutex<LruCache<String, ()>>,
32}
33
34impl PreparedStatementCache {
35    /// Create a new statement cache with the given maximum size.
36    ///
37    /// `max_size` of 0 is treated as 1 to satisfy `NonZeroUsize`.
38    pub fn new(max_size: usize) -> Self {
39        let cap = NonZeroUsize::new(max_size.max(1)).expect("max(1) ensures non-zero");
40        Self {
41            max_size,
42            prepared_queries: Mutex::new(LruCache::new(cap)),
43        }
44    }
45
46    /// Get or prepare a statement for the given SQL.
47    pub async fn get_or_prepare(&self, client: &Object, sql: &str) -> PgResult<Statement> {
48        let is_cached = {
49            let mut cache = self
50                .prepared_queries
51                .lock()
52                .unwrap_or_else(|e| e.into_inner());
53            if cache.get(sql).is_some() {
54                true
55            } else {
56                cache.put(sql.to_string(), ());
57                false
58            }
59        };
60
61        if is_cached {
62            trace!(sql = %sql, "Using cached prepared statement");
63        } else {
64            trace!(sql = %sql, "Preparing new statement");
65        }
66
67        // Always prepare - the database will reuse if it's cached server-side
68        let stmt = client.prepare_cached(sql).await?;
69        Ok(stmt)
70    }
71
72    /// Get or prepare a statement within a transaction.
73    pub async fn get_or_prepare_in_txn<'a>(
74        &self,
75        txn: &Transaction<'a>,
76        sql: &str,
77    ) -> PgResult<Statement> {
78        let is_cached = {
79            let mut cache = self
80                .prepared_queries
81                .lock()
82                .unwrap_or_else(|e| e.into_inner());
83            if cache.get(sql).is_some() {
84                true
85            } else {
86                cache.put(sql.to_string(), ());
87                false
88            }
89        };
90
91        if is_cached {
92            trace!(sql = %sql, "Using cached prepared statement (txn)");
93        } else {
94            trace!(sql = %sql, "Preparing new statement (txn)");
95        }
96
97        let stmt = txn.prepare_cached(sql).await?;
98        Ok(stmt)
99    }
100
101    /// Clear all cached statements.
102    pub fn clear(&self) {
103        let mut cache = self
104            .prepared_queries
105            .lock()
106            .unwrap_or_else(|e| e.into_inner());
107        cache.clear();
108        debug!("Statement cache cleared");
109    }
110
111    /// Get the number of cached statement keys.
112    pub fn len(&self) -> usize {
113        let cache = self
114            .prepared_queries
115            .lock()
116            .unwrap_or_else(|e| e.into_inner());
117        cache.len()
118    }
119
120    /// Check if the cache is empty.
121    pub fn is_empty(&self) -> bool {
122        self.len() == 0
123    }
124
125    /// Get the maximum cache size.
126    pub fn max_size(&self) -> usize {
127        self.max_size
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_creation() {
        let cache = PreparedStatementCache::new(100);
        assert_eq!(cache.max_size(), 100);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_clear() {
        let cache = PreparedStatementCache::new(100);

        // Manually insert some entries for testing
        {
            let mut inner = cache.prepared_queries.lock().unwrap();
            inner.put("SELECT 1".to_string(), ());
            inner.put("SELECT 2".to_string(), ());
        }

        assert_eq!(cache.len(), 2);
        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_lru_eviction() {
        let cache = PreparedStatementCache::new(2);
        {
            let mut inner = cache.prepared_queries.lock().unwrap();
            inner.put("A".to_string(), ());
            inner.put("B".to_string(), ());
            // Touch A so B becomes LRU.
            let _ = inner.get("A");
            inner.put("C".to_string(), ());
        }
        let inner = cache.prepared_queries.lock().unwrap();
        assert_eq!(inner.len(), 2);
        assert!(inner.peek("A").is_some());
        assert!(inner.peek("B").is_none(), "B should have been evicted");
        assert!(inner.peek("C").is_some());
    }

    #[test]
    fn test_zero_capacity_holds_one_entry() {
        // `new(0)` is documented to behave like capacity 1: the second insert
        // must evict the first rather than panic or grow.
        let cache = PreparedStatementCache::new(0);
        {
            let mut inner = cache.prepared_queries.lock().unwrap();
            inner.put("A".to_string(), ());
            inner.put("B".to_string(), ());
        }
        assert_eq!(cache.len(), 1);
        let inner = cache.prepared_queries.lock().unwrap();
        assert!(inner.peek("A").is_none(), "A should have been evicted");
        assert!(inner.peek("B").is_some(), "most recent insert should survive");
    }
}