1use crate::{GpuError, Result};
7use parking_lot::RwLock;
8use std::collections::HashMap;
9use std::fs::{self, File};
10use std::io::{Read, Write};
11use std::path::{Path, PathBuf};
12use std::sync::Arc;
13use std::time::SystemTime;
14
/// Index record for one file stored by the pipeline cache.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// Cache key; also the stem of the `<key>.bin` file on disk.
    key: String,
    /// Time the entry was last written; eviction removes the smallest value first.
    timestamp: SystemTime,
    /// Number of bytes written for this entry.
    size: usize,
}
25
26pub struct PipelineCache {
28 cache_dir: PathBuf,
29 entries: Arc<RwLock<HashMap<String, CacheEntry>>>,
30 max_cache_size: usize,
31 enabled: bool,
32}
33
34impl PipelineCache {
35 pub fn new(cache_dir: impl AsRef<Path>, max_cache_size: usize) -> Result<Self> {
42 let cache_dir = cache_dir.as_ref().to_path_buf();
43
44 if !cache_dir.exists() {
46 fs::create_dir_all(&cache_dir).map_err(|e| {
47 GpuError::Internal(format!("Failed to create cache directory: {e}"))
48 })?;
49 }
50
51 let cache = Self {
52 cache_dir,
53 entries: Arc::new(RwLock::new(HashMap::new())),
54 max_cache_size,
55 enabled: true,
56 };
57
58 cache.load_cache_index()?;
60
61 Ok(cache)
62 }
63
64 pub fn default_cache() -> Result<Self> {
66 let cache_dir = Self::default_cache_dir()?;
67 Self::new(cache_dir, 100 * 1024 * 1024) }
69
70 fn default_cache_dir() -> Result<PathBuf> {
72 let cache_dir = if let Some(cache_dir) = dirs::cache_dir() {
73 cache_dir.join("oximedia").join("gpu_cache")
74 } else {
75 PathBuf::from(".oximedia_cache")
76 };
77
78 Ok(cache_dir)
79 }
80
81 fn load_cache_index(&self) -> Result<()> {
83 let index_path = self.cache_dir.join("index.json");
84
85 if !index_path.exists() {
86 return Ok(());
87 }
88
89 let mut file = File::open(&index_path)
90 .map_err(|e| GpuError::Internal(format!("Failed to open cache index: {e}")))?;
91
92 let mut contents = String::new();
93 file.read_to_string(&mut contents)
94 .map_err(|e| GpuError::Internal(format!("Failed to read cache index: {e}")))?;
95
96 Ok(())
99 }
100
101 #[allow(dead_code)]
103 fn save_cache_index(&self) -> Result<()> {
104 let index_path = self.cache_dir.join("index.json");
105
106 let mut file = File::create(&index_path)
107 .map_err(|e| GpuError::Internal(format!("Failed to create cache index: {e}")))?;
108
109 file.write_all(b"{}")
111 .map_err(|e| GpuError::Internal(format!("Failed to write cache index: {e}")))?;
112
113 Ok(())
114 }
115
116 #[must_use]
126 pub fn get(&self, key: &str) -> Option<Vec<u8>> {
127 if !self.enabled {
128 return None;
129 }
130
131 let entries = self.entries.read();
132 if !entries.contains_key(key) {
133 return None;
134 }
135
136 let cache_path = self.cache_dir.join(format!("{key}.bin"));
137 if !cache_path.exists() {
138 return None;
139 }
140
141 let mut file = File::open(&cache_path).ok()?;
142 let mut data = Vec::new();
143 file.read_to_end(&mut data).ok()?;
144
145 Some(data)
146 }
147
148 pub fn put(&self, key: &str, data: &[u8]) -> Result<()> {
155 if !self.enabled {
156 return Ok(());
157 }
158
159 if self.max_cache_size > 0 {
161 let current_size = self.total_cache_size();
162 if current_size + data.len() > self.max_cache_size {
163 self.evict_oldest()?;
164 }
165 }
166
167 let cache_path = self.cache_dir.join(format!("{key}.bin"));
168 let mut file = File::create(&cache_path)
169 .map_err(|e| GpuError::Internal(format!("Failed to create cache file: {e}")))?;
170
171 file.write_all(data)
172 .map_err(|e| GpuError::Internal(format!("Failed to write cache file: {e}")))?;
173
174 let mut entries = self.entries.write();
176 entries.insert(
177 key.to_string(),
178 CacheEntry {
179 key: key.to_string(),
180 timestamp: SystemTime::now(),
181 size: data.len(),
182 },
183 );
184
185 Ok(())
186 }
187
188 pub fn remove(&self, key: &str) -> Result<()> {
190 let cache_path = self.cache_dir.join(format!("{key}.bin"));
191
192 if cache_path.exists() {
193 fs::remove_file(&cache_path)
194 .map_err(|e| GpuError::Internal(format!("Failed to remove cache file: {e}")))?;
195 }
196
197 let mut entries = self.entries.write();
198 entries.remove(key);
199
200 Ok(())
201 }
202
203 pub fn clear(&self) -> Result<()> {
205 let entries: Vec<String> = {
206 let entries = self.entries.read();
207 entries.keys().cloned().collect()
208 };
209
210 for key in entries {
211 self.remove(&key)?;
212 }
213
214 Ok(())
215 }
216
217 #[must_use]
219 pub fn total_cache_size(&self) -> usize {
220 let entries = self.entries.read();
221 entries.values().map(|e| e.size).sum()
222 }
223
224 #[must_use]
226 pub fn entry_count(&self) -> usize {
227 let entries = self.entries.read();
228 entries.len()
229 }
230
231 fn evict_oldest(&self) -> Result<()> {
233 let oldest_key = {
234 let entries = self.entries.read();
235 entries
236 .values()
237 .min_by_key(|e| e.timestamp)
238 .map(|e| e.key.clone())
239 };
240
241 if let Some(key) = oldest_key {
242 self.remove(&key)?;
243 }
244
245 Ok(())
246 }
247
248 pub fn set_enabled(&mut self, enabled: bool) {
250 self.enabled = enabled;
251 }
252
253 #[must_use]
255 pub fn is_enabled(&self) -> bool {
256 self.enabled
257 }
258
259 #[must_use]
261 pub fn stats(&self) -> CacheStats {
262 CacheStats {
263 entry_count: self.entry_count(),
264 total_size: self.total_cache_size(),
265 max_size: self.max_cache_size,
266 enabled: self.enabled,
267 }
268 }
269}
270
/// Point-in-time summary of a pipeline cache's occupancy.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of indexed entries.
    pub entry_count: usize,
    /// Sum of entry sizes, in bytes.
    pub total_size: usize,
    /// Configured byte budget; `0` means no size-based eviction.
    pub max_size: usize,
    /// Whether the cache was enabled when the snapshot was taken.
    pub enabled: bool,
}
283
284impl CacheStats {
285 #[must_use]
287 pub fn utilization(&self) -> f64 {
288 if self.max_size == 0 {
289 0.0
290 } else {
291 (self.total_size as f64 / self.max_size as f64) * 100.0
292 }
293 }
294
295 #[must_use]
297 pub fn size_mb(&self) -> f64 {
298 self.total_size as f64 / (1024.0 * 1024.0)
299 }
300}
301
302pub struct ShaderCache {
304 cache: Arc<RwLock<HashMap<String, Vec<u8>>>>,
305}
306
307impl ShaderCache {
308 #[must_use]
310 pub fn new() -> Self {
311 Self {
312 cache: Arc::new(RwLock::new(HashMap::new())),
313 }
314 }
315
316 #[must_use]
318 pub fn get(&self, source_hash: &str) -> Option<Vec<u8>> {
319 let cache = self.cache.read();
320 cache.get(source_hash).cloned()
321 }
322
323 pub fn put(&self, source_hash: String, compiled_shader: Vec<u8>) {
325 let mut cache = self.cache.write();
326 cache.insert(source_hash, compiled_shader);
327 }
328
329 pub fn clear(&self) {
331 let mut cache = self.cache.write();
332 cache.clear();
333 }
334
335 #[must_use]
337 pub fn size(&self) -> usize {
338 let cache = self.cache.read();
339 cache.len()
340 }
341}
342
343impl Default for ShaderCache {
344 fn default() -> Self {
345 Self::new()
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-test scratch directory.
    ///
    /// It is wiped before use and removed on drop, so a failing assertion does
    /// not leave files behind. The process id is included in the name so that
    /// concurrent test runs do not share, or delete, each other's directory.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        fn new(name: &str) -> Self {
            let path = std::env::temp_dir().join(format!("{name}_{}", std::process::id()));
            let _ = fs::remove_dir_all(&path);
            Self(path)
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_pipeline_cache_creation() {
        let dir = ScratchDir::new("oximedia_cache_test");
        let cache = PipelineCache::new(&dir.0, 1024 * 1024)
            .expect("pipeline cache creation should succeed");

        assert!(cache.is_enabled());
        assert_eq!(cache.entry_count(), 0);
        assert_eq!(cache.total_cache_size(), 0);
    }

    #[test]
    fn test_pipeline_cache_put_get() {
        let dir = ScratchDir::new("oximedia_cache_test_2");
        let cache = PipelineCache::new(&dir.0, 1024 * 1024)
            .expect("pipeline cache creation should succeed");

        let key = "test_shader";
        let data = vec![1, 2, 3, 4, 5];

        cache.put(key, &data).expect("cache put should succeed");
        let retrieved = cache.get(key).expect("cache get should return stored data");

        assert_eq!(data, retrieved);
        assert_eq!(cache.entry_count(), 1);
    }

    #[test]
    fn test_pipeline_cache_remove() {
        let dir = ScratchDir::new("oximedia_cache_test_remove");
        let cache =
            PipelineCache::new(&dir.0, 1024).expect("pipeline cache creation should succeed");

        cache.put("k", &[7u8; 10]).expect("cache put should succeed");
        cache.remove("k").expect("cache remove should succeed");

        assert_eq!(cache.entry_count(), 0);
        assert!(cache.get("k").is_none());
        // Removing a key that is no longer present is not an error.
        cache.remove("k").expect("removing an absent key should succeed");
    }

    #[test]
    fn test_pipeline_cache_evicts_oldest_when_full() {
        let dir = ScratchDir::new("oximedia_cache_test_evict");
        let cache =
            PipelineCache::new(&dir.0, 1024).expect("pipeline cache creation should succeed");

        cache.put("old", &[1u8; 600]).expect("cache put should succeed");
        cache.put("new", &[2u8; 600]).expect("cache put should succeed");

        assert_eq!(cache.entry_count(), 1);
        assert!(cache.get("old").is_none());
        assert_eq!(cache.get("new"), Some(vec![2u8; 600]));
    }

    #[test]
    fn test_shader_cache() {
        let cache = ShaderCache::new();
        assert_eq!(cache.size(), 0);

        cache.put("shader1".to_string(), vec![1, 2, 3]);
        assert_eq!(cache.size(), 1);

        let shader = cache.get("shader1");
        assert_eq!(shader, Some(vec![1, 2, 3]));

        cache.clear();
        assert_eq!(cache.size(), 0);
    }

    #[test]
    fn test_cache_stats() {
        let dir = ScratchDir::new("oximedia_cache_test_3");
        let cache =
            PipelineCache::new(&dir.0, 1024).expect("pipeline cache creation should succeed");

        cache
            .put("key1", &[0u8; 100])
            .expect("cache put should succeed");
        cache
            .put("key2", &[0u8; 200])
            .expect("cache put should succeed");

        let stats = cache.stats();
        assert_eq!(stats.entry_count, 2);
        assert_eq!(stats.total_size, 300);
    }
}