1use std::borrow::Cow;
26use std::collections::HashMap;
27use std::hash::{Hash, Hasher};
28use std::sync::{Arc, RwLock};
29use tracing::debug;
30
/// Thread-safe cache mapping [`QueryKey`]s to SQL text.
///
/// The entries and the counters each sit behind their own `std::sync::RwLock`.
/// Inserting a key that is not yet cached while `max_size` entries are present
/// triggers eviction.
#[derive(Debug)]
pub struct QueryCache {
    /// Entry count at which inserting a new key triggers eviction.
    max_size: usize,
    /// The cached queries, keyed by [`QueryKey`].
    cache: RwLock<HashMap<QueryKey, CachedQuery>>,
    /// Hit/miss/eviction/insertion counters, returned by `QueryCache::stats`.
    stats: RwLock<CacheStats>,
}
43
/// Lookup key for [`QueryCache`].
///
/// `'static` literals are borrowed without allocating; runtime-built keys are owned.
/// Equality and hashing compare the text only, so a borrowed and an owned key with
/// the same contents are interchangeable.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct QueryKey {
    /// Key text: `Cow::Borrowed` for literals, `Cow::Owned` for built strings.
    key: Cow<'static, str>,
}

impl QueryKey {
    /// Wraps a `'static` string without allocating; usable in `const` contexts.
    #[inline]
    pub const fn new(key: &'static str) -> Self {
        Self { key: Cow::Borrowed(key) }
    }

    /// Takes ownership of a runtime-built key string.
    #[inline]
    pub fn owned(key: String) -> Self {
        let key = Cow::Owned(key);
        Self { key }
    }
}

impl From<&'static str> for QueryKey {
    /// Equivalent to [`QueryKey::new`].
    fn from(literal: &'static str) -> Self {
        QueryKey::new(literal)
    }
}

impl From<String> for QueryKey {
    /// Equivalent to [`QueryKey::owned`].
    fn from(text: String) -> Self {
        QueryKey::owned(text)
    }
}
80
/// One cached SQL statement plus the bookkeeping [`QueryCache`] keeps for it.
#[derive(Debug, Clone)]
pub struct CachedQuery {
    /// The SQL text.
    pub sql: String,
    /// Number of bind parameters the SQL expects.
    pub param_count: usize,
    /// Access counter consulted by `QueryCache` when choosing eviction victims; starts at 0.
    access_count: u64,
}

impl CachedQuery {
    /// Wraps `sql` together with its parameter count; the access counter starts at zero.
    pub fn new(sql: impl Into<String>, param_count: usize) -> Self {
        let sql: String = sql.into();
        CachedQuery {
            access_count: 0,
            param_count,
            sql,
        }
    }

    /// Borrows the SQL text.
    #[inline]
    pub fn sql(&self) -> &str {
        self.sql.as_str()
    }

