1use crate::raw_statement::RawStatement;
4use crate::{Connection, Result, Statement};
5use hashlink::LruCache;
6use std::cell::RefCell;
7use std::ops::{Deref, DerefMut};
8use std::sync::Arc;
9
10impl Connection {
11 #[inline]
38 pub fn prepare_cached(&self, sql: &str) -> Result<CachedStatement<'_>> {
39 self.cache.get(self, sql)
40 }
41
42 #[inline]
48 pub fn set_prepared_statement_cache_capacity(&self, capacity: usize) {
49 self.cache.set_capacity(capacity)
50 }
51
52 #[inline]
54 pub fn flush_prepared_statement_cache(&self) {
55 self.cache.flush()
56 }
57}
58
/// Per-connection LRU cache of prepared statements, keyed by their SQL text
/// with surrounding whitespace trimmed.
///
/// A statement is removed from the cache while a [`CachedStatement`] is using
/// it and re-inserted when that wrapper is dropped. The `RefCell` supplies the
/// interior mutability needed because the cache is only reached through
/// shared (`&self`) references.
pub struct StatementCache(RefCell<LruCache<Arc<str>, RawStatement>>);
62
/// A prepared statement checked out of a connection's [`StatementCache`].
///
/// It dereferences to [`Statement`]. When dropped, the statement is handed
/// back to the cache so a later `prepare_cached` call with the same SQL can
/// reuse it. Call [`CachedStatement::discard`] to drop it instead.
pub struct CachedStatement<'conn> {
    // `Some` for as long as the wrapper is usable. `Drop` takes it to return it
    // to the cache, and `discard` clears it so nothing is returned.
    stmt: Option<Statement<'conn>>,
    // The cache the statement is returned to on drop.
    cache: &'conn StatementCache,
}
71
72impl<'conn> Deref for CachedStatement<'conn> {
73 type Target = Statement<'conn>;
74
75 #[inline]
76 fn deref(&self) -> &Statement<'conn> {
77 self.stmt.as_ref().unwrap()
78 }
79}
80
81impl<'conn> DerefMut for CachedStatement<'conn> {
82 #[inline]
83 fn deref_mut(&mut self) -> &mut Statement<'conn> {
84 self.stmt.as_mut().unwrap()
85 }
86}
87
88impl Drop for CachedStatement<'_> {
89 #[allow(unused_must_use)]
90 #[inline]
91 fn drop(&mut self) {
92 if let Some(stmt) = self.stmt.take() {
93 self.cache.cache_stmt(unsafe { stmt.into_raw() });
94 }
95 }
96}
97
98impl CachedStatement<'_> {
99 #[inline]
100 fn new<'conn>(stmt: Statement<'conn>, cache: &'conn StatementCache) -> CachedStatement<'conn> {
101 CachedStatement {
102 stmt: Some(stmt),
103 cache,
104 }
105 }
106
107 #[inline]
110 pub fn discard(mut self) {
111 self.stmt = None;
112 }
113}
114
115impl StatementCache {
116 #[inline]
118 pub fn with_capacity(capacity: usize) -> StatementCache {
119 StatementCache(RefCell::new(LruCache::new(capacity)))
120 }
121
122 #[inline]
123 fn set_capacity(&self, capacity: usize) {
124 self.0.borrow_mut().set_capacity(capacity)
125 }
126
127 fn get<'conn>(
135 &'conn self,
136 conn: &'conn Connection,
137 sql: &str,
138 ) -> Result<CachedStatement<'conn>> {
139 let trimmed = sql.trim();
140 let mut cache = self.0.borrow_mut();
141 let stmt = match cache.remove(trimmed) {
142 Some(raw_stmt) => Ok(Statement::new(conn, raw_stmt)),
143 None => conn.prepare(trimmed),
144 };
145 stmt.map(|mut stmt| {
146 stmt.stmt.set_statement_cache_key(trimmed);
147 CachedStatement::new(stmt, self)
148 })
149 }
150
151 fn cache_stmt(&self, stmt: RawStatement) {
153 if stmt.is_null() {
154 return;
155 }
156 let mut cache = self.0.borrow_mut();
157 stmt.clear_bindings();
158 if let Some(sql) = stmt.statement_cache_key() {
159 cache.insert(sql, stmt);
160 } else {
161 debug_assert!(
162 false,
163 "bug in statement cache code, statement returned to cache that without key"
164 );
165 }
166 }
167
168 #[inline]
169 fn flush(&self) {
170 let mut cache = self.0.borrow_mut();
171 cache.clear()
172 }
173}
174
#[cfg(test)]
mod test {
    use super::StatementCache;
    use crate::{Connection, Result};
    use fallible_iterator::FallibleIterator;

