1use std::borrow::Cow;
26use std::collections::HashMap;
27use std::hash::{Hash, Hasher};
28use std::sync::atomic::{AtomicU64, Ordering};
29use std::sync::{Arc, RwLock};
30use tracing::debug;
31
/// Thread-safe, size-bounded map from [`QueryKey`] to cached SQL text.
///
/// Entries live behind a `std::sync::RwLock`; statistics are kept in atomic
/// counters so recording a hit or miss does not require the map's write lock.
#[derive(Debug)]
pub struct QueryCache {
    /// Number of entries at which inserting a new key triggers eviction.
    max_size: usize,
    /// The cached entries, keyed by caller-chosen query key.
    cache: RwLock<HashMap<QueryKey, CachedQuery>>,
    /// Hit/miss/eviction/insertion counters.
    stats: AtomicCacheStats,
}
49
50#[derive(Debug, Default)]
57struct AtomicCacheStats {
58 hits: AtomicU64,
59 misses: AtomicU64,
60 evictions: AtomicU64,
61 insertions: AtomicU64,
62}
63
64impl AtomicCacheStats {
65 #[inline]
66 fn record_hit(&self) {
67 self.hits.fetch_add(1, Ordering::Relaxed);
68 }
69
70 #[inline]
71 fn record_miss(&self) {
72 self.misses.fetch_add(1, Ordering::Relaxed);
73 }
74
75 #[inline]
76 fn record_eviction(&self) {
77 self.evictions.fetch_add(1, Ordering::Relaxed);
78 }
79
80 #[inline]
81 fn record_insertion(&self) {
82 self.insertions.fetch_add(1, Ordering::Relaxed);
83 }
84
85 fn snapshot(&self) -> CacheStats {
86 CacheStats {
87 hits: self.hits.load(Ordering::Relaxed),
88 misses: self.misses.load(Ordering::Relaxed),
89 evictions: self.evictions.load(Ordering::Relaxed),
90 insertions: self.insertions.load(Ordering::Relaxed),
91 }
92 }
93
94 fn reset(&self) {
95 self.hits.store(0, Ordering::Relaxed);
96 self.misses.store(0, Ordering::Relaxed);
97 self.evictions.store(0, Ordering::Relaxed);
98 self.insertions.store(0, Ordering::Relaxed);
99 }
100}
101
/// Lookup key for [`QueryCache`].
///
/// Wraps either a borrowed `'static` string (no allocation) or an owned `String`.
/// Equality and hashing compare the text only, so a borrowed and an owned key
/// with the same contents are interchangeable.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct QueryKey {
    key: Cow<'static, str>,
}

impl QueryKey {
    /// Builds a key that borrows a `'static` string; usable in `const` contexts.
    #[inline]
    pub const fn new(key: &'static str) -> Self {
        QueryKey {
            key: Cow::Borrowed(key),
        }
    }

    /// Builds a key that takes ownership of `key`.
    #[inline]
    pub fn owned(key: String) -> Self {
        let key: Cow<'static, str> = key.into();
        QueryKey { key }
    }
}

impl From<&'static str> for QueryKey {
    fn from(s: &'static str) -> Self {
        QueryKey::new(s)
    }
}