    /// Returns the number of bind parameters.
    #[inline]
    pub fn param_count(&self) -> usize {
        let count = self.param_count;
        count
    }
}
114
/// Cumulative counters reported by the caches in this module.
#[derive(Debug, Default, Clone)]
pub struct CacheStats {
    /// Lookups that found an entry.
    pub hits: u64,
    /// Lookups that found nothing.
    pub misses: u64,
    /// Eviction counter, incremented by the caches' insert paths.
    pub evictions: u64,
    /// Successful inserts/registrations.
    pub insertions: u64,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `0.0..=1.0`; `0.0` before any lookup.
    #[inline]
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
140
141impl QueryCache {
142 pub fn new(max_size: usize) -> Self {
144 tracing::info!(max_size, "QueryCache initialized");
145 Self {
146 max_size,
147 cache: RwLock::new(HashMap::with_capacity(max_size)),
148 stats: RwLock::new(CacheStats::default()),
149 }
150 }
151
152 pub fn insert(&self, key: impl Into<QueryKey>, sql: impl Into<String>) {
154 let key = key.into();
155 let sql = sql.into();
156 let param_count = count_placeholders(&sql);
157 debug!(key = ?key.key, sql_len = sql.len(), param_count, "QueryCache::insert()");
158
159 let mut cache = self.cache.write().unwrap();
160 let mut stats = self.stats.write().unwrap();
161
162 if cache.len() >= self.max_size && !cache.contains_key(&key) {
164 self.evict_lru(&mut cache);
165 stats.evictions += 1;
166 debug!("QueryCache evicted entry");
167 }
168
169 cache.insert(key, CachedQuery::new(sql, param_count));
170 stats.insertions += 1;
171 }
172
173 pub fn insert_with_params(
175 &self,
176 key: impl Into<QueryKey>,
177 sql: impl Into<String>,
178 param_count: usize,
179 ) {
180 let key = key.into();
181 let sql = sql.into();
182
183 let mut cache = self.cache.write().unwrap();
184 let mut stats = self.stats.write().unwrap();
185
186 if cache.len() >= self.max_size && !cache.contains_key(&key) {
188 self.evict_lru(&mut cache);
189 stats.evictions += 1;
190 }
191
192 cache.insert(key, CachedQuery::new(sql, param_count));
193 stats.insertions += 1;
194 }
195
196 pub fn get(&self, key: impl Into<QueryKey>) -> Option<String> {
198 let key = key.into();
199
200 {
202 let cache = self.cache.read().unwrap();
203 if let Some(entry) = cache.get(&key) {
204 let mut stats = self.stats.write().unwrap();
205 stats.hits += 1;
206 debug!(key = ?key.key, "QueryCache hit");
207 return Some(entry.sql.clone());
208 }
209 }
210
211 let mut stats = self.stats.write().unwrap();
212 stats.misses += 1;
213 debug!(key = ?key.key, "QueryCache miss");
214 None
215 }
216
217 pub fn get_entry(&self, key: impl Into<QueryKey>) -> Option<CachedQuery> {
219 let key = key.into();
220
221 let cache = self.cache.read().unwrap();
222 if let Some(entry) = cache.get(&key) {
223 let mut stats = self.stats.write().unwrap();
224 stats.hits += 1;
225 return Some(entry.clone());
226 }
227
228 let mut stats = self.stats.write().unwrap();
229 stats.misses += 1;
230 None
231 }
232
233 pub fn get_or_insert<F>(&self, key: impl Into<QueryKey>, f: F) -> String
238 where
239 F: FnOnce() -> String,
240 {
241 let key = key.into();
242
243 if let Some(sql) = self.get(key.clone()) {
245 return sql;
246 }
247
248 let sql = f();
250 self.insert(key, sql.clone());
251 sql
252 }
253
254 pub fn contains(&self, key: impl Into<QueryKey>) -> bool {
256 let key = key.into();
257 let cache = self.cache.read().unwrap();
258 cache.contains_key(&key)
259 }
260
261 pub fn remove(&self, key: impl Into<QueryKey>) -> Option<String> {
263 let key = key.into();
264 let mut cache = self.cache.write().unwrap();
265 cache.remove(&key).map(|e| e.sql)
266 }
267
268 pub fn clear(&self) {
270 let mut cache = self.cache.write().unwrap();
271 cache.clear();
272 }
273
274 pub fn len(&self) -> usize {
276 let cache = self.cache.read().unwrap();
277 cache.len()
278 }
279
280 pub fn is_empty(&self) -> bool {
282 self.len() == 0
283 }
284
285 pub fn max_size(&self) -> usize {
287 self.max_size
288 }
289
290 pub fn stats(&self) -> CacheStats {
292 let stats = self.stats.read().unwrap();
293 stats.clone()
294 }
295
296 pub fn reset_stats(&self) {
298 let mut stats = self.stats.write().unwrap();
299 *stats = CacheStats::default();
300 }
301
302 fn evict_lru(&self, cache: &mut HashMap<QueryKey, CachedQuery>) {
304 let to_evict = cache.len() / 4; if to_evict == 0 {
308 return;
309 }
310
311 let mut entries: Vec<_> = cache
312 .iter()
313 .map(|(k, v)| (k.clone(), v.access_count))
314 .collect();
315 entries.sort_by_key(|(_, count)| *count);
316
317 for (key, _) in entries.into_iter().take(to_evict) {
318 cache.remove(&key);
319 }
320 }
321}
322
323impl Default for QueryCache {
324 fn default() -> Self {
325 Self::new(1000)
326 }
327}
328
/// Counts the bind parameters a SQL statement expects.
///
/// Two placeholder styles are recognised:
/// - PostgreSQL-style numbered placeholders (`$1`, `$2`, …): the highest number seen wins.
/// - Positional `?` placeholders: each one adds one.
///
/// Text inside single-quoted literals, double-quoted or backtick-quoted identifiers, and
/// `--` line comments is skipped, so `WHERE note = 'why?'` has no parameters. A doubled
/// quote (`'it''s'`) stays inside the literal. Backslash escapes are not interpreted.
fn count_placeholders(sql: &str) -> usize {
    let mut count = 0usize;
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        match c {
            // Skip to the matching closing quote. A doubled quote closes and immediately
            // reopens the literal, so it needs no special casing.
            '\'' | '"' | '`' => {
                for inner in chars.by_ref() {
                    if inner == c {
                        break;
                    }
                }
            }
            // `--` starts a comment that runs to the end of the line.
            '-' if chars.peek() == Some(&'-') => {
                for inner in chars.by_ref() {
                    if inner == '\n' {
                        break;
                    }
                }
            }
            '$' => {
                let mut digits = String::new();
                while let Some(d) = chars.next_if(char::is_ascii_digit) {
                    digits.push(d);
                }
                // A bare `$`, or a number too large for usize, is not a placeholder.
                if let Ok(n) = digits.parse::<usize>() {
                    count = count.max(n);
                }
            }
            '?' => count += 1,
            _ => {}
        }
    }

