Skip to main content

oximedia_gpu/
shader_cache.rs

1//! GPU shader cache management.
2//!
3//! This module provides compiled shader storage, cache eviction, version
4//! tracking, and **disk-persistent** caching to avoid redundant shader
5//! compilation work across process restarts.
6//!
7//! # Disk cache layout
8//!
9//! The persistent cache stores each entry as a pair of files inside the
10//! configured directory:
11//!
12//! ```text
13//! <cache_dir>/<hex_hash>_<backend>_<flags>.shd   – raw bytecode
//! <cache_dir>/<hex_hash>_<backend>_<flags>.meta  – metadata (plain text)
//! ```
//!
//! The `.meta` file contains a single space-separated line, with the hash
//! written in decimal:
//! `<source_hash> <backend> <feature_flags> <created_unix_secs>`.
19
20use std::collections::HashMap;
21use std::io::{Read, Write};
22use std::path::{Path, PathBuf};
23use std::time::{Duration, SystemTime, UNIX_EPOCH};
24
/// Identity of a compiled shader: which source, for which backend, with
/// which feature flags.
///
/// Used as the lookup key of the in-memory shader cache.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ShaderVersion {
    /// 64-bit hash of the shader source (for example from [`hash_source`]).
    pub source_hash: u64,
    /// Backend name (e.g. "vulkan", "metal", "dx12").
    pub backend: String,
    /// Feature-flag bitmask the shader was compiled with.
    pub feature_flags: u32,
}

