1use std::borrow::Cow;
26use std::collections::HashMap;
27use std::hash::{Hash, Hasher};
28use std::sync::{Arc, RwLock};
29use tracing::debug;
30
31#[derive(Debug)]
35pub struct QueryCache {
36 max_size: usize,
38 cache: RwLock<HashMap<QueryKey, CachedQuery>>,
40 stats: RwLock<CacheStats>,
42}
43
/// Lookup key for cached queries.
///
/// The text is held in a [`Cow`], so keys built from `'static` string
/// literals are stored without allocating while dynamically built keys own
/// their `String`. Equality and hashing depend only on the text, not on
/// whether it is borrowed or owned.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct QueryKey {
    /// The key text.
    key: Cow<'static, str>,
}

impl QueryKey {
    /// Builds a key that borrows a `'static` string; usable in `const` contexts.
    #[inline]
    pub const fn new(key: &'static str) -> Self {
        QueryKey {
            key: Cow::Borrowed(key),
        }
    }

    /// Builds a key that takes ownership of `key`.
    #[inline]
    pub fn owned(key: String) -> Self {
        QueryKey {
            key: Cow::Owned(key),
        }
    }
}

/// Borrowing conversion, equivalent to [`QueryKey::new`].
impl From<&'static str> for QueryKey {
    fn from(s: &'static str) -> Self {
        QueryKey {
            key: Cow::Borrowed(s),
        }
    }
}

/// Owning conversion, equivalent to [`QueryKey::owned`].
impl From<String> for QueryKey {
    fn from(s: String) -> Self {
        QueryKey { key: Cow::Owned(s) }
    }
}
80
/// A cached SQL statement together with its parameter count.
#[derive(Debug, Clone)]
pub struct CachedQuery {
    /// The SQL text.
    pub sql: String,
    /// Number of bind parameters recorded for `sql`.
    pub param_count: usize,
    /// Usage counter; starts at zero for every new entry.
    access_count: u64,
}

impl CachedQuery {
    /// Creates an entry from SQL text and a caller-supplied parameter count.
    ///
    /// The usage counter starts at zero.
    pub fn new(sql: impl Into<String>, param_count: usize) -> Self {
        let sql: String = sql.into();
        CachedQuery {
            sql,
            param_count,
            access_count: 0,
        }
    }

    /// Borrows the SQL text.
    #[inline]
    pub fn sql(&self) -> &str {
        self.sql.as_str()
    }

