1use crate::raw_statement::RawStatement;
4use crate::{Connection, Result, Statement};
5use hashlink::LruCache;
6use std::cell::RefCell;
7use std::ops::{Deref, DerefMut};
8use std::sync::Arc;
9
10impl Connection {
11 #[inline]
38 pub fn prepare_cached(&self, sql: &str) -> Result<CachedStatement<'_>> {
39 self.cache.get(self, sql)
40 }
41
42 #[inline]
48 pub fn set_prepared_statement_cache_capacity(&self, capacity: usize) {
49 self.cache.set_capacity(capacity);
50 }
51
52 #[inline]
54 pub fn flush_prepared_statement_cache(&self) {
55 self.cache.flush();
56 }
57}
58
/// Least-recently-used cache of prepared statements, keyed by their trimmed SQL
/// text.
///
/// The `RefCell` provides interior mutability so the cache can be updated through
/// `&self`, as done by `Connection::prepare_cached` and by `CachedStatement`'s
/// `Drop` impl.
pub struct StatementCache(RefCell<LruCache<Arc<str>, RawStatement>>);

// SAFETY (review): the clippy allowance below indicates that a field type
// (presumably `RawStatement`, which wraps a raw SQLite statement handle) is not
// `Send`. This impl assumes the cache is only ever moved together with the
// `Connection` that owns it and on which its statements were prepared.
// TODO: confirm against the `Connection` `Send` reasoning.
#[allow(clippy::non_send_fields_in_send_ty)]
unsafe impl Send for StatementCache {}
65
/// A prepared statement checked out of a connection's [`StatementCache`].
///
/// Derefs to [`Statement`]. When dropped, the statement is handed back to the
/// cache for reuse, unless [`CachedStatement::discard`] was called.
pub struct CachedStatement<'conn> {
    // `Some` for the whole usable life of the value. It is taken in `Drop` (to
    // return the statement to the cache) and cleared by `discard`.
    stmt: Option<Statement<'conn>>,
    // The cache the statement is returned to on drop.
    cache: &'conn StatementCache,
}
75
76impl<'conn> Deref for CachedStatement<'conn> {
77 type Target = Statement<'conn>;
78
79 #[inline]
80 fn deref(&self) -> &Statement<'conn> {
81 self.stmt.as_ref().unwrap()
82 }
83}
84
85impl<'conn> DerefMut for CachedStatement<'conn> {
86 #[inline]
87 fn deref_mut(&mut self) -> &mut Statement<'conn> {
88 self.stmt.as_mut().unwrap()
89 }
90}
91
92impl Drop for CachedStatement<'_> {
93 #[allow(unused_must_use)]
94 #[inline]
95 fn drop(&mut self) {
96 if let Some(stmt) = self.stmt.take() {
97 self.cache.cache_stmt(unsafe { stmt.into_raw() });
98 }
99 }
100}
101
102impl CachedStatement<'_> {
103 #[inline]
104 fn new<'conn>(stmt: Statement<'conn>, cache: &'conn StatementCache) -> CachedStatement<'conn> {
105 CachedStatement {
106 stmt: Some(stmt),
107 cache,
108 }
109 }
110
111 #[inline]
114 pub fn discard(mut self) {
115 self.stmt = None;
116 }
117}
118
119impl StatementCache {
120 #[inline]
122 pub fn with_capacity(capacity: usize) -> StatementCache {
123 StatementCache(RefCell::new(LruCache::new(capacity)))
124 }
125
126 #[inline]
127 fn set_capacity(&self, capacity: usize) {
128 self.0.borrow_mut().set_capacity(capacity);
129 }
130
131 fn get<'conn>(
139 &'conn self,
140 conn: &'conn Connection,
141 sql: &str,
142 ) -> Result<CachedStatement<'conn>> {
143 let trimmed = sql.trim();
144 let mut cache = self.0.borrow_mut();
145 let stmt = match cache.remove(trimmed) {
146 Some(raw_stmt) => Ok(Statement::new(conn, raw_stmt)),
147 None => conn.prepare(trimmed),
148 };
149 stmt.map(|mut stmt| {
150 stmt.stmt.set_statement_cache_key(trimmed);
151 CachedStatement::new(stmt, self)
152 })
153 }
154
155 fn cache_stmt(&self, stmt: RawStatement) {
157 if stmt.is_null() {
158 return;
159 }
160 let mut cache = self.0.borrow_mut();
161 stmt.clear_bindings();
162 if let Some(sql) = stmt.statement_cache_key() {
163 cache.insert(sql, stmt);
164 } else {
165 debug_assert!(
166 false,
167 "bug in statement cache code, statement returned to cache that without key"
168 );
169 }
170 }
171
172 #[inline]
173 fn flush(&self) {
174 let mut cache = self.0.borrow_mut();
175 cache.clear();
176 }
177}
178
#[cfg(test)]
mod test {
    use super::StatementCache;
    use crate::{Connection, Result};
    use fallible_iterator::FallibleIterator;

