mssql_client/
statement_cache.rs1#![allow(clippy::expect_used)]
17
18use std::collections::hash_map::DefaultHasher;
19use std::hash::{Hash, Hasher};
20use std::num::NonZeroUsize;
21use std::time::Instant;
22
23use lru::LruCache;
24
/// Capacity used by [`StatementCache::with_default_size`] and by the default
/// [`StatementCacheConfig`] when no explicit size is supplied.
pub const DEFAULT_MAX_STATEMENTS: usize = 256;
27
/// A prepared-statement handle paired with the SQL text it was prepared from.
#[derive(Debug, Clone)]
pub struct PreparedStatement {
    /// Handle the statement was registered under (presumably the server-side
    /// prepare handle — confirm against the caller).
    handle: i32,
    /// [`hash_sql`] digest of `sql`, computed once and used as the cache key.
    sql_hash: u64,
    /// Exact SQL text this handle belongs to.
    sql: String,
    /// Moment this value was constructed.
    created_at: Instant,
}
42
43impl PreparedStatement {
44 pub fn new(handle: i32, sql: String) -> Self {
46 Self {
47 handle,
48 sql_hash: hash_sql(&sql),
49 sql,
50 created_at: Instant::now(),
51 }
52 }
53
54 #[must_use]
56 pub fn handle(&self) -> i32 {
57 self.handle
58 }
59
60 #[must_use]
62 pub fn sql_hash(&self) -> u64 {
63 self.sql_hash
64 }
65
66 #[must_use]
68 pub fn sql(&self) -> &str {
69 &self.sql
70 }
71
72 #[must_use]
74 pub fn created_at(&self) -> Instant {
75 self.created_at
76 }
77
78 #[must_use]
80 pub fn age(&self) -> std::time::Duration {
81 self.created_at.elapsed()
82 }
83}
84
85pub struct StatementCache {
91 cache: LruCache<u64, PreparedStatement>,
93 max_size: usize,
95 hits: u64,
97 misses: u64,
99}
100
101impl StatementCache {
102 #[must_use]
108 pub fn new(max_size: usize) -> Self {
109 assert!(max_size > 0, "max_size must be greater than 0");
110 Self {
111 cache: LruCache::new(NonZeroUsize::new(max_size).expect("max_size > 0")),
112 max_size,
113 hits: 0,
114 misses: 0,
115 }
116 }
117
118 #[must_use]
120 pub fn with_default_size() -> Self {
121 Self::new(DEFAULT_MAX_STATEMENTS)
122 }
123
124 pub fn get(&mut self, sql: &str) -> Option<i32> {
129 let hash = hash_sql(sql);
130 if let Some(stmt) = self.cache.get(&hash) {
131 self.hits += 1;
132 tracing::trace!(sql = sql, handle = stmt.handle, "statement cache hit");
133 Some(stmt.handle)
134 } else {
135 self.misses += 1;
136 tracing::trace!(sql = sql, "statement cache miss");
137 None
138 }
139 }
140
141 pub fn peek(&self, sql: &str) -> Option<&PreparedStatement> {
143 let hash = hash_sql(sql);
144 self.cache.peek(&hash)
145 }
146
147 pub fn insert(&mut self, stmt: PreparedStatement) -> Option<PreparedStatement> {
151 let hash = stmt.sql_hash;
152 tracing::debug!(
153 sql = stmt.sql(),
154 handle = stmt.handle,
155 "caching prepared statement"
156 );
157
158 let evicted = if self.cache.len() >= self.max_size {
160 self.cache.pop_lru().map(|(_, stmt)| stmt)
162 } else {
163 None
164 };
165
166 self.cache.put(hash, stmt);
167 evicted
168 }
169
170 pub fn remove(&mut self, sql: &str) -> Option<PreparedStatement> {
174 let hash = hash_sql(sql);
175 self.cache.pop(&hash)
176 }
177
178 pub fn clear(&mut self) -> impl Iterator<Item = PreparedStatement> + '_ {
183 let mut statements = Vec::with_capacity(self.cache.len());
184 while let Some((_, stmt)) = self.cache.pop_lru() {
185 statements.push(stmt);
186 }
187 tracing::debug!(count = statements.len(), "cleared statement cache");
188 statements.into_iter()
189 }
190
191 #[must_use]
193 pub fn len(&self) -> usize {
194 self.cache.len()
195 }
196
197 #[must_use]
199 pub fn is_empty(&self) -> bool {
200 self.cache.is_empty()
201 }
202
203 #[must_use]
205 pub fn max_size(&self) -> usize {
206 self.max_size
207 }
208
209 #[must_use]
211 pub fn hits(&self) -> u64 {
212 self.hits
213 }
214
215 #[must_use]
217 pub fn misses(&self) -> u64 {
218 self.misses
219 }
220
221 #[must_use]
223 pub fn hit_ratio(&self) -> f64 {
224 let total = self.hits + self.misses;
225 if total == 0 {
226 0.0
227 } else {
228 self.hits as f64 / total as f64
229 }
230 }
231
232 pub fn reset_stats(&mut self) {
234 self.hits = 0;
235 self.misses = 0;
236 }
237}
238
239impl Default for StatementCache {
240 fn default() -> Self {
241 Self::with_default_size()
242 }
243}
244
245impl std::fmt::Debug for StatementCache {
246 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247 f.debug_struct("StatementCache")
248 .field("len", &self.cache.len())
249 .field("max_size", &self.max_size)
250 .field("hits", &self.hits)
251 .field("misses", &self.misses)
252 .finish()
253 }
254}
255
/// Hashes `sql` into the key used by [`StatementCache`].
///
/// Uses `DefaultHasher::new()`, which is deterministic for a given build.
/// The standard library does not promise the same algorithm across Rust
/// releases, so these hashes should not be persisted or sent elsewhere.
#[must_use]
pub fn hash_sql(sql: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(sql, &mut state);
    state.finish()
}
265
266#[derive(Debug, Clone)]
268pub struct StatementCacheConfig {
269 pub enabled: bool,
271 pub max_size: usize,
273}
274
275impl Default for StatementCacheConfig {
276 fn default() -> Self {
277 Self {
278 enabled: true,
279 max_size: DEFAULT_MAX_STATEMENTS,
280 }
281 }
282}
283
284impl StatementCacheConfig {
285 #[must_use]
287 pub fn disabled() -> Self {
288 Self {
289 enabled: false,
290 max_size: 0,
291 }
292 }
293
294 #[must_use]
296 pub fn with_max_size(max_size: usize) -> Self {
297 Self {
298 enabled: true,
299 max_size,
300 }
301 }
302}
303
#[cfg(test)]
#[allow(clippy::unwrap_used)] // unwrap on known-present values keeps assertions terse
mod tests {
    use super::*;

