1use std::borrow::Cow;
26use std::collections::HashMap;
27use std::hash::{Hash, Hasher};
28use std::sync::{Arc, RwLock};
29use tracing::debug;
30
31#[derive(Debug)]
35pub struct QueryCache {
36 max_size: usize,
38 cache: RwLock<HashMap<QueryKey, CachedQuery>>,
40 stats: RwLock<CacheStats>,
42}
43
44#[derive(Debug, Clone, PartialEq, Eq, Hash)]
46pub struct QueryKey {
47 key: Cow<'static, str>,
49}
50
51impl QueryKey {
52 #[inline]
54 pub const fn new(key: &'static str) -> Self {
55 Self {
56 key: Cow::Borrowed(key),
57 }
58 }
59
60 #[inline]
62 pub fn owned(key: String) -> Self {
63 Self {
64 key: Cow::Owned(key),
65 }
66 }
67}
68
69impl From<&'static str> for QueryKey {
70 fn from(s: &'static str) -> Self {
71 Self::new(s)
72 }
73}
74
75impl From<String> for QueryKey {
76 fn from(s: String) -> Self {
77 Self::owned(s)
78 }
79}
80
81#[derive(Debug, Clone)]
83pub struct CachedQuery {
84 pub sql: String,
86 pub param_count: usize,
88 access_count: u64,
90}
91
92impl CachedQuery {
93 pub fn new(sql: impl Into<String>, param_count: usize) -> Self {
95 Self {
96 sql: sql.into(),
97 param_count,
98 access_count: 0,
99 }
100 }
101
102 #[inline]
104 pub fn sql(&self) -> &str {
105 &self.sql
106 }
107
108 #[inline]
110 pub fn param_count(&self) -> usize {
111 self.param_count
112 }
113}
114
115#[derive(Debug, Default, Clone)]
117pub struct CacheStats {
118 pub hits: u64,
120 pub misses: u64,
122 pub evictions: u64,
124 pub insertions: u64,
126}
127
128impl CacheStats {
129 #[inline]
131 pub fn hit_rate(&self) -> f64 {
132 let total = self.hits + self.misses;
133 if total == 0 {
134 0.0
135 } else {
136 self.hits as f64 / total as f64
137 }
138 }
139}
140
141impl QueryCache {
142 pub fn new(max_size: usize) -> Self {
144 tracing::info!(max_size, "QueryCache initialized");
145 Self {
146 max_size,
147 cache: RwLock::new(HashMap::with_capacity(max_size)),
148 stats: RwLock::new(CacheStats::default()),
149 }
150 }
151
152 pub fn insert(&self, key: impl Into<QueryKey>, sql: impl Into<String>) {
154 let key = key.into();
155 let sql = sql.into();
156 let param_count = count_placeholders(&sql);
157 debug!(key = ?key.key, sql_len = sql.len(), param_count, "QueryCache::insert()");
158
159 let mut cache = self.cache.write().unwrap();
160 let mut stats = self.stats.write().unwrap();
161
162 if cache.len() >= self.max_size && !cache.contains_key(&key) {
164 self.evict_lru(&mut cache);
165 stats.evictions += 1;
166 debug!("QueryCache evicted entry");
167 }
168
169 cache.insert(key, CachedQuery::new(sql, param_count));
170 stats.insertions += 1;
171 }
172
173 pub fn insert_with_params(
175 &self,
176 key: impl Into<QueryKey>,
177 sql: impl Into<String>,
178 param_count: usize,
179 ) {
180 let key = key.into();
181 let sql = sql.into();
182
183 let mut cache = self.cache.write().unwrap();
184 let mut stats = self.stats.write().unwrap();
185
186 if cache.len() >= self.max_size && !cache.contains_key(&key) {
188 self.evict_lru(&mut cache);
189 stats.evictions += 1;
190 }
191
192 cache.insert(key, CachedQuery::new(sql, param_count));
193 stats.insertions += 1;
194 }
195
196 pub fn get(&self, key: impl Into<QueryKey>) -> Option<String> {
198 let key = key.into();
199
200 {
202 let cache = self.cache.read().unwrap();
203 if let Some(entry) = cache.get(&key) {
204 let mut stats = self.stats.write().unwrap();
205 stats.hits += 1;
206 debug!(key = ?key.key, "QueryCache hit");
207 return Some(entry.sql.clone());
208 }
209 }
210
211 let mut stats = self.stats.write().unwrap();
212 stats.misses += 1;
213 debug!(key = ?key.key, "QueryCache miss");
214 None
215 }
216
217 pub fn get_entry(&self, key: impl Into<QueryKey>) -> Option<CachedQuery> {
219 let key = key.into();
220
221 let cache = self.cache.read().unwrap();
222 if let Some(entry) = cache.get(&key) {
223 let mut stats = self.stats.write().unwrap();
224 stats.hits += 1;
225 return Some(entry.clone());
226 }
227
228 let mut stats = self.stats.write().unwrap();
229 stats.misses += 1;
230 None
231 }
232
233 pub fn get_or_insert<F>(&self, key: impl Into<QueryKey>, f: F) -> String
238 where
239 F: FnOnce() -> String,
240 {
241 let key = key.into();
242
243 if let Some(sql) = self.get(key.clone()) {
245 return sql;
246 }
247
248 let sql = f();
250 self.insert(key, sql.clone());
251 sql
252 }
253
254 pub fn contains(&self, key: impl Into<QueryKey>) -> bool {
256 let key = key.into();
257 let cache = self.cache.read().unwrap();
258 cache.contains_key(&key)
259 }
260
261 pub fn remove(&self, key: impl Into<QueryKey>) -> Option<String> {
263 let key = key.into();
264 let mut cache = self.cache.write().unwrap();
265 cache.remove(&key).map(|e| e.sql)
266 }
267
268 pub fn clear(&self) {
270 let mut cache = self.cache.write().unwrap();
271 cache.clear();
272 }
273
274 pub fn len(&self) -> usize {
276 let cache = self.cache.read().unwrap();
277 cache.len()
278 }
279
280 pub fn is_empty(&self) -> bool {
282 self.len() == 0
283 }
284
285 pub fn max_size(&self) -> usize {
287 self.max_size
288 }
289
290 pub fn stats(&self) -> CacheStats {
292 let stats = self.stats.read().unwrap();
293 stats.clone()
294 }
295
296 pub fn reset_stats(&self) {
298 let mut stats = self.stats.write().unwrap();
299 *stats = CacheStats::default();
300 }
301
302 fn evict_lru(&self, cache: &mut HashMap<QueryKey, CachedQuery>) {
304 let to_evict = cache.len() / 4; if to_evict == 0 {
308 return;
309 }
310
311 let mut entries: Vec<_> = cache
312 .iter()
313 .map(|(k, v)| (k.clone(), v.access_count))
314 .collect();
315 entries.sort_by_key(|(_, count)| *count);
316
317 for (key, _) in entries.into_iter().take(to_evict) {
318 cache.remove(&key);
319 }
320 }
321}
322
323impl Default for QueryCache {
324 fn default() -> Self {
325 Self::new(1000)
326 }
327}
328
329fn count_placeholders(sql: &str) -> usize {
331 let mut count = 0;
332 let mut chars = sql.chars().peekable();
333
334 while let Some(c) = chars.next() {
335 if c == '$' {
336 let mut num = String::new();
338 while let Some(&d) = chars.peek() {
339 if d.is_ascii_digit() {
340 num.push(d);
341 chars.next();
342 } else {
343 break;
344 }
345 }
346 if !num.is_empty() {
347 if let Ok(n) = num.parse::<usize>() {
348 count = count.max(n);
349 }
350 }
351 } else if c == '?' {
352 count += 1;
354 }
355 }
356
357 count
358}
359
360#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
362pub struct QueryHash(u64);
363
364impl QueryHash {
365 pub fn new(sql: &str) -> Self {
367 let mut hasher = std::collections::hash_map::DefaultHasher::new();
368 sql.hash(&mut hasher);
369 Self(hasher.finish())
370 }
371
372 #[inline]
374 pub fn value(&self) -> u64 {
375 self.0
376 }
377}
378
379pub mod patterns {
381 use super::QueryKey;
382
383 #[inline]
385 pub fn select_by_id(table: &str) -> QueryKey {
386 QueryKey::owned(format!("select_by_id:{}", table))
387 }
388
389 #[inline]
391 pub fn select_all(table: &str) -> QueryKey {
392 QueryKey::owned(format!("select_all:{}", table))
393 }
394
395 #[inline]
397 pub fn insert(table: &str, columns: usize) -> QueryKey {
398 QueryKey::owned(format!("insert:{}:{}", table, columns))
399 }
400
401 #[inline]
403 pub fn update_by_id(table: &str, columns: usize) -> QueryKey {
404 QueryKey::owned(format!("update_by_id:{}:{}", table, columns))
405 }
406
407 #[inline]
409 pub fn delete_by_id(table: &str) -> QueryKey {
410 QueryKey::owned(format!("delete_by_id:{}", table))
411 }
412
413 #[inline]
415 pub fn count(table: &str) -> QueryKey {
416 QueryKey::owned(format!("count:{}", table))
417 }
418
419 #[inline]
421 pub fn count_filtered(table: &str, filter_hash: u64) -> QueryKey {
422 QueryKey::owned(format!("count:{}:{}", table, filter_hash))
423 }
424}
425
426#[derive(Debug)]
456pub struct SqlTemplateCache {
457 max_size: usize,
459 templates: parking_lot::RwLock<HashMap<u64, Arc<SqlTemplate>>>,
461 key_index: parking_lot::RwLock<HashMap<Cow<'static, str>, u64>>,
463 stats: parking_lot::RwLock<CacheStats>,
465}
466
467#[derive(Debug)]
469pub struct SqlTemplate {
470 pub sql: Arc<str>,
472 pub hash: u64,
474 pub param_count: usize,
476 last_access: std::sync::atomic::AtomicU64,
478}
479
480impl Clone for SqlTemplate {
481 fn clone(&self) -> Self {
482 use std::sync::atomic::Ordering;
483 Self {
484 sql: Arc::clone(&self.sql),
485 hash: self.hash,
486 param_count: self.param_count,
487 last_access: std::sync::atomic::AtomicU64::new(
488 self.last_access.load(Ordering::Relaxed),
489 ),
490 }
491 }
492}
493
494impl SqlTemplate {
495 pub fn new(sql: impl AsRef<str>) -> Self {
497 let sql_str = sql.as_ref();
498 let param_count = count_placeholders(sql_str);
499 let hash = {
500 let mut hasher = std::collections::hash_map::DefaultHasher::new();
501 sql_str.hash(&mut hasher);
502 hasher.finish()
503 };
504
505 Self {
506 sql: Arc::from(sql_str),
507 hash,
508 param_count,
509 last_access: std::sync::atomic::AtomicU64::new(0),
510 }
511 }
512
513 #[inline(always)]
515 pub fn sql(&self) -> &str {
516 &self.sql
517 }
518
519 #[inline(always)]
521 pub fn sql_arc(&self) -> Arc<str> {
522 Arc::clone(&self.sql)
523 }
524
525 #[inline]
527 fn touch(&self) {
528 use std::sync::atomic::Ordering;
529 use std::time::{SystemTime, UNIX_EPOCH};
530 let now = SystemTime::now()
531 .duration_since(UNIX_EPOCH)
532 .map(|d| d.as_secs())
533 .unwrap_or(0);
534 self.last_access.store(now, Ordering::Relaxed);
535 }
536}
537
538impl SqlTemplateCache {
539 pub fn new(max_size: usize) -> Self {
541 tracing::info!(max_size, "SqlTemplateCache initialized");
542 Self {
543 max_size,
544 templates: parking_lot::RwLock::new(HashMap::with_capacity(max_size)),
545 key_index: parking_lot::RwLock::new(HashMap::with_capacity(max_size)),
546 stats: parking_lot::RwLock::new(CacheStats::default()),
547 }
548 }
549
550 #[inline]
554 pub fn register(
555 &self,
556 key: impl Into<Cow<'static, str>>,
557 sql: impl AsRef<str>,
558 ) -> Arc<SqlTemplate> {
559 let key = key.into();
560 let template = Arc::new(SqlTemplate::new(sql));
561 let hash = template.hash;
562
563 let mut templates = self.templates.write();
564 let mut key_index = self.key_index.write();
565 let mut stats = self.stats.write();
566
567 if templates.len() >= self.max_size {
569 self.evict_lru_internal(&mut templates, &mut key_index);
570 stats.evictions += 1;
571 }
572
573 key_index.insert(key, hash);
574 templates.insert(hash, Arc::clone(&template));
575 stats.insertions += 1;
576
577 debug!(hash, "SqlTemplateCache::register()");
578 template
579 }
580
581 #[inline]
583 pub fn register_by_hash(&self, hash: u64, sql: impl AsRef<str>) -> Arc<SqlTemplate> {
584 let template = Arc::new(SqlTemplate::new(sql));
585
586 let mut templates = self.templates.write();
587 let mut stats = self.stats.write();
588
589 if templates.len() >= self.max_size {
590 let mut key_index = self.key_index.write();
591 self.evict_lru_internal(&mut templates, &mut key_index);
592 stats.evictions += 1;
593 }
594
595 templates.insert(hash, Arc::clone(&template));
596 stats.insertions += 1;
597
598 template
599 }
600
601 #[inline]
609 pub fn get(&self, key: &str) -> Option<Arc<SqlTemplate>> {
610 let hash = {
611 let key_index = self.key_index.read();
612 match key_index.get(key) {
613 Some(&h) => h,
614 None => {
615 drop(key_index); let mut stats = self.stats.write();
617 stats.misses += 1;
618 return None;
619 }
620 }
621 };
622
623 let templates = self.templates.read();
624 if let Some(template) = templates.get(&hash) {
625 template.touch();
626 let mut stats = self.stats.write();
627 stats.hits += 1;
628 return Some(Arc::clone(template));
629 }
630
631 let mut stats = self.stats.write();
632 stats.misses += 1;
633 None
634 }
635
636 #[inline(always)]
642 pub fn get_by_hash(&self, hash: u64) -> Option<Arc<SqlTemplate>> {
643 let templates = self.templates.read();
644 if let Some(template) = templates.get(&hash) {
645 template.touch();
646 return Some(Arc::clone(template));
648 }
649 None
650 }
651
652 #[inline]
654 pub fn get_sql(&self, key: &str) -> Option<Arc<str>> {
655 self.get(key).map(|t| t.sql_arc())
656 }
657
658 #[inline]
660 pub fn get_or_register<F>(&self, key: impl Into<Cow<'static, str>>, f: F) -> Arc<SqlTemplate>
661 where
662 F: FnOnce() -> String,
663 {
664 let key = key.into();
665
666 if let Some(template) = self.get(&key) {
668 return template;
669 }
670
671 let sql = f();
673 self.register(key, sql)
674 }
675
676 #[inline]
678 pub fn contains(&self, key: &str) -> bool {
679 let key_index = self.key_index.read();
680 key_index.contains_key(key)
681 }
682
683 pub fn stats(&self) -> CacheStats {
685 self.stats.read().clone()
686 }
687
688 pub fn len(&self) -> usize {
690 self.templates.read().len()
691 }
692
693 pub fn is_empty(&self) -> bool {
695 self.len() == 0
696 }
697
698 pub fn clear(&self) {
700 self.templates.write().clear();
701 self.key_index.write().clear();
702 }
703
704 fn evict_lru_internal(
706 &self,
707 templates: &mut HashMap<u64, Arc<SqlTemplate>>,
708 key_index: &mut HashMap<Cow<'static, str>, u64>,
709 ) {
710 use std::sync::atomic::Ordering;
711
712 let to_evict = templates.len() / 4;
713 if to_evict == 0 {
714 return;
715 }
716
717 let mut entries: Vec<_> = templates
719 .iter()
720 .map(|(&hash, t)| (hash, t.last_access.load(Ordering::Relaxed)))
721 .collect();
722 entries.sort_by_key(|(_, time)| *time);
723
724 for (hash, _) in entries.into_iter().take(to_evict) {
726 templates.remove(&hash);
727 key_index.retain(|_, h| *h != hash);
729 }
730 }
731}
732
733impl Default for SqlTemplateCache {
734 fn default() -> Self {
735 Self::new(1000)
736 }
737}
738
739static GLOBAL_TEMPLATE_CACHE: std::sync::OnceLock<SqlTemplateCache> = std::sync::OnceLock::new();
762
763#[inline(always)]
765pub fn global_template_cache() -> &'static SqlTemplateCache {
766 GLOBAL_TEMPLATE_CACHE.get_or_init(|| SqlTemplateCache::new(10000))
767}
768
769#[inline]
771pub fn register_global_template(
772 key: impl Into<Cow<'static, str>>,
773 sql: impl AsRef<str>,
774) -> Arc<SqlTemplate> {
775 global_template_cache().register(key, sql)
776}
777
778#[inline(always)]
780pub fn get_global_template(key: &str) -> Option<Arc<SqlTemplate>> {
781 global_template_cache().get(key)
782}
783
784#[inline]
790pub fn precompute_query_hash(key: &str) -> u64 {
791 let mut hasher = std::collections::hash_map::DefaultHasher::new();
792 key.hash(&mut hasher);
793 hasher.finish()
794}
795
796#[cfg(test)]
797mod tests {
798 use super::*;
799
800 #[test]
801 fn test_query_cache_basic() {
802 let cache = QueryCache::new(10);
803
804 cache.insert("users_by_id", "SELECT * FROM users WHERE id = $1");
805 assert!(cache.contains("users_by_id"));
806
807 let sql = cache.get("users_by_id");
808 assert_eq!(sql, Some("SELECT * FROM users WHERE id = $1".to_string()));
809 }
810
811 #[test]
812 fn test_query_cache_get_or_insert() {
813 let cache = QueryCache::new(10);
814
815 let sql1 = cache.get_or_insert("test", || "SELECT 1".to_string());
816 assert_eq!(sql1, "SELECT 1");
817
818 let sql2 = cache.get_or_insert("test", || "SELECT 2".to_string());
819 assert_eq!(sql2, "SELECT 1"); }
821
822 #[test]
823 fn test_query_cache_stats() {
824 let cache = QueryCache::new(10);
825
826 cache.insert("test", "SELECT 1");
827 cache.get("test"); cache.get("test"); cache.get("missing"); let stats = cache.stats();
832 assert_eq!(stats.hits, 2);
833 assert_eq!(stats.misses, 1);
834 assert_eq!(stats.insertions, 1);
835 }
836
837 #[test]
838 fn test_count_placeholders_postgres() {
839 assert_eq!(count_placeholders("SELECT * FROM users WHERE id = $1"), 1);
840 assert_eq!(
841 count_placeholders("SELECT * FROM users WHERE id = $1 AND name = $2"),
842 2
843 );
844 assert_eq!(count_placeholders("SELECT * FROM users WHERE id = $10"), 10);
845 }
846
847 #[test]
848 fn test_count_placeholders_mysql() {
849 assert_eq!(count_placeholders("SELECT * FROM users WHERE id = ?"), 1);
850 assert_eq!(
851 count_placeholders("SELECT * FROM users WHERE id = ? AND name = ?"),
852 2
853 );
854 }
855
856 #[test]
857 fn test_query_hash() {
858 let hash1 = QueryHash::new("SELECT * FROM users");
859 let hash2 = QueryHash::new("SELECT * FROM users");
860 let hash3 = QueryHash::new("SELECT * FROM posts");
861
862 assert_eq!(hash1, hash2);
863 assert_ne!(hash1, hash3);
864 }
865
866 #[test]
867 fn test_patterns() {
868 let key = patterns::select_by_id("users");
869 assert!(key.key.starts_with("select_by_id:"));
870 }
871
872 #[test]
877 fn test_sql_template_cache_basic() {
878 let cache = SqlTemplateCache::new(100);
879
880 let template = cache.register("users_by_id", "SELECT * FROM users WHERE id = $1");
881 assert_eq!(template.sql(), "SELECT * FROM users WHERE id = $1");
882 assert_eq!(template.param_count, 1);
883 }
884
885 #[test]
886 fn test_sql_template_cache_get() {
887 let cache = SqlTemplateCache::new(100);
888
889 cache.register("test_query", "SELECT * FROM test WHERE x = $1");
890
891 let result = cache.get("test_query");
892 assert!(result.is_some());
893 assert_eq!(result.unwrap().sql(), "SELECT * FROM test WHERE x = $1");
894
895 let missing = cache.get("nonexistent");
896 assert!(missing.is_none());
897 }
898
899 #[test]
900 fn test_sql_template_cache_get_by_hash() {
901 let cache = SqlTemplateCache::new(100);
902
903 let template = cache.register("fast_query", "SELECT 1");
904 let hash = template.hash;
905
906 let result = cache.get_by_hash(hash);
908 assert!(result.is_some());
909 assert_eq!(result.unwrap().sql(), "SELECT 1");
910 }
911
912 #[test]
913 fn test_sql_template_cache_get_or_register() {
914 let cache = SqlTemplateCache::new(100);
915
916 let t1 = cache.get_or_register("computed", || "SELECT * FROM computed".to_string());
917 assert_eq!(t1.sql(), "SELECT * FROM computed");
918
919 let t2 = cache.get_or_register("computed", || panic!("Should not be called"));
921 assert_eq!(t2.sql(), "SELECT * FROM computed");
922 assert_eq!(t1.hash, t2.hash);
923 }
924
925 #[test]
926 fn test_sql_template_cache_stats() {
927 let cache = SqlTemplateCache::new(100);
928
929 cache.register("q1", "SELECT 1");
930 cache.get("q1"); cache.get("q1"); cache.get("missing"); let stats = cache.stats();
935 assert_eq!(stats.hits, 2);
936 assert_eq!(stats.misses, 1);
937 assert_eq!(stats.insertions, 1);
938 }
939
940 #[test]
941 fn test_global_template_cache() {
942 let template = register_global_template("global_test", "SELECT * FROM global");
944 assert_eq!(template.sql(), "SELECT * FROM global");
945
946 let result = get_global_template("global_test");
948 assert!(result.is_some());
949 assert_eq!(result.unwrap().sql(), "SELECT * FROM global");
950 }
951
952 #[test]
953 fn test_precompute_query_hash() {
954 let hash1 = precompute_query_hash("test_key");
955 let hash2 = precompute_query_hash("test_key");
956 let hash3 = precompute_query_hash("other_key");
957
958 assert_eq!(hash1, hash2);
959 assert_ne!(hash1, hash3);
960 }
961
962 #[test]
963 fn test_execution_plan_cache() {
964 let cache = ExecutionPlanCache::new(100);
965
966 let plan = cache.register(
968 "users_by_email",
969 "SELECT * FROM users WHERE email = $1",
970 PlanHint::IndexScan("users_email_idx".into()),
971 );
972 assert_eq!(plan.sql.as_ref(), "SELECT * FROM users WHERE email = $1");
973
974 let result = cache.get("users_by_email");
976 assert!(result.is_some());
977 assert!(matches!(result.unwrap().hint, PlanHint::IndexScan(_)));
978 }
979}
980
981#[derive(Debug, Clone, Default)]
991pub enum PlanHint {
992 #[default]
994 None,
995 IndexScan(String),
997 SeqScan,
999 Parallel(u32),
1001 CachePlan,
1003 Timeout(std::time::Duration),
1005 Custom(String),
1007}
1008
1009#[derive(Debug)]
1011pub struct ExecutionPlan {
1012 pub sql: Arc<str>,
1014 pub hash: u64,
1016 pub hint: PlanHint,
1018 pub estimated_cost: Option<f64>,
1020 use_count: std::sync::atomic::AtomicU64,
1022 avg_execution_us: std::sync::atomic::AtomicU64,
1024}
1025
1026fn compute_hash(s: &str) -> u64 {
1028 let mut hasher = std::collections::hash_map::DefaultHasher::new();
1029 s.hash(&mut hasher);
1030 hasher.finish()
1031}
1032
1033impl ExecutionPlan {
1034 pub fn new(sql: impl AsRef<str>, hint: PlanHint) -> Self {
1036 let sql_str = sql.as_ref();
1037 Self {
1038 sql: Arc::from(sql_str),
1039 hash: compute_hash(sql_str),
1040 hint,
1041 estimated_cost: None,
1042 use_count: std::sync::atomic::AtomicU64::new(0),
1043 avg_execution_us: std::sync::atomic::AtomicU64::new(0),
1044 }
1045 }
1046
1047 pub fn with_cost(sql: impl AsRef<str>, hint: PlanHint, cost: f64) -> Self {
1049 let sql_str = sql.as_ref();
1050 Self {
1051 sql: Arc::from(sql_str),
1052 hash: compute_hash(sql_str),
1053 hint,
1054 estimated_cost: Some(cost),
1055 use_count: std::sync::atomic::AtomicU64::new(0),
1056 avg_execution_us: std::sync::atomic::AtomicU64::new(0),
1057 }
1058 }
1059
1060 pub fn record_execution(&self, duration_us: u64) {
1062 let old_count = self
1063 .use_count
1064 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1065 let old_avg = self
1066 .avg_execution_us
1067 .load(std::sync::atomic::Ordering::Relaxed);
1068
1069 let new_avg = if old_count == 0 {
1071 duration_us
1072 } else {
1073 (old_avg * old_count + duration_us) / (old_count + 1)
1075 };
1076
1077 self.avg_execution_us
1078 .store(new_avg, std::sync::atomic::Ordering::Relaxed);
1079 }
1080
1081 pub fn use_count(&self) -> u64 {
1083 self.use_count.load(std::sync::atomic::Ordering::Relaxed)
1084 }
1085
1086 pub fn avg_execution_us(&self) -> u64 {
1088 self.avg_execution_us
1089 .load(std::sync::atomic::Ordering::Relaxed)
1090 }
1091}
1092
1093#[derive(Debug)]
1118pub struct ExecutionPlanCache {
1119 max_size: usize,
1121 plans: parking_lot::RwLock<HashMap<u64, Arc<ExecutionPlan>>>,
1123 key_index: parking_lot::RwLock<HashMap<Cow<'static, str>, u64>>,
1125}
1126
1127impl ExecutionPlanCache {
1128 pub fn new(max_size: usize) -> Self {
1130 Self {
1131 max_size,
1132 plans: parking_lot::RwLock::new(HashMap::with_capacity(max_size / 2)),
1133 key_index: parking_lot::RwLock::new(HashMap::with_capacity(max_size / 2)),
1134 }
1135 }
1136
1137 pub fn register(
1139 &self,
1140 key: impl Into<Cow<'static, str>>,
1141 sql: impl AsRef<str>,
1142 hint: PlanHint,
1143 ) -> Arc<ExecutionPlan> {
1144 let key = key.into();
1145 let plan = Arc::new(ExecutionPlan::new(sql, hint));
1146 let hash = plan.hash;
1147
1148 let mut plans = self.plans.write();
1149 let mut key_index = self.key_index.write();
1150
1151 if plans.len() >= self.max_size && !plans.contains_key(&hash) {
1153 if let Some((&evict_hash, _)) = plans.iter().min_by_key(|(_, p)| p.use_count()) {
1155 plans.remove(&evict_hash);
1156 key_index.retain(|_, &mut v| v != evict_hash);
1157 }
1158 }
1159
1160 plans.insert(hash, Arc::clone(&plan));
1161 key_index.insert(key, hash);
1162
1163 plan
1164 }
1165
1166 pub fn register_with_cost(
1168 &self,
1169 key: impl Into<Cow<'static, str>>,
1170 sql: impl AsRef<str>,
1171 hint: PlanHint,
1172 cost: f64,
1173 ) -> Arc<ExecutionPlan> {
1174 let key = key.into();
1175 let plan = Arc::new(ExecutionPlan::with_cost(sql, hint, cost));
1176 let hash = plan.hash;
1177
1178 let mut plans = self.plans.write();
1179 let mut key_index = self.key_index.write();
1180
1181 if plans.len() >= self.max_size && !plans.contains_key(&hash) {
1182 if let Some((&evict_hash, _)) = plans.iter().min_by_key(|(_, p)| p.use_count()) {
1183 plans.remove(&evict_hash);
1184 key_index.retain(|_, &mut v| v != evict_hash);
1185 }
1186 }
1187
1188 plans.insert(hash, Arc::clone(&plan));
1189 key_index.insert(key, hash);
1190
1191 plan
1192 }
1193
1194 pub fn get(&self, key: &str) -> Option<Arc<ExecutionPlan>> {
1196 let hash = {
1197 let key_index = self.key_index.read();
1198 *key_index.get(key)?
1199 };
1200
1201 self.plans.read().get(&hash).cloned()
1202 }
1203
1204 pub fn get_by_hash(&self, hash: u64) -> Option<Arc<ExecutionPlan>> {
1206 self.plans.read().get(&hash).cloned()
1207 }
1208
1209 pub fn get_or_register<F>(
1211 &self,
1212 key: impl Into<Cow<'static, str>>,
1213 sql_fn: F,
1214 hint: PlanHint,
1215 ) -> Arc<ExecutionPlan>
1216 where
1217 F: FnOnce() -> String,
1218 {
1219 let key = key.into();
1220
1221 if let Some(plan) = self.get(key.as_ref()) {
1223 return plan;
1224 }
1225
1226 self.register(key, sql_fn(), hint)
1228 }
1229
1230 pub fn record_execution(&self, key: &str, duration_us: u64) {
1232 if let Some(plan) = self.get(key) {
1233 plan.record_execution(duration_us);
1234 }
1235 }
1236
1237 pub fn slowest_queries(&self, limit: usize) -> Vec<Arc<ExecutionPlan>> {
1239 let plans = self.plans.read();
1240 let mut sorted: Vec<_> = plans.values().cloned().collect();
1241 sorted.sort_by_key(|a| std::cmp::Reverse(a.avg_execution_us()));
1242 sorted.truncate(limit);
1243 sorted
1244 }
1245
1246 pub fn most_used(&self, limit: usize) -> Vec<Arc<ExecutionPlan>> {
1248 let plans = self.plans.read();
1249 let mut sorted: Vec<_> = plans.values().cloned().collect();
1250 sorted.sort_by_key(|a| std::cmp::Reverse(a.use_count()));
1251 sorted.truncate(limit);
1252 sorted
1253 }
1254
1255 pub fn clear(&self) {
1257 self.plans.write().clear();
1258 self.key_index.write().clear();
1259 }
1260
1261 pub fn len(&self) -> usize {
1263 self.plans.read().len()
1264 }
1265
1266 pub fn is_empty(&self) -> bool {
1268 self.plans.read().is_empty()
1269 }
1270}