Skip to main content

prax_query/
cache.rs

1//! Query caching and prepared statement management.
2//!
3//! This module provides utilities for caching SQL queries and managing
4//! prepared statements to improve performance.
5//!
6//! # Query Cache
7//!
//! The `QueryCache` stores previously built SQL strings under a
//! caller-chosen string key, allowing fast lookup on later requests.
10//!
11//! ```rust
12//! use prax_query::cache::QueryCache;
13//!
14//! let cache = QueryCache::new(1000);
15//!
16//! // Cache a query
17//! cache.insert("users_by_id", "SELECT * FROM users WHERE id = $1");
18//!
19//! // Retrieve later
20//! if let Some(sql) = cache.get("users_by_id") {
21//!     println!("Cached SQL: {}", sql);
22//! }
23//! ```
24
25use std::borrow::Cow;
26use std::collections::HashMap;
27use std::hash::{Hash, Hasher};
28use std::sync::atomic::{AtomicU64, Ordering};
29use std::sync::{Arc, RwLock};
30use tracing::debug;
31
/// A thread-safe cache for SQL queries.
///
/// Uses a simple LRU-like eviction strategy when the cache is full: once
/// `max_size` entries are stored, inserting a new key first evicts a batch
/// of entries ordered by their recorded access count.
#[derive(Debug)]
pub struct QueryCache {
    /// Maximum number of entries in the cache.
    max_size: usize,
    /// The cached queries, guarded by a standard-library `RwLock`.
    cache: RwLock<HashMap<QueryKey, CachedQuery>>,
    /// Statistics about cache usage.
    ///
    /// Atomic counters so the `get()` hot path can record hits/misses
    /// without taking the stats write lock while still holding the
    /// cache read lock (the previous nested-lock pattern was a
    /// contention hotspot under concurrent reads).
    stats: AtomicCacheStats,
}
49
50/// Atomic backing for [`CacheStats`].
51///
52/// `record_*` methods use `Relaxed` because the counters are
53/// strictly monotonic and don't synchronize with anything else; the
54/// only consumer is `QueryCache::stats()` which snapshots them into a
55/// regular [`CacheStats`] value.
56#[derive(Debug, Default)]
57struct AtomicCacheStats {
58    hits: AtomicU64,
59    misses: AtomicU64,
60    evictions: AtomicU64,
61    insertions: AtomicU64,
62}
63
64impl AtomicCacheStats {
65    #[inline]
66    fn record_hit(&self) {
67        self.hits.fetch_add(1, Ordering::Relaxed);
68    }
69
70    #[inline]
71    fn record_miss(&self) {
72        self.misses.fetch_add(1, Ordering::Relaxed);
73    }
74
75    #[inline]
76    fn record_eviction(&self) {
77        self.evictions.fetch_add(1, Ordering::Relaxed);
78    }
79
80    #[inline]
81    fn record_insertion(&self) {
82        self.insertions.fetch_add(1, Ordering::Relaxed);
83    }
84
85    fn snapshot(&self) -> CacheStats {
86        CacheStats {
87            hits: self.hits.load(Ordering::Relaxed),
88            misses: self.misses.load(Ordering::Relaxed),
89            evictions: self.evictions.load(Ordering::Relaxed),
90            insertions: self.insertions.load(Ordering::Relaxed),
91        }
92    }
93
94    fn reset(&self) {
95        self.hits.store(0, Ordering::Relaxed);
96        self.misses.store(0, Ordering::Relaxed);
97        self.evictions.store(0, Ordering::Relaxed);
98        self.insertions.store(0, Ordering::Relaxed);
99    }
100}
101
/// Identifies an entry in the query cache.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct QueryKey {
    /// The unique identifier for this query type.
    key: Cow<'static, str>,
}

impl QueryKey {
    /// Build a key that borrows a `'static` string, without allocating.
    #[inline]
    pub const fn new(key: &'static str) -> Self {
        Self { key: Cow::Borrowed(key) }
    }

    /// Build a key that takes ownership of `key`.
    #[inline]
    pub fn owned(key: String) -> Self {
        let key = Cow::Owned(key);
        Self { key }
    }
}

impl From<&'static str> for QueryKey {
    /// Equivalent to [`QueryKey::new`].
    fn from(s: &'static str) -> Self {
        Self { key: Cow::Borrowed(s) }
    }
}

impl From<String> for QueryKey {
    /// Equivalent to [`QueryKey::owned`].
    fn from(s: String) -> Self {
        Self { key: Cow::Owned(s) }
    }
}
138
/// A cached SQL query.
#[derive(Debug, Clone)]
pub struct CachedQuery {
    /// The SQL string.
    pub sql: String,
    /// The number of parameters expected.
    pub param_count: usize,
    /// Number of times this query has been accessed.
    access_count: u64,
}

impl CachedQuery {
    /// Build an entry for `sql` expecting `param_count` parameters.
    ///
    /// The access counter starts at zero.
    pub fn new(sql: impl Into<String>, param_count: usize) -> Self {
        let sql: String = sql.into();
        Self {
            sql,
            param_count,
            access_count: 0,
        }
    }

    /// Borrow the SQL text.
    #[inline]
    pub fn sql(&self) -> &str {
        self.sql.as_str()
    }

    /// Number of parameters the query expects.
    #[inline]
    pub fn param_count(&self) -> usize {
        let Self { param_count, .. } = self;
        *param_count
    }
}
172
/// Statistics about cache usage.
#[derive(Debug, Default, Clone)]
pub struct CacheStats {
    /// Number of cache hits.
    pub hits: u64,
    /// Number of cache misses.
    pub misses: u64,
    /// Number of evictions.
    pub evictions: u64,
    /// Number of insertions.
    pub insertions: u64,
}