    count
}
359
/// 64-bit `DefaultHasher` digest of a SQL string.
///
/// Deterministic within one build, but std does not promise the algorithm stays the same
/// across Rust releases, so the value should not be persisted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct QueryHash(u64);

impl QueryHash {
    /// Hashes `sql` with `std`'s `DefaultHasher`.
    pub fn new(sql: &str) -> Self {
        use std::collections::hash_map::DefaultHasher;

        let mut state = DefaultHasher::new();
        Hash::hash(sql, &mut state);
        QueryHash(state.finish())
    }

    /// The raw 64-bit digest.
    #[inline]
    pub fn value(&self) -> u64 {
        let QueryHash(raw) = *self;
        raw
    }
}
378
379pub mod patterns {
381 use super::QueryKey;
382
383 #[inline]
385 pub fn select_by_id(table: &str) -> QueryKey {
386 QueryKey::owned(format!("select_by_id:{}", table))
387 }
388
389 #[inline]
391 pub fn select_all(table: &str) -> QueryKey {
392 QueryKey::owned(format!("select_all:{}", table))
393 }
394
395 #[inline]
397 pub fn insert(table: &str, columns: usize) -> QueryKey {
398 QueryKey::owned(format!("insert:{}:{}", table, columns))
399 }
400
401 #[inline]
403 pub fn update_by_id(table: &str, columns: usize) -> QueryKey {
404 QueryKey::owned(format!("update_by_id:{}:{}", table, columns))
405 }
406
407 #[inline]
409 pub fn delete_by_id(table: &str) -> QueryKey {
410 QueryKey::owned(format!("delete_by_id:{}", table))
411 }
412
413 #[inline]
415 pub fn count(table: &str) -> QueryKey {
416 QueryKey::owned(format!("count:{}", table))
417 }
418
419 #[inline]
421 pub fn count_filtered(table: &str, filter_hash: u64) -> QueryKey {
422 QueryKey::owned(format!("count:{}:{}", table, filter_hash))
423 }
424}
425
/// Cache of [`SqlTemplate`]s stored by hash, with a string-key index on top.
///
/// Each map sits behind its own `parking_lot::RwLock`. Its `read`/`write` return guards
/// directly, without a `Result`.
#[derive(Debug)]
pub struct SqlTemplateCache {
    /// Template count at which registering triggers eviction.
    max_size: usize,
    /// Templates keyed by hash: the SQL hash for `register`, or the caller-supplied value
    /// for `register_by_hash`.
    templates: parking_lot::RwLock<HashMap<u64, Arc<SqlTemplate>>>,
    /// Caller-chosen key -> template hash.
    key_index: parking_lot::RwLock<HashMap<Cow<'static, str>, u64>>,
    /// Hit/miss/eviction/insertion counters.
    stats: parking_lot::RwLock<CacheStats>,
}
466
467#[derive(Debug)]
469pub struct SqlTemplate {
470 pub sql: Arc<str>,
472 pub hash: u64,
474 pub param_count: usize,
476 last_access: std::sync::atomic::AtomicU64,
478}
479
480impl Clone for SqlTemplate {
481 fn clone(&self) -> Self {
482 use std::sync::atomic::Ordering;
483 Self {
484 sql: Arc::clone(&self.sql),
485 hash: self.hash,
486 param_count: self.param_count,
487 last_access: std::sync::atomic::AtomicU64::new(
488 self.last_access.load(Ordering::Relaxed),
489 ),
490 }
491 }
492}
493
494impl SqlTemplate {
495 pub fn new(sql: impl AsRef<str>) -> Self {
497 let sql_str = sql.as_ref();
498 let param_count = count_placeholders(sql_str);
499 let hash = {
500 let mut hasher = std::collections::hash_map::DefaultHasher::new();
501 sql_str.hash(&mut hasher);
502 hasher.finish()
503 };
504
505 Self {
506 sql: Arc::from(sql_str),
507 hash,
508 param_count,
509 last_access: std::sync::atomic::AtomicU64::new(0),
510 }
511 }
512
513 #[inline(always)]
515 pub fn sql(&self) -> &str {
516 &self.sql
517 }
518
519 #[inline(always)]
521 pub fn sql_arc(&self) -> Arc<str> {
522 Arc::clone(&self.sql)
523 }
524
525 #[inline]
527 fn touch(&self) {
528 use std::sync::atomic::Ordering;
529 use std::time::{SystemTime, UNIX_EPOCH};
530 let now = SystemTime::now()
531 .duration_since(UNIX_EPOCH)
532 .map(|d| d.as_secs())
533 .unwrap_or(0);
534 self.last_access.store(now, Ordering::Relaxed);
535 }
536}
537
538impl SqlTemplateCache {
539 pub fn new(max_size: usize) -> Self {
541 tracing::info!(max_size, "SqlTemplateCache initialized");
542 Self {
543 max_size,
544 templates: parking_lot::RwLock::new(HashMap::with_capacity(max_size)),
545 key_index: parking_lot::RwLock::new(HashMap::with_capacity(max_size)),
546 stats: parking_lot::RwLock::new(CacheStats::default()),
547 }
548 }
549
550 #[inline]
554 pub fn register(
555 &self,
556 key: impl Into<Cow<'static, str>>,
557 sql: impl AsRef<str>,
558 ) -> Arc<SqlTemplate> {
559 let key = key.into();
560 let template = Arc::new(SqlTemplate::new(sql));
561 let hash = template.hash;
562
563 let mut templates = self.templates.write();
564 let mut key_index = self.key_index.write();
565 let mut stats = self.stats.write();
566
567 if templates.len() >= self.max_size {
569 self.evict_lru_internal(&mut templates, &mut key_index);
570 stats.evictions += 1;
571 }
572
573 key_index.insert(key, hash);
574 templates.insert(hash, Arc::clone(&template));
575 stats.insertions += 1;
576
577 debug!(hash, "SqlTemplateCache::register()");
578 template
579 }
580
581 #[inline]
583 pub fn register_by_hash(&self, hash: u64, sql: impl AsRef<str>) -> Arc<SqlTemplate> {
584 let template = Arc::new(SqlTemplate::new(sql));
585
586 let mut templates = self.templates.write();
587 let mut stats = self.stats.write();
588
589 if templates.len() >= self.max_size {
590 let mut key_index = self.key_index.write();
591 self.evict_lru_internal(&mut templates, &mut key_index);
592 stats.evictions += 1;
593 }
594
595 templates.insert(hash, Arc::clone(&template));
596 stats.insertions += 1;
597
598 template
599 }
600
601 #[inline]
609 pub fn get(&self, key: &str) -> Option<Arc<SqlTemplate>> {
610 let hash = {
611 let key_index = self.key_index.read();
612 match key_index.get(key) {
613 Some(&h) => h,
614 None => {
615 drop(key_index); let mut stats = self.stats.write();
617 stats.misses += 1;
618 return None;
619 }
620 }
621 };
622
623 let templates = self.templates.read();
624 if let Some(template) = templates.get(&hash) {
625 template.touch();
626 let mut stats = self.stats.write();
627 stats.hits += 1;
628 return Some(Arc::clone(template));
629 }
630
631 let mut stats = self.stats.write();
632 stats.misses += 1;
633 None
634 }
635
636 #[inline(always)]
642 pub fn get_by_hash(&self, hash: u64) -> Option<Arc<SqlTemplate>> {
643 let templates = self.templates.read();
644 if let Some(template) = templates.get(&hash) {
645 template.touch();
646 return Some(Arc::clone(template));
648 }
649 None
650 }
651
652 #[inline]
654 pub fn get_sql(&self, key: &str) -> Option<Arc<str>> {
655 self.get(key).map(|t| t.sql_arc())
656 }
657
658 #[inline]
660 pub fn get_or_register<F>(&self, key: impl Into<Cow<'static, str>>, f: F) -> Arc<SqlTemplate>
661 where
662 F: FnOnce() -> String,
663 {
664 let key = key.into();
665
666 if let Some(template) = self.get(&key) {
668 return template;
669 }
670
671 let sql = f();
673 self.register(key, sql)
674 }
675
676 #[inline]
678 pub fn contains(&self, key: &str) -> bool {
679 let key_index = self.key_index.read();
680 key_index.contains_key(key)
681 }
682
683 pub fn stats(&self) -> CacheStats {
685 self.stats.read().clone()
686 }
687
688 pub fn len(&self) -> usize {
690 self.templates.read().len()
691 }
692
693 pub fn is_empty(&self) -> bool {
695 self.len() == 0
696 }
697
698 pub fn clear(&self) {
700 self.templates.write().clear();
701 self.key_index.write().clear();
702 }
703
704 fn evict_lru_internal(
706 &self,
707 templates: &mut HashMap<u64, Arc<SqlTemplate>>,
708 key_index: &mut HashMap<Cow<'static, str>, u64>,
709 ) {
710 use std::sync::atomic::Ordering;
711
712 let to_evict = templates.len() / 4;
713 if to_evict == 0 {
714 return;
715 }
716
717 let mut entries: Vec<_> = templates
719 .iter()
720 .map(|(&hash, t)| (hash, t.last_access.load(Ordering::Relaxed)))
721 .collect();
722 entries.sort_by_key(|(_, time)| *time);
723
724 for (hash, _) in entries.into_iter().take(to_evict) {
726 templates.remove(&hash);
727 key_index.retain(|_, h| *h != hash);
729 }
730 }
731}
732
733impl Default for SqlTemplateCache {
734 fn default() -> Self {
735 Self::new(1000)
736 }
737}
738
739static GLOBAL_TEMPLATE_CACHE: std::sync::OnceLock<SqlTemplateCache> = std::sync::OnceLock::new();
762
763#[inline(always)]
765pub fn global_template_cache() -> &'static SqlTemplateCache {
766 GLOBAL_TEMPLATE_CACHE.get_or_init(|| SqlTemplateCache::new(10000))
767}
768
769#[inline]
771pub fn register_global_template(
772 key: impl Into<Cow<'static, str>>,
773 sql: impl AsRef<str>,
774) -> Arc<SqlTemplate> {
775 global_template_cache().register(key, sql)
776}
777
778#[inline(always)]
780pub fn get_global_template(key: &str) -> Option<Arc<SqlTemplate>> {
781 global_template_cache().get(key)
782}
783
/// Hashes a cache key with `std`'s `DefaultHasher`, so callers can compute the value once
/// and reuse it.
///
/// Deterministic within one build; do not persist the value across Rust releases.
#[inline]
pub fn precompute_query_hash(key: &str) -> u64 {
    let mut state = std::collections::hash_map::DefaultHasher::new();
    Hash::hash(key, &mut state);
    state.finish()
}
795
#[cfg(test)]
mod tests {
    use super::*;