impl From<String> for QueryKey {
    fn from(s: String) -> Self {
        QueryKey::owned(s)
    }
}
138
139#[derive(Debug, Clone)]
141pub struct CachedQuery {
142 pub sql: String,
144 pub param_count: usize,
146 access_count: u64,
148}
149
150impl CachedQuery {
151 pub fn new(sql: impl Into<String>, param_count: usize) -> Self {
153 Self {
154 sql: sql.into(),
155 param_count,
156 access_count: 0,
157 }
158 }
159
160 #[inline]
162 pub fn sql(&self) -> &str {
163 &self.sql
164 }
165
166 #[inline]
168 pub fn param_count(&self) -> usize {
169 self.param_count
170 }
171}
172
173#[derive(Debug, Default, Clone)]
175pub struct CacheStats {
176 pub hits: u64,
178 pub misses: u64,
180 pub evictions: u64,
182 pub insertions: u64,
184}
185
186impl CacheStats {
187 #[inline]
189 pub fn hit_rate(&self) -> f64 {
190 let total = self.hits + self.misses;
191 if total == 0 {
192 0.0
193 } else {
194 self.hits as f64 / total as f64
195 }
196 }
197}
198
199impl QueryCache {
200 pub fn new(max_size: usize) -> Self {
202 tracing::info!(max_size, "QueryCache initialized");
203 Self {
204 max_size,
205 cache: RwLock::new(HashMap::with_capacity(max_size)),
206 stats: AtomicCacheStats::default(),
207 }
208 }
209
210 pub fn insert(&self, key: impl Into<QueryKey>, sql: impl Into<String>) {
212 let key = key.into();
213 let sql = sql.into();
214 let param_count = count_placeholders(&sql);
215 debug!(key = ?key.key, sql_len = sql.len(), param_count, "QueryCache::insert()");
216
217 let mut cache = self.cache.write().unwrap();
218
219 if cache.len() >= self.max_size && !cache.contains_key(&key) {
221 self.evict_lru(&mut cache);
222 self.stats.record_eviction();
223 debug!("QueryCache evicted entry");
224 }
225
226 cache.insert(key, CachedQuery::new(sql, param_count));
227 self.stats.record_insertion();
228 }
229
230 pub fn insert_with_params(
232 &self,
233 key: impl Into<QueryKey>,
234 sql: impl Into<String>,
235 param_count: usize,
236 ) {
237 let key = key.into();
238 let sql = sql.into();
239
240 let mut cache = self.cache.write().unwrap();
241
242 if cache.len() >= self.max_size && !cache.contains_key(&key) {
244 self.evict_lru(&mut cache);
245 self.stats.record_eviction();
246 }
247
248 cache.insert(key, CachedQuery::new(sql, param_count));
249 self.stats.record_insertion();
250 }
251
252 pub fn get(&self, key: impl Into<QueryKey>) -> Option<String> {
254 let key = key.into();
255
256 let cache = self.cache.read().unwrap();
257 if let Some(entry) = cache.get(&key) {
258 self.stats.record_hit();
261 debug!(key = ?key.key, "QueryCache hit");
262 return Some(entry.sql.clone());
263 }
264 drop(cache);
265
266 self.stats.record_miss();
267 debug!(key = ?key.key, "QueryCache miss");
268 None
269 }
270
271 pub fn get_entry(&self, key: impl Into<QueryKey>) -> Option<CachedQuery> {
273 let key = key.into();
274
275 let cache = self.cache.read().unwrap();
276 if let Some(entry) = cache.get(&key) {
277 self.stats.record_hit();
278 return Some(entry.clone());
279 }
280 drop(cache);
281
282 self.stats.record_miss();
283 None
284 }
285
286 pub fn get_or_insert<F>(&self, key: impl Into<QueryKey>, f: F) -> String
291 where
292 F: FnOnce() -> String,
293 {
294 let key = key.into();
295
296 if let Some(sql) = self.get(key.clone()) {
298 return sql;
299 }
300
301 let sql = f();
303 self.insert(key, sql.clone());
304 sql
305 }
306
307 pub fn contains(&self, key: impl Into<QueryKey>) -> bool {
309 let key = key.into();
310 let cache = self.cache.read().unwrap();
311 cache.contains_key(&key)
312 }
313
314 pub fn remove(&self, key: impl Into<QueryKey>) -> Option<String> {
316 let key = key.into();
317 let mut cache = self.cache.write().unwrap();
318 cache.remove(&key).map(|e| e.sql)
319 }
320
321 pub fn clear(&self) {
323 let mut cache = self.cache.write().unwrap();
324 cache.clear();
325 }
326
327 pub fn len(&self) -> usize {
329 let cache = self.cache.read().unwrap();
330 cache.len()
331 }
332
333 pub fn is_empty(&self) -> bool {
335 self.len() == 0
336 }
337
338 pub fn max_size(&self) -> usize {
340 self.max_size
341 }
342
343 pub fn stats(&self) -> CacheStats {
345 self.stats.snapshot()
346 }
347
348 pub fn reset_stats(&self) {
350 self.stats.reset();
351 }
352
353 fn evict_lru(&self, cache: &mut HashMap<QueryKey, CachedQuery>) {
355 let to_evict = cache.len() / 4; if to_evict == 0 {
359 return;
360 }
361
362 let mut entries: Vec<_> = cache
363 .iter()
364 .map(|(k, v)| (k.clone(), v.access_count))
365 .collect();
366 entries.sort_by_key(|(_, count)| *count);
367
368 for (key, _) in entries.into_iter().take(to_evict) {
369 cache.remove(&key);
370 }
371 }
372}
373
374impl Default for QueryCache {
375 fn default() -> Self {
376 Self::new(1000)
377 }
378}
379
/// Counts the bind parameters referenced by `sql`.
///
/// Two placeholder styles are recognised:
/// - numbered `$1`, `$2`, ... — the highest number seen wins;
/// - positional `?` — each occurrence adds one.
///
/// Placeholder-like text in the following places is ignored:
/// - single-quoted string literals;
/// - double-quoted or backtick-quoted identifiers;
/// - `--` line comments and `/* */` block comments;
/// - a `$N` that continues an identifier (e.g. `col$1`).
///
/// Limitations:
/// - backslash escapes inside literals are not understood;
/// - PostgreSQL dollar-quoted strings (`$tag$...$tag$`) are not understood;
/// - a `?` used as an operator (e.g. PostgreSQL's JSONB `?`) is still counted.
fn count_placeholders(sql: &str) -> usize {
    let mut count = 0;
    let mut prev: Option<char> = None;
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        match c {
            // Quoted literal or identifier: skip to the matching closing quote.
            // A doubled quote ('it''s') just closes and immediately reopens, so
            // its contents are skipped too.
            '\'' | '"' | '`' => {
                for inner in chars.by_ref() {
                    if inner == c {
                        break;
                    }
                }
            }
            // `--` line comment: skip through the end of the line.
            '-' if chars.peek() == Some(&'-') => {
                for inner in chars.by_ref() {
                    if inner == '\n' {
                        break;
                    }
                }
            }
            // `/* ... */` block comment.
            '/' if chars.peek() == Some(&'*') => {
                chars.next(); // consume the opening '*'
                let mut last = '\0';
                for inner in chars.by_ref() {
                    if last == '*' && inner == '/' {
                        break;
                    }
                    last = inner;
                }
            }
            '$' => {
                let mut digits = String::new();
                while let Some(&d) = chars.peek() {
                    if !d.is_ascii_digit() {
                        break;
                    }
                    digits.push(d);
                    chars.next();
                }
                // `col$1` is an identifier, not a placeholder.
                let continues_identifier =
                    prev.is_some_and(|p| p.is_alphanumeric() || p == '_');
                // An empty or overflowing digit run fails to parse and is ignored.
                match digits.parse::<usize>() {
                    Ok(n) if !continues_identifier => count = count.max(n),
                    _ => {}
                }
            }
            '?' => count += 1,
            _ => {}
        }
        prev = Some(c);
    }

    count
}
410
/// Opaque 64-bit fingerprint of a SQL string, computed with the standard
/// library's `DefaultHasher`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct QueryHash(u64);

