1use std::borrow::Cow;
26use std::collections::HashMap;
27use std::hash::{Hash, Hasher};
28use std::sync::{Arc, RwLock};
29use tracing::debug;
30
31#[derive(Debug)]
35pub struct QueryCache {
36 max_size: usize,
38 cache: RwLock<HashMap<QueryKey, CachedQuery>>,
40 stats: RwLock<CacheStats>,
42}
43
/// Lookup key for [`QueryCache`].
///
/// Wraps a `Cow<'static, str>`, so string literals are stored without allocating
/// while runtime-built keys own their text. Equality and hashing are by string content.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct QueryKey {
    /// Key text: borrowed for literals, owned for dynamically built keys.
    key: Cow<'static, str>,
}

impl QueryKey {
    /// Creates a key that borrows a `'static` string. No allocation; usable in `const`.
    #[inline]
    pub const fn new(key: &'static str) -> Self {
        QueryKey {
            key: Cow::Borrowed(key),
        }
    }

    /// Creates a key that takes ownership of `key`.
    #[inline]
    pub fn owned(key: String) -> Self {
        QueryKey {
            key: Cow::Owned(key),
        }
    }
}

impl From<&'static str> for QueryKey {
    /// Equivalent to [`QueryKey::new`].
    #[inline]
    fn from(s: &'static str) -> Self {
        QueryKey::new(s)
    }
}

impl From<String> for QueryKey {
    /// Equivalent to [`QueryKey::owned`].
    #[inline]
    fn from(s: String) -> Self {
        QueryKey::owned(s)
    }
}
80
/// A cached SQL statement and the number of bind parameters it takes.
#[derive(Debug, Clone)]
pub struct CachedQuery {
    /// The SQL text.
    pub sql: String,
    /// Number of bind parameters the statement expects.
    pub param_count: usize,
    /// Usage counter consulted by `QueryCache` eviction; starts at zero.
    access_count: u64,
}

impl CachedQuery {
    /// Builds an entry from `sql` and an explicit parameter count; `access_count` starts at 0.
    pub fn new(sql: impl Into<String>, param_count: usize) -> Self {
        let sql: String = sql.into();
        CachedQuery {
            sql,
            param_count,
            access_count: 0,
        }
    }

    /// Borrows the SQL text.
    #[inline]
    pub fn sql(&self) -> &str {
        self.sql.as_str()
    }

