1use std::collections::HashMap;
7use std::time::{Duration, SystemTime};
8
/// Key identifying one compiled variant of a shader.
///
/// Equality and hashing cover all three fields, so two versions match only when
/// the source hash, the backend name, and the feature flags are all identical.
#[allow(dead_code)]
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ShaderVersion {
    /// Hash identifying the shader source.
    pub source_hash: u64,
    /// Backend name, e.g. `"vulkan"`, `"metal"` or `"dx12"`.
    pub backend: String,
    /// Feature flags the variant was built with; bit meanings are caller-defined.
    pub feature_flags: u32,
}

impl ShaderVersion {
    /// Builds a version key; `backend` may be any value convertible into a `String`.
    #[allow(dead_code)]
    #[must_use]
    pub fn new(source_hash: u64, backend: impl Into<String>, feature_flags: u32) -> Self {
        let backend: String = backend.into();
        ShaderVersion {
            source_hash,
            backend,
            feature_flags,
        }
    }
}
33
34#[allow(dead_code)]
36#[derive(Debug, Clone)]
37pub struct CompiledShader {
38 pub bytecode: Vec<u8>,
40 pub version: ShaderVersion,
42 pub created_at: SystemTime,
44 pub size_bytes: usize,
46 pub hit_count: u64,
48}
49
50impl CompiledShader {
51 #[allow(dead_code)]
53 #[must_use]
54 pub fn new(bytecode: Vec<u8>, version: ShaderVersion) -> Self {
55 let size_bytes = bytecode.len();
56 Self {
57 bytecode,
58 version,
59 created_at: SystemTime::now(),
60 size_bytes,
61 hit_count: 0,
62 }
63 }
64}
65
/// Snapshot of a shader cache's occupancy and lifetime activity counters.
#[allow(dead_code)]
#[derive(Clone, Debug, Default)]
pub struct ShaderCacheStats {
    /// Number of shaders currently stored.
    pub entry_count: usize,
    /// Sum of `size_bytes` over the stored shaders.
    pub total_bytes: usize,
    /// Lookups that found a cached shader.
    pub hits: u64,
    /// Lookups that found nothing.
    pub misses: u64,
    /// Entries dropped to make room for new insertions.
    pub evictions: u64,
}
81
/// Strategy used to choose which cached shader to drop when the cache is full.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum EvictionPolicy {
    /// Drop the least recently used entry (the default policy).
    #[default]
    Lru,
    /// Drop the entry with the fewest recorded hits.
    Lfu,
    /// Drop the entry with the earliest `created_at` timestamp.
    OldestFirst,
}
94
95#[allow(dead_code)]
97pub struct GpuShaderCache {
98 entries: HashMap<ShaderVersion, CompiledShader>,
99 max_bytes: usize,
101 max_entries: usize,
103 policy: EvictionPolicy,
105 stats: ShaderCacheStats,
107 last_access: HashMap<ShaderVersion, SystemTime>,
109}
110
111impl GpuShaderCache {
112 #[allow(dead_code)]
118 #[must_use]
119 pub fn new(max_bytes: usize, max_entries: usize, policy: EvictionPolicy) -> Self {
120 Self {
121 entries: HashMap::new(),
122 max_bytes,
123 max_entries,
124 policy,
125 stats: ShaderCacheStats::default(),
126 last_access: HashMap::new(),
127 }
128 }
129
130 #[allow(dead_code)]
135 pub fn insert(&mut self, shader: CompiledShader) {
136 while self.needs_eviction(shader.size_bytes) {
138 if !self.evict_one() {
139 break; }
141 }
142
143 self.stats.total_bytes += shader.size_bytes;
144 self.stats.entry_count += 1;
145 self.last_access
146 .insert(shader.version.clone(), SystemTime::now());
147 self.entries.insert(shader.version.clone(), shader);
148 }
149
150 #[allow(dead_code)]
154 pub fn get(&mut self, version: &ShaderVersion) -> Option<&CompiledShader> {
155 if self.entries.contains_key(version) {
156 self.stats.hits += 1;
157 self.last_access.insert(version.clone(), SystemTime::now());
159 if let Some(e) = self.entries.get_mut(version) {
160 e.hit_count += 1;
161 }
162 self.entries.get(version)
163 } else {
164 self.stats.misses += 1;
165 None
166 }
167 }
168
169 #[allow(dead_code)]
171 #[must_use]
172 pub fn contains(&self, version: &ShaderVersion) -> bool {
173 self.entries.contains_key(version)
174 }
175
176 #[allow(dead_code)]
178 pub fn remove(&mut self, version: &ShaderVersion) -> Option<CompiledShader> {
179 if let Some(shader) = self.entries.remove(version) {
180 self.stats.total_bytes = self.stats.total_bytes.saturating_sub(shader.size_bytes);
181 self.stats.entry_count = self.stats.entry_count.saturating_sub(1);
182 self.last_access.remove(version);
183 Some(shader)
184 } else {
185 None
186 }
187 }
188
189 #[allow(dead_code)]
191 pub fn invalidate_backend(&mut self, backend: &str) {
192 let to_remove: Vec<ShaderVersion> = self
193 .entries
194 .keys()
195 .filter(|v| v.backend == backend)
196 .cloned()
197 .collect();
198 for key in to_remove {
199 self.remove(&key);
200 }
201 }
202
203 #[allow(dead_code)]
205 pub fn clear(&mut self) {
206 self.entries.clear();
207 self.last_access.clear();
208 self.stats.total_bytes = 0;
209 self.stats.entry_count = 0;
210 }
211
212 #[allow(dead_code)]
214 #[must_use]
215 pub fn stats(&self) -> &ShaderCacheStats {
216 &self.stats
217 }
218
219 #[allow(dead_code)]
221 #[must_use]
222 pub fn len(&self) -> usize {
223 self.entries.len()
224 }
225
226 #[allow(dead_code)]
228 #[must_use]
229 pub fn is_empty(&self) -> bool {
230 self.entries.is_empty()
231 }
232
233 fn needs_eviction(&self, incoming_bytes: usize) -> bool {
238 let bytes_after = self.stats.total_bytes + incoming_bytes;
239 bytes_after > self.max_bytes || self.stats.entry_count >= self.max_entries
240 }
241
242 fn evict_one(&mut self) -> bool {
245 if self.entries.is_empty() {
246 return false;
247 }
248
249 let victim_key: Option<ShaderVersion> = match self.policy {
250 EvictionPolicy::Lru => {
251 self.last_access
253 .iter()
254 .min_by_key(|(_, t)| *t)
255 .map(|(k, _)| k.clone())
256 }
257 EvictionPolicy::Lfu => {
258 self.entries
260 .iter()
261 .min_by_key(|(_, v)| v.hit_count)
262 .map(|(k, _)| k.clone())
263 }
264 EvictionPolicy::OldestFirst => {
265 self.entries
267 .iter()
268 .min_by_key(|(_, v)| v.created_at)
269 .map(|(k, _)| k.clone())
270 }
271 };
272
273 if let Some(key) = victim_key {
274 self.remove(&key);
275 self.stats.evictions += 1;
276 true
277 } else {
278 false
279 }
280 }
281}
282
283impl Default for GpuShaderCache {
284 fn default() -> Self {
285 Self::new(64 * 1024 * 1024, 256, EvictionPolicy::Lru)
287 }
288}
289
/// Hashes shader source bytes with 64-bit FNV-1a.
///
/// The result depends only on `data`, so it is stable across runs and can be
/// used as a persistent `source_hash`. FNV-1a is fast but not collision-resistant
/// against deliberately crafted input.
#[allow(dead_code)]
#[must_use]
pub fn hash_source(data: &[u8]) -> u64 {
    // Standard 64-bit FNV offset basis and prime.
    const FNV_OFFSET: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    data.iter().fold(FNV_OFFSET, |hash, &byte| {
        // FNV-1a: xor the byte in first, then multiply (mod 2^64).
        (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
309
/// Returns how much wall-clock time has passed since `t`.
///
/// If `t` lies in the future (for example after the system clock moved
/// backwards), the age is reported as zero instead of failing.
#[allow(dead_code)]
#[must_use]
pub fn age_of(t: SystemTime) -> Duration {
    match SystemTime::now().duration_since(t) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::ZERO,
    }
}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Version key on the "vulkan" backend with no feature flags.
    fn make_version(hash: u64) -> ShaderVersion {
        ShaderVersion::new(hash, "vulkan", 0)
    }

    /// Zero-filled shader of `size` bytes for `make_version(hash)`.
    fn make_shader(hash: u64, size: usize) -> CompiledShader {
        CompiledShader::new(vec![0u8; size], make_version(hash))
    }

    #[test]
    fn test_insert_and_get() {
        let mut cache = GpuShaderCache::default();
        let shader = make_shader(1, 100);
        let version = shader.version.clone();
        cache.insert(shader);
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.stats().total_bytes, 100);
        assert!(cache.get(&version).is_some());
        assert_eq!(cache.stats().hits, 1);
    }

    #[test]
    fn test_cache_miss() {
        let mut cache = GpuShaderCache::default();
        let v = make_version(42);
        assert!(cache.get(&v).is_none());
        assert_eq!(cache.stats().misses, 1);
        assert_eq!(cache.stats().hits, 0);
    }

    #[test]
    fn test_hit_count_increments() {
        let mut cache = GpuShaderCache::default();
        let shader = make_shader(7, 50);
        let version = shader.version.clone();
        cache.insert(shader);
        cache.get(&version);
        cache.get(&version);
        assert_eq!(cache.get(&version).unwrap().hit_count, 3);
        assert_eq!(cache.stats().hits, 3);
    }

    #[test]
    fn test_remove() {
        let mut cache = GpuShaderCache::default();
        let shader = make_shader(99, 200);
        let version = shader.version.clone();
        cache.insert(shader);
        assert!(cache.remove(&version).is_some());
        assert_eq!(cache.stats().total_bytes, 0);
        assert_eq!(cache.stats().entry_count, 0);
        assert!(cache.get(&version).is_none());
        // Removing an absent key is a no-op.
        assert!(cache.remove(&version).is_none());
    }

    #[test]
    fn test_contains() {
        let mut cache = GpuShaderCache::default();
        let shader = make_shader(5, 10);
        let version = shader.version.clone();
        assert!(!cache.contains(&version));
        cache.insert(shader);
        assert!(cache.contains(&version));
    }

    #[test]
    fn test_clear() {
        let mut cache = GpuShaderCache::default();
        cache.insert(make_shader(1, 10));
        cache.insert(make_shader(2, 10));
        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.stats().total_bytes, 0);
        assert_eq!(cache.stats().entry_count, 0);
    }

    #[test]
    fn test_eviction_by_entry_count() {
        let mut cache = GpuShaderCache::new(usize::MAX, 2, EvictionPolicy::Lfu);
        cache.insert(make_shader(1, 10));
        cache.insert(make_shader(2, 10));
        // Give shader 2 a hit so LFU must pick shader 1 as the victim.
        cache.get(&make_version(2));
        cache.insert(make_shader(3, 10));
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.stats().evictions, 1);
        assert!(!cache.contains(&make_version(1)));
        assert!(cache.contains(&make_version(2)));
        assert!(cache.contains(&make_version(3)));
    }

    #[test]
    fn test_eviction_by_bytes() {
        let mut cache = GpuShaderCache::new(30, usize::MAX, EvictionPolicy::OldestFirst);
        cache.insert(make_shader(1, 15));
        cache.insert(make_shader(2, 15));
        cache.insert(make_shader(3, 15));
        // Exactly one eviction brings the cache back within its 30-byte budget.
        assert_eq!(cache.stats().evictions, 1);
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.stats().total_bytes, 30);
        assert!(cache.contains(&make_version(3)));
    }

    #[test]
    fn test_invalidate_backend() {
        let mut cache = GpuShaderCache::default();
        let v1 = ShaderVersion::new(1, "vulkan", 0);
        let v2 = ShaderVersion::new(2, "metal", 0);
        cache.insert(CompiledShader::new(vec![0u8; 10], v1));
        cache.insert(CompiledShader::new(vec![0u8; 10], v2.clone()));
        cache.invalidate_backend("vulkan");
        assert!(!cache.contains(&ShaderVersion::new(1, "vulkan", 0)));
        assert!(cache.contains(&v2));
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.stats().total_bytes, 10);
    }

    #[test]
    fn test_hash_source_deterministic() {
        let data = b"hello world shader";
        assert_eq!(hash_source(data), hash_source(data));
    }

    #[test]
    fn test_hash_source_differs_for_different_inputs() {
        assert_ne!(hash_source(b"shader_a"), hash_source(b"shader_b"));
    }

    #[test]
    fn test_hash_source_matches_fnv1a_vectors() {
        // Published 64-bit FNV-1a test vectors.
        assert_eq!(hash_source(b""), 0xcbf2_9ce4_8422_2325);
        assert_eq!(hash_source(b"a"), 0xaf63_dc4c_8601_ec8c);
    }

    #[test]
    fn test_default_cache_capacity() {
        let cache = GpuShaderCache::default();
        assert!(cache.is_empty());
        assert_eq!(cache.max_bytes, 64 * 1024 * 1024);
        assert_eq!(cache.max_entries, 256);
        assert_eq!(cache.policy, EvictionPolicy::Lru);
    }

    #[test]
    fn test_shader_version_equality() {
        let v1 = ShaderVersion::new(10, "dx12", 3);
        let v2 = ShaderVersion::new(10, "dx12", 3);
        let v3 = ShaderVersion::new(10, "dx12", 4);
        assert_eq!(v1, v2);
        assert_ne!(v1, v3);
    }

    #[test]
    fn test_age_of_recent_and_future() {
        let t = SystemTime::now();
        assert!(age_of(t) < Duration::from_secs(5));
        // A timestamp in the future reports zero age instead of failing.
        let future = SystemTime::now() + Duration::from_secs(3600);
        assert_eq!(age_of(future), Duration::ZERO);
    }
}