Skip to main content

oximedia_gpu/
cache.rs

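//! Pipeline cache management for faster startup and reduced compilation overhead
//!
//! This module provides caching mechanisms for compiled pipelines and shaders.
//! `PipelineCache` stores pipeline blobs on disk so they can eventually be reused
//! across application runs; until its on-disk index is persisted and reloaded,
//! only entries written by the current `PipelineCache` instance are served.
//! `ShaderCache` is purely in-memory.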
5
6use crate::{GpuError, Result};
7use parking_lot::RwLock;
8use std::collections::HashMap;
9use std::fs::{self, File};
10use std::io::{Read, Write};
11use std::path::{Path, PathBuf};
12use std::sync::Arc;
13use std::time::SystemTime;
14
15/// Cache entry metadata
16#[derive(Debug, Clone)]
17struct CacheEntry {
18    /// Cache key (usually a hash of the shader source)
19    key: String,
20    /// Timestamp when the entry was created
21    timestamp: SystemTime,
22    /// Size in bytes
23    size: usize,
24}
25
26/// Pipeline cache for storing compiled pipelines
27pub struct PipelineCache {
28    cache_dir: PathBuf,
29    entries: Arc<RwLock<HashMap<String, CacheEntry>>>,
30    max_cache_size: usize,
31    enabled: bool,
32}
33
34impl PipelineCache {
35    /// Create a new pipeline cache
36    ///
37    /// # Arguments
38    ///
39    /// * `cache_dir` - Directory to store cached pipelines
40    /// * `max_cache_size` - Maximum cache size in bytes (0 = unlimited)
41    pub fn new(cache_dir: impl AsRef<Path>, max_cache_size: usize) -> Result<Self> {
42        let cache_dir = cache_dir.as_ref().to_path_buf();
43
44        // Create cache directory if it doesn't exist
45        if !cache_dir.exists() {
46            fs::create_dir_all(&cache_dir).map_err(|e| {
47                GpuError::Internal(format!("Failed to create cache directory: {e}"))
48            })?;
49        }
50
51        let cache = Self {
52            cache_dir,
53            entries: Arc::new(RwLock::new(HashMap::new())),
54            max_cache_size,
55            enabled: true,
56        };
57
58        // Load existing cache entries
59        cache.load_cache_index()?;
60
61        Ok(cache)
62    }
63
64    /// Create a default pipeline cache in the system cache directory
65    pub fn default_cache() -> Result<Self> {
66        let cache_dir = Self::default_cache_dir()?;
67        Self::new(cache_dir, 100 * 1024 * 1024) // 100 MB default
68    }
69
70    /// Get the default cache directory for the current platform
71    fn default_cache_dir() -> Result<PathBuf> {
72        let cache_dir = if let Some(cache_dir) = dirs::cache_dir() {
73            cache_dir.join("oximedia").join("gpu_cache")
74        } else {
75            PathBuf::from(".oximedia_cache")
76        };
77
78        Ok(cache_dir)
79    }
80
81    /// Load the cache index from disk
82    fn load_cache_index(&self) -> Result<()> {
83        let index_path = self.cache_dir.join("index.json");
84
85        if !index_path.exists() {
86            return Ok(());
87        }
88
89        let mut file = File::open(&index_path)
90            .map_err(|e| GpuError::Internal(format!("Failed to open cache index: {e}")))?;
91
92        let mut contents = String::new();
93        file.read_to_string(&mut contents)
94            .map_err(|e| GpuError::Internal(format!("Failed to read cache index: {e}")))?;
95
96        // Parse JSON index (simplified - in production, use serde_json)
97        // For now, just create empty index
98        Ok(())
99    }
100
101    /// Save the cache index to disk
102    #[allow(dead_code)]
103    fn save_cache_index(&self) -> Result<()> {
104        let index_path = self.cache_dir.join("index.json");
105
106        let mut file = File::create(&index_path)
107            .map_err(|e| GpuError::Internal(format!("Failed to create cache index: {e}")))?;
108
109        // Write JSON index (simplified)
110        file.write_all(b"{}")
111            .map_err(|e| GpuError::Internal(format!("Failed to write cache index: {e}")))?;
112
113        Ok(())
114    }
115
116    /// Get a cached pipeline by key
117    ///
118    /// # Arguments
119    ///
120    /// * `key` - Cache key (usually shader source hash)
121    ///
122    /// # Returns
123    ///
124    /// Cached pipeline data if found, None otherwise
125    #[must_use]
126    pub fn get(&self, key: &str) -> Option<Vec<u8>> {
127        if !self.enabled {
128            return None;
129        }
130
131        let entries = self.entries.read();
132        if !entries.contains_key(key) {
133            return None;
134        }
135
136        let cache_path = self.cache_dir.join(format!("{key}.bin"));
137        if !cache_path.exists() {
138            return None;
139        }
140
141        let mut file = File::open(&cache_path).ok()?;
142        let mut data = Vec::new();
143        file.read_to_end(&mut data).ok()?;
144
145        Some(data)
146    }
147
148    /// Store a pipeline in the cache
149    ///
150    /// # Arguments
151    ///
152    /// * `key` - Cache key
153    /// * `data` - Pipeline data to cache
154    pub fn put(&self, key: &str, data: &[u8]) -> Result<()> {
155        if !self.enabled {
156            return Ok(());
157        }
158
159        // Check cache size limit
160        if self.max_cache_size > 0 {
161            let current_size = self.total_cache_size();
162            if current_size + data.len() > self.max_cache_size {
163                self.evict_oldest()?;
164            }
165        }
166
167        let cache_path = self.cache_dir.join(format!("{key}.bin"));
168        let mut file = File::create(&cache_path)
169            .map_err(|e| GpuError::Internal(format!("Failed to create cache file: {e}")))?;
170
171        file.write_all(data)
172            .map_err(|e| GpuError::Internal(format!("Failed to write cache file: {e}")))?;
173
174        // Update cache index
175        let mut entries = self.entries.write();
176        entries.insert(
177            key.to_string(),
178            CacheEntry {
179                key: key.to_string(),
180                timestamp: SystemTime::now(),
181                size: data.len(),
182            },
183        );
184
185        Ok(())
186    }
187
188    /// Remove a cached pipeline
189    pub fn remove(&self, key: &str) -> Result<()> {
190        let cache_path = self.cache_dir.join(format!("{key}.bin"));
191
192        if cache_path.exists() {
193            fs::remove_file(&cache_path)
194                .map_err(|e| GpuError::Internal(format!("Failed to remove cache file: {e}")))?;
195        }
196
197        let mut entries = self.entries.write();
198        entries.remove(key);
199
200        Ok(())
201    }
202
203    /// Clear the entire cache
204    pub fn clear(&self) -> Result<()> {
205        let entries: Vec<String> = {
206            let entries = self.entries.read();
207            entries.keys().cloned().collect()
208        };
209
210        for key in entries {
211            self.remove(&key)?;
212        }
213
214        Ok(())
215    }
216
217    /// Get the total cache size in bytes
218    #[must_use]
219    pub fn total_cache_size(&self) -> usize {
220        let entries = self.entries.read();
221        entries.values().map(|e| e.size).sum()
222    }
223
224    /// Get the number of cached entries
225    #[must_use]
226    pub fn entry_count(&self) -> usize {
227        let entries = self.entries.read();
228        entries.len()
229    }
230
231    /// Evict the oldest cache entry
232    fn evict_oldest(&self) -> Result<()> {
233        let oldest_key = {
234            let entries = self.entries.read();
235            entries
236                .values()
237                .min_by_key(|e| e.timestamp)
238                .map(|e| e.key.clone())
239        };
240
241        if let Some(key) = oldest_key {
242            self.remove(&key)?;
243        }
244
245        Ok(())
246    }
247
248    /// Enable or disable the cache
249    pub fn set_enabled(&mut self, enabled: bool) {
250        self.enabled = enabled;
251    }
252
253    /// Check if the cache is enabled
254    #[must_use]
255    pub fn is_enabled(&self) -> bool {
256        self.enabled
257    }
258
259    /// Get cache statistics
260    #[must_use]
261    pub fn stats(&self) -> CacheStats {
262        CacheStats {
263            entry_count: self.entry_count(),
264            total_size: self.total_cache_size(),
265            max_size: self.max_cache_size,
266            enabled: self.enabled,
267        }
268    }
269}
270
/// Cache statistics
///
/// A point-in-time snapshot (as produced by `PipelineCache::stats`); it does not
/// reflect later changes to the cache. All fields are plain values, so the type is
/// `Copy` and can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of cached entries
    pub entry_count: usize,
    /// Total cache size in bytes
    pub total_size: usize,
    /// Maximum cache size in bytes (0 = unlimited)
    pub max_size: usize,
    /// Whether the cache is enabled
    pub enabled: bool,
}

impl CacheStats {
    /// Get cache utilization as a percentage of `max_size`.
    ///
    /// Returns `0.0` for an unlimited cache (`max_size == 0`). Values above `100.0`
    /// mean the cache currently holds more than its budget.
    #[must_use]
    pub fn utilization(&self) -> f64 {
        if self.max_size == 0 {
            0.0
        } else {
            (self.total_size as f64 / self.max_size as f64) * 100.0
        }
    }

    /// Get total size in megabytes (binary: 1 MB = 1024 * 1024 bytes).
    #[must_use]
    pub fn size_mb(&self) -> f64 {
        self.total_size as f64 / (1024.0 * 1024.0)
    }
}
301
302/// Shader cache for storing compiled shader modules
303pub struct ShaderCache {
304    cache: Arc<RwLock<HashMap<String, Vec<u8>>>>,
305}
306
307impl ShaderCache {
308    /// Create a new shader cache
309    #[must_use]
310    pub fn new() -> Self {
311        Self {
312            cache: Arc::new(RwLock::new(HashMap::new())),
313        }
314    }
315
316    /// Get a cached shader
317    #[must_use]
318    pub fn get(&self, source_hash: &str) -> Option<Vec<u8>> {
319        let cache = self.cache.read();
320        cache.get(source_hash).cloned()
321    }
322
323    /// Store a shader in the cache
324    pub fn put(&self, source_hash: String, compiled_shader: Vec<u8>) {
325        let mut cache = self.cache.write();
326        cache.insert(source_hash, compiled_shader);
327    }
328
329    /// Clear the shader cache
330    pub fn clear(&self) {
331        let mut cache = self.cache.write();
332        cache.clear();
333    }
334
335    /// Get the number of cached shaders
336    #[must_use]
337    pub fn size(&self) -> usize {
338        let cache = self.cache.read();
339        cache.len()
340    }
341}
342
343impl Default for ShaderCache {
344    fn default() -> Self {
345        Self::new()
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-test cache directory that is deleted when dropped.
    ///
    /// The process id is appended so concurrent test runs sharing one temp dir do
    /// not trample each other, and `Drop` cleans up even if an assertion panics.
    struct TempCacheDir(PathBuf);

    impl TempCacheDir {
        fn new(name: &str) -> Self {
            let path = std::env::temp_dir().join(format!("{name}_{}", std::process::id()));
            // Start from a clean slate in case a previous run left files behind.
            let _ = fs::remove_dir_all(&path);
            Self(path)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempCacheDir {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover temp dir must not fail the test.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_pipeline_cache_creation() {
        let dir = TempCacheDir::new("oximedia_cache_test");
        let cache = PipelineCache::new(dir.path(), 1024 * 1024)
            .expect("pipeline cache creation should succeed");

        assert!(cache.is_enabled());
        assert_eq!(cache.entry_count(), 0);
        assert_eq!(cache.total_cache_size(), 0);
    }

    #[test]
    fn test_pipeline_cache_put_get() {
        let dir = TempCacheDir::new("oximedia_cache_test_2");
        let cache = PipelineCache::new(dir.path(), 1024 * 1024)
            .expect("pipeline cache creation should succeed");

        let key = "test_shader";
        let data = vec![1, 2, 3, 4, 5];

        cache.put(key, &data).expect("cache put should succeed");
        let retrieved = cache.get(key).expect("cache get should return stored data");

        assert_eq!(data, retrieved);
        assert_eq!(cache.entry_count(), 1);
    }

    #[test]
    fn test_pipeline_cache_remove_and_clear() {
        let dir = TempCacheDir::new("oximedia_cache_test_remove");
        let cache =
            PipelineCache::new(dir.path(), 0).expect("pipeline cache creation should succeed");

        cache.put("a", &[1, 2]).expect("cache put should succeed");
        cache.put("b", &[3]).expect("cache put should succeed");

        cache.remove("a").expect("cache remove should succeed");
        assert!(cache.get("a").is_none());
        assert_eq!(cache.get("b"), Some(vec![3]));

        cache.clear().expect("cache clear should succeed");
        assert_eq!(cache.entry_count(), 0);
        assert!(cache.get("b").is_none());
    }

    #[test]
    fn test_pipeline_cache_disabled() {
        let dir = TempCacheDir::new("oximedia_cache_test_disabled");
        let mut cache =
            PipelineCache::new(dir.path(), 0).expect("pipeline cache creation should succeed");

        cache.set_enabled(false);
        cache
            .put("k", &[1])
            .expect("put on a disabled cache should be a no-op");

        assert_eq!(cache.entry_count(), 0);
        assert!(cache.get("k").is_none());
        assert!(!cache.stats().enabled);
    }

    #[test]
    fn test_pipeline_cache_evicts_oldest_when_full() {
        let dir = TempCacheDir::new("oximedia_cache_test_evict");
        let cache =
            PipelineCache::new(dir.path(), 1024).expect("pipeline cache creation should succeed");

        cache.put("old", &[0u8; 600]).expect("cache put should succeed");
        cache.put("new", &[1u8; 600]).expect("cache put should succeed");

        assert!(cache.get("old").is_none());
        assert_eq!(cache.get("new"), Some(vec![1u8; 600]));
        assert_eq!(cache.total_cache_size(), 600);
    }

    #[test]
    fn test_shader_cache() {
        let cache = ShaderCache::new();
        assert_eq!(cache.size(), 0);

        cache.put("shader1".to_string(), vec![1, 2, 3]);
        assert_eq!(cache.size(), 1);

        let shader = cache.get("shader1");
        assert_eq!(shader, Some(vec![1, 2, 3]));

        cache.clear();
        assert_eq!(cache.size(), 0);
    }

    #[test]
    fn test_cache_stats() {
        let dir = TempCacheDir::new("oximedia_cache_test_3");
        let cache =
            PipelineCache::new(dir.path(), 1024).expect("pipeline cache creation should succeed");

        cache
            .put("key1", &[0u8; 100])
            .expect("cache put should succeed");
        cache
            .put("key2", &[0u8; 200])
            .expect("cache put should succeed");

        let stats = cache.stats();
        assert_eq!(stats.entry_count, 2);
        assert_eq!(stats.total_size, 300);
    }

    #[test]
    fn test_cache_stats_utilization() {
        let stats = CacheStats {
            entry_count: 1,
            total_size: 512 * 1024,
            max_size: 1024 * 1024,
            enabled: true,
        };
        assert!((stats.utilization() - 50.0).abs() < f64::EPSILON);
        assert!((stats.size_mb() - 0.5).abs() < f64::EPSILON);

        let unlimited = CacheStats {
            max_size: 0,
            ..stats
        };
        assert!(unlimited.utilization().abs() < f64::EPSILON);
    }
}