impl CacheStats {
    /// Calculate the hit rate.
    ///
    /// Returns `hits / (hits + misses)` as a value in `0.0..=1.0`, or `0.0`
    /// when no lookups have been recorded.
    #[inline]
    pub fn hit_rate(&self) -> f64 {
        // Widen before adding: the fields are public, so `hits + misses` in
        // `u64` could overflow (a panic in debug builds, a wrapped total in
        // release builds).
        let total = u128::from(self.hits) + u128::from(self.misses);
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }
}
198
199impl QueryCache {
200    /// Create a new query cache with the given maximum size.
201    pub fn new(max_size: usize) -> Self {
202        tracing::info!(max_size, "QueryCache initialized");
203        Self {
204            max_size,
205            cache: RwLock::new(HashMap::with_capacity(max_size)),
206            stats: AtomicCacheStats::default(),
207        }
208    }
209
210    /// Insert a query into the cache.
211    pub fn insert(&self, key: impl Into<QueryKey>, sql: impl Into<String>) {
212        let key = key.into();
213        let sql = sql.into();
214        let param_count = count_placeholders(&sql);
215        debug!(key = ?key.key, sql_len = sql.len(), param_count, "QueryCache::insert()");
216
217        let mut cache = self.cache.write().unwrap();
218
219        // Evict if full
220        if cache.len() >= self.max_size && !cache.contains_key(&key) {
221            self.evict_lru(&mut cache);
222            self.stats.record_eviction();
223            debug!("QueryCache evicted entry");
224        }
225
226        cache.insert(key, CachedQuery::new(sql, param_count));
227        self.stats.record_insertion();
228    }
229
230    /// Insert a query with known parameter count.
231    pub fn insert_with_params(
232        &self,
233        key: impl Into<QueryKey>,
234        sql: impl Into<String>,
235        param_count: usize,
236    ) {
237        let key = key.into();
238        let sql = sql.into();
239
240        let mut cache = self.cache.write().unwrap();
241
242        // Evict if full
243        if cache.len() >= self.max_size && !cache.contains_key(&key) {
244            self.evict_lru(&mut cache);
245            self.stats.record_eviction();
246        }
247
248        cache.insert(key, CachedQuery::new(sql, param_count));
249        self.stats.record_insertion();
250    }
251
252    /// Get a query from the cache.
253    pub fn get(&self, key: impl Into<QueryKey>) -> Option<String> {
254        let key = key.into();
255
256        let cache = self.cache.read().unwrap();
257        if let Some(entry) = cache.get(&key) {
258            // Atomic counter — no write-lock conflict with the read
259            // lock we still hold on `cache`.
260            self.stats.record_hit();
261            debug!(key = ?key.key, "QueryCache hit");
262            return Some(entry.sql.clone());
263        }
264        drop(cache);
265
266        self.stats.record_miss();
267        debug!(key = ?key.key, "QueryCache miss");
268        None
269    }
270
271    /// Get a cached query entry (includes metadata).
272    pub fn get_entry(&self, key: impl Into<QueryKey>) -> Option<CachedQuery> {
273        let key = key.into();
274
275        let cache = self.cache.read().unwrap();
276        if let Some(entry) = cache.get(&key) {
277            self.stats.record_hit();
278            return Some(entry.clone());
279        }
280        drop(cache);
281
282        self.stats.record_miss();
283        None
284    }
285
286    /// Get or compute a query.
287    ///
288    /// If the query is cached, returns the cached version.
289    /// Otherwise, computes it using the provided function and caches it.
290    pub fn get_or_insert<F>(&self, key: impl Into<QueryKey>, f: F) -> String
291    where
292        F: FnOnce() -> String,
293    {
294        let key = key.into();
295
296        // Try to get from cache
297        if let Some(sql) = self.get(key.clone()) {
298            return sql;
299        }
300
301        // Compute and insert
302        let sql = f();
303        self.insert(key, sql.clone());
304        sql
305    }
306
307    /// Check if a key exists in the cache.
308    pub fn contains(&self, key: impl Into<QueryKey>) -> bool {
309        let key = key.into();
310        let cache = self.cache.read().unwrap();
311        cache.contains_key(&key)
312    }
313
314    /// Remove a query from the cache.
315    pub fn remove(&self, key: impl Into<QueryKey>) -> Option<String> {
316        let key = key.into();
317        let mut cache = self.cache.write().unwrap();
318        cache.remove(&key).map(|e| e.sql)
319    }
320
321    /// Clear the entire cache.
322    pub fn clear(&self) {
323        let mut cache = self.cache.write().unwrap();
324        cache.clear();
325    }
326
327    /// Get the current number of cached queries.
328    pub fn len(&self) -> usize {
329        let cache = self.cache.read().unwrap();
330        cache.len()
331    }
332
333    /// Check if the cache is empty.
334    pub fn is_empty(&self) -> bool {
335        self.len() == 0
336    }
337
338    /// Get the maximum cache size.
339    pub fn max_size(&self) -> usize {
340        self.max_size
341    }
342
343    /// Get cache statistics.
344    pub fn stats(&self) -> CacheStats {
345        self.stats.snapshot()
346    }
347
348    /// Reset cache statistics.
349    pub fn reset_stats(&self) {
350        self.stats.reset();
351    }
352
353    /// Evict the least recently used entries.
354    fn evict_lru(&self, cache: &mut HashMap<QueryKey, CachedQuery>) {
355        // Simple strategy: evict entries with lowest access count
356        // In production, consider using a proper LRU data structure
357        let to_evict = cache.len() / 4; // Evict 25%
358        if to_evict == 0 {
359            return;
360        }
361
362        let mut entries: Vec<_> = cache
363            .iter()
364            .map(|(k, v)| (k.clone(), v.access_count))
365            .collect();
366        entries.sort_by_key(|(_, count)| *count);
367
368        for (key, _) in entries.into_iter().take(to_evict) {
369            cache.remove(&key);
370        }
371    }
372}
373
374impl Default for QueryCache {
375    fn default() -> Self {
376        Self::new(1000)
377    }
378}
379
/// Count the number of parameter placeholders in a SQL string.
///
/// Two placeholder styles are recognised:
///
/// - PostgreSQL-style `$N`: the result is raised to the largest `N` seen,
///   so repeated uses of the same index are not double-counted.
/// - MySQL/SQLite-style `?`: every occurrence adds one.
///
/// A `$` that is not followed by digits, or whose digits do not fit in a
/// `usize`, is ignored. The scan is purely textual: quoting and comments
/// are not interpreted.
fn count_placeholders(sql: &str) -> usize {
    let mut total: usize = 0;
    let mut rest = sql.chars().peekable();

    while let Some(ch) = rest.next() {
        match ch {
            '$' => {
                // Consume the run of ASCII digits that follows the `$`.
                let mut digits = String::new();
                while let Some(digit) = rest.next_if(char::is_ascii_digit) {
                    digits.push(digit);
                }
                // An empty or overflowing run fails to parse and is skipped.
                if let Ok(index) = digits.parse::<usize>() {
                    total = total.max(index);
                }
            }
            '?' => total += 1,
            _ => {}
        }
    }

    total
}
410
/// A query hash for fast lookup.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct QueryHash(u64);

