Skip to main content

oximedia_gpu/
shader_cache.rs

1//! GPU shader cache management.
2//!
3//! This module provides compiled shader storage, cache eviction, and version
4//! tracking to avoid redundant shader compilation work.
5
6use std::collections::HashMap;
7use std::time::{Duration, SystemTime};
8
/// Version identifier for a compiled shader.
///
/// This is the cache key: two shaders map to the same cache entry only when
/// the source hash, the backend name and the feature flags all match.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ShaderVersion {
    /// 64-bit hash of the shader source (for example as produced by
    /// [`hash_source`]).
    pub source_hash: u64,
    /// Backend name (e.g. "vulkan", "metal", "dx12").
    pub backend: String,
    /// Optional feature flags bitmask.
    pub feature_flags: u32,
}
20
21impl ShaderVersion {
22    /// Create a new `ShaderVersion`.
23    #[allow(dead_code)]
24    #[must_use]
25    pub fn new(source_hash: u64, backend: impl Into<String>, feature_flags: u32) -> Self {
26        Self {
27            source_hash,
28            backend: backend.into(),
29            feature_flags,
30        }
31    }
32}
33
34/// A compiled shader blob stored in the cache.
35#[allow(dead_code)]
36#[derive(Debug, Clone)]
37pub struct CompiledShader {
38    /// The shader byte code or SPIR-V blob.
39    pub bytecode: Vec<u8>,
40    /// Version information.
41    pub version: ShaderVersion,
42    /// When this entry was inserted.
43    pub created_at: SystemTime,
44    /// Approximate size of the bytecode in bytes.
45    pub size_bytes: usize,
46    /// How many times this shader has been retrieved from the cache.
47    pub hit_count: u64,
48}
49
50impl CompiledShader {
51    /// Create a new `CompiledShader`.
52    #[allow(dead_code)]
53    #[must_use]
54    pub fn new(bytecode: Vec<u8>, version: ShaderVersion) -> Self {
55        let size_bytes = bytecode.len();
56        Self {
57            bytecode,
58            version,
59            created_at: SystemTime::now(),
60            size_bytes,
61            hit_count: 0,
62        }
63    }
64}
65
/// Statistics for the shader cache.
///
/// `entry_count` and `total_bytes` describe what is cached right now. `hits`,
/// `misses` and `evictions` are running totals that clearing the cache does
/// not reset.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct ShaderCacheStats {
    /// Total number of entries currently held.
    pub entry_count: usize,
    /// Total bytes occupied by all cached bytecodes (the sum of each entry's
    /// `size_bytes`).
    pub total_bytes: usize,
    /// Total number of cache hits (lookups via `get` that found an entry).
    pub hits: u64,
    /// Total number of cache misses (lookups via `get` that found nothing).
    pub misses: u64,
    /// Total number of entries evicted by the eviction policy. Explicit
    /// removals and clearing the cache are not counted.
    pub evictions: u64,
}
81
/// Eviction policy for the shader cache.
///
/// Each policy evicts one victim at a time: the entry with the minimum value
/// of some per-entry key. Ties are broken by hash-map iteration order, which
/// is effectively arbitrary.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EvictionPolicy {
    /// Least-recently-used eviction (the default). The entry whose last
    /// insert or successful lookup is oldest is evicted first.
    #[default]
    Lru,
    /// Least-frequently-used eviction. The entry with the lowest `hit_count`
    /// is evicted first.
    Lfu,
    /// Oldest-first eviction. The entry with the earliest `created_at`
    /// timestamp is evicted first.
    OldestFirst,
}
94
95/// In-process GPU shader cache.
96#[allow(dead_code)]
97pub struct GpuShaderCache {
98    entries: HashMap<ShaderVersion, CompiledShader>,
99    /// Maximum total bytes stored before eviction triggers.
100    max_bytes: usize,
101    /// Maximum number of entries before eviction triggers.
102    max_entries: usize,
103    /// Selected eviction policy.
104    policy: EvictionPolicy,
105    /// Accumulated statistics.
106    stats: ShaderCacheStats,
107    /// Last-access timestamps (version → instant-as-duration-since-epoch).
108    last_access: HashMap<ShaderVersion, SystemTime>,
109}
110
111impl GpuShaderCache {
112    /// Create a new shader cache.
113    ///
114    /// * `max_bytes`   – evict when total payload exceeds this many bytes.
115    /// * `max_entries` – evict when the number of entries exceeds this.
116    /// * `policy`      – which eviction strategy to use.
117    #[allow(dead_code)]
118    #[must_use]
119    pub fn new(max_bytes: usize, max_entries: usize, policy: EvictionPolicy) -> Self {
120        Self {
121            entries: HashMap::new(),
122            max_bytes,
123            max_entries,
124            policy,
125            stats: ShaderCacheStats::default(),
126            last_access: HashMap::new(),
127        }
128    }
129
130    /// Insert a compiled shader into the cache.
131    ///
132    /// If the cache is full, one entry is evicted according to the current
133    /// policy before the new entry is stored.
134    #[allow(dead_code)]
135    pub fn insert(&mut self, shader: CompiledShader) {
136        // Evict if needed before inserting.
137        while self.needs_eviction(shader.size_bytes) {
138            if !self.evict_one() {
139                break; // Nothing left to evict.
140            }
141        }
142
143        self.stats.total_bytes += shader.size_bytes;
144        self.stats.entry_count += 1;
145        self.last_access
146            .insert(shader.version.clone(), SystemTime::now());
147        self.entries.insert(shader.version.clone(), shader);
148    }
149
150    /// Retrieve a compiled shader by its version key.
151    ///
152    /// Returns `None` if the shader is not cached.
153    #[allow(dead_code)]
154    pub fn get(&mut self, version: &ShaderVersion) -> Option<&CompiledShader> {
155        if self.entries.contains_key(version) {
156            self.stats.hits += 1;
157            // Update access time and hit count.
158            self.last_access.insert(version.clone(), SystemTime::now());
159            if let Some(e) = self.entries.get_mut(version) {
160                e.hit_count += 1;
161            }
162            self.entries.get(version)
163        } else {
164            self.stats.misses += 1;
165            None
166        }
167    }
168
169    /// Check whether a shader is present in the cache.
170    #[allow(dead_code)]
171    #[must_use]
172    pub fn contains(&self, version: &ShaderVersion) -> bool {
173        self.entries.contains_key(version)
174    }
175
176    /// Remove a specific shader from the cache.
177    #[allow(dead_code)]
178    pub fn remove(&mut self, version: &ShaderVersion) -> Option<CompiledShader> {
179        if let Some(shader) = self.entries.remove(version) {
180            self.stats.total_bytes = self.stats.total_bytes.saturating_sub(shader.size_bytes);
181            self.stats.entry_count = self.stats.entry_count.saturating_sub(1);
182            self.last_access.remove(version);
183            Some(shader)
184        } else {
185            None
186        }
187    }
188
189    /// Remove all shaders for a given backend.
190    #[allow(dead_code)]
191    pub fn invalidate_backend(&mut self, backend: &str) {
192        let to_remove: Vec<ShaderVersion> = self
193            .entries
194            .keys()
195            .filter(|v| v.backend == backend)
196            .cloned()
197            .collect();
198        for key in to_remove {
199            self.remove(&key);
200        }
201    }
202
203    /// Clear all entries.
204    #[allow(dead_code)]
205    pub fn clear(&mut self) {
206        self.entries.clear();
207        self.last_access.clear();
208        self.stats.total_bytes = 0;
209        self.stats.entry_count = 0;
210    }
211
212    /// Current statistics.
213    #[allow(dead_code)]
214    #[must_use]
215    pub fn stats(&self) -> &ShaderCacheStats {
216        &self.stats
217    }
218
219    /// Number of entries currently held.
220    #[allow(dead_code)]
221    #[must_use]
222    pub fn len(&self) -> usize {
223        self.entries.len()
224    }
225
226    /// Returns true if the cache is empty.
227    #[allow(dead_code)]
228    #[must_use]
229    pub fn is_empty(&self) -> bool {
230        self.entries.is_empty()
231    }
232
233    // -----------------------------------------------------------------------
234    // Private helpers
235    // -----------------------------------------------------------------------
236
237    fn needs_eviction(&self, incoming_bytes: usize) -> bool {
238        let bytes_after = self.stats.total_bytes + incoming_bytes;
239        bytes_after > self.max_bytes || self.stats.entry_count >= self.max_entries
240    }
241
242    /// Evict one entry according to the current policy. Returns `true` if
243    /// something was actually removed.
244    fn evict_one(&mut self) -> bool {
245        if self.entries.is_empty() {
246            return false;
247        }
248
249        let victim_key: Option<ShaderVersion> = match self.policy {
250            EvictionPolicy::Lru => {
251                // Remove the entry with the oldest last-access time.
252                self.last_access
253                    .iter()
254                    .min_by_key(|(_, t)| *t)
255                    .map(|(k, _)| k.clone())
256            }
257            EvictionPolicy::Lfu => {
258                // Remove the entry with the lowest hit_count.
259                self.entries
260                    .iter()
261                    .min_by_key(|(_, v)| v.hit_count)
262                    .map(|(k, _)| k.clone())
263            }
264            EvictionPolicy::OldestFirst => {
265                // Remove the entry created earliest.
266                self.entries
267                    .iter()
268                    .min_by_key(|(_, v)| v.created_at)
269                    .map(|(k, _)| k.clone())
270            }
271        };
272
273        if let Some(key) = victim_key {
274            self.remove(&key);
275            self.stats.evictions += 1;
276            true
277        } else {
278            false
279        }
280    }
281}
282
283impl Default for GpuShaderCache {
284    fn default() -> Self {
285        // 64 MB default cache, up to 256 entries.
286        Self::new(64 * 1024 * 1024, 256, EvictionPolicy::Lru)
287    }
288}
289
290/// Compute a simple 64-bit FNV-1a hash of a byte slice.
291#[allow(dead_code)]
292#[must_use]
293pub fn hash_source(data: &[u8]) -> u64 {
294    const FNV_OFFSET: u64 = 14_695_981_039_346_656_037;
295    const FNV_PRIME: u64 = 1_099_511_628_211;
296    let mut hash = FNV_OFFSET;
297    for &byte in data {
298        hash ^= u64(byte);
299        hash = hash.wrapping_mul(FNV_PRIME);
300    }
301    hash
302}
303
/// Widen a `u8` to a `u64` losslessly, without an `as` cast.
#[inline(always)]
fn u64(v: u8) -> u64 {
    v.into()
}
309
/// Return how much wall-clock time has passed since `t`.
///
/// If `t` lies in the future (for example because the system clock moved
/// backwards), this returns [`Duration::ZERO`] instead of failing.
#[allow(dead_code)]
#[must_use]
pub fn age_of(t: SystemTime) -> Duration {
    match SystemTime::now().duration_since(t) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::ZERO,
    }
}
318
319// ---------------------------------------------------------------------------
320// Unit tests
321// ---------------------------------------------------------------------------
322
#[cfg(test)]
mod tests {
    use super::*;