impl QueryHash {
    /// Hashes `sql` into a fingerprint.
    pub fn new(sql: &str) -> Self {
        let mut state = std::collections::hash_map::DefaultHasher::new();
        Hash::hash(sql, &mut state);
        QueryHash(state.finish())
    }

    /// The raw 64-bit hash value.
    #[inline]
    pub fn value(&self) -> u64 {
        let QueryHash(raw) = *self;
        raw
    }
}
429
430pub mod patterns {
432 use super::QueryKey;
433
434 #[inline]
436 pub fn select_by_id(table: &str) -> QueryKey {
437 QueryKey::owned(format!("select_by_id:{}", table))
438 }
439
440 #[inline]
442 pub fn select_all(table: &str) -> QueryKey {
443 QueryKey::owned(format!("select_all:{}", table))
444 }
445
446 #[inline]
448 pub fn insert(table: &str, columns: usize) -> QueryKey {
449 QueryKey::owned(format!("insert:{}:{}", table, columns))
450 }
451
452 #[inline]
454 pub fn update_by_id(table: &str, columns: usize) -> QueryKey {
455 QueryKey::owned(format!("update_by_id:{}:{}", table, columns))
456 }
457
458 #[inline]
460 pub fn delete_by_id(table: &str) -> QueryKey {
461 QueryKey::owned(format!("delete_by_id:{}", table))
462 }
463
464 #[inline]
466 pub fn count(table: &str) -> QueryKey {
467 QueryKey::owned(format!("count:{}", table))
468 }
469
470 #[inline]
472 pub fn count_filtered(table: &str, filter_hash: u64) -> QueryKey {
473 QueryKey::owned(format!("count:{}:{}", table, filter_hash))
474 }
475}
476
/// Size-bounded cache of [`SqlTemplate`]s guarded by `parking_lot` read/write locks.
///
/// Templates are stored by hash in `templates`. `key_index` maps caller-chosen
/// string keys to those hashes, so a template can be found by name or by hash.
#[derive(Debug)]
pub struct SqlTemplateCache {
    /// Template count at which registering a new template triggers eviction.
    max_size: usize,
    /// Templates keyed by hash. `register` uses the SQL's own hash;
    /// `register_by_hash` uses a caller-supplied hash.
    templates: parking_lot::RwLock<HashMap<u64, Arc<SqlTemplate>>>,
    /// Maps a template's string key to its hash in `templates`.
    key_index: parking_lot::RwLock<HashMap<Cow<'static, str>, u64>>,
    /// Hit/miss/eviction/insertion counters.
    stats: parking_lot::RwLock<CacheStats>,
}
517
518#[derive(Debug)]
520pub struct SqlTemplate {
521 pub sql: Arc<str>,
523 pub hash: u64,
525 pub param_count: usize,
527 last_access: std::sync::atomic::AtomicU64,
529}
530
531impl Clone for SqlTemplate {
532 fn clone(&self) -> Self {
533 use std::sync::atomic::Ordering;
534 Self {
535 sql: Arc::clone(&self.sql),
536 hash: self.hash,
537 param_count: self.param_count,
538 last_access: std::sync::atomic::AtomicU64::new(
539 self.last_access.load(Ordering::Relaxed),
540 ),
541 }
542 }
543}
544
545impl SqlTemplate {
546 pub fn new(sql: impl AsRef<str>) -> Self {
548 let sql_str = sql.as_ref();
549 let param_count = count_placeholders(sql_str);
550 let hash = {
551 let mut hasher = std::collections::hash_map::DefaultHasher::new();
552 sql_str.hash(&mut hasher);
553 hasher.finish()
554 };
555
556 Self {
557 sql: Arc::from(sql_str),
558 hash,
559 param_count,
560 last_access: std::sync::atomic::AtomicU64::new(0),
561 }
562 }
563
564 #[inline(always)]
566 pub fn sql(&self) -> &str {
567 &self.sql
568 }
569
570 #[inline(always)]
572 pub fn sql_arc(&self) -> Arc<str> {
573 Arc::clone(&self.sql)
574 }
575
576 #[inline]
578 fn touch(&self) {
579 use std::sync::atomic::Ordering;
580 use std::time::{SystemTime, UNIX_EPOCH};
581 let now = SystemTime::now()
582 .duration_since(UNIX_EPOCH)
583 .map(|d| d.as_secs())
584 .unwrap_or(0);
585 self.last_access.store(now, Ordering::Relaxed);
586 }
587}
588
589impl SqlTemplateCache {
590 pub fn new(max_size: usize) -> Self {
592 tracing::info!(max_size, "SqlTemplateCache initialized");
593 Self {
594 max_size,
595 templates: parking_lot::RwLock::new(HashMap::with_capacity(max_size)),
596 key_index: parking_lot::RwLock::new(HashMap::with_capacity(max_size)),
597 stats: parking_lot::RwLock::new(CacheStats::default()),
598 }
599 }
600
601 #[inline]
605 pub fn register(
606 &self,
607 key: impl Into<Cow<'static, str>>,
608 sql: impl AsRef<str>,
609 ) -> Arc<SqlTemplate> {
610 let key = key.into();
611 let template = Arc::new(SqlTemplate::new(sql));
612 let hash = template.hash;
613
614 let mut templates = self.templates.write();
615 let mut key_index = self.key_index.write();
616 let mut stats = self.stats.write();
617
618 if templates.len() >= self.max_size {
620 self.evict_lru_internal(&mut templates, &mut key_index);
621 stats.evictions += 1;
622 }
623
624 key_index.insert(key, hash);
625 templates.insert(hash, Arc::clone(&template));
626 stats.insertions += 1;
627
628 debug!(hash, "SqlTemplateCache::register()");
629 template
630 }
631
632 #[inline]
634 pub fn register_by_hash(&self, hash: u64, sql: impl AsRef<str>) -> Arc<SqlTemplate> {
635 let template = Arc::new(SqlTemplate::new(sql));
636
637 let mut templates = self.templates.write();
638 let mut stats = self.stats.write();
639
640 if templates.len() >= self.max_size {
641 let mut key_index = self.key_index.write();
642 self.evict_lru_internal(&mut templates, &mut key_index);
643 stats.evictions += 1;
644 }
645
646 templates.insert(hash, Arc::clone(&template));
647 stats.insertions += 1;
648
649 template
650 }
651
652 #[inline]
660 pub fn get(&self, key: &str) -> Option<Arc<SqlTemplate>> {
661 let hash = {
662 let key_index = self.key_index.read();
663 match key_index.get(key) {
664 Some(&h) => h,
665 None => {
666 drop(key_index); let mut stats = self.stats.write();
668 stats.misses += 1;
669 return None;
670 }
671 }
672 };
673
674 let templates = self.templates.read();
675 if let Some(template) = templates.get(&hash) {
676 template.touch();
677 let mut stats = self.stats.write();
678 stats.hits += 1;
679 return Some(Arc::clone(template));
680 }
681
682 let mut stats = self.stats.write();
683 stats.misses += 1;
684 None
685 }
686
687 #[inline(always)]
693 pub fn get_by_hash(&self, hash: u64) -> Option<Arc<SqlTemplate>> {
694 let templates = self.templates.read();
695 if let Some(template) = templates.get(&hash) {
696 template.touch();
697 return Some(Arc::clone(template));
699 }
700 None
701 }
702
703 #[inline]
705 pub fn get_sql(&self, key: &str) -> Option<Arc<str>> {
706 self.get(key).map(|t| t.sql_arc())
707 }
708
709 #[inline]
711 pub fn get_or_register<F>(&self, key: impl Into<Cow<'static, str>>, f: F) -> Arc<SqlTemplate>
712 where
713 F: FnOnce() -> String,
714 {
715 let key = key.into();
716
717 if let Some(template) = self.get(&key) {
719 return template;
720 }
721
722 let sql = f();
724 self.register(key, sql)
725 }
726
727 #[inline]
729 pub fn contains(&self, key: &str) -> bool {
730 let key_index = self.key_index.read();
731 key_index.contains_key(key)
732 }
733
734 pub fn stats(&self) -> CacheStats {
736 self.stats.read().clone()
737 }
738
739 pub fn len(&self) -> usize {
741 self.templates.read().len()
742 }
743
744 pub fn is_empty(&self) -> bool {
746 self.len() == 0
747 }
748
749 pub fn clear(&self) {
751 self.templates.write().clear();
752 self.key_index.write().clear();
753 }
754
755 fn evict_lru_internal(
757 &self,
758 templates: &mut HashMap<u64, Arc<SqlTemplate>>,
759 key_index: &mut HashMap<Cow<'static, str>, u64>,
760 ) {
761 use std::sync::atomic::Ordering;
762
763 let to_evict = templates.len() / 4;
764 if to_evict == 0 {
765 return;
766 }
767
768 let mut entries: Vec<_> = templates
770 .iter()
771 .map(|(&hash, t)| (hash, t.last_access.load(Ordering::Relaxed)))
772 .collect();
773 entries.sort_by_key(|(_, time)| *time);
774
775 let evicted: std::collections::HashSet<u64> = entries
778 .into_iter()
779 .take(to_evict)
780 .map(|(hash, _)| {
781 templates.remove(&hash);
782 hash
783 })
784 .collect();
785 key_index.retain(|_, h| !evicted.contains(h));
786 }
787}
788
789impl Default for SqlTemplateCache {
790 fn default() -> Self {
791 Self::new(1000)
792 }
793}
794
795static GLOBAL_TEMPLATE_CACHE: std::sync::OnceLock<SqlTemplateCache> = std::sync::OnceLock::new();
818
819#[inline(always)]
821pub fn global_template_cache() -> &'static SqlTemplateCache {
822 GLOBAL_TEMPLATE_CACHE.get_or_init(|| SqlTemplateCache::new(10000))
823}
824
825#[inline]
827pub fn register_global_template(
828 key: impl Into<Cow<'static, str>>,
829 sql: impl AsRef<str>,
830) -> Arc<SqlTemplate> {
831 global_template_cache().register(key, sql)
832}
833
834#[inline(always)]
836pub fn get_global_template(key: &str) -> Option<Arc<SqlTemplate>> {
837 global_template_cache().get(key)
838}
839
/// Hashes a string key with `DefaultHasher`, the same way `SqlTemplate::new`
/// hashes SQL text.
///
/// Note: `SqlTemplateCache::register` stores templates under the hash of their
/// SQL, not of their key. A hash produced here therefore only finds templates
/// that were stored with `register_by_hash(precompute_query_hash(key), ...)`.
#[inline]
pub fn precompute_query_hash(key: &str) -> u64 {
    let mut state = std::collections::hash_map::DefaultHasher::new();
    Hash::hash(key, &mut state);
    state.finish()
}
851
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_query_cache_basic() {
        let cache = QueryCache::new(10);
        let expected = "SELECT * FROM users WHERE id = $1";

        cache.insert("users_by_id", expected);
        assert!(cache.contains("users_by_id"));
        assert_eq!(cache.get("users_by_id"), Some(expected.to_string()));
    }

    #[test]
    fn test_query_cache_get_or_insert() {
        let cache = QueryCache::new(10);

        let first = cache.get_or_insert("test", || "SELECT 1".to_string());
        assert_eq!(first, "SELECT 1");

        // The second call must return the cached SQL, not the new closure's output.
        let second = cache.get_or_insert("test", || "SELECT 2".to_string());
        assert_eq!(second, "SELECT 1");
    }

    #[test]
    fn test_query_cache_stats() {
        let cache = QueryCache::new(10);
        cache.insert("test", "SELECT 1");

        // Two hits followed by one miss.
        for key in ["test", "test", "missing"] {
            let _ = cache.get(key);
        }

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.insertions, 1);
    }

    #[test]
    fn test_count_placeholders_postgres() {
        let cases = [
            ("SELECT * FROM users WHERE id = $1", 1),
            ("SELECT * FROM users WHERE id = $1 AND name = $2", 2),
            ("SELECT * FROM users WHERE id = $10", 10),
        ];
        for (sql, expected) in cases {
            assert_eq!(count_placeholders(sql), expected, "{sql}");
        }
    }

    #[test]
    fn test_count_placeholders_mysql() {
        let cases = [
            ("SELECT * FROM users WHERE id = ?", 1),
            ("SELECT * FROM users WHERE id = ? AND name = ?", 2),
        ];
        for (sql, expected) in cases {
            assert_eq!(count_placeholders(sql), expected, "{sql}");
        }
    }

    #[test]
    fn test_query_hash() {
        let users = QueryHash::new("SELECT * FROM users");

        assert_eq!(users, QueryHash::new("SELECT * FROM users"));
        assert_ne!(users, QueryHash::new("SELECT * FROM posts"));
    }

    #[test]
    fn test_patterns() {
        assert!(patterns::select_by_id("users").key.starts_with("select_by_id:"));
    }

    #[test]
    fn test_sql_template_cache_basic() {
        let cache = SqlTemplateCache::new(100);

        let registered = cache.register("users_by_id", "SELECT * FROM users WHERE id = $1");
        assert_eq!(registered.sql(), "SELECT * FROM users WHERE id = $1");
        assert_eq!(registered.param_count, 1);
    }

    #[test]
    fn test_sql_template_cache_get() {
        let cache = SqlTemplateCache::new(100);
        cache.register("test_query", "SELECT * FROM test WHERE x = $1");

        let found = cache.get("test_query").map(|t| t.sql().to_string());
        assert_eq!(found, Some("SELECT * FROM test WHERE x = $1".to_string()));

        assert!(cache.get("nonexistent").is_none());
    }

    #[test]
    fn test_sql_template_cache_get_by_hash() {
        let cache = SqlTemplateCache::new(100);
        let hash = cache.register("fast_query", "SELECT 1").hash;

        let found = cache.get_by_hash(hash).map(|t| t.sql_arc());
        assert_eq!(found.as_deref(), Some("SELECT 1"));
    }

    #[test]
    fn test_sql_template_cache_get_or_register() {
        let cache = SqlTemplateCache::new(100);

        let built = cache.get_or_register("computed", || "SELECT * FROM computed".to_string());
        assert_eq!(built.sql(), "SELECT * FROM computed");

        // Already cached: the closure must not run.
        let cached = cache.get_or_register("computed", || panic!("Should not be called"));
        assert_eq!(cached.sql(), "SELECT * FROM computed");
        assert_eq!(built.hash, cached.hash);
    }

    #[test]
    fn test_sql_template_cache_stats() {
        let cache = SqlTemplateCache::new(100);
        cache.register("q1", "SELECT 1");

        // Two hits followed by one miss.
        for key in ["q1", "q1", "missing"] {
            let _ = cache.get(key);
        }

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.insertions, 1);
    }

    #[test]
    fn test_global_template_cache() {
        let registered = register_global_template("global_test", "SELECT * FROM global");
        assert_eq!(registered.sql(), "SELECT * FROM global");

        let found = get_global_template("global_test").map(|t| t.sql().to_string());
        assert_eq!(found, Some("SELECT * FROM global".to_string()));
    }

    #[test]
    fn test_precompute_query_hash() {
        let test_key = precompute_query_hash("test_key");

        assert_eq!(test_key, precompute_query_hash("test_key"));
        assert_ne!(test_key, precompute_query_hash("other_key"));
    }

    #[test]
    fn test_execution_plan_cache() {
        let cache = ExecutionPlanCache::new(100);

        let plan = cache.register(
            "users_by_email",
            "SELECT * FROM users WHERE email = $1",
            PlanHint::IndexScan("users_email_idx".into()),
        );
        assert_eq!(plan.sql.as_ref(), "SELECT * FROM users WHERE email = $1");

        let hint = cache.get("users_by_email").map(|p| p.hint.clone());
        assert!(matches!(hint, Some(PlanHint::IndexScan(_))));
    }
}
1036
/// Optional execution hint attached to an [`ExecutionPlan`].
///
/// Nothing in this module interprets the hint; it is only stored and handed
/// back to callers.
#[derive(Debug, Clone, Default)]
pub enum PlanHint {
    /// No hint (the default).
    #[default]
    None,
    /// Prefer an index scan; the payload is presumably the index name.
    IndexScan(String),
    /// Prefer a sequential scan.
    SeqScan,
    /// Parallel execution; the payload is presumably a worker count — TODO confirm.
    Parallel(u32),
    /// Presumably asks the backend to cache its plan — TODO confirm.
    CachePlan,
    /// Presumably an execution time limit.
    Timeout(std::time::Duration),
    /// Free-form hint whose meaning is defined by the consumer.
    Custom(String),
}

