1use crate::raw_statement::RawStatement;
4use crate::{Connection, Result, Statement};
5use hashlink::LruCache;
6use std::cell::RefCell;
7use std::ops::{Deref, DerefMut};
8use std::sync::Arc;
9
10impl Connection {
11    #[inline]
38    pub fn prepare_cached(&self, sql: &str) -> Result<CachedStatement<'_>> {
39        self.cache.get(self, sql)
40    }
41
42    #[inline]
48    pub fn set_prepared_statement_cache_capacity(&self, capacity: usize) {
49        self.cache.set_capacity(capacity)
50    }
51
52    #[inline]
54    pub fn flush_prepared_statement_cache(&self) {
55        self.cache.flush()
56    }
57}
58
/// LRU cache of prepared statements held by a [`Connection`], keyed by each
/// statement's SQL text with surrounding whitespace trimmed.
///
/// The `RefCell` lets the cache be updated through the shared `&Connection`
/// references that `prepare_cached` and the other cache methods receive.
pub struct StatementCache(RefCell<LruCache<Arc<str>, RawStatement>>);
62
/// A prepared [`Statement`] checked out of a connection's [`StatementCache`].
///
/// Dereferences to `Statement`. When dropped, the statement is handed back to
/// the cache for reuse, unless it was removed with [`CachedStatement::discard`].
pub struct CachedStatement<'conn> {
    // `Some` for as long as the wrapper is usable; emptied by `Drop` (to return
    // the statement to the cache) or by `discard` (to drop it without caching).
    stmt: Option<Statement<'conn>>,
    // The cache the statement is handed back to on drop.
    cache: &'conn StatementCache,
}
71
72impl<'conn> Deref for CachedStatement<'conn> {
73    type Target = Statement<'conn>;
74
75    #[inline]
76    fn deref(&self) -> &Statement<'conn> {
77        self.stmt.as_ref().unwrap()
78    }
79}
80
81impl<'conn> DerefMut for CachedStatement<'conn> {
82    #[inline]
83    fn deref_mut(&mut self) -> &mut Statement<'conn> {
84        self.stmt.as_mut().unwrap()
85    }
86}
87
88impl Drop for CachedStatement<'_> {
89    #[allow(unused_must_use)]
90    #[inline]
91    fn drop(&mut self) {
92        if let Some(stmt) = self.stmt.take() {
93            self.cache.cache_stmt(unsafe { stmt.into_raw() });
94        }
95    }
96}
97
98impl CachedStatement<'_> {
99    #[inline]
100    fn new<'conn>(stmt: Statement<'conn>, cache: &'conn StatementCache) -> CachedStatement<'conn> {
101        CachedStatement {
102            stmt: Some(stmt),
103            cache,
104        }
105    }
106
107    #[inline]
110    pub fn discard(mut self) {
111        self.stmt = None;
112    }
113}
114
115impl StatementCache {
116    #[inline]
118    pub fn with_capacity(capacity: usize) -> StatementCache {
119        StatementCache(RefCell::new(LruCache::new(capacity)))
120    }
121
122    #[inline]
123    fn set_capacity(&self, capacity: usize) {
124        self.0.borrow_mut().set_capacity(capacity)
125    }
126
127    fn get<'conn>(
135        &'conn self,
136        conn: &'conn Connection,
137        sql: &str,
138    ) -> Result<CachedStatement<'conn>> {
139        let trimmed = sql.trim();
140        let mut cache = self.0.borrow_mut();
141        let stmt = match cache.remove(trimmed) {
142            Some(raw_stmt) => Ok(Statement::new(conn, raw_stmt)),
143            None => conn.prepare(trimmed),
144        };
145        stmt.map(|mut stmt| {
146            stmt.stmt.set_statement_cache_key(trimmed);
147            CachedStatement::new(stmt, self)
148        })
149    }
150
151    fn cache_stmt(&self, stmt: RawStatement) {
153        if stmt.is_null() {
154            return;
155        }
156        let mut cache = self.0.borrow_mut();
157        stmt.clear_bindings();
158        if let Some(sql) = stmt.statement_cache_key() {
159            cache.insert(sql, stmt);
160        } else {
161            debug_assert!(
162                false,
163                "bug in statement cache code, statement returned to cache that without key"
164            );
165        }
166    }
167
168    #[inline]
169    fn flush(&self) {
170        let mut cache = self.0.borrow_mut();
171        cache.clear()
172    }
173}
174
#[cfg(test)]
mod test {
    // The assertions below check `cache.len()` at points chosen around when a
    // `CachedStatement` is alive (its statement is checked out of the cache)
    // versus dropped (its statement has been handed back), so the block scopes
    // around each `prepare_cached` call are significant.
    use super::StatementCache;
    use crate::{Connection, Result};
    // Presumably needed for `.map(..).next()` on query results in `test_ddl`.
    use fallible_iterator::FallibleIterator;