    fn make_version(hash: u64) -> ShaderVersion {
        ShaderVersion::new(hash, "vulkan", 0)
    }

    fn make_shader(hash: u64, size: usize) -> CompiledShader {
        CompiledShader::new(vec![0u8; size], make_version(hash))
    }

    #[test]
    fn test_insert_and_get() {
        let mut cache = GpuShaderCache::default();
        let shader = make_shader(1, 100);
        let version = shader.version.clone();
        cache.insert(shader);
        assert_eq!(cache.stats().entry_count, 1);
        assert_eq!(cache.stats().total_bytes, 100);
        assert!(cache.get(&version).is_some());
        assert_eq!(cache.stats().hits, 1);
        assert_eq!(cache.stats().misses, 0);
    }

    #[test]
    fn test_cache_miss() {
        let mut cache = GpuShaderCache::default();
        let v = make_version(42);
        assert!(cache.get(&v).is_none());
        assert_eq!(cache.stats().misses, 1);
        assert_eq!(cache.stats().hits, 0);
    }

    #[test]
    fn test_hit_count_increments() {
        let mut cache = GpuShaderCache::default();
        let shader = make_shader(7, 50);
        let version = shader.version.clone();
        cache.insert(shader);
        cache.get(&version);
        cache.get(&version);
        assert_eq!(cache.get(&version).unwrap().hit_count, 3);
    }