/// A SQL statement with a hint and running execution statistics.
#[derive(Debug)]
pub struct ExecutionPlan {
    /// The SQL text.
    pub sql: Arc<str>,
    /// `compute_hash` of `sql`.
    pub hash: u64,
    /// The attached hint.
    pub hint: PlanHint,
    /// Caller-supplied cost estimate, if any (set by [`ExecutionPlan::with_cost`]).
    pub estimated_cost: Option<f64>,
    /// Number of executions recorded via [`ExecutionPlan::record_execution`].
    use_count: AtomicU64,
    /// Running mean of the durations passed to `record_execution` (microseconds, per the name).
    avg_execution_us: AtomicU64,
}

/// Hashes `s` with the standard library's `DefaultHasher`.
fn compute_hash(s: &str) -> u64 {
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    s.hash(&mut hasher);
    hasher.finish()
}

impl ExecutionPlan {
    /// Creates a plan with no cost estimate and zeroed statistics.
    pub fn new(sql: impl AsRef<str>, hint: PlanHint) -> Self {
        let sql_str = sql.as_ref();
        Self {
            sql: Arc::from(sql_str),
            hash: compute_hash(sql_str),
            hint,
            estimated_cost: None,
            use_count: AtomicU64::new(0),
            avg_execution_us: AtomicU64::new(0),
        }
    }