impl ShaderVersion {
    /// Create a new `ShaderVersion`.
    ///
    /// `backend` accepts anything convertible into a `String`
    /// (`&str`, `String`, …).
    #[allow(dead_code)]
    #[must_use]
    pub fn new(source_hash: u64, backend: impl Into<String>, feature_flags: u32) -> Self {
        Self {
            source_hash,
            backend: backend.into(),
            feature_flags,
        }
    }
}
49
50/// A compiled shader blob stored in the cache.
51#[allow(dead_code)]
52#[derive(Debug, Clone)]
53pub struct CompiledShader {
54    /// The shader byte code or SPIR-V blob.
55    pub bytecode: Vec<u8>,
56    /// Version information.
57    pub version: ShaderVersion,
58    /// When this entry was inserted.
59    pub created_at: SystemTime,
60    /// Approximate size of the bytecode in bytes.
61    pub size_bytes: usize,
62    /// How many times this shader has been retrieved from the cache.
63    pub hit_count: u64,
64}
65
66impl CompiledShader {
67    /// Create a new `CompiledShader`.
68    #[allow(dead_code)]
69    #[must_use]
70    pub fn new(bytecode: Vec<u8>, version: ShaderVersion) -> Self {
71        let size_bytes = bytecode.len();
72        Self {
73            bytecode,
74            version,
75            created_at: SystemTime::now(),
76            size_bytes,
77            hit_count: 0,
78        }
79    }
80}
81
/// Counters describing the contents and history of the in-memory shader cache.
#[allow(dead_code)]
#[derive(Clone, Debug, Default)]
pub struct ShaderCacheStats {
    /// Number of entries currently held.
    pub entry_count: usize,
    /// Sum of the bytecode sizes of all held entries.
    pub total_bytes: usize,
    /// Lookups that found a cached shader.
    pub hits: u64,
    /// Lookups that found nothing.
    pub misses: u64,
    /// Entries removed by the eviction policy.
    pub evictions: u64,
}
97
/// Strategy used to pick the entry to drop when the shader cache is full.
#[allow(dead_code)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum EvictionPolicy {
    /// Drop the entry that was accessed least recently (the default).
    #[default]
    Lru,
    /// Drop the entry with the fewest cache hits.
    Lfu,
    /// Drop the entry with the earliest creation time.
    OldestFirst,
}
110
111/// In-process GPU shader cache.
112#[allow(dead_code)]
113pub struct GpuShaderCache {
114    entries: HashMap<ShaderVersion, CompiledShader>,
115    /// Maximum total bytes stored before eviction triggers.
116    max_bytes: usize,
117    /// Maximum number of entries before eviction triggers.
118    max_entries: usize,
119    /// Selected eviction policy.
120    policy: EvictionPolicy,
121    /// Accumulated statistics.
122    stats: ShaderCacheStats,
123    /// Last-access timestamps (version → instant-as-duration-since-epoch).
124    last_access: HashMap<ShaderVersion, SystemTime>,
125}
126
127impl GpuShaderCache {
128    /// Create a new shader cache.
129    ///
130    /// * `max_bytes`   – evict when total payload exceeds this many bytes.
131    /// * `max_entries` – evict when the number of entries exceeds this.
132    /// * `policy`      – which eviction strategy to use.
133    #[allow(dead_code)]
134    #[must_use]
135    pub fn new(max_bytes: usize, max_entries: usize, policy: EvictionPolicy) -> Self {
136        Self {
137            entries: HashMap::new(),
138            max_bytes,
139            max_entries,
140            policy,
141            stats: ShaderCacheStats::default(),
142            last_access: HashMap::new(),
143        }
144    }
145
146    /// Insert a compiled shader into the cache.
147    ///
148    /// If the cache is full, one entry is evicted according to the current
149    /// policy before the new entry is stored.
150    #[allow(dead_code)]
151    pub fn insert(&mut self, shader: CompiledShader) {
152        // Evict if needed before inserting.
153        while self.needs_eviction(shader.size_bytes) {
154            if !self.evict_one() {
155                break; // Nothing left to evict.
156            }
157        }
158
159        self.stats.total_bytes += shader.size_bytes;
160        self.stats.entry_count += 1;
161        self.last_access
162            .insert(shader.version.clone(), SystemTime::now());
163        self.entries.insert(shader.version.clone(), shader);
164    }
165
166    /// Retrieve a compiled shader by its version key.
167    ///
168    /// Returns `None` if the shader is not cached.
169    #[allow(dead_code)]
170    pub fn get(&mut self, version: &ShaderVersion) -> Option<&CompiledShader> {
171        if self.entries.contains_key(version) {
172            self.stats.hits += 1;
173            // Update access time and hit count.
174            self.last_access.insert(version.clone(), SystemTime::now());
175            if let Some(e) = self.entries.get_mut(version) {
176                e.hit_count += 1;
177            }
178            self.entries.get(version)
179        } else {
180            self.stats.misses += 1;
181            None
182        }
183    }
184
185    /// Check whether a shader is present in the cache.
186    #[allow(dead_code)]
187    #[must_use]
188    pub fn contains(&self, version: &ShaderVersion) -> bool {
189        self.entries.contains_key(version)
190    }
191
192    /// Remove a specific shader from the cache.
193    #[allow(dead_code)]
194    pub fn remove(&mut self, version: &ShaderVersion) -> Option<CompiledShader> {
195        if let Some(shader) = self.entries.remove(version) {
196            self.stats.total_bytes = self.stats.total_bytes.saturating_sub(shader.size_bytes);
197            self.stats.entry_count = self.stats.entry_count.saturating_sub(1);
198            self.last_access.remove(version);
199            Some(shader)
200        } else {
201            None
202        }
203    }
204
205    /// Remove all shaders for a given backend.
206    #[allow(dead_code)]
207    pub fn invalidate_backend(&mut self, backend: &str) {
208        let to_remove: Vec<ShaderVersion> = self
209            .entries
210            .keys()
211            .filter(|v| v.backend == backend)
212            .cloned()
213            .collect();
214        for key in to_remove {
215            self.remove(&key);
216        }
217    }
218
219    /// Clear all entries.
220    #[allow(dead_code)]
221    pub fn clear(&mut self) {
222        self.entries.clear();
223        self.last_access.clear();
224        self.stats.total_bytes = 0;
225        self.stats.entry_count = 0;
226    }
227
228    /// Current statistics.
229    #[allow(dead_code)]
230    #[must_use]
231    pub fn stats(&self) -> &ShaderCacheStats {
232        &self.stats
233    }
234
235    /// Number of entries currently held.
236    #[allow(dead_code)]
237    #[must_use]
238    pub fn len(&self) -> usize {
239        self.entries.len()
240    }
241
242    /// Returns true if the cache is empty.
243    #[allow(dead_code)]
244    #[must_use]
245    pub fn is_empty(&self) -> bool {
246        self.entries.is_empty()
247    }
248
249    // -----------------------------------------------------------------------
250    // Private helpers
251    // -----------------------------------------------------------------------
252
253    fn needs_eviction(&self, incoming_bytes: usize) -> bool {
254        let bytes_after = self.stats.total_bytes + incoming_bytes;
255        bytes_after > self.max_bytes || self.stats.entry_count >= self.max_entries
256    }
257
258    /// Evict one entry according to the current policy. Returns `true` if
259    /// something was actually removed.
260    fn evict_one(&mut self) -> bool {
261        if self.entries.is_empty() {
262            return false;
263        }
264
265        let victim_key: Option<ShaderVersion> = match self.policy {
266            EvictionPolicy::Lru => {
267                // Remove the entry with the oldest last-access time.
268                self.last_access
269                    .iter()
270                    .min_by_key(|(_, t)| *t)
271                    .map(|(k, _)| k.clone())
272            }
273            EvictionPolicy::Lfu => {
274                // Remove the entry with the lowest hit_count.
275                self.entries
276                    .iter()
277                    .min_by_key(|(_, v)| v.hit_count)
278                    .map(|(k, _)| k.clone())
279            }
280            EvictionPolicy::OldestFirst => {
281                // Remove the entry created earliest.
282                self.entries
283                    .iter()
284                    .min_by_key(|(_, v)| v.created_at)
285                    .map(|(k, _)| k.clone())
286            }
287        };
288
289        if let Some(key) = victim_key {
290            self.remove(&key);
291            self.stats.evictions += 1;
292            true
293        } else {
294            false
295        }
296    }
297}
298
299impl Default for GpuShaderCache {
300    fn default() -> Self {
301        // 64 MB default cache, up to 256 entries.
302        Self::new(64 * 1024 * 1024, 256, EvictionPolicy::Lru)
303    }
304}
305
/// Compute the 64-bit FNV-1a hash of a byte slice.
///
/// The result is deterministic across runs, which makes it usable as a cache
/// key. It is not a cryptographic hash.
#[allow(dead_code)]
#[must_use]
pub fn hash_source(data: &[u8]) -> u64 {
    // Standard FNV-1a 64-bit parameters (offset basis and prime).
    const FNV_OFFSET: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    data.iter().fold(FNV_OFFSET, |hash, &byte| {
        (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
325
/// How long ago `t` was, measured against the current wall-clock time.
///
/// Returns [`Duration::ZERO`] when `t` lies in the future (e.g. after a clock
/// adjustment).
#[allow(dead_code)]
#[must_use]
pub fn age_of(t: SystemTime) -> Duration {
    t.elapsed().unwrap_or_default()
}
334
335// =============================================================================
336// Disk-persistent shader cache (Task 3)
337// =============================================================================
338
/// Errors that can occur during disk cache I/O.
#[derive(Debug)]
pub enum DiskCacheError {
    /// A filesystem I/O error occurred.
    Io(std::io::Error),
    /// The metadata file was malformed (could not be parsed).
    MalformedMetadata(String),
}

impl std::fmt::Display for DiskCacheError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(e) => write!(f, "disk cache I/O error: {e}"),
            Self::MalformedMetadata(s) => write!(f, "malformed cache metadata: {s}"),
        }
    }
}

/// Lets `DiskCacheError` be used with `?` into `Box<dyn Error>` and error
/// reporters. The underlying I/O error is exposed as the source.
impl std::error::Error for DiskCacheError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(e) => Some(e),
            Self::MalformedMetadata(_) => None,
        }
    }
}