impl QueryHash {
    /// Hash `sql` with the standard library's `DefaultHasher`.
    ///
    /// Equal SQL strings produce equal hashes.
    pub fn new(sql: &str) -> Self {
        let mut state = std::collections::hash_map::DefaultHasher::new();
        Hash::hash(sql, &mut state);
        Self(Hasher::finish(&state))
    }

    /// The underlying 64-bit hash value.
    #[inline]
    pub fn value(&self) -> u64 {
        let Self(raw) = *self;
        raw
    }
}
429
430/// Common query patterns for caching.
431pub mod patterns {
432    use super::QueryKey;
433
434    /// Query key for SELECT by ID.
435    #[inline]
436    pub fn select_by_id(table: &str) -> QueryKey {
437        QueryKey::owned(format!("select_by_id:{}", table))
438    }
439
440    /// Query key for SELECT all.
441    #[inline]
442    pub fn select_all(table: &str) -> QueryKey {
443        QueryKey::owned(format!("select_all:{}", table))
444    }
445
446    /// Query key for INSERT.
447    #[inline]
448    pub fn insert(table: &str, columns: usize) -> QueryKey {
449        QueryKey::owned(format!("insert:{}:{}", table, columns))
450    }
451
452    /// Query key for UPDATE by ID.
453    #[inline]
454    pub fn update_by_id(table: &str, columns: usize) -> QueryKey {
455        QueryKey::owned(format!("update_by_id:{}:{}", table, columns))
456    }
457
458    /// Query key for DELETE by ID.
459    #[inline]
460    pub fn delete_by_id(table: &str) -> QueryKey {
461        QueryKey::owned(format!("delete_by_id:{}", table))
462    }
463
464    /// Query key for COUNT.
465    #[inline]
466    pub fn count(table: &str) -> QueryKey {
467        QueryKey::owned(format!("count:{}", table))
468    }
469
470    /// Query key for COUNT with filter.
471    #[inline]
472    pub fn count_filtered(table: &str, filter_hash: u64) -> QueryKey {
473        QueryKey::owned(format!("count:{}:{}", table, filter_hash))
474    }
475}
476
477// =============================================================================
478// High-Performance SQL Template Cache
479// =============================================================================
480
/// A high-performance SQL template cache optimized for repeated queries.
///
/// Unlike `QueryCache` which stores full SQL strings, `SqlTemplateCache` stores
/// template structures with pre-computed placeholder positions for very fast
/// instantiation.
///
/// Lookups by string key go through two maps: `key_index` resolves the key
/// to a `u64` hash, and `templates` maps that hash to the shared template.
///
/// # Performance
///
/// - Cache lookup: O(1) hash lookup, ~5-10ns
/// - Template instantiation: O(n) where n is parameter count
/// - Thread-safe with minimal contention (parking_lot RwLock)
///
/// # Examples
///
/// ```rust
/// use prax_query::cache::SqlTemplateCache;
///
/// let cache = SqlTemplateCache::new(1000);
///
/// // Register a template
/// let template = cache.register("users_by_id", "SELECT * FROM users WHERE id = $1");
///
/// // Instant retrieval (~5ns)
/// let sql = cache.get("users_by_id");
/// ```
#[derive(Debug)]
pub struct SqlTemplateCache {
    /// Maximum number of templates.
    max_size: usize,
    /// Cached templates (using Arc for cheap cloning), keyed by hash.
    templates: parking_lot::RwLock<HashMap<u64, Arc<SqlTemplate>>>,
    /// String key to hash lookup.
    key_index: parking_lot::RwLock<HashMap<Cow<'static, str>, u64>>,
    /// Statistics (hits, misses, evictions, insertions).
    stats: parking_lot::RwLock<CacheStats>,
}
517
518/// A pre-parsed SQL template for fast instantiation.
519#[derive(Debug)]
520pub struct SqlTemplate {
521    /// The complete SQL string (for direct use).
522    pub sql: Arc<str>,
523    /// Pre-computed hash for fast lookup.
524    pub hash: u64,
525    /// Number of parameters.
526    pub param_count: usize,
527    /// Access timestamp for LRU.
528    last_access: std::sync::atomic::AtomicU64,
529}
530
531impl Clone for SqlTemplate {
532    fn clone(&self) -> Self {
533        use std::sync::atomic::Ordering;
534        Self {
535            sql: Arc::clone(&self.sql),
536            hash: self.hash,
537            param_count: self.param_count,
538            last_access: std::sync::atomic::AtomicU64::new(
539                self.last_access.load(Ordering::Relaxed),
540            ),
541        }
542    }
543}
544
545impl SqlTemplate {
546    /// Create a new SQL template.
547    pub fn new(sql: impl AsRef<str>) -> Self {
548        let sql_str = sql.as_ref();
549        let param_count = count_placeholders(sql_str);
550        let hash = {
551            let mut hasher = std::collections::hash_map::DefaultHasher::new();
552            sql_str.hash(&mut hasher);
553            hasher.finish()
554        };
555
556        Self {
557            sql: Arc::from(sql_str),
558            hash,
559            param_count,
560            last_access: std::sync::atomic::AtomicU64::new(0),
561        }
562    }
563
564    /// Get the SQL string as a reference.
565    #[inline(always)]
566    pub fn sql(&self) -> &str {
567        &self.sql
568    }
569
570    /// Get the SQL string as an Arc (zero-copy clone).
571    #[inline(always)]
572    pub fn sql_arc(&self) -> Arc<str> {
573        Arc::clone(&self.sql)
574    }
575
576    /// Touch the template to update LRU access time.
577    #[inline]
578    fn touch(&self) {
579        use std::sync::atomic::Ordering;
580        use std::time::{SystemTime, UNIX_EPOCH};
581        let now = SystemTime::now()
582            .duration_since(UNIX_EPOCH)
583            .map(|d| d.as_secs())
584            .unwrap_or(0);
585        self.last_access.store(now, Ordering::Relaxed);
586    }
587}
588
589impl SqlTemplateCache {
590    /// Create a new template cache with the given maximum size.
591    pub fn new(max_size: usize) -> Self {
592        tracing::info!(max_size, "SqlTemplateCache initialized");
593        Self {
594            max_size,
595            templates: parking_lot::RwLock::new(HashMap::with_capacity(max_size)),
596            key_index: parking_lot::RwLock::new(HashMap::with_capacity(max_size)),
597            stats: parking_lot::RwLock::new(CacheStats::default()),
598        }
599    }
600
601    /// Register a SQL template with a string key.
602    ///
603    /// Returns the template for immediate use.
604    #[inline]
605    pub fn register(
606        &self,
607        key: impl Into<Cow<'static, str>>,
608        sql: impl AsRef<str>,
609    ) -> Arc<SqlTemplate> {
610        let key = key.into();
611        let template = Arc::new(SqlTemplate::new(sql));
612        let hash = template.hash;
613
614        let mut templates = self.templates.write();
615        let mut key_index = self.key_index.write();
616        let mut stats = self.stats.write();
617
618        // Evict if full
619        if templates.len() >= self.max_size {
620            self.evict_lru_internal(&mut templates, &mut key_index);
621            stats.evictions += 1;
622        }
623
624        key_index.insert(key, hash);
625        templates.insert(hash, Arc::clone(&template));
626        stats.insertions += 1;
627
628        debug!(hash, "SqlTemplateCache::register()");
629        template
630    }
631
632    /// Register a template by hash (for pre-computed hashes).
633    #[inline]
634    pub fn register_by_hash(&self, hash: u64, sql: impl AsRef<str>) -> Arc<SqlTemplate> {
635        let template = Arc::new(SqlTemplate::new(sql));
636
637        let mut templates = self.templates.write();
638        let mut stats = self.stats.write();
639
640        if templates.len() >= self.max_size {
641            let mut key_index = self.key_index.write();
642            self.evict_lru_internal(&mut templates, &mut key_index);
643            stats.evictions += 1;
644        }
645
646        templates.insert(hash, Arc::clone(&template));
647        stats.insertions += 1;
648
649        template
650    }
651
652    /// Get a template by string key (returns Arc for zero-copy).
653    ///
654    /// # Performance
655    ///
656    /// This is the fastest way to get cached SQL:
657    /// - Hash lookup: ~5ns
658    /// - Returns Arc<SqlTemplate> (no allocation)
659    #[inline]
660    pub fn get(&self, key: &str) -> Option<Arc<SqlTemplate>> {
661        let hash = {
662            let key_index = self.key_index.read();
663            match key_index.get(key) {
664                Some(&h) => h,
665                None => {
666                    drop(key_index); // Release read lock before write
667                    let mut stats = self.stats.write();
668                    stats.misses += 1;
669                    return None;
670                }
671            }
672        };
673
674        let templates = self.templates.read();
675        if let Some(template) = templates.get(&hash) {
676            template.touch();
677            let mut stats = self.stats.write();
678            stats.hits += 1;
679            return Some(Arc::clone(template));
680        }
681
682        let mut stats = self.stats.write();
683        stats.misses += 1;
684        None
685    }
686
687    /// Get a template by pre-computed hash (fastest path).
688    ///
689    /// # Performance
690    ///
691    /// ~3-5ns for cache hit with pre-computed hash.
692    #[inline(always)]
693    pub fn get_by_hash(&self, hash: u64) -> Option<Arc<SqlTemplate>> {
694        let templates = self.templates.read();
695        if let Some(template) = templates.get(&hash) {
696            template.touch();
697            // Skip stats update for maximum performance
698            return Some(Arc::clone(template));
699        }
700        None
701    }
702
703    /// Get the SQL string directly (convenience method).
704    #[inline]
705    pub fn get_sql(&self, key: &str) -> Option<Arc<str>> {
706        self.get(key).map(|t| t.sql_arc())
707    }
708
709    /// Get or compute a template.
710    #[inline]
711    pub fn get_or_register<F>(&self, key: impl Into<Cow<'static, str>>, f: F) -> Arc<SqlTemplate>
712    where
713        F: FnOnce() -> String,
714    {
715        let key = key.into();
716
717        // Fast path: check if exists
718        if let Some(template) = self.get(&key) {
719            return template;
720        }
721
722        // Slow path: compute and register
723        let sql = f();
724        self.register(key, sql)
725    }
726
727    /// Check if a key exists.
728    #[inline]
729    pub fn contains(&self, key: &str) -> bool {
730        let key_index = self.key_index.read();
731        key_index.contains_key(key)
732    }
733
734    /// Get cache statistics.
735    pub fn stats(&self) -> CacheStats {
736        self.stats.read().clone()
737    }
738
739    /// Get the number of cached templates.
740    pub fn len(&self) -> usize {
741        self.templates.read().len()
742    }
743
744    /// Check if the cache is empty.
745    pub fn is_empty(&self) -> bool {
746        self.len() == 0
747    }
748
749    /// Clear the cache.
750    pub fn clear(&self) {
751        self.templates.write().clear();
752        self.key_index.write().clear();
753    }
754
755    /// Evict least recently used templates (internal, assumes locks held).
756    fn evict_lru_internal(
757        &self,
758        templates: &mut HashMap<u64, Arc<SqlTemplate>>,
759        key_index: &mut HashMap<Cow<'static, str>, u64>,
760    ) {
761        use std::sync::atomic::Ordering;
762
763        let to_evict = templates.len() / 4;
764        if to_evict == 0 {
765            return;
766        }
767
768        // Find templates with oldest access times
769        let mut entries: Vec<_> = templates
770            .iter()
771            .map(|(&hash, t)| (hash, t.last_access.load(Ordering::Relaxed)))
772            .collect();
773        entries.sort_by_key(|(_, time)| *time);
774
775        // Collect the set of hashes to evict, remove their templates, then
776        // prune key_index in a single pass (avoids O(evicted * key_index.len())).
777        let evicted: std::collections::HashSet<u64> = entries
778            .into_iter()
779            .take(to_evict)
780            .map(|(hash, _)| {
781                templates.remove(&hash);
782                hash
783            })
784            .collect();
785        key_index.retain(|_, h| !evicted.contains(h));
786    }
787}
788
789impl Default for SqlTemplateCache {
790    fn default() -> Self {
791        Self::new(1000)
792    }
793}
794
795// =============================================================================
796// Global Template Cache (for zero-overhead repeated queries)
797// =============================================================================
798
/// Global SQL template cache for maximum performance.
///
/// Use this for queries that are repeated many times with only parameter changes.
/// The global cache avoids the overhead of passing cache references around.
///
/// The cell starts empty and is filled on the first call to
/// `global_template_cache`.
///
/// # Examples
///
/// ```rust
/// use prax_query::cache::{global_template_cache, register_global_template};
///
/// // Pre-register common queries at startup
/// register_global_template("users_by_id", "SELECT * FROM users WHERE id = $1");
///
/// // Later, get the cached SQL (~5ns)
/// if let Some(template) = global_template_cache().get("users_by_id") {
///     println!("SQL: {}", template.sql());
/// }
/// ```
static GLOBAL_TEMPLATE_CACHE: std::sync::OnceLock<SqlTemplateCache> = std::sync::OnceLock::new();
818
819/// Get the global SQL template cache.
820#[inline(always)]
821pub fn global_template_cache() -> &'static SqlTemplateCache {
822    GLOBAL_TEMPLATE_CACHE.get_or_init(|| SqlTemplateCache::new(10000))
823}
824
825/// Register a template in the global cache.
826#[inline]
827pub fn register_global_template(
828    key: impl Into<Cow<'static, str>>,
829    sql: impl AsRef<str>,
830) -> Arc<SqlTemplate> {
831    global_template_cache().register(key, sql)
832}
833
834/// Get a template from the global cache.
835#[inline(always)]
836pub fn get_global_template(key: &str) -> Option<Arc<SqlTemplate>> {
837    global_template_cache().get(key)
838}
839
/// Pre-compute a query hash for repeated lookups.
///
/// Hashes `key` with the standard library's `DefaultHasher`. Use this when
/// a query key will be used many times: computing the hash once and using
/// `get_by_hash` is faster than string key lookups.
#[inline]
pub fn precompute_query_hash(key: &str) -> u64 {
    let mut state = std::collections::hash_map::DefaultHasher::new();
    Hash::hash(key, &mut state);
    Hasher::finish(&state)
}
851
// Unit tests for the query cache, placeholder counting, hashing helpers,
// the SQL template cache, the global template cache, and plan caching.
#[cfg(test)]
mod tests {
    use super::*;