    #[test]
    fn test_statement_cache_new() {
        let cache = StatementCache::new(10);
        assert_eq!(cache.max_size(), 10);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    #[should_panic(expected = "max_size must be greater than 0")]
    fn test_statement_cache_new_zero_panics() {
        let _ = StatementCache::new(0);
    }

    #[test]
    fn test_statement_cache_default_size() {
        let cache = StatementCache::default();
        assert_eq!(cache.max_size(), DEFAULT_MAX_STATEMENTS);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_statement_cache_insert_and_get() {
        let mut cache = StatementCache::new(10);

        let stmt = PreparedStatement::new(1, "SELECT * FROM users".to_string());
        cache.insert(stmt);

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get("SELECT * FROM users"), Some(1));
        assert_eq!(cache.hits(), 1);
        assert_eq!(cache.misses(), 0);
    }

    #[test]
    fn test_statement_cache_insert_below_capacity_returns_none() {
        let mut cache = StatementCache::new(10);
        assert!(cache
            .insert(PreparedStatement::new(1, "SELECT 1".to_string()))
            .is_none());
        assert!(cache
            .insert(PreparedStatement::new(2, "SELECT 2".to_string()))
            .is_none());
        assert_eq!(cache.len(), 2);
    }

    #[test]
    fn test_statement_cache_miss() {
        let mut cache = StatementCache::new(10);

        assert_eq!(cache.get("SELECT 1"), None);
        assert_eq!(cache.misses(), 1);
        assert_eq!(cache.hits(), 0);
    }

    #[test]
    fn test_statement_cache_lru_eviction() {
        let mut cache = StatementCache::new(2);

        cache.insert(PreparedStatement::new(1, "SELECT 1".to_string()));
        cache.insert(PreparedStatement::new(2, "SELECT 2".to_string()));
        assert_eq!(cache.len(), 2);

        // Touch "SELECT 1" so "SELECT 2" becomes the least recently used entry.
        cache.get("SELECT 1");

        let evicted = cache.insert(PreparedStatement::new(3, "SELECT 3".to_string()));

        assert!(evicted.is_some());
        assert_eq!(evicted.unwrap().handle(), 2);
        assert_eq!(cache.len(), 2);

        assert_eq!(cache.get("SELECT 1"), Some(1));
        assert_eq!(cache.get("SELECT 2"), None);
        assert_eq!(cache.get("SELECT 3"), Some(3));
    }

    #[test]
    fn test_statement_cache_peek_does_not_touch_stats_or_recency() {
        let mut cache = StatementCache::new(2);
        cache.insert(PreparedStatement::new(1, "SELECT 1".to_string()));
        cache.insert(PreparedStatement::new(2, "SELECT 2".to_string()));

        assert_eq!(
            cache.peek("SELECT 1").map(PreparedStatement::handle),
            Some(1)
        );
        assert!(cache.peek("SELECT 3").is_none());
        assert_eq!(cache.hits(), 0);
        assert_eq!(cache.misses(), 0);

        // peek did not promote "SELECT 1", so it is still the LRU entry.
        let evicted = cache.insert(PreparedStatement::new(3, "SELECT 3".to_string()));
        assert_eq!(evicted.map(|stmt| stmt.handle()), Some(1));
    }

    #[test]
    fn test_statement_cache_clear() {
        let mut cache = StatementCache::new(10);

        cache.insert(PreparedStatement::new(1, "SELECT 1".to_string()));
        cache.insert(PreparedStatement::new(2, "SELECT 2".to_string()));

        let cleared: Vec<_> = cache.clear().collect();
        assert_eq!(cleared.len(), 2);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_statement_cache_remove() {
        let mut cache = StatementCache::new(10);

        cache.insert(PreparedStatement::new(1, "SELECT 1".to_string()));
        assert_eq!(cache.len(), 1);

        assert!(cache.remove("SELECT 2").is_none());
        assert_eq!(cache.len(), 1);

        let removed = cache.remove("SELECT 1");
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().handle(), 1);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_statement_cache_hit_ratio() {
        let mut cache = StatementCache::new(10);

        cache.insert(PreparedStatement::new(1, "SELECT 1".to_string()));

        cache.get("SELECT 1");
        cache.get("SELECT 1");
        cache.get("SELECT 2");

        assert_eq!(cache.hits(), 2);
        assert_eq!(cache.misses(), 1);
        assert!((cache.hit_ratio() - 2.0 / 3.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_statement_cache_reset_stats() {
        let mut cache = StatementCache::new(10);
        assert!(cache.hit_ratio().abs() < f64::EPSILON);

        cache.insert(PreparedStatement::new(1, "SELECT 1".to_string()));
        cache.get("SELECT 1");
        cache.get("SELECT 2");
        cache.reset_stats();

        assert_eq!(cache.hits(), 0);
        assert_eq!(cache.misses(), 0);
        assert!(cache.hit_ratio().abs() < f64::EPSILON);
        // Resetting counters must not drop cached entries.
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_hash_sql_consistency() {
        let sql = "SELECT * FROM users WHERE id = @p1";
        let hash1 = hash_sql(sql);
        let hash2 = hash_sql(sql);
        assert_eq!(hash1, hash2);
    }

    #[test]
    fn test_hash_sql_different() {
        let hash1 = hash_sql("SELECT 1");
        let hash2 = hash_sql("SELECT 2");
        assert_ne!(hash1, hash2);
    }

    #[test]
    fn test_prepared_statement_accessors() {
        let stmt = PreparedStatement::new(7, "SELECT 1".to_string());
        assert_eq!(stmt.handle(), 7);
        assert_eq!(stmt.sql(), "SELECT 1");
        assert_eq!(stmt.sql_hash(), hash_sql("SELECT 1"));
    }

    #[test]
    fn test_prepared_statement_age() {
        let stmt = PreparedStatement::new(1, "SELECT 1".to_string());
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(stmt.age().as_millis() >= 10);
    }

    #[test]
    fn test_statement_cache_config_default() {
        let config = StatementCacheConfig::default();
        assert!(config.enabled);
        assert_eq!(config.max_size, DEFAULT_MAX_STATEMENTS);
    }

    #[test]
    fn test_statement_cache_config_disabled() {
        let config = StatementCacheConfig::disabled();
        assert!(!config.enabled);
        assert_eq!(config.max_size, 0);
    }

    #[test]
    fn test_statement_cache_config_with_max_size() {
        let config = StatementCacheConfig::with_max_size(64);
        assert!(config.enabled);
        assert_eq!(config.max_size, 64);
    }
}