    /// Creates a plan carrying the estimated `cost`.
    pub fn with_cost(sql: impl AsRef<str>, hint: PlanHint, cost: f64) -> Self {
        Self {
            estimated_cost: Some(cost),
            ..Self::new(sql, hint)
        }
    }

    /// Records one execution taking `duration_us` and folds it into the running mean.
    ///
    /// The average is updated with a compare-and-swap loop, so concurrent calls
    /// do not overwrite each other's result. The count and the average are still
    /// two separate atomics, so under concurrency the mean is an approximation.
    pub fn record_execution(&self, duration_us: u64) {
        let prior_samples = self.use_count.fetch_add(1, Ordering::Relaxed);
        // The closure always returns Some, so fetch_update cannot fail.
        let _ = self
            .avg_execution_us
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |avg| {
                Some(Self::running_mean(avg, prior_samples, duration_us))
            });
    }

    /// Folds `sample` into a mean of `samples` previous values.
    ///
    /// Uses 128-bit arithmetic because `avg * samples` can exceed `u64::MAX`.
    fn running_mean(avg: u64, samples: u64, sample: u64) -> u64 {
        let total = u128::from(avg) * u128::from(samples) + u128::from(sample);
        let mean = total / (u128::from(samples) + 1);
        // The mean never exceeds max(avg, sample), so it always fits in u64.
        u64::try_from(mean).unwrap_or(u64::MAX)
    }

    /// Number of recorded executions.
    pub fn use_count(&self) -> u64 {
        self.use_count.load(Ordering::Relaxed)
    }

    /// Running mean execution time, in microseconds.
    pub fn avg_execution_us(&self) -> u64 {
        self.avg_execution_us.load(Ordering::Relaxed)
    }
}
1148
/// Size-bounded cache of [`ExecutionPlan`]s, addressable by string key or by SQL hash.
#[derive(Debug)]
pub struct ExecutionPlanCache {
    /// Plan count at which registering a plan with new SQL evicts the least-used plan.
    max_size: usize,
    /// Plans keyed by `ExecutionPlan::hash` (the hash of their SQL).
    plans: parking_lot::RwLock<HashMap<u64, Arc<ExecutionPlan>>>,
    /// Maps a caller-chosen key to the hash of its plan in `plans`.
    key_index: parking_lot::RwLock<HashMap<Cow<'static, str>, u64>>,
}
1182
1183impl ExecutionPlanCache {
1184 pub fn new(max_size: usize) -> Self {
1186 Self {
1187 max_size,
1188 plans: parking_lot::RwLock::new(HashMap::with_capacity(max_size / 2)),
1189 key_index: parking_lot::RwLock::new(HashMap::with_capacity(max_size / 2)),
1190 }
1191 }
1192
1193 pub fn register(
1195 &self,
1196 key: impl Into<Cow<'static, str>>,
1197 sql: impl AsRef<str>,
1198 hint: PlanHint,
1199 ) -> Arc<ExecutionPlan> {
1200 let key = key.into();
1201 let plan = Arc::new(ExecutionPlan::new(sql, hint));
1202 let hash = plan.hash;
1203
1204 let mut plans = self.plans.write();
1205 let mut key_index = self.key_index.write();
1206
1207 if plans.len() >= self.max_size && !plans.contains_key(&hash) {
1209 if let Some((&evict_hash, _)) = plans.iter().min_by_key(|(_, p)| p.use_count()) {
1211 plans.remove(&evict_hash);
1212 key_index.retain(|_, &mut v| v != evict_hash);
1213 }
1214 }
1215
1216 plans.insert(hash, Arc::clone(&plan));
1217 key_index.insert(key, hash);
1218
1219 plan
1220 }
1221
1222 pub fn register_with_cost(
1224 &self,
1225 key: impl Into<Cow<'static, str>>,
1226 sql: impl AsRef<str>,
1227 hint: PlanHint,
1228 cost: f64,
1229 ) -> Arc<ExecutionPlan> {
1230 let key = key.into();
1231 let plan = Arc::new(ExecutionPlan::with_cost(sql, hint, cost));
1232 let hash = plan.hash;
1233
1234 let mut plans = self.plans.write();
1235 let mut key_index = self.key_index.write();
1236
1237 if plans.len() >= self.max_size
1238 && !plans.contains_key(&hash)
1239 && let Some((&evict_hash, _)) = plans.iter().min_by_key(|(_, p)| p.use_count())
1240 {
1241 plans.remove(&evict_hash);
1242 key_index.retain(|_, &mut v| v != evict_hash);
1243 }
1244
1245 plans.insert(hash, Arc::clone(&plan));
1246 key_index.insert(key, hash);
1247
1248 plan
1249 }
1250
1251 pub fn get(&self, key: &str) -> Option<Arc<ExecutionPlan>> {
1253 let hash = {
1254 let key_index = self.key_index.read();
1255 *key_index.get(key)?
1256 };
1257
1258 self.plans.read().get(&hash).cloned()
1259 }
1260
1261 pub fn get_by_hash(&self, hash: u64) -> Option<Arc<ExecutionPlan>> {
1263 self.plans.read().get(&hash).cloned()
1264 }
1265
1266 pub fn get_or_register<F>(
1268 &self,
1269 key: impl Into<Cow<'static, str>>,
1270 sql_fn: F,
1271 hint: PlanHint,
1272 ) -> Arc<ExecutionPlan>
1273 where
1274 F: FnOnce() -> String,
1275 {
1276 let key = key.into();
1277
1278 if let Some(plan) = self.get(key.as_ref()) {
1280 return plan;
1281 }
1282
1283 self.register(key, sql_fn(), hint)
1285 }
1286
1287 pub fn record_execution(&self, key: &str, duration_us: u64) {
1289 if let Some(plan) = self.get(key) {
1290 plan.record_execution(duration_us);
1291 }
1292 }
1293
1294 pub fn slowest_queries(&self, limit: usize) -> Vec<Arc<ExecutionPlan>> {
1296 let plans = self.plans.read();
1297 let mut sorted: Vec<_> = plans.values().cloned().collect();
1298 sorted.sort_by_key(|a| std::cmp::Reverse(a.avg_execution_us()));
1299 sorted.truncate(limit);
1300 sorted
1301 }
1302
1303 pub fn most_used(&self, limit: usize) -> Vec<Arc<ExecutionPlan>> {
1305 let plans = self.plans.read();
1306 let mut sorted: Vec<_> = plans.values().cloned().collect();
1307 sorted.sort_by_key(|a| std::cmp::Reverse(a.use_count()));
1308 sorted.truncate(limit);
1309 sorted
1310 }
1311
1312 pub fn clear(&self) {
1314 self.plans.write().clear();
1315 self.key_index.write().clear();
1316 }
1317
1318 pub fn len(&self) -> usize {
1320 self.plans.read().len()
1321 }
1322
1323 pub fn is_empty(&self) -> bool {
1325 self.plans.read().is_empty()
1326 }
1327}