    // Insert one query, then confirm it is present and returned verbatim.
    #[test]
    fn test_query_cache_basic() {
        let cache = QueryCache::new(10);

        cache.insert("users_by_id", "SELECT * FROM users WHERE id = $1");
        assert!(cache.contains("users_by_id"));

        let sql = cache.get("users_by_id");
        assert_eq!(sql, Some("SELECT * FROM users WHERE id = $1".to_string()));
    }

    // The builder closure runs only on the first call for a key.
    #[test]
    fn test_query_cache_get_or_insert() {
        let cache = QueryCache::new(10);

        let sql1 = cache.get_or_insert("test", || "SELECT 1".to_string());
        assert_eq!(sql1, "SELECT 1");

        let sql2 = cache.get_or_insert("test", || "SELECT 2".to_string());
        assert_eq!(sql2, "SELECT 1"); // Should return cached value
    }

    // Hits, misses and insertions are counted separately.
    #[test]
    fn test_query_cache_stats() {
        let cache = QueryCache::new(10);

        cache.insert("test", "SELECT 1");
        cache.get("test"); // Hit
        cache.get("test"); // Hit
        cache.get("missing"); // Miss

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.insertions, 1);
    }

    // `$N` placeholders: the count is the highest index seen.
    #[test]
    fn test_count_placeholders_postgres() {
        assert_eq!(count_placeholders("SELECT * FROM users WHERE id = $1"), 1);
        assert_eq!(
            count_placeholders("SELECT * FROM users WHERE id = $1 AND name = $2"),
            2
        );
        assert_eq!(count_placeholders("SELECT * FROM users WHERE id = $10"), 10);
    }

    // `?` placeholders: each occurrence counts once.
    #[test]
    fn test_count_placeholders_mysql() {
        assert_eq!(count_placeholders("SELECT * FROM users WHERE id = ?"), 1);
        assert_eq!(
            count_placeholders("SELECT * FROM users WHERE id = ? AND name = ?"),
            2
        );
    }

    // Equal SQL hashes equal; different SQL hashes differently.
    #[test]
    fn test_query_hash() {
        let hash1 = QueryHash::new("SELECT * FROM users");
        let hash2 = QueryHash::new("SELECT * FROM users");
        let hash3 = QueryHash::new("SELECT * FROM posts");

        assert_eq!(hash1, hash2);
        assert_ne!(hash1, hash3);
    }

    // Pattern helpers prefix the key with the operation name.
    #[test]
    fn test_patterns() {
        let key = patterns::select_by_id("users");
        assert!(key.key.starts_with("select_by_id:"));
    }

    // =========================================================================
    // SqlTemplateCache Tests
    // =========================================================================

    // Registering returns a template carrying the SQL and its parameter count.
    #[test]
    fn test_sql_template_cache_basic() {
        let cache = SqlTemplateCache::new(100);

        let template = cache.register("users_by_id", "SELECT * FROM users WHERE id = $1");
        assert_eq!(template.sql(), "SELECT * FROM users WHERE id = $1");
        assert_eq!(template.param_count, 1);
    }

    // Lookup by key finds registered templates and misses unknown keys.
    #[test]
    fn test_sql_template_cache_get() {
        let cache = SqlTemplateCache::new(100);

        cache.register("test_query", "SELECT * FROM test WHERE x = $1");

        let result = cache.get("test_query");
        assert!(result.is_some());
        assert_eq!(result.unwrap().sql(), "SELECT * FROM test WHERE x = $1");

        let missing = cache.get("nonexistent");
        assert!(missing.is_none());
    }

    // A template registered by key can be fetched through its `hash` field.
    #[test]
    fn test_sql_template_cache_get_by_hash() {
        let cache = SqlTemplateCache::new(100);

        let template = cache.register("fast_query", "SELECT 1");
        let hash = template.hash;

        // Get by hash should be very fast
        let result = cache.get_by_hash(hash);
        assert!(result.is_some());
        assert_eq!(result.unwrap().sql(), "SELECT 1");
    }

    // The builder closure must not run when the key is already registered.
    #[test]
    fn test_sql_template_cache_get_or_register() {
        let cache = SqlTemplateCache::new(100);

        let t1 = cache.get_or_register("computed", || "SELECT * FROM computed".to_string());
        assert_eq!(t1.sql(), "SELECT * FROM computed");

        // Second call should return cached version
        let t2 = cache.get_or_register("computed", || panic!("Should not be called"));
        assert_eq!(t2.sql(), "SELECT * FROM computed");
        assert_eq!(t1.hash, t2.hash);
    }

    // Hits, misses and insertions are counted separately.
    #[test]
    fn test_sql_template_cache_stats() {
        let cache = SqlTemplateCache::new(100);

        cache.register("q1", "SELECT 1");
        cache.get("q1"); // Hit
        cache.get("q1"); // Hit
        cache.get("missing"); // Miss

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.insertions, 1);
    }

    // Round-trip through the process-wide cache.
    #[test]
    fn test_global_template_cache() {
        // Register in global cache
        let template = register_global_template("global_test", "SELECT * FROM global");
        assert_eq!(template.sql(), "SELECT * FROM global");

        // Retrieve from global cache
        let result = get_global_template("global_test");
        assert!(result.is_some());
        assert_eq!(result.unwrap().sql(), "SELECT * FROM global");
    }

    // Equal keys hash equal; different keys hash differently.
    #[test]
    fn test_precompute_query_hash() {
        let hash1 = precompute_query_hash("test_key");
        let hash2 = precompute_query_hash("test_key");
        let hash3 = precompute_query_hash("other_key");

        assert_eq!(hash1, hash2);
        assert_ne!(hash1, hash3);
    }

    // Register a plan with an index-scan hint and read it back by key.
    #[test]
    fn test_execution_plan_cache() {
        let cache = ExecutionPlanCache::new(100);

        // Register a plan
        let plan = cache.register(
            "users_by_email",
            "SELECT * FROM users WHERE email = $1",
            PlanHint::IndexScan("users_email_idx".into()),
        );
        assert_eq!(plan.sql.as_ref(), "SELECT * FROM users WHERE email = $1");

        // Get cached plan
        let result = cache.get("users_by_email");
        assert!(result.is_some());
        assert!(matches!(result.unwrap().hint, PlanHint::IndexScan(_)));
    }
}
1036
1037// ============================================================================
1038// Execution Plan Caching
1039// ============================================================================
1040
/// Optimization hints attached to a query.
///
/// A hint is advisory: each database engine implementation decides whether
/// and how to honor it, because engines differ in which hints they support.
#[derive(Clone, Debug, Default)]
pub enum PlanHint {
    /// No hint; this is the default.
    #[default]
    None,
    /// Ask the engine to use the named index.
    IndexScan(String),
    /// Ask the engine for a sequential scan (intended for analytics queries).
    SeqScan,
    /// Ask for parallel execution.
    ///
    /// NOTE(review): the `u32` is presumably the requested degree of
    /// parallelism; confirm against the engine implementations.
    Parallel(u32),
    /// Ask the engine to cache this query's execution plan.
    CachePlan,
    /// Apply a timeout to this query.
    Timeout(std::time::Duration),
    /// Free-form hint for a specific database.
    Custom(String),
}
1064
1065/// A cached execution plan with optimization hints.
1066#[derive(Debug)]
1067pub struct ExecutionPlan {
1068    /// The SQL query.
1069    pub sql: Arc<str>,
1070    /// Pre-computed hash for fast lookup.
1071    pub hash: u64,
1072    /// Execution hint.
1073    pub hint: PlanHint,
1074    /// Estimated cost (if available from EXPLAIN).
1075    pub estimated_cost: Option<f64>,
1076    /// Number of times this plan has been used.
1077    use_count: std::sync::atomic::AtomicU64,
1078    /// Average execution time in microseconds.
1079    avg_execution_us: std::sync::atomic::AtomicU64,
1080}
1081
/// Hash a string with the standard library's default hasher.
///
/// The same input always yields the same value within one build. The
/// algorithm behind `DefaultHasher` is unspecified by the standard library,
/// so the value should not be persisted or compared across builds.
fn compute_hash(s: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;

    let mut state = DefaultHasher::new();
    Hash::hash(s, &mut state);
    Hasher::finish(&state)
}
1088
1089impl ExecutionPlan {
1090    /// Create a new execution plan.
1091    pub fn new(sql: impl AsRef<str>, hint: PlanHint) -> Self {
1092        let sql_str = sql.as_ref();
1093        Self {
1094            sql: Arc::from(sql_str),
1095            hash: compute_hash(sql_str),
1096            hint,
1097            estimated_cost: None,
1098            use_count: std::sync::atomic::AtomicU64::new(0),
1099            avg_execution_us: std::sync::atomic::AtomicU64::new(0),
1100        }
1101    }
1102
1103    /// Create with estimated cost.
1104    pub fn with_cost(sql: impl AsRef<str>, hint: PlanHint, cost: f64) -> Self {
1105        let sql_str = sql.as_ref();
1106        Self {
1107            sql: Arc::from(sql_str),
1108            hash: compute_hash(sql_str),
1109            hint,
1110            estimated_cost: Some(cost),
1111            use_count: std::sync::atomic::AtomicU64::new(0),
1112            avg_execution_us: std::sync::atomic::AtomicU64::new(0),
1113        }
1114    }
1115
1116    /// Record an execution with timing.
1117    pub fn record_execution(&self, duration_us: u64) {
1118        let old_count = self
1119            .use_count
1120            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
1121        let old_avg = self
1122            .avg_execution_us
1123            .load(std::sync::atomic::Ordering::Relaxed);
1124
1125        // Update running average
1126        let new_avg = if old_count == 0 {
1127            duration_us
1128        } else {
1129            // Weighted average: (old_avg * old_count + new_value) / (old_count + 1)
1130            (old_avg * old_count + duration_us) / (old_count + 1)
1131        };
1132
1133        self.avg_execution_us
1134            .store(new_avg, std::sync::atomic::Ordering::Relaxed);
1135    }
1136
1137    /// Get the use count.
1138    pub fn use_count(&self) -> u64 {
1139        self.use_count.load(std::sync::atomic::Ordering::Relaxed)
1140    }
1141
1142    /// Get the average execution time in microseconds.
1143    pub fn avg_execution_us(&self) -> u64 {
1144        self.avg_execution_us
1145            .load(std::sync::atomic::Ordering::Relaxed)
1146    }
1147}
1148
1149/// Cache for query execution plans.
1150///
1151/// This cache stores not just SQL strings but also execution hints and
1152/// performance metrics for each query, enabling adaptive optimization.
1153///
1154/// # Example
1155///
1156/// ```rust
1157/// use prax_query::cache::{ExecutionPlanCache, PlanHint};
1158///
1159/// let cache = ExecutionPlanCache::new(1000);
1160///
1161/// // Register a plan with an index hint
1162/// let plan = cache.register(
1163///     "find_user_by_email",
1164///     "SELECT * FROM users WHERE email = $1",
1165///     PlanHint::IndexScan("idx_users_email".into()),
1166/// );
1167///
1168/// // Get the plan later
1169/// if let Some(plan) = cache.get("find_user_by_email") {
1170///     println!("Using plan with hint: {:?}", plan.hint);
1171/// }
1172/// ```
1173#[derive(Debug)]
1174pub struct ExecutionPlanCache {
1175    /// Maximum number of plans to cache.
1176    max_size: usize,
1177    /// Cached plans.
1178    plans: parking_lot::RwLock<HashMap<u64, Arc<ExecutionPlan>>>,
1179    /// Key to hash lookup.
1180    key_index: parking_lot::RwLock<HashMap<Cow<'static, str>, u64>>,
1181}
1182
1183impl ExecutionPlanCache {
1184    /// Create a new execution plan cache.
1185    pub fn new(max_size: usize) -> Self {
1186        Self {
1187            max_size,
1188            plans: parking_lot::RwLock::new(HashMap::with_capacity(max_size / 2)),
1189            key_index: parking_lot::RwLock::new(HashMap::with_capacity(max_size / 2)),
1190        }
1191    }
1192
1193    /// Register a new execution plan.
1194    pub fn register(
1195        &self,
1196        key: impl Into<Cow<'static, str>>,
1197        sql: impl AsRef<str>,
1198        hint: PlanHint,
1199    ) -> Arc<ExecutionPlan> {
1200        let key = key.into();
1201        let plan = Arc::new(ExecutionPlan::new(sql, hint));
1202        let hash = plan.hash;
1203
1204        let mut plans = self.plans.write();
1205        let mut key_index = self.key_index.write();
1206
1207        // Evict if at capacity
1208        if plans.len() >= self.max_size && !plans.contains_key(&hash) {
1209            // Simple eviction: remove least used
1210            if let Some((&evict_hash, _)) = plans.iter().min_by_key(|(_, p)| p.use_count()) {
1211                plans.remove(&evict_hash);
1212                key_index.retain(|_, &mut v| v != evict_hash);
1213            }
1214        }
1215
1216        plans.insert(hash, Arc::clone(&plan));
1217        key_index.insert(key, hash);
1218
1219        plan
1220    }
1221
1222    /// Register a plan with estimated cost.
1223    pub fn register_with_cost(
1224        &self,
1225        key: impl Into<Cow<'static, str>>,
1226        sql: impl AsRef<str>,
1227        hint: PlanHint,
1228        cost: f64,
1229    ) -> Arc<ExecutionPlan> {
1230        let key = key.into();
1231        let plan = Arc::new(ExecutionPlan::with_cost(sql, hint, cost));
1232        let hash = plan.hash;
1233
1234        let mut plans = self.plans.write();
1235        let mut key_index = self.key_index.write();
1236
1237        if plans.len() >= self.max_size
1238            && !plans.contains_key(&hash)
1239            && let Some((&evict_hash, _)) = plans.iter().min_by_key(|(_, p)| p.use_count())
1240        {
1241            plans.remove(&evict_hash);
1242            key_index.retain(|_, &mut v| v != evict_hash);
1243        }
1244
1245        plans.insert(hash, Arc::clone(&plan));
1246        key_index.insert(key, hash);
1247
1248        plan
1249    }
1250
1251    /// Get a cached execution plan.
1252    pub fn get(&self, key: &str) -> Option<Arc<ExecutionPlan>> {
1253        let hash = {
1254            let key_index = self.key_index.read();
1255            *key_index.get(key)?
1256        };
1257
1258        self.plans.read().get(&hash).cloned()
1259    }
1260
1261    /// Get a plan by its hash.
1262    pub fn get_by_hash(&self, hash: u64) -> Option<Arc<ExecutionPlan>> {
1263        self.plans.read().get(&hash).cloned()
1264    }
1265
1266    /// Get or create a plan.
1267    pub fn get_or_register<F>(
1268        &self,
1269        key: impl Into<Cow<'static, str>>,
1270        sql_fn: F,
1271        hint: PlanHint,
1272    ) -> Arc<ExecutionPlan>
1273    where
1274        F: FnOnce() -> String,
1275    {
1276        let key = key.into();
1277
1278        // Fast path: check if exists
1279        if let Some(plan) = self.get(key.as_ref()) {
1280            return plan;
1281        }
1282
1283        // Slow path: create and register
1284        self.register(key, sql_fn(), hint)
1285    }
1286
1287    /// Record execution timing for a plan.
1288    pub fn record_execution(&self, key: &str, duration_us: u64) {
1289        if let Some(plan) = self.get(key) {
1290            plan.record_execution(duration_us);
1291        }
1292    }
1293
1294    /// Get plans sorted by average execution time (slowest first).
1295    pub fn slowest_queries(&self, limit: usize) -> Vec<Arc<ExecutionPlan>> {
1296        let plans = self.plans.read();
1297        let mut sorted: Vec<_> = plans.values().cloned().collect();
1298        sorted.sort_by_key(|a| std::cmp::Reverse(a.avg_execution_us()));
1299        sorted.truncate(limit);
1300        sorted
1301    }
1302
1303    /// Get plans sorted by use count (most used first).
1304    pub fn most_used(&self, limit: usize) -> Vec<Arc<ExecutionPlan>> {
1305        let plans = self.plans.read();
1306        let mut sorted: Vec<_> = plans.values().cloned().collect();
1307        sorted.sort_by_key(|a| std::cmp::Reverse(a.use_count()));
1308        sorted.truncate(limit);
1309        sorted
1310    }
1311
1312    /// Clear all cached plans.
1313    pub fn clear(&self) {
1314        self.plans.write().clear();
1315        self.key_index.write().clear();
1316    }
1317
1318    /// Get the number of cached plans.
1319    pub fn len(&self) -> usize {
1320        self.plans.read().len()
1321    }
1322
1323    /// Check if the cache is empty.
1324    pub fn is_empty(&self) -> bool {
1325        self.plans.read().is_empty()
1326    }
1327}