impl From<std::io::Error> for DiskCacheError {
    fn from(e: std::io::Error) -> Self {
        Self::Io(e)
    }
}
362
/// Running counters for a disk-backed shader cache.
#[allow(dead_code)]
#[derive(Clone, Debug, Default)]
pub struct DiskCacheStats {
    /// Bytecode blobs successfully served from disk.
    pub disk_hits: u64,
    /// Lookups that could not be served from disk.
    pub disk_misses: u64,
    /// Bytecode blobs (with metadata) successfully written to disk.
    pub disk_writes: u64,
    /// Non-fatal I/O failures; the affected operation is abandoned.
    pub io_errors: u64,
}
376
377/// Disk-persistent GPU shader cache.
378///
379/// Each entry is stored as two files in `cache_dir`:
380/// - `<key>.shd` – the raw bytecode blob.
381/// - `<key>.meta` – a single-line text file:
382///   `<source_hash> <backend> <feature_flags> <unix_secs_since_epoch>`.
383///
384/// The cache directory is created automatically on first use.
385#[allow(dead_code)]
386pub struct DiskShaderCache {
387    cache_dir: PathBuf,
388    stats: DiskCacheStats,
389}
390
391impl DiskShaderCache {
392    /// Open (or create) a disk shader cache rooted at `cache_dir`.
393    ///
394    /// The directory is created if it does not already exist.
395    ///
396    /// # Errors
397    ///
398    /// Returns a [`DiskCacheError`] if the directory cannot be created.
399    #[allow(dead_code)]
400    pub fn open(cache_dir: impl AsRef<Path>) -> Result<Self, DiskCacheError> {
401        let cache_dir = cache_dir.as_ref().to_path_buf();
402        std::fs::create_dir_all(&cache_dir)?;
403        Ok(Self {
404            cache_dir,
405            stats: DiskCacheStats::default(),
406        })
407    }
408
409    /// Look up a shader by its [`ShaderVersion`].
410    ///
411    /// Returns `Some(bytecode)` if both `.shd` and `.meta` files exist and
412    /// the metadata matches the requested version.  Returns `None` on any
413    /// mismatch or I/O error.
414    #[allow(dead_code)]
415    pub fn get(&mut self, version: &ShaderVersion) -> Option<Vec<u8>> {
416        let key = self.entry_key(version);
417        let shd_path = self.cache_dir.join(format!("{key}.shd"));
418        let meta_path = self.cache_dir.join(format!("{key}.meta"));
419
420        // Read and validate the metadata.
421        match self.read_meta(&meta_path, version) {
422            Err(_) => {
423                self.stats.disk_misses += 1;
424                return None;
425            }
426            Ok(false) => {
427                self.stats.disk_misses += 1;
428                return None;
429            }
430            Ok(true) => {}
431        }
432
433        // Read the bytecode blob.
434        match std::fs::read(&shd_path) {
435            Ok(bytes) => {
436                self.stats.disk_hits += 1;
437                Some(bytes)
438            }
439            Err(_) => {
440                self.stats.disk_misses += 1;
441                self.stats.io_errors += 1;
442                None
443            }
444        }
445    }
446
447    /// Store a compiled shader on disk.
448    ///
449    /// On any I/O error the error is recorded in statistics but is **not**
450    /// propagated — the in-memory cache remains the source of truth.
451    #[allow(dead_code)]
452    pub fn put(&mut self, shader: &CompiledShader) {
453        let key = self.entry_key(&shader.version);
454        let shd_path = self.cache_dir.join(format!("{key}.shd"));
455        let meta_path = self.cache_dir.join(format!("{key}.meta"));
456
457        // Write bytecode.
458        if let Err(_e) = self.write_bytes(&shd_path, &shader.bytecode) {
459            self.stats.io_errors += 1;
460            return;
461        }
462
463        // Write metadata: "<source_hash> <backend> <feature_flags> <unix_secs>"
464        let unix_secs = shader
465            .created_at
466            .duration_since(UNIX_EPOCH)
467            .unwrap_or(Duration::ZERO)
468            .as_secs();
469        let meta_content = format!(
470            "{} {} {} {}",
471            shader.version.source_hash,
472            shader.version.backend,
473            shader.version.feature_flags,
474            unix_secs
475        );
476        if let Err(_e) = self.write_str(&meta_path, &meta_content) {
477            self.stats.io_errors += 1;
478            // Remove orphaned .shd to avoid inconsistency.
479            let _ = std::fs::remove_file(&shd_path);
480            return;
481        }
482
483        self.stats.disk_writes += 1;
484    }
485
486    /// Invalidate (delete) all entries for a specific backend.
487    ///
488    /// Errors during directory listing or file deletion are silently ignored.
489    #[allow(dead_code)]
490    pub fn invalidate_backend(&mut self, backend: &str) {
491        let Ok(entries) = std::fs::read_dir(&self.cache_dir) else {
492            return;
493        };
494        for entry in entries.flatten() {
495            let path = entry.path();
496            if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
497                // The key encodes the backend: `<hash>_<backend>_<flags>.*`
498                if name.contains(&format!("_{backend}_")) {
499                    let _ = std::fs::remove_file(&path);
500                }
501            }
502        }
503    }
504
505    /// Remove all cached entries from disk.
506    ///
507    /// Errors during directory listing or file deletion are silently ignored.
508    #[allow(dead_code)]
509    pub fn clear(&mut self) {
510        let Ok(entries) = std::fs::read_dir(&self.cache_dir) else {
511            return;
512        };
513        for entry in entries.flatten() {
514            let path = entry.path();
515            if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
516                if ext == "shd" || ext == "meta" {
517                    let _ = std::fs::remove_file(&path);
518                }
519            }
520        }
521    }
522
523    /// Returns a snapshot of the accumulated disk-cache statistics.
524    #[allow(dead_code)]
525    #[must_use]
526    pub fn stats(&self) -> &DiskCacheStats {
527        &self.stats
528    }
529
530    // ── private helpers ──────────────────────────────────────────────────────
531
532    /// Derive a filesystem-safe key from a [`ShaderVersion`].
533    fn entry_key(&self, v: &ShaderVersion) -> String {
534        // Sanitise the backend string (remove chars that are illegal on some
535        // filesystems).  We allow alphanumerics and hyphens only.
536        let safe_backend: String = v
537            .backend
538            .chars()
539            .map(|c| {
540                if c.is_alphanumeric() || c == '-' {
541                    c
542                } else {
543                    '_'
544                }
545            })
546            .collect();
547        format!(
548            "{:016x}_{}_{}",
549            v.source_hash, safe_backend, v.feature_flags
550        )
551    }
552
553    /// Read and validate a `.meta` file.
554    ///
555    /// Returns `Ok(true)` if the file matches `version`, `Ok(false)` if it
556    /// does not match, and `Err` on I/O failure.
557    fn read_meta(&mut self, path: &Path, version: &ShaderVersion) -> Result<bool, DiskCacheError> {
558        let mut file = std::fs::File::open(path)?;
559        let mut content = String::new();
560        file.read_to_string(&mut content)?;
561        let parts: Vec<&str> = content.trim().splitn(4, ' ').collect();
562        if parts.len() < 3 {
563            return Err(DiskCacheError::MalformedMetadata(content.clone()));
564        }
565        let stored_hash: u64 = parts[0]
566            .parse()
567            .map_err(|_| DiskCacheError::MalformedMetadata(parts[0].to_string()))?;
568        let stored_backend = parts[1];
569        let stored_flags: u32 = parts[2]
570            .parse()
571            .map_err(|_| DiskCacheError::MalformedMetadata(parts[2].to_string()))?;
572        Ok(stored_hash == version.source_hash
573            && stored_backend == version.backend
574            && stored_flags == version.feature_flags)
575    }
576
577    fn write_bytes(&self, path: &Path, data: &[u8]) -> std::io::Result<()> {
578        let mut f = std::fs::File::create(path)?;
579        f.write_all(data)
580    }
581
582    fn write_str(&self, path: &Path, s: &str) -> std::io::Result<()> {
583        let mut f = std::fs::File::create(path)?;
584        f.write_all(s.as_bytes())
585    }
586}
587
588// ---------------------------------------------------------------------------
589// Unit tests
590// ---------------------------------------------------------------------------
591
#[cfg(test)]
mod tests {
    // Unit tests for the in-memory `GpuShaderCache`, the `hash_source` /
    // `age_of` helpers, and the on-disk `DiskShaderCache`.
    use super::*;

