rsfbclient/connection/
stmt_cache.rs1use lru_cache::LruCache;
8use std::{collections::HashSet, mem};
9
10use crate::{statement::StatementData, Connection, FbError, Transaction};
11use rsfbclient_core::FirebirdClient;
12
13pub struct StmtCache<T> {
17 cache: LruCache<String, T>,
18 sqls: HashSet<String>,
19}
20
/// An entry taken out of (or about to be put back into) a `StmtCache`: the cached value
/// together with the SQL text that serves as its key.
pub struct StmtCacheData<T> {
    /// SQL text the statement was prepared from; used as the cache key.
    pub(crate) sql: String,
    /// The cached value (a prepared statement in practice).
    pub(crate) stmt: T,
}
25
26impl<T> StmtCache<T> {
28 pub fn new(capacity: usize) -> Self {
29 Self {
30 cache: LruCache::new(capacity),
31 sqls: HashSet::with_capacity(capacity),
32 }
33 }
34
35 fn get(&mut self, sql: &str) -> Option<StmtCacheData<T>> {
37 if let Some(stmt) = self.cache.remove(sql) {
38 let sql = self.sqls.take(sql).unwrap();
39
40 Some(StmtCacheData { stmt, sql })
41 } else {
42 None
43 }
44 }
45
46 fn insert(&mut self, data: StmtCacheData<T>) -> Option<T> {
49 if self.sqls.contains(&data.sql) {
50 self.cache.insert(data.sql, data.stmt)
52 } else {
53 self.sqls.insert(data.sql.clone());
55
56 let old = if self.cache.len() == self.cache.capacity() {
58 if let Some((sql, stmt)) = self.cache.remove_lru() {
59 self.sqls.remove(&sql);
61
62 Some(stmt)
63 } else {
64 None
65 }
66 } else {
67 None
68 };
69
70 self.cache.insert(data.sql, data.stmt);
72
73 old
74 }
75 }
76}
77
78impl<C> StmtCache<StatementData<C>>
80where
81 C: FirebirdClient,
82{
83 pub fn get_or_prepare(
85 tr: &mut Transaction<C>,
86 sql: &str,
87 named_params: bool,
88 ) -> Result<StmtCacheData<StatementData<C>>, FbError> {
89 if let Some(data) = tr.conn.stmt_cache.get(sql) {
90 Ok(data)
91 } else {
92 Ok(StmtCacheData {
93 sql: sql.to_string(),
94 stmt: StatementData::prepare(tr.conn, &mut tr.data, sql, named_params)?,
95 })
96 }
97 }
98
99 pub fn insert_and_close(
102 conn: &mut Connection<C>,
103 data: StmtCacheData<StatementData<C>>,
104 ) -> Result<(), FbError> {
105 conn.stmt_cache.sqls.insert(data.sql.clone());
106
107 if let Some(mut stmt) = conn.stmt_cache.insert(data) {
109 stmt.close(conn)?;
110 }
111
112 Ok(())
113 }
114
115 pub fn close_all(conn: &mut Connection<C>) {
118 let mut stmt_cache = mem::replace(&mut conn.stmt_cache, StmtCache::new(0));
119
120 for (_, stmt) in stmt_cache.cache.iter_mut() {
121 stmt.close(conn).ok();
122 }
123 }
124}
125
#[test]
fn stmt_cache_test() {
    // A capacity of two means the third distinct insert must evict the LRU entry.
    let mut cache = StmtCache::new(2);

    // Builds an entry whose statement payload is simply its index.
    let entry = |n: usize| StmtCacheData {
        sql: format!("sql {}", n),
        stmt: n,
    };

    // Lookups on an empty cache miss.
    let first = entry(1);
    assert!(cache.get(&first.sql).is_none());

    // The first two inserts fit, so nothing is displaced.
    assert!(cache.insert(first).is_none());
    assert!(cache.insert(entry(2)).is_none());

    // The third insert evicts the least recently used entry, "sql 1".
    let evicted_first = cache.insert(entry(3)).expect("sql1 not returned");
    assert_eq!(evicted_first, 1);
    assert!(cache.get("sql 1").is_none());

    // Taking "sql 2" out and putting it back makes it the most recently used entry.
    let taken = cache.get("sql 2").expect("Sql 2 not in the cache");
    assert!(cache.insert(taken).is_none());

    // Further inserts evict strictly in LRU order: 3, then 2, then 4.
    for &(next, evicted, msg) in &[
        (4, 3, "sql3 not returned"),
        (5, 2, "sql2 not returned"),
        (6, 4, "sql4 not returned"),
    ] {
        let displaced = cache.insert(entry(next)).expect(msg);
        assert_eq!(displaced, evicted);
    }

    assert_eq!(cache.get("sql 5").expect("sql5 not in the cache").stmt, 5);
    assert_eq!(cache.get("sql 6").expect("sql6 not in the cache").stmt, 6);

    // Every `get` removes its entry, so both the LRU map and the key set end up empty.
    assert!(cache.cache.is_empty());
    assert!(cache.sqls.is_empty());
}