    #[test]
    fn test_remove() {
        let mut cache = GpuShaderCache::default();
        let shader = make_shader(99, 200);
        let version = shader.version.clone();
        cache.insert(shader);
        assert!(cache.remove(&version).is_some());
        assert!(cache.remove(&version).is_none());
        assert!(cache.get(&version).is_none());
        assert_eq!(cache.stats().entry_count, 0);
        assert_eq!(cache.stats().total_bytes, 0);
        // An explicit removal is not an eviction.
        assert_eq!(cache.stats().evictions, 0);
    }

    #[test]
    fn test_contains() {
        let mut cache = GpuShaderCache::default();
        let shader = make_shader(5, 10);
        let version = shader.version.clone();
        assert!(!cache.contains(&version));
        cache.insert(shader);
        assert!(cache.contains(&version));
    }

    #[test]
    fn test_clear() {
        let mut cache = GpuShaderCache::default();
        cache.insert(make_shader(1, 10));
        cache.insert(make_shader(2, 10));
        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.stats().total_bytes, 0);
        assert_eq!(cache.stats().entry_count, 0);
    }

    #[test]
    fn test_eviction_by_entry_count() {
        // Allow at most 2 entries.
        let mut cache = GpuShaderCache::new(usize::MAX, 2, EvictionPolicy::Lfu);
        cache.insert(make_shader(1, 10));
        cache.insert(make_shader(2, 10));
        // Hitting shader 2 raises its hit count so shader 1 gets evicted (LFU).
        cache.get(&make_version(2));
        // Insert a third shader – should evict the LFU entry.
        cache.insert(make_shader(3, 10));
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.stats().evictions, 1);
        assert!(!cache.contains(&make_version(1)));
        assert!(cache.contains(&make_version(2)));
        assert!(cache.contains(&make_version(3)));
    }

    #[test]
    fn test_eviction_by_bytes() {
        // Allow at most 30 bytes.
        let mut cache = GpuShaderCache::new(30, usize::MAX, EvictionPolicy::OldestFirst);
        cache.insert(make_shader(1, 15));
        cache.insert(make_shader(2, 15));
        // Third insert (15 bytes) exceeds the cap – exactly one entry must go.
        cache.insert(make_shader(3, 15));
        assert_eq!(cache.stats().evictions, 1);
        assert_eq!(cache.len(), 2);
        assert!(cache.stats().total_bytes <= 30);
        assert!(cache.contains(&make_version(3)));
    }

    #[test]
    fn test_oversized_shader_is_still_stored() {
        // A shader larger than the whole byte budget empties the cache and is
        // then stored anyway.
        let mut cache = GpuShaderCache::new(30, usize::MAX, EvictionPolicy::Lru);
        cache.insert(make_shader(1, 10));
        cache.insert(make_shader(2, 100));
        assert_eq!(cache.len(), 1);
        assert!(cache.contains(&make_version(2)));
        assert_eq!(cache.stats().evictions, 1);
        assert_eq!(cache.stats().total_bytes, 100);
    }

    #[test]
    fn test_invalidate_backend() {
        let mut cache = GpuShaderCache::default();
        let v1 = ShaderVersion::new(1, "vulkan", 0);
        let v2 = ShaderVersion::new(2, "metal", 0);
        cache.insert(CompiledShader::new(vec![0u8; 10], v1));
        cache.insert(CompiledShader::new(vec![0u8; 10], v2.clone()));
        cache.invalidate_backend("vulkan");
        assert!(!cache.contains(&ShaderVersion::new(1, "vulkan", 0)));
        assert!(cache.contains(&v2));
        assert_eq!(cache.stats().entry_count, 1);
        assert_eq!(cache.stats().total_bytes, 10);
        assert_eq!(cache.stats().evictions, 0);
    }

    #[test]
    fn test_hash_source_deterministic() {
        let data = b"hello world shader";
        assert_eq!(hash_source(data), hash_source(data));
    }

    #[test]
    fn test_hash_source_differs_for_different_inputs() {
        assert_ne!(hash_source(b"shader_a"), hash_source(b"shader_b"));
    }

    #[test]
    fn test_hash_source_known_vectors() {
        // Standard 64-bit FNV-1a reference values.
        assert_eq!(hash_source(b""), 0xcbf2_9ce4_8422_2325);
        assert_eq!(hash_source(b"a"), 0xaf63_dc4c_8601_ec8c);
    }

    #[test]
    fn test_default_cache_capacity() {
        let cache = GpuShaderCache::default();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.stats().entry_count, 0);
        assert_eq!(cache.stats().total_bytes, 0);
    }

    #[test]
    fn test_shader_version_equality() {
        let v1 = ShaderVersion::new(10, "dx12", 3);
        let v2 = ShaderVersion::new(10, "dx12", 3);
        let v3 = ShaderVersion::new(10, "dx12", 4);
        assert_eq!(v1, v2);
        assert_ne!(v1, v3);
    }

    #[test]
    fn test_age_of_is_non_negative() {
        let t = SystemTime::now();
        let age = age_of(t);
        // Age should be very small but non-negative.
        assert!(age < Duration::from_secs(5));
    }

    #[test]
    fn test_age_of_future_is_zero() {
        let future = SystemTime::now() + Duration::from_secs(3600);
        assert_eq!(age_of(future), Duration::ZERO);
    }
}