Skip to main content

oximedia_gpu/
pipeline_cache.rs

//! Pipeline object cache for the GPU crate.
//!
//! Avoids redundant pipeline compilation by keying compiled pipeline
//! binaries on a `u64` hash.  On a cache miss, a user-supplied closure
//! is called to produce the binary, and the result is stored for subsequent
//! hits.
//!
//! The cache is entirely in-memory; for disk persistence see the
//! higher-level `shader_cache` module.
//!
//! # Example
//!
//! ```rust
//! use oximedia_gpu::pipeline_cache::PipelineCache;
//!
//! let mut cache = PipelineCache::new();
//! let binary = cache.get_or_create(0xDEAD_BEEF, || vec![0x01, 0x02, 0x03]);
//! assert_eq!(binary, &[0x01, 0x02, 0x03]);
//! // The second call returns the cached value without invoking the closure.
//! let again = cache.get_or_create(0xDEAD_BEEF, || panic!("should not be called"));
//! assert_eq!(again, &[0x01, 0x02, 0x03]);
//! ```
23
24#![allow(dead_code)]
25
use std::collections::hash_map::Entry;
use std::collections::HashMap;
27
// ── PipelineCache ─────────────────────────────────────────────────────────────
29
/// In-memory cache mapping pipeline keys to compiled pipeline binaries.
///
/// `key` is a caller-defined `u64` (typically a hash of the pipeline
/// descriptor / shader source).  `value` is an opaque `Vec<u8>` that
/// represents the compiled pipeline object — format is backend-specific.
///
/// The public statistics (`total_lookups`, `hits`, `misses`) are updated only
/// by [`PipelineCache::get_or_create`]. [`PipelineCache::get`] and
/// [`PipelineCache::invalidate`] leave them untouched, and
/// [`PipelineCache::clear`] resets them to zero.
#[derive(Debug, Default)]
pub struct PipelineCache {
    /// Compiled binaries keyed by the caller-supplied pipeline key.
    entries: HashMap<u64, Vec<u8>>,
    /// Total number of cache lookups (hits + misses).
    pub total_lookups: u64,
    /// Total number of cache hits.
    pub hits: u64,
    /// Total number of cache misses (compilations triggered).
    pub misses: u64,
}

impl PipelineCache {
    /// Create a new, empty pipeline cache with all statistics at zero.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Return a reference to the cached binary for `key`, compiling and
    /// caching it first if not already present.
    ///
    /// `create_fn` is invoked **at most once** per unique `key` (until the key
    /// is invalidated or the cache is cleared), so it only needs to be
    /// `FnOnce`. That lets callers move owned data into the closure.
    ///
    /// Every call increments `total_lookups`, plus either `hits` or `misses`.
    ///
    /// # Panics
    ///
    /// Propagates any panic raised by `create_fn`. In that case nothing is
    /// inserted, but `total_lookups` and `misses` have already been incremented.
    pub fn get_or_create(&mut self, key: u64, create_fn: impl FnOnce() -> Vec<u8>) -> &Vec<u8> {
        self.total_lookups += 1;
        // A single hash lookup via the entry API, instead of
        // contains_key + insert + get.
        match self.entries.entry(key) {
            Entry::Occupied(slot) => {
                self.hits += 1;
                slot.into_mut()
            }
            Entry::Vacant(slot) => {
                self.misses += 1;
                slot.insert(create_fn())
            }
        }
    }

    /// Return the cached binary for `key`, or `None` if not present.
    ///
    /// Does not update the lookup statistics.
    #[must_use]
    pub fn get(&self, key: u64) -> Option<&Vec<u8>> {
        self.entries.get(&key)
    }

    /// Remove a cached entry (e.g. when the shader source changes).
    ///
    /// Returns `true` if an entry for `key` was present and has been removed.
    /// The next [`PipelineCache::get_or_create`] for `key` will be a miss.
    /// Statistics are not modified.
    pub fn invalidate(&mut self, key: u64) -> bool {
        self.entries.remove(&key).is_some()
    }