    // Test-only accessors for the `RefCell<LruCache>` wrapped by `StatementCache`.
    impl StatementCache {
        // Removes every cached statement.
        fn clear(&self) {
            self.0.borrow_mut().clear();
        }

        // Number of statements currently held in the cache.
        fn len(&self) -> usize {
            self.0.borrow().len()
        }

        // Maximum number of statements the cache may hold.
        fn capacity(&self) -> usize {
            self.0.borrow().capacity()
        }
    }

    // A statement is absent from the cache while checked out, returns to it on
    // drop, and is taken out again by the next `prepare_cached` with the same
    // SQL. Clearing the cache leaves its capacity unchanged.
    #[test]
    fn test_cache() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;
        let initial_capacity = cache.capacity();
        assert_eq!(0, cache.len());
        assert!(initial_capacity > 0);

        let sql = "PRAGMA schema_version";
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(1, cache.len());

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(1, cache.len());

        cache.clear();
        assert_eq!(0, cache.len());
        assert_eq!(initial_capacity, cache.capacity());
        Ok(())
    }

    // Setting the capacity to 0 evicts cached statements and keeps dropped
    // statements from being cached; raising it again re-enables caching.
    #[test]
    fn test_set_capacity() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;

        let sql = "PRAGMA schema_version";
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(1, cache.len());

        db.set_prepared_statement_cache_capacity(0);
        assert_eq!(0, cache.len());

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(0, cache.len());

        db.set_prepared_statement_cache_capacity(8);
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(1, cache.len());
        Ok(())
    }

    // `discard` keeps the statement from being returned to the cache.
    #[test]
    fn test_discard() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;

        let sql = "PRAGMA schema_version";
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
            stmt.discard();
        }
        assert_eq!(0, cache.len());
        Ok(())
    }

    // A statement fetched again from the cache after `ALTER TABLE` sees the
    // newly added column.
    #[test]
    fn test_ddl() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch(
            r#"
            CREATE TABLE foo (x INT);
            INSERT INTO foo VALUES (1);
        "#,
        )?;

        let sql = "SELECT * FROM foo";

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(Ok(Some(1i32)), stmt.query([])?.map(|r| r.get(0)).next());
        }

        db.execute_batch(
            r#"
            ALTER TABLE foo ADD COLUMN y INT;
            UPDATE foo SET y = 2;
        "#,
        )?;

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(
                Ok(Some((1i32, 2i32))),
                stmt.query([])?.map(|r| Ok((r.get(0)?, r.get(1)?))).next()
            );
        }
        Ok(())
    }

    // Closing the connection succeeds while a statement is still held in its
    // cache (the temporary `CachedStatement` is dropped before `close`).
    #[test]
    fn test_connection_close() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.prepare_cached("SELECT * FROM sqlite_master;")?;

        conn.close().expect("connection not closed");
        Ok(())
    }

    // SQL with trailing whitespace is cached on drop and found again by the
    // next `prepare_cached` call with the same text.
    #[test]
    fn test_cache_key() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;
        assert_eq!(0, cache.len());

        let sql = "PRAGMA schema_version; ";
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(1, cache.len());

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(1, cache.len());
        Ok(())
    }

    // Empty SQL can be prepared through the cache and the result dropped
    // without error.
    #[test]
    fn test_empty_stmt() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.prepare_cached("")?;
        Ok(())
    }
}