    /// SQL shared by several tests below.
    const USERS_BY_ID: &str = "SELECT * FROM users WHERE id = $1";

    #[test]
    fn test_query_cache_basic() {
        let cache = QueryCache::new(10);
        cache.insert("users_by_id", USERS_BY_ID);

        assert!(cache.contains("users_by_id"));
        assert_eq!(cache.get("users_by_id").as_deref(), Some(USERS_BY_ID));
    }

    #[test]
    fn test_query_cache_get_or_insert() {
        let cache = QueryCache::new(10);

        let first = cache.get_or_insert("test", || String::from("SELECT 1"));
        assert_eq!(first, "SELECT 1");

        // The cached value must win over the second closure.
        let second = cache.get_or_insert("test", || String::from("SELECT 2"));
        assert_eq!(second, "SELECT 1");
    }

    #[test]
    fn test_query_cache_stats() {
        let cache = QueryCache::new(10);
        cache.insert("test", "SELECT 1");

        for key in ["test", "test", "missing"] {
            let _ = cache.get(key);
        }

        let CacheStats { hits, misses, insertions, .. } = cache.stats();
        assert_eq!(hits, 2);
        assert_eq!(misses, 1);
        assert_eq!(insertions, 1);
    }

    #[test]
    fn test_count_placeholders_postgres() {
        let cases = [
            (USERS_BY_ID, 1),
            ("SELECT * FROM users WHERE id = $1 AND name = $2", 2),
            ("SELECT * FROM users WHERE id = $10", 10),
        ];
        for (sql, expected) in cases {
            assert_eq!(count_placeholders(sql), expected, "{sql}");
        }
    }