    /// Returns the stored parameter count.
    #[inline]
    pub fn param_count(&self) -> usize {
        self.param_count
    }
}
114
/// Counters describing cache activity. All fields start at zero (`Default`).
#[derive(Debug, Default, Clone)]
pub struct CacheStats {
    /// Lookups that found an entry.
    pub hits: u64,
    /// Lookups that found nothing.
    pub misses: u64,
    /// Evictions, as recorded by the owning cache.
    pub evictions: u64,
    /// Insert / register calls.
    pub insertions: u64,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`; `0.0` when there were no lookups.
    #[inline]
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
140
141impl QueryCache {
142 pub fn new(max_size: usize) -> Self {
144 tracing::info!(max_size, "QueryCache initialized");
145 Self {
146 max_size,
147 cache: RwLock::new(HashMap::with_capacity(max_size)),
148 stats: RwLock::new(CacheStats::default()),
149 }
150 }
151
152 pub fn insert(&self, key: impl Into<QueryKey>, sql: impl Into<String>) {
154 let key = key.into();
155 let sql = sql.into();
156 let param_count = count_placeholders(&sql);
157 debug!(key = ?key.key, sql_len = sql.len(), param_count, "QueryCache::insert()");
158
159 let mut cache = self.cache.write().unwrap();
160 let mut stats = self.stats.write().unwrap();
161
162 if cache.len() >= self.max_size && !cache.contains_key(&key) {
164 self.evict_lru(&mut cache);
165 stats.evictions += 1;
166 debug!("QueryCache evicted entry");
167 }
168
169 cache.insert(key, CachedQuery::new(sql, param_count));
170 stats.insertions += 1;
171 }
172
173 pub fn insert_with_params(
175 &self,
176 key: impl Into<QueryKey>,
177 sql: impl Into<String>,
178 param_count: usize,
179 ) {
180 let key = key.into();
181 let sql = sql.into();
182
183 let mut cache = self.cache.write().unwrap();
184 let mut stats = self.stats.write().unwrap();
185
186 if cache.len() >= self.max_size && !cache.contains_key(&key) {
188 self.evict_lru(&mut cache);
189 stats.evictions += 1;
190 }
191
192 cache.insert(key, CachedQuery::new(sql, param_count));
193 stats.insertions += 1;
194 }
195
196 pub fn get(&self, key: impl Into<QueryKey>) -> Option<String> {
198 let key = key.into();
199
200 {
202 let cache = self.cache.read().unwrap();
203 if let Some(entry) = cache.get(&key) {
204 let mut stats = self.stats.write().unwrap();
205 stats.hits += 1;
206 debug!(key = ?key.key, "QueryCache hit");
207 return Some(entry.sql.clone());
208 }
209 }
210
211 let mut stats = self.stats.write().unwrap();
212 stats.misses += 1;
213 debug!(key = ?key.key, "QueryCache miss");
214 None
215 }
216
217 pub fn get_entry(&self, key: impl Into<QueryKey>) -> Option<CachedQuery> {
219 let key = key.into();
220
221 let cache = self.cache.read().unwrap();
222 if let Some(entry) = cache.get(&key) {
223 let mut stats = self.stats.write().unwrap();
224 stats.hits += 1;
225 return Some(entry.clone());
226 }
227
228 let mut stats = self.stats.write().unwrap();
229 stats.misses += 1;
230 None
231 }
232
233 pub fn get_or_insert<F>(&self, key: impl Into<QueryKey>, f: F) -> String
238 where
239 F: FnOnce() -> String,
240 {
241 let key = key.into();
242
243 if let Some(sql) = self.get(key.clone()) {
245 return sql;
246 }
247
248 let sql = f();
250 self.insert(key, sql.clone());
251 sql
252 }
253
254 pub fn contains(&self, key: impl Into<QueryKey>) -> bool {
256 let key = key.into();
257 let cache = self.cache.read().unwrap();
258 cache.contains_key(&key)
259 }
260
261 pub fn remove(&self, key: impl Into<QueryKey>) -> Option<String> {
263 let key = key.into();
264 let mut cache = self.cache.write().unwrap();
265 cache.remove(&key).map(|e| e.sql)
266 }
267
268 pub fn clear(&self) {
270 let mut cache = self.cache.write().unwrap();
271 cache.clear();
272 }
273
274 pub fn len(&self) -> usize {
276 let cache = self.cache.read().unwrap();
277 cache.len()
278 }
279
280 pub fn is_empty(&self) -> bool {
282 self.len() == 0
283 }
284
285 pub fn max_size(&self) -> usize {
287 self.max_size
288 }
289
290 pub fn stats(&self) -> CacheStats {
292 let stats = self.stats.read().unwrap();
293 stats.clone()
294 }
295
296 pub fn reset_stats(&self) {
298 let mut stats = self.stats.write().unwrap();
299 *stats = CacheStats::default();
300 }
301
302 fn evict_lru(&self, cache: &mut HashMap<QueryKey, CachedQuery>) {
304 let to_evict = cache.len() / 4; if to_evict == 0 {
308 return;
309 }
310
311 let mut entries: Vec<_> = cache
312 .iter()
313 .map(|(k, v)| (k.clone(), v.access_count))
314 .collect();
315 entries.sort_by_key(|(_, count)| *count);
316
317 for (key, _) in entries.into_iter().take(to_evict) {
318 cache.remove(&key);
319 }
320 }
321}
322
323impl Default for QueryCache {
324 fn default() -> Self {
325 Self::new(1000)
326 }
327}
328
/// Estimates how many bind parameters `sql` uses.
///
/// Every `?` (MySQL/SQLite style) adds one. A `$N` (PostgreSQL style) raises the result
/// to at least `N`. A `$` without digits, or with a digit run too large for `usize`,
/// is ignored. This is a lexical scan only, so placeholders inside string literals or
/// comments are counted too.
fn count_placeholders(sql: &str) -> usize {
    // All tokens of interest are ASCII, so scanning UTF-8 bytes is safe: no multi-byte
    // character contains an ASCII byte.
    let bytes = sql.as_bytes();
    let mut total = 0usize;
    let mut pos = 0;

    while pos < bytes.len() {
        match bytes[pos] {
            b'?' => {
                total += 1;
                pos += 1;
            }
            b'$' => {
                let digits_start = pos + 1;
                let digits_len = bytes[digits_start..]
                    .iter()
                    .take_while(|b| b.is_ascii_digit())
                    .count();
                let digits_end = digits_start + digits_len;
                if digits_len > 0 {
                    if let Ok(index) = sql[digits_start..digits_end].parse::<usize>() {
                        total = total.max(index);
                    }
                }
                // Resume at the first non-digit so it is examined on its own.
                pos = digits_end;
            }
            _ => pos += 1,
        }
    }

    total
}
359
/// 64-bit hash of a SQL string, computed with `std`'s `DefaultHasher`.
///
/// `DefaultHasher`'s algorithm is not guaranteed to stay the same across Rust releases,
/// so these values should not be persisted or shared between builds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct QueryHash(u64);

impl QueryHash {
    /// Hashes `sql`.
    pub fn new(sql: &str) -> Self {
        let mut state = std::collections::hash_map::DefaultHasher::new();
        Hash::hash(sql, &mut state);
        QueryHash(state.finish())
    }

    /// The raw hash value.
    #[inline]
    pub fn value(&self) -> u64 {
        let QueryHash(raw) = *self;
        raw
    }
}
378
379pub mod patterns {
381 use super::QueryKey;
382
383 #[inline]
385 pub fn select_by_id(table: &str) -> QueryKey {
386 QueryKey::owned(format!("select_by_id:{}", table))
387 }
388
389 #[inline]
391 pub fn select_all(table: &str) -> QueryKey {
392 QueryKey::owned(format!("select_all:{}", table))
393 }
394
395 #[inline]
397 pub fn insert(table: &str, columns: usize) -> QueryKey {
398 QueryKey::owned(format!("insert:{}:{}", table, columns))
399 }
400
401 #[inline]
403 pub fn update_by_id(table: &str, columns: usize) -> QueryKey {
404 QueryKey::owned(format!("update_by_id:{}:{}", table, columns))
405 }
406
407 #[inline]
409 pub fn delete_by_id(table: &str) -> QueryKey {
410 QueryKey::owned(format!("delete_by_id:{}", table))
411 }
412
413 #[inline]
415 pub fn count(table: &str) -> QueryKey {
416 QueryKey::owned(format!("count:{}", table))
417 }
418
419 #[inline]
421 pub fn count_filtered(table: &str, filter_hash: u64) -> QueryKey {
422 QueryKey::owned(format!("count:{}:{}", table, filter_hash))
423 }
424}
425
426#[derive(Debug)]
456pub struct SqlTemplateCache {
457 max_size: usize,
459 templates: parking_lot::RwLock<HashMap<u64, Arc<SqlTemplate>>>,
461 key_index: parking_lot::RwLock<HashMap<Cow<'static, str>, u64>>,
463 stats: parking_lot::RwLock<CacheStats>,
465}
466
467#[derive(Debug)]
469pub struct SqlTemplate {
470 pub sql: Arc<str>,
472 pub hash: u64,
474 pub param_count: usize,
476 last_access: std::sync::atomic::AtomicU64,
478}
479
480impl Clone for SqlTemplate {
481 fn clone(&self) -> Self {
482 use std::sync::atomic::Ordering;
483 Self {
484 sql: Arc::clone(&self.sql),
485 hash: self.hash,
486 param_count: self.param_count,
487 last_access: std::sync::atomic::AtomicU64::new(
488 self.last_access.load(Ordering::Relaxed),
489 ),
490 }
491 }
492}
493
494impl SqlTemplate {
495 pub fn new(sql: impl AsRef<str>) -> Self {
497 let sql_str = sql.as_ref();
498 let param_count = count_placeholders(sql_str);
499 let hash = {
500 let mut hasher = std::collections::hash_map::DefaultHasher::new();
501 sql_str.hash(&mut hasher);
502 hasher.finish()
503 };
504
505 Self {
506 sql: Arc::from(sql_str),
507 hash,
508 param_count,
509 last_access: std::sync::atomic::AtomicU64::new(0),
510 }
511 }
512
513 #[inline(always)]
515 pub fn sql(&self) -> &str {
516 &self.sql
517 }
518
519 #[inline(always)]
521 pub fn sql_arc(&self) -> Arc<str> {
522 Arc::clone(&self.sql)
523 }
524
525 #[inline]
527 fn touch(&self) {
528 use std::sync::atomic::Ordering;
529 use std::time::{SystemTime, UNIX_EPOCH};
530 let now = SystemTime::now()
531 .duration_since(UNIX_EPOCH)
532 .map(|d| d.as_secs())
533 .unwrap_or(0);
534 self.last_access.store(now, Ordering::Relaxed);
535 }
536}
537
538impl SqlTemplateCache {
539 pub fn new(max_size: usize) -> Self {
541 tracing::info!(max_size, "SqlTemplateCache initialized");
542 Self {
543 max_size,
544 templates: parking_lot::RwLock::new(HashMap::with_capacity(max_size)),
545 key_index: parking_lot::RwLock::new(HashMap::with_capacity(max_size)),
546 stats: parking_lot::RwLock::new(CacheStats::default()),
547 }
548 }
549
550 #[inline]
554 pub fn register(
555 &self,
556 key: impl Into<Cow<'static, str>>,
557 sql: impl AsRef<str>,
558 ) -> Arc<SqlTemplate> {
559 let key = key.into();
560 let template = Arc::new(SqlTemplate::new(sql));
561 let hash = template.hash;
562
563 let mut templates = self.templates.write();
564 let mut key_index = self.key_index.write();
565 let mut stats = self.stats.write();
566
567 if templates.len() >= self.max_size {
569 self.evict_lru_internal(&mut templates, &mut key_index);
570 stats.evictions += 1;
571 }
572
573 key_index.insert(key, hash);
574 templates.insert(hash, Arc::clone(&template));
575 stats.insertions += 1;
576
577 debug!(hash, "SqlTemplateCache::register()");
578 template
579 }
580
581 #[inline]
583 pub fn register_by_hash(&self, hash: u64, sql: impl AsRef<str>) -> Arc<SqlTemplate> {
584 let template = Arc::new(SqlTemplate::new(sql));
585
586 let mut templates = self.templates.write();
587 let mut stats = self.stats.write();
588
589 if templates.len() >= self.max_size {
590 let mut key_index = self.key_index.write();
591 self.evict_lru_internal(&mut templates, &mut key_index);
592 stats.evictions += 1;
593 }
594
595 templates.insert(hash, Arc::clone(&template));
596 stats.insertions += 1;
597
598 template
599 }
600
601 #[inline]
609 pub fn get(&self, key: &str) -> Option<Arc<SqlTemplate>> {
610 let hash = {
611 let key_index = self.key_index.read();
612 match key_index.get(key) {
613 Some(&h) => h,
614 None => {
615 drop(key_index); let mut stats = self.stats.write();
617 stats.misses += 1;
618 return None;
619 }
620 }
621 };
622
623 let templates = self.templates.read();
624 if let Some(template) = templates.get(&hash) {
625 template.touch();
626 let mut stats = self.stats.write();
627 stats.hits += 1;
628 return Some(Arc::clone(template));
629 }
630
631 let mut stats = self.stats.write();
632 stats.misses += 1;
633 None
634 }
635
636 #[inline(always)]
642 pub fn get_by_hash(&self, hash: u64) -> Option<Arc<SqlTemplate>> {
643 let templates = self.templates.read();
644 if let Some(template) = templates.get(&hash) {
645 template.touch();
646 return Some(Arc::clone(template));
648 }
649 None
650 }
651
652 #[inline]
654 pub fn get_sql(&self, key: &str) -> Option<Arc<str>> {
655 self.get(key).map(|t| t.sql_arc())
656 }
657
658 #[inline]
660 pub fn get_or_register<F>(&self, key: impl Into<Cow<'static, str>>, f: F) -> Arc<SqlTemplate>
661 where
662 F: FnOnce() -> String,
663 {
664 let key = key.into();
665
666 if let Some(template) = self.get(&key) {
668 return template;
669 }
670
671 let sql = f();
673 self.register(key, sql)
674 }
675
676 #[inline]
678 pub fn contains(&self, key: &str) -> bool {
679 let key_index = self.key_index.read();
680 key_index.contains_key(key)
681 }
682
683 pub fn stats(&self) -> CacheStats {
685 self.stats.read().clone()
686 }
687
688 pub fn len(&self) -> usize {
690 self.templates.read().len()
691 }
692
693 pub fn is_empty(&self) -> bool {
695 self.len() == 0
696 }
697
698 pub fn clear(&self) {
700 self.templates.write().clear();
701 self.key_index.write().clear();
702 }
703
704 fn evict_lru_internal(
706 &self,
707 templates: &mut HashMap<u64, Arc<SqlTemplate>>,
708 key_index: &mut HashMap<Cow<'static, str>, u64>,
709 ) {
710 use std::sync::atomic::Ordering;
711
712 let to_evict = templates.len() / 4;
713 if to_evict == 0 {
714 return;
715 }
716
717 let mut entries: Vec<_> = templates
719 .iter()
720 .map(|(&hash, t)| (hash, t.last_access.load(Ordering::Relaxed)))
721 .collect();
722 entries.sort_by_key(|(_, time)| *time);
723
724 for (hash, _) in entries.into_iter().take(to_evict) {
726 templates.remove(&hash);
727 key_index.retain(|_, h| *h != hash);
729 }
730 }
731}
732
733impl Default for SqlTemplateCache {
734 fn default() -> Self {
735 Self::new(1000)
736 }
737}
738
739static GLOBAL_TEMPLATE_CACHE: std::sync::OnceLock<SqlTemplateCache> = std::sync::OnceLock::new();
762
763#[inline(always)]
765pub fn global_template_cache() -> &'static SqlTemplateCache {
766 GLOBAL_TEMPLATE_CACHE.get_or_init(|| SqlTemplateCache::new(10000))
767}
768
769#[inline]
771pub fn register_global_template(
772 key: impl Into<Cow<'static, str>>,
773 sql: impl AsRef<str>,
774) -> Arc<SqlTemplate> {
775 global_template_cache().register(key, sql)
776}
777
778#[inline(always)]
780pub fn get_global_template(key: &str) -> Option<Arc<SqlTemplate>> {
781 global_template_cache().get(key)
782}
783
/// Hashes `key` with `DefaultHasher`, using the same recipe as `SqlTemplate::new`.
///
/// To get a value usable with `SqlTemplateCache::get_by_hash` for a template created
/// by `register`, pass the template's SQL text. A registration key name does not hash
/// to the template's map key. `DefaultHasher` output may differ between Rust releases,
/// so do not persist the result.
#[inline]
pub fn precompute_query_hash(key: &str) -> u64 {
    let mut state = std::collections::hash_map::DefaultHasher::default();
    Hash::hash(key, &mut state);
    state.finish()
}
795
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_query_cache_basic() {
        let cache = QueryCache::new(10);
        cache.insert("users_by_id", "SELECT * FROM users WHERE id = $1");

        assert!(cache.contains("users_by_id"));
        assert_eq!(
            cache.get("users_by_id").as_deref(),
            Some("SELECT * FROM users WHERE id = $1")
        );
    }

    #[test]
    fn test_query_cache_get_or_insert() {
        let cache = QueryCache::new(10);

        let first = cache.get_or_insert("test", || String::from("SELECT 1"));
        assert_eq!(first, "SELECT 1");

        // Already cached: the second closure's value must not replace the first.
        let second = cache.get_or_insert("test", || String::from("SELECT 2"));
        assert_eq!(second, "SELECT 1");
    }

    #[test]
    fn test_query_cache_stats() {
        let cache = QueryCache::new(10);
        cache.insert("test", "SELECT 1");

        // Two hits, then one miss.
        for key in ["test", "test", "missing"] {
            let _ = cache.get(key);
        }

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.insertions, 1);
    }

    #[test]
    fn test_count_placeholders_postgres() {
        let cases = [
            ("SELECT * FROM users WHERE id = $1", 1),
            ("SELECT * FROM users WHERE id = $1 AND name = $2", 2),
            ("SELECT * FROM users WHERE id = $10", 10),
        ];
        for (sql, expected) in cases {
            assert_eq!(count_placeholders(sql), expected, "sql: {}", sql);
        }
    }

    #[test]
    fn test_count_placeholders_mysql() {
        let cases = [
            ("SELECT * FROM users WHERE id = ?", 1),
            ("SELECT * FROM users WHERE id = ? AND name = ?", 2),
        ];
        for (sql, expected) in cases {
            assert_eq!(count_placeholders(sql), expected, "sql: {}", sql);
        }
    }

    #[test]
    fn test_query_hash() {
        let users_a = QueryHash::new("SELECT * FROM users");
        let users_b = QueryHash::new("SELECT * FROM users");
        let posts = QueryHash::new("SELECT * FROM posts");

        assert_eq!(users_a, users_b);
        assert_ne!(users_a, posts);
    }

    #[test]
    fn test_patterns() {
        let key = patterns::select_by_id("users");
        assert!(key.key.starts_with("select_by_id:"));
    }

    #[test]
    fn test_sql_template_cache_basic() {
        let cache = SqlTemplateCache::new(100);

        let template = cache.register("users_by_id", "SELECT * FROM users WHERE id = $1");

        assert_eq!(template.sql(), "SELECT * FROM users WHERE id = $1");
        assert_eq!(template.param_count, 1);
    }

    #[test]
    fn test_sql_template_cache_get() {
        let cache = SqlTemplateCache::new(100);
        cache.register("test_query", "SELECT * FROM test WHERE x = $1");

        let found = cache
            .get("test_query")
            .expect("registered key must be found");
        assert_eq!(found.sql(), "SELECT * FROM test WHERE x = $1");

        assert!(cache.get("nonexistent").is_none());
    }

    #[test]
    fn test_sql_template_cache_get_by_hash() {
        let cache = SqlTemplateCache::new(100);
        let hash = cache.register("fast_query", "SELECT 1").hash;

        let found = cache.get_by_hash(hash).expect("hash must resolve");
        assert_eq!(found.sql(), "SELECT 1");
    }

    #[test]
    fn test_sql_template_cache_get_or_register() {
        let cache = SqlTemplateCache::new(100);

        let computed = cache.get_or_register("computed", || "SELECT * FROM computed".to_string());
        assert_eq!(computed.sql(), "SELECT * FROM computed");

        // A second lookup must hit the cache and never invoke the closure.
        let cached = cache.get_or_register("computed", || panic!("Should not be called"));
        assert_eq!(cached.sql(), "SELECT * FROM computed");
        assert_eq!(computed.hash, cached.hash);
    }

    #[test]
    fn test_sql_template_cache_stats() {
        let cache = SqlTemplateCache::new(100);
        cache.register("q1", "SELECT 1");

        // Two hits, then one miss.
        for key in ["q1", "q1", "missing"] {
            let _ = cache.get(key);
        }

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.insertions, 1);
    }

    #[test]
    fn test_global_template_cache() {
        let registered = register_global_template("global_test", "SELECT * FROM global");
        assert_eq!(registered.sql(), "SELECT * FROM global");

        let fetched = get_global_template("global_test").expect("global key must resolve");
        assert_eq!(fetched.sql(), "SELECT * FROM global");
    }

    #[test]
    fn test_precompute_query_hash() {
        let first = precompute_query_hash("test_key");
        let again = precompute_query_hash("test_key");
        let other = precompute_query_hash("other_key");

        assert_eq!(first, again);
        assert_ne!(first, other);
    }

    #[test]
    fn test_execution_plan_cache() {
        let cache = ExecutionPlanCache::new(100);

        let plan = cache.register(
            "users_by_email",
            "SELECT * FROM users WHERE email = $1",
            PlanHint::IndexScan("users_email_idx".into()),
        );
        assert_eq!(plan.sql.as_ref(), "SELECT * FROM users WHERE email = $1");

        let found = cache
            .get("users_by_email")
            .expect("registered plan must be found");
        assert!(matches!(found.hint, PlanHint::IndexScan(_)));
    }
}
983
/// Optional hint attached to an [`ExecutionPlan`].
///
/// Hints are stored as given; nothing in this module interprets them.
#[derive(Debug, Clone, Default)]
pub enum PlanHint {
    /// No hint (the default).
    #[default]
    None,
    /// Index-scan hint; the string is presumably the index name (tests pass e.g.
    /// `"users_email_idx"`).
    IndexScan(String),
    /// Sequential-scan hint.
    SeqScan,
    /// Parallelism hint; the `u32` is presumably a worker count — TODO confirm.
    Parallel(u32),
    /// Hint to cache the plan.
    CachePlan,
    /// Timeout hint.
    Timeout(std::time::Duration),
    /// Free-form hint text.
    Custom(String),
}

/// A SQL statement with a hint, an optional cost estimate, and running execution stats.
#[derive(Debug)]
pub struct ExecutionPlan {
    /// The SQL text.
    pub sql: Arc<str>,
    /// `compute_hash` of `sql`.
    pub hash: u64,
    /// Planner hint supplied at construction.
    pub hint: PlanHint,
    /// Cost estimate; set only by [`ExecutionPlan::with_cost`].
    pub estimated_cost: Option<f64>,
    /// Number of recorded executions.
    use_count: std::sync::atomic::AtomicU64,
    /// Running mean execution time in microseconds.
    avg_execution_us: std::sync::atomic::AtomicU64,
}

/// Hashes `s` with a freshly created `DefaultHasher` (same recipe as `SqlTemplate::new`).
fn compute_hash(s: &str) -> u64 {
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    s.hash(&mut hasher);
    hasher.finish()
}

impl ExecutionPlan {
    /// Builds a plan for `sql` with `hint`, no cost estimate, and zeroed stats.
    pub fn new(sql: impl AsRef<str>, hint: PlanHint) -> Self {
        let sql_str = sql.as_ref();
        Self {
            sql: Arc::from(sql_str),
            hash: compute_hash(sql_str),
            hint,
            estimated_cost: None,
            use_count: std::sync::atomic::AtomicU64::new(0),
            avg_execution_us: std::sync::atomic::AtomicU64::new(0),
        }
    }

    /// Like [`ExecutionPlan::new`] but with an estimated cost.
    pub fn with_cost(sql: impl AsRef<str>, hint: PlanHint, cost: f64) -> Self {
        Self {
            estimated_cost: Some(cost),
            ..Self::new(sql, hint)
        }
    }

    /// Records one execution taking `duration_us` microseconds and updates the mean.
    ///
    /// The count and the mean are two separate atomics, so concurrent callers can
    /// overwrite each other's mean. Treat the average as approximate under contention.
    pub fn record_execution(&self, duration_us: u64) {
        use std::sync::atomic::Ordering;

        let prior_runs = self.use_count.fetch_add(1, Ordering::Relaxed);
        let prior_avg = self.avg_execution_us.load(Ordering::Relaxed);

        let new_avg = if prior_runs == 0 {
            duration_us
        } else {
            // Widen to u128: `prior_avg * prior_runs` overflows u64 for long-running or
            // slow queries (a panic in debug builds, silent wraparound in release).
            let total = u128::from(prior_avg) * u128::from(prior_runs) + u128::from(duration_us);
            let mean = total / (u128::from(prior_runs) + 1);
            // mean <= max(prior_avg, duration_us), so this conversion cannot fail.
            u64::try_from(mean).unwrap_or(u64::MAX)
        };

        self.avg_execution_us.store(new_avg, Ordering::Relaxed);
    }

    /// Number of recorded executions.
    pub fn use_count(&self) -> u64 {
        self.use_count.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Running mean execution time in microseconds (0 before any execution).
    pub fn avg_execution_us(&self) -> u64 {
        self.avg_execution_us
            .load(std::sync::atomic::Ordering::Relaxed)
    }
}
1097
1098#[derive(Debug)]
1123pub struct ExecutionPlanCache {
1124 max_size: usize,
1126 plans: parking_lot::RwLock<HashMap<u64, Arc<ExecutionPlan>>>,
1128 key_index: parking_lot::RwLock<HashMap<Cow<'static, str>, u64>>,
1130}
1131
1132impl ExecutionPlanCache {
1133 pub fn new(max_size: usize) -> Self {
1135 Self {
1136 max_size,
1137 plans: parking_lot::RwLock::new(HashMap::with_capacity(max_size / 2)),
1138 key_index: parking_lot::RwLock::new(HashMap::with_capacity(max_size / 2)),
1139 }
1140 }
1141
1142 pub fn register(
1144 &self,
1145 key: impl Into<Cow<'static, str>>,
1146 sql: impl AsRef<str>,
1147 hint: PlanHint,
1148 ) -> Arc<ExecutionPlan> {
1149 let key = key.into();
1150 let plan = Arc::new(ExecutionPlan::new(sql, hint));
1151 let hash = plan.hash;
1152
1153 let mut plans = self.plans.write();
1154 let mut key_index = self.key_index.write();
1155
1156 if plans.len() >= self.max_size && !plans.contains_key(&hash) {
1158 if let Some((&evict_hash, _)) = plans.iter().min_by_key(|(_, p)| p.use_count()) {
1160 plans.remove(&evict_hash);
1161 key_index.retain(|_, &mut v| v != evict_hash);
1162 }
1163 }
1164
1165 plans.insert(hash, Arc::clone(&plan));
1166 key_index.insert(key, hash);
1167
1168 plan
1169 }
1170
1171 pub fn register_with_cost(
1173 &self,
1174 key: impl Into<Cow<'static, str>>,
1175 sql: impl AsRef<str>,
1176 hint: PlanHint,
1177 cost: f64,
1178 ) -> Arc<ExecutionPlan> {
1179 let key = key.into();
1180 let plan = Arc::new(ExecutionPlan::with_cost(sql, hint, cost));
1181 let hash = plan.hash;
1182
1183 let mut plans = self.plans.write();
1184 let mut key_index = self.key_index.write();
1185
1186 if plans.len() >= self.max_size && !plans.contains_key(&hash) {
1187 if let Some((&evict_hash, _)) = plans.iter().min_by_key(|(_, p)| p.use_count()) {
1188 plans.remove(&evict_hash);
1189 key_index.retain(|_, &mut v| v != evict_hash);
1190 }
1191 }
1192
1193 plans.insert(hash, Arc::clone(&plan));
1194 key_index.insert(key, hash);
1195
1196 plan
1197 }
1198
1199 pub fn get(&self, key: &str) -> Option<Arc<ExecutionPlan>> {
1201 let hash = {
1202 let key_index = self.key_index.read();
1203 *key_index.get(key)?
1204 };
1205
1206 self.plans.read().get(&hash).cloned()
1207 }
1208
1209 pub fn get_by_hash(&self, hash: u64) -> Option<Arc<ExecutionPlan>> {
1211 self.plans.read().get(&hash).cloned()
1212 }
1213
1214 pub fn get_or_register<F>(
1216 &self,
1217 key: impl Into<Cow<'static, str>>,
1218 sql_fn: F,
1219 hint: PlanHint,
1220 ) -> Arc<ExecutionPlan>
1221 where
1222 F: FnOnce() -> String,
1223 {
1224 let key = key.into();
1225
1226 if let Some(plan) = self.get(key.as_ref()) {
1228 return plan;
1229 }
1230
1231 self.register(key, sql_fn(), hint)
1233 }
1234
1235 pub fn record_execution(&self, key: &str, duration_us: u64) {
1237 if let Some(plan) = self.get(key) {
1238 plan.record_execution(duration_us);
1239 }
1240 }
1241
1242 pub fn slowest_queries(&self, limit: usize) -> Vec<Arc<ExecutionPlan>> {
1244 let plans = self.plans.read();
1245 let mut sorted: Vec<_> = plans.values().cloned().collect();
1246 sorted.sort_by_key(|a| std::cmp::Reverse(a.avg_execution_us()));
1247 sorted.truncate(limit);
1248 sorted
1249 }
1250
1251 pub fn most_used(&self, limit: usize) -> Vec<Arc<ExecutionPlan>> {
1253 let plans = self.plans.read();
1254 let mut sorted: Vec<_> = plans.values().cloned().collect();
1255 sorted.sort_by_key(|a| std::cmp::Reverse(a.use_count()));
1256 sorted.truncate(limit);
1257 sorted
1258 }
1259
1260 pub fn clear(&self) {
1262 self.plans.write().clear();
1263 self.key_index.write().clear();
1264 }
1265
1266 pub fn len(&self) -> usize {
1268 self.plans.read().len()
1269 }
1270
1271 pub fn is_empty(&self) -> bool {
1273 self.plans.read().is_empty()
1274 }
1275}