    // Test-only accessors for inspecting the cache's internal `LruCache`.
    impl StatementCache {
        fn clear(&self) {
            self.0.borrow_mut().clear();
        }

        fn len(&self) -> usize {
            self.0.borrow().len()
        }

        fn capacity(&self) -> usize {
            self.0.borrow().capacity()
        }
    }

    #[test]
    fn test_cache() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;
        let initial_capacity = cache.capacity();
        assert_eq!(0, cache.len());
        assert!(initial_capacity > 0);

        let sql = "PRAGMA schema_version";
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(1, cache.len());

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(1, cache.len());

        cache.clear();
        assert_eq!(0, cache.len());
        assert_eq!(initial_capacity, cache.capacity());
        Ok(())
    }

    #[test]
    fn test_set_capacity() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;

        let sql = "PRAGMA schema_version";
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(1, cache.len());

        db.set_prepared_statement_cache_capacity(0);
        assert_eq!(0, cache.len());

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(0, cache.len());

        db.set_prepared_statement_cache_capacity(8);
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(1, cache.len());
        Ok(())
    }

    #[test]
    fn test_discard() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;

        let sql = "PRAGMA schema_version";
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
            stmt.discard();
        }
        assert_eq!(0, cache.len());
        Ok(())
    }

    #[test]
    fn test_ddl() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch(
            r#"
            CREATE TABLE foo (x INT);
            INSERT INTO foo VALUES (1);
        "#,
        )?;

        let sql = "SELECT * FROM foo";

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(Ok(Some(1i32)), stmt.query([])?.map(|r| r.get(0)).next());
        }

        db.execute_batch(
            r#"
            ALTER TABLE foo ADD COLUMN y INT;
            UPDATE foo SET y = 2;
        "#,
        )?;

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(
                Ok(Some((1i32, 2i32))),
                stmt.query([])?.map(|r| Ok((r.get(0)?, r.get(1)?))).next()
            );
        }
        Ok(())
    }

    #[test]
    fn test_connection_close() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.prepare_cached("SELECT * FROM sqlite_master;")?;

        conn.close().expect("connection not closed");
        Ok(())
    }

    #[test]
    fn test_cache_key() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;
        assert_eq!(0, cache.len());

        let sql = "PRAGMA schema_version; ";
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(1, cache.len());

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(1, cache.len());
        Ok(())
    }

    // The lookup key is the trimmed SQL, so differently padded copies of the
    // same statement must share a single cache entry.
    #[test]
    fn test_cache_key_ignores_surrounding_whitespace() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;

        {
            let mut stmt = db.prepare_cached("SELECT 1")?;
            assert_eq!(1, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(1, cache.len());

        {
            let mut stmt = db.prepare_cached("  \n SELECT 1 \t")?;
            // A hit checks the entry out of the cache; a miss would leave it there.
            assert_eq!(0, cache.len());
            assert_eq!(1, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(1, cache.len());
        Ok(())
    }

    // Flushing empties the cache but leaves checked-out statements alone; they
    // are filed back into the cache when dropped.
    #[test]
    fn test_flush() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;

        db.prepare_cached("SELECT 1")?;
        assert_eq!(1, cache.len());

        db.flush_prepared_statement_cache();
        assert_eq!(0, cache.len());

        {
            let _stmt = db.prepare_cached("SELECT 1")?;
            db.flush_prepared_statement_cache();
            assert_eq!(0, cache.len());
        }
        assert_eq!(1, cache.len());
        Ok(())
    }

    #[test]
    fn test_empty_stmt() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.prepare_cached("")?;
        Ok(())
    }
}