    /// Helper: a `ShaderVersion` for the "vulkan" backend with no feature flags.
    fn make_version(hash: u64) -> ShaderVersion {
        ShaderVersion::new(hash, "vulkan", 0)
    }

    /// Helper: a zero-filled shader of `size` bytes keyed by `make_version(hash)`.
    fn make_shader(hash: u64, size: usize) -> CompiledShader {
        CompiledShader::new(vec![0u8; size], make_version(hash))
    }

    #[test]
    fn test_insert_and_get() {
        let mut cache = GpuShaderCache::default();
        let shader = make_shader(1, 100);
        let version = shader.version.clone();
        cache.insert(shader);
        assert!(cache.get(&version).is_some());
    }

    #[test]
    fn test_cache_miss() {
        let mut cache = GpuShaderCache::default();
        let v = make_version(42);
        assert!(cache.get(&v).is_none());
        assert_eq!(cache.stats().misses, 1);
    }

    #[test]
    fn test_hit_count_increments() {
        let mut cache = GpuShaderCache::default();
        let shader = make_shader(7, 50);
        let version = shader.version.clone();
        cache.insert(shader);
        cache.get(&version);
        cache.get(&version);
        // The third `get` observes the count after its own increment, hence 3.
        assert_eq!(
            cache
                .get(&version)
                .expect("cache get should return stored data")
                .hit_count,
            3
        );
    }