    #[test]
    fn test_count_placeholders_mysql() {
        let cases = [
            ("SELECT * FROM users WHERE id = ?", 1),
            ("SELECT * FROM users WHERE id = ? AND name = ?", 2),
        ];
        for (sql, expected) in cases {
            assert_eq!(count_placeholders(sql), expected, "{sql}");
        }
    }

    #[test]
    fn test_query_hash() {
        let users = QueryHash::new("SELECT * FROM users");

        assert_eq!(users, QueryHash::new("SELECT * FROM users"));
        assert_ne!(users, QueryHash::new("SELECT * FROM posts"));
    }

    #[test]
    fn test_patterns() {
        assert!(patterns::select_by_id("users").key.starts_with("select_by_id:"));
    }

    #[test]
    fn test_sql_template_cache_basic() {
        let cache = SqlTemplateCache::new(100);
        let template = cache.register("users_by_id", USERS_BY_ID);

        assert_eq!(template.sql(), USERS_BY_ID);
        assert_eq!(template.param_count, 1);
    }

    #[test]
    fn test_sql_template_cache_get() {
        let cache = SqlTemplateCache::new(100);
        cache.register("test_query", "SELECT * FROM test WHERE x = $1");

        let found = cache
            .get("test_query")
            .expect("registered template should be found");
        assert_eq!(found.sql(), "SELECT * FROM test WHERE x = $1");

        assert!(cache.get("nonexistent").is_none());
    }