    /// Remove all cached entries **and** reset `total_lookups`, `hits` and
    /// `misses` to zero.
    pub fn clear(&mut self) {
        self.entries.clear();
        self.total_lookups = 0;
        self.hits = 0;
        self.misses = 0;
    }

    /// Number of cached pipelines.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` if the cache is empty.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Hit ratio in `[0.0, 1.0]`.  Returns `0.0` if no lookups have occurred.
    #[must_use]
    pub fn hit_ratio(&self) -> f64 {
        if self.total_lookups == 0 {
            0.0
        } else {
            self.hits as f64 / self.total_lookups as f64
        }
    }
}
111
// ── Tests ─────────────────────────────────────────────────────────────────────
113
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;

    #[test]
    fn test_get_or_create_miss_then_hit() {
        let mut cache = PipelineCache::new();
        let b1 = cache.get_or_create(1, || vec![0xAA]);
        assert_eq!(b1, &[0xAA]);
        // Second call: the closure panics if invoked, which proves the hit
        // path never runs it.
        let b2 = cache.get_or_create(1, || panic!("closure must not run on a cache hit"));
        assert_eq!(
            b2,
            &[0xAA],
            "cached value must be returned, not a new closure result"
        );
        // Exactly one compilation and one hit.
        assert_eq!(cache.misses, 1);
        assert_eq!(cache.hits, 1);
        assert_eq!(cache.total_lookups, 2);
    }

    #[test]
    fn test_create_fn_called_once_per_key() {
        let calls = Cell::new(0u32);
        let mut cache = PipelineCache::new();
        for _ in 0..3 {
            cache.get_or_create(5, || {
                calls.set(calls.get() + 1);
                vec![5]
            });
        }
        assert_eq!(calls.get(), 1);
    }

    #[test]
    fn test_different_keys_stored_separately() {
        let mut cache = PipelineCache::new();
        cache.get_or_create(10, || vec![1]);
        cache.get_or_create(20, || vec![2]);
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get(10), Some(&vec![1]));
        assert_eq!(cache.get(20), Some(&vec![2]));
    }

    #[test]
    fn test_invalidate_removes_entry() {
        let mut cache = PipelineCache::new();
        cache.get_or_create(42, || vec![0x42]);
        assert!(cache.invalidate(42));
        assert!(cache.get(42).is_none());
        assert!(!cache.invalidate(42), "already removed");
    }

    #[test]
    fn test_invalidate_forces_recompile() {
        let mut cache = PipelineCache::new();
        cache.get_or_create(42, || vec![0x42]);
        assert!(cache.invalidate(42));
        let rebuilt = cache.get_or_create(42, || vec![0x43]);
        assert_eq!(rebuilt, &[0x43]);
        assert_eq!(cache.misses, 2);
        assert_eq!(cache.hits, 0);
    }

    #[test]
    fn test_clear_resets_state() {
        let mut cache = PipelineCache::new();
        cache.get_or_create(1, || vec![1]);
        cache.get_or_create(1, || vec![1]);
        cache.get_or_create(2, || vec![2]);
        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.total_lookups, 0);
        assert_eq!(cache.hits, 0);
        assert_eq!(cache.misses, 0);
        assert_eq!(cache.hit_ratio(), 0.0);
    }

    #[test]
    fn test_hit_ratio_calculation() {
        let mut cache = PipelineCache::new();
        cache.get_or_create(1, || vec![1]); // miss
        cache.get_or_create(1, || vec![1]); // hit
        cache.get_or_create(1, || vec![1]); // hit
        assert!((cache.hit_ratio() - 2.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_hit_ratio_no_lookups_returns_zero() {
        let cache = PipelineCache::new();
        assert_eq!(cache.hit_ratio(), 0.0);
    }

    #[test]
    fn test_empty_binary_cached() {
        let mut cache = PipelineCache::new();
        let b = cache.get_or_create(0, || vec![]);
        assert!(b.is_empty());
        assert_eq!(cache.len(), 1);
    }
}