prax_postgres/
statement.rs

1//! Prepared statement caching.
2
3use std::collections::HashMap;
4use std::sync::RwLock;
5
6use deadpool_postgres::{Object, Transaction};
7use tokio_postgres::Statement;
8use tracing::debug;
9
10use crate::error::PgResult;
11
/// Tracks which SQL strings have already been submitted for preparation.
///
/// Only the SQL text is recorded; no `Statement` handle is stored here.
/// Entries are keyed by the exact query string and kept behind an `RwLock`
/// so the tracker can be used through a shared reference.
#[derive(Debug)]
pub struct PreparedStatementCache {
    /// Upper bound on the number of SQL strings tracked at once.
    max_size: usize,
    /// SQL strings seen so far. The `bool` value carries no information:
    /// every insertion in this file stores `true`, so the map is used as a set.
    ///
    /// Note: this is a plain `HashMap`, not an LRU. Prepared statements are
    /// tied to connections, so this map only tracks which queries have been
    /// requested; it does not own the statements themselves.
    prepared_queries: RwLock<HashMap<String, bool>>,
}
24
25impl PreparedStatementCache {
26    /// Create a new statement cache with the given maximum size.
27    pub fn new(max_size: usize) -> Self {
28        Self {
29            max_size,
30            prepared_queries: RwLock::new(HashMap::new()),
31        }
32    }
33
34    /// Get or prepare a statement for the given SQL.
35    pub async fn get_or_prepare(&self, client: &Object, sql: &str) -> PgResult<Statement> {
36        // Check if we've prepared this statement before
37        let is_cached = {
38            let cache = self.prepared_queries.read().unwrap();
39            cache.contains_key(sql)
40        };
41
42        if is_cached {
43            debug!(sql = %sql, "Using cached prepared statement");
44        } else {
45            debug!(sql = %sql, "Preparing new statement");
46
47            // Check cache size and potentially evict
48            let mut cache = self.prepared_queries.write().unwrap();
49            if cache.len() >= self.max_size {
50                // Simple eviction: clear half the cache
51                // In production, use an LRU cache
52                let to_remove: Vec<_> = cache.keys().take(cache.len() / 2).cloned().collect();
53                for key in to_remove {
54                    cache.remove(&key);
55                }
56            }
57            cache.insert(sql.to_string(), true);
58        }
59
60        // Always prepare - the database will reuse if it's cached server-side
61        let stmt = client.prepare_cached(sql).await?;
62        Ok(stmt)
63    }
64
65    /// Get or prepare a statement within a transaction.
66    pub async fn get_or_prepare_in_txn<'a>(
67        &self,
68        txn: &Transaction<'a>,
69        sql: &str,
70    ) -> PgResult<Statement> {
71        // Similar logic to above, but for transactions
72        let is_cached = {
73            let cache = self.prepared_queries.read().unwrap();
74            cache.contains_key(sql)
75        };
76
77        if is_cached {
78            debug!(sql = %sql, "Using cached prepared statement (txn)");
79        } else {
80            debug!(sql = %sql, "Preparing new statement (txn)");
81
82            let mut cache = self.prepared_queries.write().unwrap();
83            if cache.len() >= self.max_size {
84                let to_remove: Vec<_> = cache.keys().take(cache.len() / 2).cloned().collect();
85                for key in to_remove {
86                    cache.remove(&key);
87                }
88            }
89            cache.insert(sql.to_string(), true);
90        }
91
92        let stmt = txn.prepare_cached(sql).await?;
93        Ok(stmt)
94    }
95
96    /// Clear all cached statements.
97    pub fn clear(&self) {
98        let mut cache = self.prepared_queries.write().unwrap();
99        cache.clear();
100        debug!("Statement cache cleared");
101    }
102
103    /// Get the number of cached statement keys.
104    pub fn len(&self) -> usize {
105        let cache = self.prepared_queries.read().unwrap();
106        cache.len()
107    }
108
109    /// Check if the cache is empty.
110    pub fn is_empty(&self) -> bool {
111        self.len() == 0
112    }
113
114    /// Get the maximum cache size.
115    pub fn max_size(&self) -> usize {
116        self.max_size
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed cache reports its configured limit and holds nothing.
    #[test]
    fn test_cache_creation() {
        let fresh = PreparedStatementCache::new(100);
        assert!(fresh.is_empty());
        assert_eq!(fresh.max_size(), 100);
    }

    /// Clearing drops every tracked query.
    #[test]
    fn test_cache_clear() {
        let tracker = PreparedStatementCache::new(100);

        // Seed the inner map directly; the public insertion paths need a
        // live database connection.
        {
            let mut guard = tracker.prepared_queries.write().unwrap();
            for sql in ["SELECT 1", "SELECT 2"] {
                guard.insert(sql.to_string(), true);
            }
        }
        assert_eq!(tracker.len(), 2);

        tracker.clear();
        assert!(tracker.is_empty());
    }
}