    // Test-only helpers to inspect the private LRU map behind the cache.
    impl StatementCache {
        fn clear(&self) {
            self.0.borrow_mut().clear();
        }

        fn len(&self) -> usize {
            self.0.borrow().len()
        }

        fn capacity(&self) -> usize {
            self.0.borrow().capacity()
        }
    }

    // A statement leaves the cache while checked out and is back after drop;
    // clearing empties the cache without changing its capacity.
    #[test]
    fn test_cache() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;
        let initial_capacity = cache.capacity();
        assert_eq!(0, cache.len());
        assert!(initial_capacity > 0);

        let sql = "PRAGMA schema_version";
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(1, cache.len());

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(1, cache.len());

        cache.clear();
        assert_eq!(0, cache.len());
        assert_eq!(initial_capacity, cache.capacity());
        Ok(())
    }

    // Shrinking the capacity to 0 evicts cached statements and stops retaining new
    // ones; raising it again re-enables caching.
    #[test]
    fn test_set_capacity() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;

        let sql = "PRAGMA schema_version";
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(1, cache.len());

        db.set_prepared_statement_cache_capacity(0);
        assert_eq!(0, cache.len());

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(0, cache.len());

        db.set_prepared_statement_cache_capacity(8);
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(1, cache.len());
        Ok(())
    }

    // A discarded statement is not returned to the cache.
    #[test]
    fn test_discard() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;

        let sql = "PRAGMA schema_version";
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
            stmt.discard();
        }
        assert_eq!(0, cache.len());
        Ok(())
    }

    // Reusing cached SQL after a schema change sees the new schema (the second
    // query reads the column added by ALTER TABLE).
    #[test]
    fn test_ddl() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch(
            r#"
            CREATE TABLE foo (x INT);
            INSERT INTO foo VALUES (1);
            "#,
        )?;

        let sql = "SELECT * FROM foo";

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(Ok(Some(1i32)), stmt.query([])?.map(|r| r.get(0)).next());
        }

        db.execute_batch(
            r#"
            ALTER TABLE foo ADD COLUMN y INT;
            UPDATE foo SET y = 2;
            "#,
        )?;

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(
                Ok(Some((1i32, 2i32))),
                stmt.query([])?.map(|r| Ok((r.get(0)?, r.get(1)?))).next()
            );
        }
        Ok(())
    }

    // A statement sitting in the cache must not prevent `close` from succeeding.
    #[test]
    fn test_connection_close() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.prepare_cached("SELECT * FROM sqlite_master;")?;

        conn.close().expect("connection not closed");
        Ok(())
    }

    // SQL with trailing whitespace is cached (under its trimmed form) and hit
    // again on reuse, so the cache holds a single entry.
    #[test]
    fn test_cache_key() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;
        assert_eq!(0, cache.len());

        let sql = "PRAGMA schema_version; ";
        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(1, cache.len());

        {
            let mut stmt = db.prepare_cached(sql)?;
            assert_eq!(0, cache.len());
            assert_eq!(0, stmt.query_row([], |r| r.get::<_, i64>(0))?);
        }
        assert_eq!(1, cache.len());
        Ok(())
    }

    // Preparing an empty SQL string through the cache succeeds.
    #[test]
    fn test_empty_stmt() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.prepare_cached("")?;
        Ok(())
    }
}