rsfbclient/connection/
stmt_cache.rs

//!
//! Rust Firebird Client
//!
//! Statement Cache
//!
6
7use lru_cache::LruCache;
8use std::{collections::HashSet, mem};
9
10use crate::{statement::StatementData, Connection, FbError, Transaction};
11use rsfbclient_core::FirebirdClient;
12
13/// Cache of prepared statements.
14///
15/// Must be emptied by calling `close_all` before dropping.
16pub struct StmtCache<T> {
17    cache: LruCache<String, T>,
18    sqls: HashSet<String>,
19}
20
/// A statement taken out of (or about to be put into) a `StmtCache`,
/// together with the sql it is cached under.
#[derive(Debug)]
pub struct StmtCacheData<T> {
    /// The sql text, used as the cache key.
    pub(crate) sql: String,
    /// The prepared statement for `sql`.
    pub(crate) stmt: T,
}
25
26/// General functions
27impl<T> StmtCache<T> {
28    pub fn new(capacity: usize) -> Self {
29        Self {
30            cache: LruCache::new(capacity),
31            sqls: HashSet::with_capacity(capacity),
32        }
33    }
34
35    /// Get a prepared statement from the cache
36    fn get(&mut self, sql: &str) -> Option<StmtCacheData<T>> {
37        if let Some(stmt) = self.cache.remove(sql) {
38            let sql = self.sqls.take(sql).unwrap();
39
40            Some(StmtCacheData { stmt, sql })
41        } else {
42            None
43        }
44    }
45
46    /// Adds a prepared statement to the cache, returning the previous one for this sql
47    /// or another if the cache is full
48    fn insert(&mut self, data: StmtCacheData<T>) -> Option<T> {
49        if self.sqls.contains(&data.sql) {
50            // Insert the new one and return the old
51            self.cache.insert(data.sql, data.stmt)
52        } else {
53            // Insert the sql
54            self.sqls.insert(data.sql.clone());
55
56            // If full, remove the last recently used
57            let old = if self.cache.len() == self.cache.capacity() {
58                if let Some((sql, stmt)) = self.cache.remove_lru() {
59                    // Remove the sql
60                    self.sqls.remove(&sql);
61
62                    Some(stmt)
63                } else {
64                    None
65                }
66            } else {
67                None
68            };
69
70            // Insert the new one
71            self.cache.insert(data.sql, data.stmt);
72
73            old
74        }
75    }
76}
77
78/// Functions specific for when the data is a `StatementData`
79impl<C> StmtCache<StatementData<C>>
80where
81    C: FirebirdClient,
82{
83    /// Get a prepared statement from the cache, or prepare one
84    pub fn get_or_prepare(
85        tr: &mut Transaction<C>,
86        sql: &str,
87        named_params: bool,
88    ) -> Result<StmtCacheData<StatementData<C>>, FbError> {
89        if let Some(data) = tr.conn.stmt_cache.get(sql) {
90            Ok(data)
91        } else {
92            Ok(StmtCacheData {
93                sql: sql.to_string(),
94                stmt: StatementData::prepare(tr.conn, &mut tr.data, sql, named_params)?,
95            })
96        }
97    }
98
99    /// Adds a prepared statement to the cache, closing the previous one for this sql
100    /// or another if the cache is full
101    pub fn insert_and_close(
102        conn: &mut Connection<C>,
103        data: StmtCacheData<StatementData<C>>,
104    ) -> Result<(), FbError> {
105        conn.stmt_cache.sqls.insert(data.sql.clone());
106
107        // Insert the new one and close the old if exists
108        if let Some(mut stmt) = conn.stmt_cache.insert(data) {
109            stmt.close(conn)?;
110        }
111
112        Ok(())
113    }
114
115    /// Closes all statements in the cache.
116    /// Needs to be called before dropping the cache.
117    pub fn close_all(conn: &mut Connection<C>) {
118        let mut stmt_cache = mem::replace(&mut conn.stmt_cache, StmtCache::new(0));
119
120        for (_, stmt) in stmt_cache.cache.iter_mut() {
121            stmt.close(conn).ok();
122        }
123    }
124}
125
#[test]
fn stmt_cache_test() {
    // Builds a cache entry whose statement is just its number
    let entry = |n: usize| StmtCacheData {
        sql: format!("sql {}", n),
        stmt: n,
    };

    let mut cache = StmtCache::new(2);

    // Nothing is cached yet
    assert!(cache.get(&entry(1).sql).is_none());

    // Filling the cache up to its capacity evicts nothing
    assert!(cache.insert(entry(1)).is_none());
    assert!(cache.insert(entry(2)).is_none());

    // One more pushes out the least recently used entry (sql 1)
    let evicted = cache.insert(entry(3)).expect("sql1 not returned");
    assert_eq!(evicted, 1);
    assert!(cache.get("sql 1").is_none());

    // Taking sql 2 out and putting it back makes it the most recently used,
    // so sql 3 becomes the next eviction candidate
    let reused = cache.get("sql 2").expect("Sql 2 not in the cache");
    assert!(cache.insert(reused).is_none());

    // Every further insert evicts the oldest remaining entry
    let rounds = [
        (4, 3, "sql3 not returned"),
        (5, 2, "sql2 not returned"),
        (6, 4, "sql4 not returned"),
    ];
    for (inserted, expected, msg) in rounds.iter().copied() {
        let evicted = cache.insert(entry(inserted)).expect(msg);
        assert_eq!(evicted, expected);
    }

    // Draining the two survivors leaves both containers empty
    assert_eq!(cache.get("sql 5").expect("sql5 not in the cache").stmt, 5);
    assert_eq!(cache.get("sql 6").expect("sql6 not in the cache").stmt, 6);

    assert!(cache.cache.is_empty());
    assert!(cache.sqls.is_empty());
}