    /// Returns the recorded parameter count.
    #[inline]
    pub fn param_count(&self) -> usize {
        let CachedQuery { param_count, .. } = self;
        *param_count
    }
}
114
/// Counters describing cache activity.
#[derive(Debug, Default, Clone)]
pub struct CacheStats {
    /// Lookups that found an entry.
    pub hits: u64,
    /// Lookups that found nothing.
    pub misses: u64,
    /// Evictions recorded by the owning cache.
    pub evictions: u64,
    /// Insertions recorded by the owning cache.
    pub insertions: u64,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in the range `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded, so the result is
    /// never `NaN`.
    #[inline]
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
140
141impl QueryCache {
142 pub fn new(max_size: usize) -> Self {
144 tracing::info!(max_size, "QueryCache initialized");
145 Self {
146 max_size,
147 cache: RwLock::new(HashMap::with_capacity(max_size)),
148 stats: RwLock::new(CacheStats::default()),
149 }
150 }
151
152 pub fn insert(&self, key: impl Into<QueryKey>, sql: impl Into<String>) {
154 let key = key.into();
155 let sql = sql.into();
156 let param_count = count_placeholders(&sql);
157 debug!(key = ?key.key, sql_len = sql.len(), param_count, "QueryCache::insert()");
158
159 let mut cache = self.cache.write().unwrap();
160 let mut stats = self.stats.write().unwrap();
161
162 if cache.len() >= self.max_size && !cache.contains_key(&key) {
164 self.evict_lru(&mut cache);
165 stats.evictions += 1;
166 debug!("QueryCache evicted entry");
167 }
168
169 cache.insert(key, CachedQuery::new(sql, param_count));
170 stats.insertions += 1;
171 }
172
173 pub fn insert_with_params(&self, key: impl Into<QueryKey>, sql: impl Into<String>, param_count: usize) {
175 let key = key.into();
176 let sql = sql.into();
177
178 let mut cache = self.cache.write().unwrap();
179 let mut stats = self.stats.write().unwrap();
180
181 if cache.len() >= self.max_size && !cache.contains_key(&key) {
183 self.evict_lru(&mut cache);
184 stats.evictions += 1;
185 }
186
187 cache.insert(key, CachedQuery::new(sql, param_count));
188 stats.insertions += 1;
189 }
190
191 pub fn get(&self, key: impl Into<QueryKey>) -> Option<String> {
193 let key = key.into();
194
195 {
197 let cache = self.cache.read().unwrap();
198 if let Some(entry) = cache.get(&key) {
199 let mut stats = self.stats.write().unwrap();
200 stats.hits += 1;
201 debug!(key = ?key.key, "QueryCache hit");
202 return Some(entry.sql.clone());
203 }
204 }
205
206 let mut stats = self.stats.write().unwrap();
207 stats.misses += 1;
208 debug!(key = ?key.key, "QueryCache miss");
209 None
210 }
211
212 pub fn get_entry(&self, key: impl Into<QueryKey>) -> Option<CachedQuery> {
214 let key = key.into();
215
216 let cache = self.cache.read().unwrap();
217 if let Some(entry) = cache.get(&key) {
218 let mut stats = self.stats.write().unwrap();
219 stats.hits += 1;
220 return Some(entry.clone());
221 }
222
223 let mut stats = self.stats.write().unwrap();
224 stats.misses += 1;
225 None
226 }
227
228 pub fn get_or_insert<F>(&self, key: impl Into<QueryKey>, f: F) -> String
233 where
234 F: FnOnce() -> String,
235 {
236 let key = key.into();
237
238 if let Some(sql) = self.get(key.clone()) {
240 return sql;
241 }
242
243 let sql = f();
245 self.insert(key, sql.clone());
246 sql
247 }
248
249 pub fn contains(&self, key: impl Into<QueryKey>) -> bool {
251 let key = key.into();
252 let cache = self.cache.read().unwrap();
253 cache.contains_key(&key)
254 }
255
256 pub fn remove(&self, key: impl Into<QueryKey>) -> Option<String> {
258 let key = key.into();
259 let mut cache = self.cache.write().unwrap();
260 cache.remove(&key).map(|e| e.sql)
261 }
262
263 pub fn clear(&self) {
265 let mut cache = self.cache.write().unwrap();
266 cache.clear();
267 }
268
269 pub fn len(&self) -> usize {
271 let cache = self.cache.read().unwrap();
272 cache.len()
273 }
274
275 pub fn is_empty(&self) -> bool {
277 self.len() == 0
278 }
279
280 pub fn max_size(&self) -> usize {
282 self.max_size
283 }
284
285 pub fn stats(&self) -> CacheStats {
287 let stats = self.stats.read().unwrap();
288 stats.clone()
289 }
290
291 pub fn reset_stats(&self) {
293 let mut stats = self.stats.write().unwrap();
294 *stats = CacheStats::default();
295 }
296
297 fn evict_lru(&self, cache: &mut HashMap<QueryKey, CachedQuery>) {
299 let to_evict = cache.len() / 4; if to_evict == 0 {
303 return;
304 }
305
306 let mut entries: Vec<_> = cache.iter().map(|(k, v)| (k.clone(), v.access_count)).collect();
307 entries.sort_by_key(|(_, count)| *count);
308
309 for (key, _) in entries.into_iter().take(to_evict) {
310 cache.remove(&key);
311 }
312 }
313}
314
315impl Default for QueryCache {
316 fn default() -> Self {
317 Self::new(1000)
318 }
319}
320
/// Estimates how many bind parameters `sql` expects.
///
/// Two placeholder styles are recognised in a single left-to-right pass:
///
/// * `$N` (numbered): the running result becomes `max(result, N)`. A `$`
///   not followed by digits, or followed by a number too large for `usize`,
///   is ignored.
/// * `?` (positional): the running result is incremented by one.
///
/// The scan is purely lexical; it does not skip string literals, quoted
/// identifiers or comments.
fn count_placeholders(sql: &str) -> usize {
    let bytes = sql.as_bytes();
    let mut total: usize = 0;
    let mut pos = 0;

    // `$`, `?` and the digits are all ASCII, so scanning bytes sees exactly
    // the same placeholders as scanning chars would.
    while pos < bytes.len() {
        let current = bytes[pos];
        pos += 1;

        match current {
            b'$' => {
                let digits_start = pos;
                while pos < bytes.len() && bytes[pos].is_ascii_digit() {
                    pos += 1;
                }
                // An empty or overflowing digit run fails to parse and is skipped.
                if let Ok(index) = sql[digits_start..pos].parse::<usize>() {
                    total = total.max(index);
                }
            }
            b'?' => total += 1,
            _ => {}
        }
    }

    total
}
351
/// 64-bit hash of a SQL string, produced with the standard library's
/// `DefaultHasher`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct QueryHash(u64);

impl QueryHash {
    /// Hashes `sql`.
    ///
    /// Equal strings always produce equal hashes within one build of the
    /// program.
    pub fn new(sql: &str) -> Self {
        use std::collections::hash_map::DefaultHasher;

        let mut state = DefaultHasher::new();
        Hash::hash(sql, &mut state);
        QueryHash(Hasher::finish(&state))
    }

    /// The raw 64-bit hash value.
    #[inline]
    pub fn value(&self) -> u64 {
        let QueryHash(raw) = *self;
        raw
    }
}
370
371pub mod patterns {
373 use super::QueryKey;
374
375 #[inline]
377 pub fn select_by_id(table: &str) -> QueryKey {
378 QueryKey::owned(format!("select_by_id:{}", table))
379 }
380
381 #[inline]
383 pub fn select_all(table: &str) -> QueryKey {
384 QueryKey::owned(format!("select_all:{}", table))
385 }
386
387 #[inline]
389 pub fn insert(table: &str, columns: usize) -> QueryKey {
390 QueryKey::owned(format!("insert:{}:{}", table, columns))
391 }
392
393 #[inline]
395 pub fn update_by_id(table: &str, columns: usize) -> QueryKey {
396 QueryKey::owned(format!("update_by_id:{}:{}", table, columns))
397 }
398
399 #[inline]
401 pub fn delete_by_id(table: &str) -> QueryKey {
402 QueryKey::owned(format!("delete_by_id:{}", table))
403 }
404
405 #[inline]
407 pub fn count(table: &str) -> QueryKey {
408 QueryKey::owned(format!("count:{}", table))
409 }
410
411 #[inline]
413 pub fn count_filtered(table: &str, filter_hash: u64) -> QueryKey {
414 QueryKey::owned(format!("count:{}:{}", table, filter_hash))
415 }
416}
417
418#[derive(Debug)]
448pub struct SqlTemplateCache {
449 max_size: usize,
451 templates: parking_lot::RwLock<HashMap<u64, Arc<SqlTemplate>>>,
453 key_index: parking_lot::RwLock<HashMap<Cow<'static, str>, u64>>,
455 stats: parking_lot::RwLock<CacheStats>,
457}
458
/// A SQL statement with precomputed metadata.
#[derive(Debug)]
pub struct SqlTemplate {
    /// The SQL text, shared so handing it out does not copy the string.
    pub sql: Arc<str>,
    /// 64-bit hash associated with this template.
    pub hash: u64,
    /// Number of bind parameters recorded for `sql`.
    pub param_count: usize,
    /// Last-access stamp; atomic so it can be updated through `&self`.
    last_access: std::sync::atomic::AtomicU64,
}

/// Manual `Clone`, because `AtomicU64` does not implement `Clone`.
impl Clone for SqlTemplate {
    /// Produces a copy that shares the SQL text and carries a snapshot of
    /// the current last-access stamp in a fresh atomic.
    fn clone(&self) -> Self {
        use std::sync::atomic::{AtomicU64, Ordering};

        let stamp = self.last_access.load(Ordering::Relaxed);
        SqlTemplate {
            sql: Arc::clone(&self.sql),
            hash: self.hash,
            param_count: self.param_count,
            last_access: AtomicU64::new(stamp),
        }
    }
}
485
486impl SqlTemplate {
487 pub fn new(sql: impl AsRef<str>) -> Self {
489 let sql_str = sql.as_ref();
490 let param_count = count_placeholders(sql_str);
491 let hash = {
492 let mut hasher = std::collections::hash_map::DefaultHasher::new();
493 sql_str.hash(&mut hasher);
494 hasher.finish()
495 };
496
497 Self {
498 sql: Arc::from(sql_str),
499 hash,
500 param_count,
501 last_access: std::sync::atomic::AtomicU64::new(0),
502 }
503 }
504
505 #[inline(always)]
507 pub fn sql(&self) -> &str {
508 &self.sql
509 }
510
511 #[inline(always)]
513 pub fn sql_arc(&self) -> Arc<str> {
514 Arc::clone(&self.sql)
515 }
516
517 #[inline]
519 fn touch(&self) {
520 use std::sync::atomic::Ordering;
521 use std::time::{SystemTime, UNIX_EPOCH};
522 let now = SystemTime::now()
523 .duration_since(UNIX_EPOCH)
524 .map(|d| d.as_secs())
525 .unwrap_or(0);
526 self.last_access.store(now, Ordering::Relaxed);
527 }
528}
529
530impl SqlTemplateCache {
531 pub fn new(max_size: usize) -> Self {
533 tracing::info!(max_size, "SqlTemplateCache initialized");
534 Self {
535 max_size,
536 templates: parking_lot::RwLock::new(HashMap::with_capacity(max_size)),
537 key_index: parking_lot::RwLock::new(HashMap::with_capacity(max_size)),
538 stats: parking_lot::RwLock::new(CacheStats::default()),
539 }
540 }
541
542 #[inline]
546 pub fn register(&self, key: impl Into<Cow<'static, str>>, sql: impl AsRef<str>) -> Arc<SqlTemplate> {
547 let key = key.into();
548 let template = Arc::new(SqlTemplate::new(sql));
549 let hash = template.hash;
550
551 let mut templates = self.templates.write();
552 let mut key_index = self.key_index.write();
553 let mut stats = self.stats.write();
554
555 if templates.len() >= self.max_size {
557 self.evict_lru_internal(&mut templates, &mut key_index);
558 stats.evictions += 1;
559 }
560
561 key_index.insert(key, hash);
562 templates.insert(hash, Arc::clone(&template));
563 stats.insertions += 1;
564
565 debug!(hash, "SqlTemplateCache::register()");
566 template
567 }
568
569 #[inline]
571 pub fn register_by_hash(&self, hash: u64, sql: impl AsRef<str>) -> Arc<SqlTemplate> {
572 let template = Arc::new(SqlTemplate::new(sql));
573
574 let mut templates = self.templates.write();
575 let mut stats = self.stats.write();
576
577 if templates.len() >= self.max_size {
578 let mut key_index = self.key_index.write();
579 self.evict_lru_internal(&mut templates, &mut key_index);
580 stats.evictions += 1;
581 }
582
583 templates.insert(hash, Arc::clone(&template));
584 stats.insertions += 1;
585
586 template
587 }
588
589 #[inline]
597 pub fn get(&self, key: &str) -> Option<Arc<SqlTemplate>> {
598 let hash = {
599 let key_index = self.key_index.read();
600 match key_index.get(key) {
601 Some(&h) => h,
602 None => {
603 drop(key_index); let mut stats = self.stats.write();
605 stats.misses += 1;
606 return None;
607 }
608 }
609 };
610
611 let templates = self.templates.read();
612 if let Some(template) = templates.get(&hash) {
613 template.touch();
614 let mut stats = self.stats.write();
615 stats.hits += 1;
616 return Some(Arc::clone(template));
617 }
618
619 let mut stats = self.stats.write();
620 stats.misses += 1;
621 None
622 }
623
624 #[inline(always)]
630 pub fn get_by_hash(&self, hash: u64) -> Option<Arc<SqlTemplate>> {
631 let templates = self.templates.read();
632 if let Some(template) = templates.get(&hash) {
633 template.touch();
634 return Some(Arc::clone(template));
636 }
637 None
638 }
639
640 #[inline]
642 pub fn get_sql(&self, key: &str) -> Option<Arc<str>> {
643 self.get(key).map(|t| t.sql_arc())
644 }
645
646 #[inline]
648 pub fn get_or_register<F>(&self, key: impl Into<Cow<'static, str>>, f: F) -> Arc<SqlTemplate>
649 where
650 F: FnOnce() -> String,
651 {
652 let key = key.into();
653
654 if let Some(template) = self.get(&key) {
656 return template;
657 }
658
659 let sql = f();
661 self.register(key, sql)
662 }
663
664 #[inline]
666 pub fn contains(&self, key: &str) -> bool {
667 let key_index = self.key_index.read();
668 key_index.contains_key(key)
669 }
670
671 pub fn stats(&self) -> CacheStats {
673 self.stats.read().clone()
674 }
675
676 pub fn len(&self) -> usize {
678 self.templates.read().len()
679 }
680
681 pub fn is_empty(&self) -> bool {
683 self.len() == 0
684 }
685
686 pub fn clear(&self) {
688 self.templates.write().clear();
689 self.key_index.write().clear();
690 }
691
692 fn evict_lru_internal(
694 &self,
695 templates: &mut HashMap<u64, Arc<SqlTemplate>>,
696 key_index: &mut HashMap<Cow<'static, str>, u64>,
697 ) {
698 use std::sync::atomic::Ordering;
699
700 let to_evict = templates.len() / 4;
701 if to_evict == 0 {
702 return;
703 }
704
705 let mut entries: Vec<_> = templates
707 .iter()
708 .map(|(&hash, t)| (hash, t.last_access.load(Ordering::Relaxed)))
709 .collect();
710 entries.sort_by_key(|(_, time)| *time);
711
712 for (hash, _) in entries.into_iter().take(to_evict) {
714 templates.remove(&hash);
715 key_index.retain(|_, h| *h != hash);
717 }
718 }
719}
720
721impl Default for SqlTemplateCache {
722 fn default() -> Self {
723 Self::new(1000)
724 }
725}
726
727static GLOBAL_TEMPLATE_CACHE: std::sync::OnceLock<SqlTemplateCache> = std::sync::OnceLock::new();
750
751#[inline(always)]
753pub fn global_template_cache() -> &'static SqlTemplateCache {
754 GLOBAL_TEMPLATE_CACHE.get_or_init(|| SqlTemplateCache::new(10000))
755}
756
757#[inline]
759pub fn register_global_template(key: impl Into<Cow<'static, str>>, sql: impl AsRef<str>) -> Arc<SqlTemplate> {
760 global_template_cache().register(key, sql)
761}
762
763#[inline(always)]
765pub fn get_global_template(key: &str) -> Option<Arc<SqlTemplate>> {
766 global_template_cache().get(key)
767}
768
/// Hashes `key` with the standard library's `DefaultHasher` and returns the
/// 64-bit result.
///
/// Equal keys always produce equal hashes within one build of the program.
#[inline]
pub fn precompute_query_hash(key: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;

    let mut state = DefaultHasher::new();
    Hash::hash(key, &mut state);
    Hasher::finish(&state)
}
780
#[cfg(test)]
mod tests {
    use super::*;

    // ---- QueryCache ----------------------------------------------------

    /// An inserted query is reported by `contains` and returned by `get`.
    #[test]
    fn test_query_cache_basic() {
        const SQL: &str = "SELECT * FROM users WHERE id = $1";
        let cache = QueryCache::new(10);

        cache.insert("users_by_id", SQL);

        assert!(cache.contains("users_by_id"));
        assert_eq!(cache.get("users_by_id"), Some(SQL.to_string()));
    }

    /// The builder closure only contributes the value on the first call.
    #[test]
    fn test_query_cache_get_or_insert() {
        let cache = QueryCache::new(10);

        let first = cache.get_or_insert("test", || String::from("SELECT 1"));
        assert_eq!(first, "SELECT 1");

        // The key is cached now, so the second builder's SQL must not be used.
        let second = cache.get_or_insert("test", || String::from("SELECT 2"));
        assert_eq!(second, "SELECT 1");
    }

    /// Hits, misses and insertions are counted.
    #[test]
    fn test_query_cache_stats() {
        let cache = QueryCache::new(10);
        cache.insert("test", "SELECT 1");

        // Two hits followed by one miss.
        for key in ["test", "test", "missing"] {
            let _ = cache.get(key);
        }

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.insertions, 1);
    }

    // ---- count_placeholders ---------------------------------------------

    /// `$N` placeholders report the highest index.
    #[test]
    fn test_count_placeholders_postgres() {
        let cases = [
            ("SELECT * FROM users WHERE id = $1", 1),
            ("SELECT * FROM users WHERE id = $1 AND name = $2", 2),
            ("SELECT * FROM users WHERE id = $10", 10),
        ];
        for (sql, expected) in cases {
            assert_eq!(count_placeholders(sql), expected);
        }
    }

    /// `?` placeholders are counted one by one.
    #[test]
    fn test_count_placeholders_mysql() {
        let cases = [
            ("SELECT * FROM users WHERE id = ?", 1),
            ("SELECT * FROM users WHERE id = ? AND name = ?", 2),
        ];
        for (sql, expected) in cases {
            assert_eq!(count_placeholders(sql), expected);
        }
    }

    // ---- QueryHash / patterns ---------------------------------------------

    /// Equal SQL hashes equal; different SQL hashes differently.
    #[test]
    fn test_query_hash() {
        let users = QueryHash::new("SELECT * FROM users");
        let users_again = QueryHash::new("SELECT * FROM users");
        let posts = QueryHash::new("SELECT * FROM posts");

        assert_eq!(users, users_again);
        assert_ne!(users, posts);
    }

    /// Pattern keys carry their documented prefix.
    #[test]
    fn test_patterns() {
        let key = patterns::select_by_id("users");
        assert!(key.key.starts_with("select_by_id:"));
    }

    // ---- SqlTemplateCache -------------------------------------------------

    /// `register` returns a template with the SQL and its parameter count.
    #[test]
    fn test_sql_template_cache_basic() {
        let cache = SqlTemplateCache::new(100);

        let template = cache.register("users_by_id", "SELECT * FROM users WHERE id = $1");

        assert_eq!(template.sql(), "SELECT * FROM users WHERE id = $1");
        assert_eq!(template.param_count, 1);
    }

    /// Registered keys are found; unknown keys are not.
    #[test]
    fn test_sql_template_cache_get() {
        let cache = SqlTemplateCache::new(100);
        cache.register("test_query", "SELECT * FROM test WHERE x = $1");

        let found = cache.get("test_query");
        assert!(found.is_some());
        assert_eq!(found.unwrap().sql(), "SELECT * FROM test WHERE x = $1");

        assert!(cache.get("nonexistent").is_none());
    }

    /// A template can be fetched through the hash stored on it.
    #[test]
    fn test_sql_template_cache_get_by_hash() {
        let cache = SqlTemplateCache::new(100);
        let hash = cache.register("fast_query", "SELECT 1").hash;

        let found = cache.get_by_hash(hash);

        assert!(found.is_some());
        assert_eq!(found.unwrap().sql(), "SELECT 1");
    }

    /// The builder closure runs only when the key is missing.
    #[test]
    fn test_sql_template_cache_get_or_register() {
        let cache = SqlTemplateCache::new(100);

        let first = cache.get_or_register("computed", || "SELECT * FROM computed".to_string());
        assert_eq!(first.sql(), "SELECT * FROM computed");

        let second = cache.get_or_register("computed", || panic!("Should not be called"));
        assert_eq!(second.sql(), "SELECT * FROM computed");
        assert_eq!(first.hash, second.hash);
    }

    /// Hits, misses and insertions are counted.
    #[test]
    fn test_sql_template_cache_stats() {
        let cache = SqlTemplateCache::new(100);
        cache.register("q1", "SELECT 1");

        // Two hits followed by one miss.
        for key in ["q1", "q1", "missing"] {
            let _ = cache.get(key);
        }

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.insertions, 1);
    }

    // ---- global helpers ---------------------------------------------------

    /// Templates registered globally can be fetched globally.
    #[test]
    fn test_global_template_cache() {
        let registered = register_global_template("global_test", "SELECT * FROM global");
        assert_eq!(registered.sql(), "SELECT * FROM global");

        let fetched = get_global_template("global_test");
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().sql(), "SELECT * FROM global");
    }

    /// Equal keys hash equal; different keys hash differently.
    #[test]
    fn test_precompute_query_hash() {
        let first = precompute_query_hash("test_key");
        let first_again = precompute_query_hash("test_key");
        let other = precompute_query_hash("other_key");

        assert_eq!(first, first_again);
        assert_ne!(first, other);
    }
}
941