    #[test]
    fn test_sql_template_cache_get_by_hash() {
        let cache = SqlTemplateCache::new(100);
        let hash = cache.register("fast_query", "SELECT 1").hash;

        let found = cache
            .get_by_hash(hash)
            .expect("template should be reachable by hash");
        assert_eq!(found.sql(), "SELECT 1");
    }

    #[test]
    fn test_sql_template_cache_get_or_register() {
        let cache = SqlTemplateCache::new(100);

        let computed = cache.get_or_register("computed", || "SELECT * FROM computed".to_string());
        assert_eq!(computed.sql(), "SELECT * FROM computed");

        // Once the key is cached, the closure must not run.
        let cached = cache.get_or_register("computed", || panic!("Should not be called"));
        assert_eq!(cached.sql(), "SELECT * FROM computed");
        assert_eq!(computed.hash, cached.hash);
    }

    #[test]
    fn test_sql_template_cache_stats() {
        let cache = SqlTemplateCache::new(100);
        cache.register("q1", "SELECT 1");

        for key in ["q1", "q1", "missing"] {
            let _ = cache.get(key);
        }

        let CacheStats { hits, misses, insertions, .. } = cache.stats();
        assert_eq!(hits, 2);
        assert_eq!(misses, 1);
        assert_eq!(insertions, 1);
    }

    #[test]
    fn test_global_template_cache() {
        let registered = register_global_template("global_test", "SELECT * FROM global");
        assert_eq!(registered.sql(), "SELECT * FROM global");

        let fetched = get_global_template("global_test")
            .expect("global template should be found");
        assert_eq!(fetched.sql(), "SELECT * FROM global");
    }

    #[test]
    fn test_precompute_query_hash() {
        let test_key = precompute_query_hash("test_key");

        assert_eq!(test_key, precompute_query_hash("test_key"));
        assert_ne!(test_key, precompute_query_hash("other_key"));
    }