    #[test]
    fn test_remove() {
        let mut cache = GpuShaderCache::default();
        let shader = make_shader(99, 200);
        let version = shader.version.clone();
        cache.insert(shader);
        assert!(cache.remove(&version).is_some());
        assert!(cache.get(&version).is_none());
    }

    #[test]
    fn test_contains() {
        let mut cache = GpuShaderCache::default();
        let shader = make_shader(5, 10);
        let version = shader.version.clone();
        assert!(!cache.contains(&version));
        cache.insert(shader);
        assert!(cache.contains(&version));
    }

    #[test]
    fn test_clear() {
        let mut cache = GpuShaderCache::default();
        cache.insert(make_shader(1, 10));
        cache.insert(make_shader(2, 10));
        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.stats().total_bytes, 0);
    }

    #[test]
    fn test_eviction_by_entry_count() {
        // Allow at most 2 entries.
        let mut cache = GpuShaderCache::new(usize::MAX, 2, EvictionPolicy::Lfu);
        cache.insert(make_shader(1, 10));
        cache.insert(make_shader(2, 10));
        // Hitting shader 2 raises its hit count so shader 1 gets evicted (LFU).
        cache.get(&make_version(2));
        // Insert a third shader – should evict the LFU entry.
        cache.insert(make_shader(3, 10));
        assert_eq!(cache.len(), 2);
        assert!(cache.stats().evictions >= 1);
    }

    #[test]
    fn test_eviction_by_bytes() {
        // Allow at most 30 bytes.
        let mut cache = GpuShaderCache::new(30, usize::MAX, EvictionPolicy::OldestFirst);
        cache.insert(make_shader(1, 15));
        cache.insert(make_shader(2, 15));
        // Third insert (15 bytes) exceeds the cap – one entry should be evicted.
        cache.insert(make_shader(3, 15));
        assert!(cache.stats().evictions >= 1);
    }

    #[test]
    fn test_invalidate_backend() {
        let mut cache = GpuShaderCache::default();
        let v1 = ShaderVersion::new(1, "vulkan", 0);
        let v2 = ShaderVersion::new(2, "metal", 0);
        cache.insert(CompiledShader::new(vec![0u8; 10], v1));
        cache.insert(CompiledShader::new(vec![0u8; 10], v2.clone()));
        cache.invalidate_backend("vulkan");
        assert!(!cache.contains(&ShaderVersion::new(1, "vulkan", 0)));
        assert!(cache.contains(&v2));
    }

    #[test]
    fn test_hash_source_deterministic() {
        let data = b"hello world shader";
        assert_eq!(hash_source(data), hash_source(data));
    }

    #[test]
    fn test_hash_source_differs_for_different_inputs() {
        assert_ne!(hash_source(b"shader_a"), hash_source(b"shader_b"));
    }

    #[test]
    fn test_default_cache_capacity() {
        let cache = GpuShaderCache::default();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_shader_version_equality() {
        let v1 = ShaderVersion::new(10, "dx12", 3);
        let v2 = ShaderVersion::new(10, "dx12", 3);
        let v3 = ShaderVersion::new(10, "dx12", 4);
        assert_eq!(v1, v2);
        assert_ne!(v1, v3);
    }

    #[test]
    fn test_age_of_is_non_negative() {
        let t = SystemTime::now();
        let age = age_of(t);
        // Age should be very small but non-negative.
        assert!(age < Duration::from_secs(5));
    }

    // ── DiskShaderCache tests ─────────────────────────────────────────────────
    //
    // Each test uses its own fixed directory under the system temp dir and
    // deletes it before and after running.
    // NOTE(review): fixed names can collide if two test processes for this
    // crate run concurrently; a per-process suffix would avoid that.

    #[test]
    fn test_disk_cache_put_and_get() {
        let dir = std::env::temp_dir().join("oximedia_gpu_disk_cache_test_pg");
        let _ = std::fs::remove_dir_all(&dir); // clean slate
        let mut cache = DiskShaderCache::open(&dir).expect("open disk cache");
        let version = ShaderVersion::new(0xDEAD_BEEF, "vulkan", 7);
        let shader = CompiledShader::new(vec![1, 2, 3, 4, 5], version.clone());
        cache.put(&shader);
        let bytes = cache.get(&version).expect("should find stored shader");
        assert_eq!(bytes, vec![1u8, 2, 3, 4, 5]);
        assert_eq!(cache.stats().disk_writes, 1);
        assert_eq!(cache.stats().disk_hits, 1);
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_disk_cache_miss_unknown_version() {
        let dir = std::env::temp_dir().join("oximedia_gpu_disk_cache_test_miss");
        let _ = std::fs::remove_dir_all(&dir);
        let mut cache = DiskShaderCache::open(&dir).expect("open disk cache");
        let version = ShaderVersion::new(0x1234, "metal", 0);
        assert!(cache.get(&version).is_none());
        assert_eq!(cache.stats().disk_misses, 1);
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_disk_cache_roundtrip_large_bytecode() {
        let dir = std::env::temp_dir().join("oximedia_gpu_disk_cache_test_large");
        let _ = std::fs::remove_dir_all(&dir);
        let mut cache = DiskShaderCache::open(&dir).expect("open disk cache");
        let version = ShaderVersion::new(0xABCD_1234, "dx12", 3);
        // 4 KiB of the repeating byte pattern 0..=255.
        let bytecode: Vec<u8> = (0..=255u8).cycle().take(4096).collect();
        let shader = CompiledShader::new(bytecode.clone(), version.clone());
        cache.put(&shader);
        let result = cache.get(&version).expect("should retrieve large blob");
        assert_eq!(result, bytecode);
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_disk_cache_version_mismatch_returns_none() {
        let dir = std::env::temp_dir().join("oximedia_gpu_disk_cache_test_mismatch");
        let _ = std::fs::remove_dir_all(&dir);
        let mut cache = DiskShaderCache::open(&dir).expect("open disk cache");
        let v1 = ShaderVersion::new(0xAAAA, "vulkan", 1);
        let v2 = ShaderVersion::new(0xBBBB, "vulkan", 1); // different hash
        cache.put(&CompiledShader::new(vec![0u8; 10], v1));
        // v2 was never written; looking it up must return None.
        assert!(cache.get(&v2).is_none());
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_disk_cache_clear_removes_all_files() {
        let dir = std::env::temp_dir().join("oximedia_gpu_disk_cache_test_clear");
        let _ = std::fs::remove_dir_all(&dir);
        let mut cache = DiskShaderCache::open(&dir).expect("open disk cache");
        for i in 0u64..5 {
            cache.put(&CompiledShader::new(
                vec![i as u8; 8],
                ShaderVersion::new(i, "vulkan", 0),
            ));
        }
        cache.clear();
        // After clearing, no .shd or .meta files should remain.
        let file_count = std::fs::read_dir(&dir)
            .map(|it| it.flatten().count())
            .unwrap_or(0);
        assert_eq!(
            file_count, 0,
            "expected 0 files after clear, got {file_count}"
        );
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_disk_cache_invalidate_backend() {
        let dir = std::env::temp_dir().join("oximedia_gpu_disk_cache_test_inval");
        let _ = std::fs::remove_dir_all(&dir);
        let mut cache = DiskShaderCache::open(&dir).expect("open disk cache");
        let v_vulkan = ShaderVersion::new(0x10, "vulkan", 0);
        let v_metal = ShaderVersion::new(0x20, "metal", 0);
        cache.put(&CompiledShader::new(vec![1u8; 8], v_vulkan.clone()));
        cache.put(&CompiledShader::new(vec![2u8; 8], v_metal.clone()));
        cache.invalidate_backend("vulkan");
        // Vulkan entry must be gone; metal must remain.
        assert!(
            cache.get(&v_vulkan).is_none(),
            "vulkan entry should be gone"
        );
        assert!(
            cache.get(&v_metal).is_some(),
            "metal entry should still exist"
        );
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_disk_cache_stats_accumulate() {
        let dir = std::env::temp_dir().join("oximedia_gpu_disk_cache_test_stats");
        let _ = std::fs::remove_dir_all(&dir);
        let mut cache = DiskShaderCache::open(&dir).expect("open disk cache");
        let v = ShaderVersion::new(0xFF, "dx12", 0);
        // Miss first.
        cache.get(&v);
        // Write then hit twice.
        cache.put(&CompiledShader::new(vec![7u8; 4], v.clone()));
        cache.get(&v);
        cache.get(&v);
        assert_eq!(cache.stats().disk_misses, 1);
        assert_eq!(cache.stats().disk_writes, 1);
        assert_eq!(cache.stats().disk_hits, 2);
        let _ = std::fs::remove_dir_all(&dir);
    }

    // ── Shader cache invalidation tests ──────────────────────────────────────
    //
    // These check that a changed shader source (and therefore a changed
    // `source_hash`) is treated as a different cache key.

    /// Helper: build a `ShaderVersion` keyed by `source_hash`.
    fn versioned(source: &[u8]) -> ShaderVersion {
        ShaderVersion::new(hash_source(source), "vulkan", 0)
    }

    #[test]
    fn test_invalidation_initial_hit() {
        let mut cache = GpuShaderCache::default();
        let source_v1 = b"// shader version 1\nvoid main() {}";
        let version_v1 = versioned(source_v1);
        let shader = CompiledShader::new(vec![0xAA; 32], version_v1.clone());
        cache.insert(shader);
        // First retrieval must be a hit.
        assert!(cache.get(&version_v1).is_some(), "version 1 must hit");
        assert_eq!(cache.stats().hits, 1);
        assert_eq!(cache.stats().misses, 0);
    }

    #[test]
    fn test_invalidation_different_source_is_miss() {
        let mut cache = GpuShaderCache::default();
        let source_v1 = b"// shader version 1\nvoid main() {}";
        let source_v2 = b"// shader version 2\nvoid main() { discard; }";
        let version_v1 = versioned(source_v1);
        let version_v2 = versioned(source_v2);
        // Insert version 1 only.
        cache.insert(CompiledShader::new(vec![0x11; 16], version_v1.clone()));
        // Looking up version 2 must be a miss.
        assert!(
            cache.get(&version_v2).is_none(),
            "different source hash must be a miss"
        );
        assert_eq!(cache.stats().misses, 1);
    }

    #[test]
    fn test_invalidation_old_version_not_accessible_after_remove() {
        let mut cache = GpuShaderCache::default();
        let source_v1 = b"// version 1";
        let source_v2 = b"// version 2";
        let version_v1 = versioned(source_v1);
        let version_v2 = versioned(source_v2);
        // Cache version 1.
        cache.insert(CompiledShader::new(vec![0x01; 8], version_v1.clone()));
        assert!(cache.get(&version_v1).is_some(), "v1 hit");
        // Simulate invalidation: remove v1, insert v2.
        cache.remove(&version_v1);
        cache.insert(CompiledShader::new(vec![0x02; 8], version_v2.clone()));
        // v1 must be gone.
        assert!(
            cache.get(&version_v1).is_none(),
            "old version must not be accessible after remove"
        );
        // v2 must be present.
        assert!(cache.get(&version_v2).is_some(), "new version must hit");
    }

    #[test]
    fn test_invalidation_source_hash_changes_on_whitespace_edit() {
        // Even a single whitespace difference must produce a different hash.
        let source_a = b"void main(){}";
        let source_b = b"void main() {}";
        assert_ne!(
            hash_source(source_a),
            hash_source(source_b),
            "whitespace change must produce different hash"
        );
    }

    #[test]
    fn test_invalidation_disk_cache_version_change() {
        let dir = std::env::temp_dir().join("oximedia_gpu_shader_inval_test");
        let _ = std::fs::remove_dir_all(&dir);
        let mut disk = DiskShaderCache::open(&dir).expect("open disk cache");

        let source_v1 = b"// v1 source";
        let source_v2 = b"// v2 source -- recompiled";
        let version_v1 = ShaderVersion::new(hash_source(source_v1), "vulkan", 0);
        let version_v2 = ShaderVersion::new(hash_source(source_v2), "vulkan", 0);

        // Write version 1.
        disk.put(&CompiledShader::new(vec![0x01; 4], version_v1.clone()));
        // Version 1 must hit.
        assert!(disk.get(&version_v1).is_some(), "v1 disk hit");
        // Version 2 must miss (not yet written).
        assert!(disk.get(&version_v2).is_none(), "v2 disk miss before write");
        // Write version 2.
        disk.put(&CompiledShader::new(vec![0x02; 4], version_v2.clone()));
        // Now version 2 must hit.
        assert!(disk.get(&version_v2).is_some(), "v2 disk hit after write");
        // Version 1 still hits (two independent entries).
        assert!(disk.get(&version_v1).is_some(), "v1 still exists");

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_invalidation_clear_invalidates_all() {
        let mut cache = GpuShaderCache::default();
        let v1 = versioned(b"shader A");
        let v2 = versioned(b"shader B");
        cache.insert(CompiledShader::new(vec![1u8; 8], v1.clone()));
        cache.insert(CompiledShader::new(vec![2u8; 8], v2.clone()));
        cache.clear();
        assert!(cache.get(&v1).is_none(), "v1 must be gone after clear");
        assert!(cache.get(&v2).is_none(), "v2 must be gone after clear");
        assert!(cache.is_empty());
    }
}