    #[test]
    fn test_execution_plan_cache() {
        let cache = ExecutionPlanCache::new(100);

        let plan = cache.register(
            "users_by_email",
            "SELECT * FROM users WHERE email = $1",
            PlanHint::IndexScan("users_email_idx".into()),
        );
        assert_eq!(&*plan.sql, "SELECT * FROM users WHERE email = $1");

        let fetched = cache
            .get("users_by_email")
            .expect("registered plan should be found");
        assert!(matches!(fetched.hint, PlanHint::IndexScan(_)));
    }
}
983
/// Optimizer hint carried by an [`ExecutionPlan`].
#[derive(Debug, Clone, Default)]
pub enum PlanHint {
    /// No hint; this is the `Default`.
    #[default]
    None,
    /// Index-scan hint; the `String` presumably names the index (e.g. `users_email_idx`).
    IndexScan(String),
    /// Sequential-scan hint.
    SeqScan,
    /// Parallel-execution hint; the `u32` is presumably the degree of parallelism (confirm).
    Parallel(u32),
    /// Plan-caching hint.
    CachePlan,
    /// Timeout hint carrying a `Duration`.
    Timeout(std::time::Duration),
    /// Free-form hint text.
    Custom(String),
}
1016
/// A SQL statement with an optimizer hint and running execution statistics.
#[derive(Debug)]
pub struct ExecutionPlan {
    /// The SQL text.
    pub sql: Arc<str>,
    /// `compute_hash(sql)`; `ExecutionPlanCache` stores the plan under this value.
    pub hash: u64,
    /// Optimizer hint supplied at construction.
    pub hint: PlanHint,
    /// Cost estimate supplied via `ExecutionPlan::with_cost`, else `None`. The units are not
    /// defined in this type.
    pub estimated_cost: Option<f64>,
    /// Number of executions recorded via `record_execution`.
    use_count: std::sync::atomic::AtomicU64,
    /// Running mean of recorded execution times, in microseconds.
    avg_execution_us: std::sync::atomic::AtomicU64,
}
1033
/// Hashes `s` with `std`'s `DefaultHasher`.
///
/// Deterministic within one build, but not guaranteed stable across Rust releases.
fn compute_hash(s: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;

    let hasher = {
        let mut state = DefaultHasher::new();
        Hash::hash(s, &mut state);
        state
    };
    hasher.finish()
}
1040
1041impl ExecutionPlan {
1042 pub fn new(sql: impl AsRef<str>, hint: PlanHint) -> Self {
1044 let sql_str = sql.as_ref();
1045 Self {
1046 sql: Arc::from(sql_str),
1047 hash: compute_hash(sql_str),
1048 hint,
1049 estimated_cost: None,
1050 use_count: std::sync::atomic::AtomicU64::new(0),
1051 avg_execution_us: std::sync::atomic::AtomicU64::new(0),
1052 }
1053 }
1054
1055 pub fn with_cost(sql: impl AsRef<str>, hint: PlanHint, cost: f64) -> Self {
1057 let sql_str = sql.as_ref();
1058 Self {
1059 sql: Arc::from(sql_str),
1060 hash: compute_hash(sql_str),
1061 hint,
1062 estimated_cost: Some(cost),
1063 use_count: std::sync::atomic::AtomicU64::new(0),
1064 avg_execution_us: std::sync::atomic::AtomicU64::new(0),
1065 }
1066 }
1067
1068 pub fn record_execution(&self, duration_us: u64) {
1070 let old_count = self
1071 .use_count
1072 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1073 let old_avg = self
1074 .avg_execution_us
1075 .load(std::sync::atomic::Ordering::Relaxed);
1076
1077 let new_avg = if old_count == 0 {
1079 duration_us
1080 } else {
1081 (old_avg * old_count + duration_us) / (old_count + 1)
1083 };
1084
1085 self.avg_execution_us
1086 .store(new_avg, std::sync::atomic::Ordering::Relaxed);
1087 }
1088
1089 pub fn use_count(&self) -> u64 {
1091 self.use_count.load(std::sync::atomic::Ordering::Relaxed)
1092 }
1093
1094 pub fn avg_execution_us(&self) -> u64 {
1096 self.avg_execution_us
1097 .load(std::sync::atomic::Ordering::Relaxed)
1098 }
1099}
1100
/// Cache of [`ExecutionPlan`]s stored by SQL hash, with a string-key index on top.
///
/// Both maps sit behind `parking_lot::RwLock`s, whose `read`/`write` return guards directly.
#[derive(Debug)]
pub struct ExecutionPlanCache {
    /// Plan count at which registering a plan for new SQL triggers eviction.
    max_size: usize,
    /// Plans keyed by `ExecutionPlan::hash`.
    plans: parking_lot::RwLock<HashMap<u64, Arc<ExecutionPlan>>>,
    /// Caller-chosen key -> plan hash; several keys can map to the same hash.
    key_index: parking_lot::RwLock<HashMap<Cow<'static, str>, u64>>,
}
1134
1135impl ExecutionPlanCache {
1136 pub fn new(max_size: usize) -> Self {
1138 Self {
1139 max_size,
1140 plans: parking_lot::RwLock::new(HashMap::with_capacity(max_size / 2)),
1141 key_index: parking_lot::RwLock::new(HashMap::with_capacity(max_size / 2)),
1142 }
1143 }
1144
1145 pub fn register(
1147 &self,
1148 key: impl Into<Cow<'static, str>>,
1149 sql: impl AsRef<str>,
1150 hint: PlanHint,
1151 ) -> Arc<ExecutionPlan> {
1152 let key = key.into();
1153 let plan = Arc::new(ExecutionPlan::new(sql, hint));
1154 let hash = plan.hash;
1155
1156 let mut plans = self.plans.write();
1157 let mut key_index = self.key_index.write();
1158
1159 if plans.len() >= self.max_size && !plans.contains_key(&hash) {
1161 if let Some((&evict_hash, _)) = plans.iter().min_by_key(|(_, p)| p.use_count()) {
1163 plans.remove(&evict_hash);
1164 key_index.retain(|_, &mut v| v != evict_hash);
1165 }
1166 }
1167
1168 plans.insert(hash, Arc::clone(&plan));
1169 key_index.insert(key, hash);
1170
1171 plan
1172 }
1173
1174 pub fn register_with_cost(
1176 &self,
1177 key: impl Into<Cow<'static, str>>,
1178 sql: impl AsRef<str>,
1179 hint: PlanHint,
1180 cost: f64,
1181 ) -> Arc<ExecutionPlan> {
1182 let key = key.into();
1183 let plan = Arc::new(ExecutionPlan::with_cost(sql, hint, cost));
1184 let hash = plan.hash;
1185
1186 let mut plans = self.plans.write();
1187 let mut key_index = self.key_index.write();
1188
1189 if plans.len() >= self.max_size && !plans.contains_key(&hash) {
1190 if let Some((&evict_hash, _)) = plans.iter().min_by_key(|(_, p)| p.use_count()) {
1191 plans.remove(&evict_hash);
1192 key_index.retain(|_, &mut v| v != evict_hash);
1193 }
1194 }
1195
1196 plans.insert(hash, Arc::clone(&plan));
1197 key_index.insert(key, hash);
1198
1199 plan
1200 }
1201
1202 pub fn get(&self, key: &str) -> Option<Arc<ExecutionPlan>> {
1204 let hash = {
1205 let key_index = self.key_index.read();
1206 *key_index.get(key)?
1207 };
1208
1209 self.plans.read().get(&hash).cloned()
1210 }
1211
1212 pub fn get_by_hash(&self, hash: u64) -> Option<Arc<ExecutionPlan>> {
1214 self.plans.read().get(&hash).cloned()
1215 }
1216
1217 pub fn get_or_register<F>(
1219 &self,
1220 key: impl Into<Cow<'static, str>>,
1221 sql_fn: F,
1222 hint: PlanHint,
1223 ) -> Arc<ExecutionPlan>
1224 where
1225 F: FnOnce() -> String,
1226 {
1227 let key = key.into();
1228
1229 if let Some(plan) = self.get(key.as_ref()) {
1231 return plan;
1232 }
1233
1234 self.register(key, sql_fn(), hint)
1236 }
1237
1238 pub fn record_execution(&self, key: &str, duration_us: u64) {
1240 if let Some(plan) = self.get(key) {
1241 plan.record_execution(duration_us);
1242 }
1243 }
1244
1245 pub fn slowest_queries(&self, limit: usize) -> Vec<Arc<ExecutionPlan>> {
1247 let plans = self.plans.read();
1248 let mut sorted: Vec<_> = plans.values().cloned().collect();
1249 sorted.sort_by(|a, b| b.avg_execution_us().cmp(&a.avg_execution_us()));
1250 sorted.truncate(limit);
1251 sorted
1252 }
1253
1254 pub fn most_used(&self, limit: usize) -> Vec<Arc<ExecutionPlan>> {
1256 let plans = self.plans.read();
1257 let mut sorted: Vec<_> = plans.values().cloned().collect();
1258 sorted.sort_by(|a, b| b.use_count().cmp(&a.use_count()));
1259 sorted.truncate(limit);
1260 sorted
1261 }
1262
1263 pub fn clear(&self) {
1265 self.plans.write().clear();
1266 self.key_index.write().clear();
1267 }
1268
1269 pub fn len(&self) -> usize {
1271 self.plans.read().len()
1272 }
1273
1274 pub fn is_empty(&self) -> bool {
1276 self.plans.read().is_empty()
1277 }
1278}