1use super::types::ShaderKey;
6
7pub(super) fn collect_placeholders(src: &str) -> Vec<String> {
9 let mut result = Vec::new();
10 let mut rest = src;
11 while let Some(start) = rest.find("{{") {
12 let after_open = &rest[start + 2..];
13 if let Some(end) = after_open.find("}}") {
14 let token = after_open[..end].trim().to_string();
15 if !token.is_empty() && !result.contains(&token) {
16 result.push(token);
17 }
18 rest = &after_open[end + 2..];
19 } else {
20 break;
21 }
22 }
23 result
24}
/// Hashes `wgsl` with 64-bit FNV-1a and returns the digest.
///
/// The hash has no random seed, so the same input always yields the same key across runs,
/// which makes it usable as a stable cache key. FNV-1a is a non-cryptographic hash and
/// should not be relied on to resist deliberately crafted collisions.
pub fn compute_cache_key(wgsl: &str) -> u64 {
    // Standard 64-bit FNV offset basis and prime. They are local to this function, so they
    // carry no visibility qualifier (`pub(super)` here was meaningless and fails to resolve
    // when the enclosing module is the crate root).
    const FNV_OFFSET: u64 = 14_695_981_039_346_656_037;
    const FNV_PRIME: u64 = 1_099_511_628_211;

    // FNV-1a: XOR the byte in first, then multiply (wrapping, as the algorithm is defined
    // modulo 2^64).
    wgsl.bytes().fold(FNV_OFFSET, |hash, byte| {
        (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
39pub fn compute_shader_cache_key(key: &ShaderKey) -> u64 {
43 let mut repr = key.name.clone();
44 repr.push('\x00');
45 for (k, v) in &key.defines {
46 repr.push_str(k);
47 repr.push('=');
48 repr.push_str(v);
49 repr.push(';');
50 }
51 compute_cache_key(&repr)
52}
#[cfg(test)]
mod shader_registry_tests {
    use super::*;
    use crate::shader_registry::CompiledVariant;
    use crate::shader_registry::HotReloadTracker;

    use crate::shader_registry::PipelineDescriptor;
    use crate::shader_registry::RegistryError;
    use crate::shader_registry::ShaderCompileOptions;

    use crate::shader_registry::ShaderKey;
    use crate::shader_registry::ShaderRegistry;
    use crate::shader_registry::ShaderSource;

    use crate::shader_registry::VariantCache;

    use std::collections::HashMap;

    /// Builds a `CompiledVariant` whose `binary_size_bytes` equals the WGSL length, so
    /// byte-accounting tests can predict totals directly from the source strings.
    fn make_compiled_variant(
        key: ShaderKey,
        wgsl: String,
        workgroup_size: [u32; 3],
    ) -> CompiledVariant {
        let binary_size_bytes = wgsl.len();
        CompiledVariant {
            key,
            wgsl,
            workgroup_size,
            binary_size_bytes,
        }
    }

    /// A minimal compute shader whose only placeholder is `{{WG_SIZE}}`.
    fn simple_source(name: &str) -> ShaderSource {
        ShaderSource::new(
            name,
            "@compute @workgroup_size({{WG_SIZE}}) fn main() {}",
            [64, 1, 1],
        )
    }

    #[test]
    fn test_shader_key_fingerprint_bare() {
        let key = ShaderKey::bare("sph_density");
        assert_eq!(key.fingerprint(), "sph_density");
    }

    #[test]
    fn test_shader_key_fingerprint_with_defines() {
        let mut d = HashMap::new();
        d.insert("WG_SIZE".to_string(), "64".to_string());
        let key = ShaderKey::new("sph", &d);
        assert!(key.fingerprint().contains("WG_SIZE"));
        assert!(key.fingerprint().contains("64"));
    }

    #[test]
    fn test_shader_key_sorted_defines() {
        let mut d = HashMap::new();
        d.insert("Z_FLAG".to_string(), "1".to_string());
        d.insert("A_FLAG".to_string(), "2".to_string());
        let key = ShaderKey::new("s", &d);
        assert_eq!(
            key.defines[0].0, "A_FLAG",
            "defines should be sorted by key"
        );
    }

    #[test]
    fn test_source_placeholder_detection() {
        let src = ShaderSource::new(
            "test",
            "size={{SIZE}} dtype={{DTYPE}} size={{SIZE}}",
            [32, 1, 1],
        );
        // Duplicate `{{SIZE}}` must be reported only once.
        assert_eq!(src.placeholders.len(), 2);
        assert!(src.placeholders.contains(&"SIZE".to_string()));
        assert!(src.placeholders.contains(&"DTYPE".to_string()));
    }

    #[test]
    fn test_source_instantiate_substitutes_tokens() {
        let src = ShaderSource::new("t", "workgroup={{WG}} type={{T}}", [1, 1, 1]);
        let mut d = HashMap::new();
        d.insert("WG".to_string(), "128".to_string());
        d.insert("T".to_string(), "f32".to_string());
        let out = src.instantiate(&d);
        assert!(out.contains("128"));
        assert!(out.contains("f32"));
        assert!(!out.contains("{{WG}}"));
    }

    #[test]
    fn test_variant_cache_lru_eviction() {
        let mut cache = VariantCache::new(2);
        let make = |name: &str| {
            make_compiled_variant(
                ShaderKey::bare(name),
                format!("fn {}(){{}}", name),
                [64, 1, 1],
            )
        };
        cache.insert(make("a"));
        cache.insert(make("b"));
        // Touch "a" so that "b" becomes the least-recently-used entry.
        let _ = cache.get(&ShaderKey::bare("a"));
        cache.insert(make("c"));
        assert!(
            cache.get(&ShaderKey::bare("a")).is_some(),
            "a should survive"
        );
        assert!(
            cache.get(&ShaderKey::bare("b")).is_none(),
            "b should be evicted"
        );
        assert!(
            cache.get(&ShaderKey::bare("c")).is_some(),
            "c should be present"
        );
    }

    #[test]
    fn test_registry_get_or_compile_success() {
        let mut reg = ShaderRegistry::new(8);
        reg.register(simple_source("dense_layer"));
        let mut d = HashMap::new();
        d.insert("WG_SIZE".to_string(), "64".to_string());
        let v = reg.get_or_compile("dense_layer", &d).unwrap();
        assert!(v.wgsl.contains("64"), "placeholder should be substituted");
        assert_eq!(reg.compilations, 1);
    }

    #[test]
    fn test_registry_cache_hit() {
        let mut reg = ShaderRegistry::new(8);
        reg.register(simple_source("sph"));
        let mut d = HashMap::new();
        d.insert("WG_SIZE".to_string(), "32".to_string());
        reg.get_or_compile("sph", &d).unwrap();
        reg.get_or_compile("sph", &d).unwrap();
        assert_eq!(reg.cache_hits, 1);
        assert_eq!(reg.compilations, 1);
    }

    #[test]
    fn test_registry_unknown_shader_error() {
        let mut reg = ShaderRegistry::new(4);
        let res = reg.get_or_compile("ghost", &HashMap::new());
        assert!(matches!(res, Err(RegistryError::UnknownShader(_))));
    }

    #[test]
    fn test_registry_missing_define_error() {
        let mut reg = ShaderRegistry::new(4);
        reg.register(simple_source("needs_wg"));
        let res = reg.get_or_compile("needs_wg", &HashMap::new());
        assert!(
            matches!(res, Err(RegistryError::MissingDefine { .. })),
            "should fail when required define is absent"
        );
    }

    #[test]
    fn test_registry_invalidate_clears_cache() {
        let mut reg = ShaderRegistry::new(8);
        reg.register(simple_source("shader_x"));
        let mut d = HashMap::new();
        d.insert("WG_SIZE".to_string(), "64".to_string());
        reg.get_or_compile("shader_x", &d).unwrap();
        assert_eq!(reg.cached_count(), 1);
        reg.invalidate("shader_x");
        assert_eq!(
            reg.cached_count(),
            0,
            "invalidate should clear variant cache"
        );
    }

    #[test]
    fn test_hot_reload_tracker_needs_recompile() {
        let mut tracker = HotReloadTracker::new();
        tracker.touch("my_shader", 10);
        tracker.record_compile("my_shader", 5);
        assert!(
            tracker.needs_recompile("my_shader"),
            "modified at 10 > compiled at 5"
        );
        tracker.record_compile("my_shader", 10);
        assert!(
            !tracker.needs_recompile("my_shader"),
            "up-to-date after recompile"
        );
    }

    #[test]
    fn test_hot_reload_tracker_stale_list() {
        let mut tracker = HotReloadTracker::new();
        tracker.touch("a", 5);
        tracker.touch("b", 5);
        tracker.record_compile("a", 10);
        let stale = tracker.stale_shaders();
        assert!(!stale.contains(&"a"), "a is up-to-date");
        assert!(stale.contains(&"b"), "b has never been compiled");
    }

    #[test]
    fn test_pipeline_descriptor_validate() {
        let key = ShaderKey::bare("sph");
        let desc = PipelineDescriptor::new(key.clone(), 2, 0, "sph_pipeline");
        assert!(desc.validate().is_ok());
        // Zero bind groups is rejected.
        let bad = PipelineDescriptor::new(key, 0, 0, "bad");
        assert!(bad.validate().is_err());
    }

    #[test]
    fn test_shader_key_bare_no_defines() {
        let key = ShaderKey::bare("my_shader");
        assert!(key.defines.is_empty());
        assert_eq!(key.name, "my_shader");
    }

    #[test]
    fn test_shader_key_equality() {
        let mut d1 = HashMap::new();
        d1.insert("A".to_string(), "1".to_string());
        let mut d2 = HashMap::new();
        d2.insert("A".to_string(), "1".to_string());
        let k1 = ShaderKey::new("s", &d1);
        let k2 = ShaderKey::new("s", &d2);
        assert_eq!(k1, k2);
    }

    #[test]
    fn test_shader_key_inequality_different_value() {
        let mut d1 = HashMap::new();
        d1.insert("A".to_string(), "1".to_string());
        let mut d2 = HashMap::new();
        d2.insert("A".to_string(), "2".to_string());
        let k1 = ShaderKey::new("s", &d1);
        let k2 = ShaderKey::new("s", &d2);
        assert_ne!(k1, k2);
    }

    #[test]
    fn test_shader_key_inequality_different_name() {
        let k1 = ShaderKey::bare("shader_a");
        let k2 = ShaderKey::bare("shader_b");
        assert_ne!(k1, k2);
    }

    #[test]
    fn test_shader_key_fingerprint_multiple_defines() {
        let mut d = HashMap::new();
        d.insert("B".to_string(), "2".to_string());
        d.insert("A".to_string(), "1".to_string());
        let key = ShaderKey::new("s", &d);
        let fp = key.fingerprint();
        let a_pos = fp.find("A_1").unwrap();
        let b_pos = fp.find("B_2").unwrap();
        assert!(
            a_pos < b_pos,
            "defines should be alphabetical in fingerprint"
        );
    }

    #[test]
    fn test_shader_source_threads_per_group() {
        let src = ShaderSource::new("t", "fn main() {}", [4, 8, 2]);
        assert_eq!(src.threads_per_group(), 64);
    }

    #[test]
    fn test_shader_source_no_placeholders() {
        let src = ShaderSource::new("t", "fn main() { let x = 1; }", [64, 1, 1]);
        assert!(src.placeholders.is_empty());
    }

    #[test]
    fn test_shader_source_instantiate_no_defines() {
        let src = ShaderSource::new("t", "fn main() {}", [1, 1, 1]);
        let out = src.instantiate(&HashMap::new());
        assert_eq!(out, "fn main() {}");
    }

    #[test]
    fn test_shader_source_instantiate_partial() {
        let src = ShaderSource::new("t", "{{A}} and {{B}}", [1, 1, 1]);
        let mut d = HashMap::new();
        d.insert("A".to_string(), "hello".to_string());
        let out = src.instantiate(&d);
        assert!(out.contains("hello"), "A should be substituted");
        assert!(out.contains("{{B}}"), "B should remain if not supplied");
    }

    #[test]
    fn test_shader_source_multiple_occurrences_replaced() {
        let src = ShaderSource::new("t", "{{X}} and {{X}} and {{X}}", [1, 1, 1]);
        let mut d = HashMap::new();
        d.insert("X".to_string(), "42".to_string());
        let out = src.instantiate(&d);
        assert_eq!(out.matches("42").count(), 3);
        assert!(!out.contains("{{X}}"));
    }

    #[test]
    fn test_variant_cache_empty_initially() {
        let cache = VariantCache::new(5);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_variant_cache_insert_and_get() {
        let mut cache = VariantCache::new(5);
        let v = make_compiled_variant(ShaderKey::bare("s"), "fn s() {}".to_string(), [1, 1, 1]);
        cache.insert(v);
        assert_eq!(cache.len(), 1);
        let result = cache.get(&ShaderKey::bare("s"));
        assert!(result.is_some());
    }

    #[test]
    fn test_variant_cache_miss() {
        let mut cache = VariantCache::new(5);
        let result = cache.get(&ShaderKey::bare("nonexistent"));
        assert!(result.is_none());
    }

    #[test]
    fn test_variant_cache_update_existing() {
        let mut cache = VariantCache::new(5);
        let v1 = make_compiled_variant(ShaderKey::bare("s"), "v1".to_string(), [1, 1, 1]);
        let v2 = make_compiled_variant(ShaderKey::bare("s"), "v2".to_string(), [1, 1, 1]);
        cache.insert(v1);
        cache.insert(v2);
        assert_eq!(cache.len(), 1, "update should not add a second entry");
        let got = cache.get(&ShaderKey::bare("s")).unwrap();
        assert_eq!(got.wgsl, "v2");
    }

    #[test]
    fn test_variant_cache_clear() {
        let mut cache = VariantCache::new(5);
        cache.insert(make_compiled_variant(
            ShaderKey::bare("a"),
            "a".to_string(),
            [1, 1, 1],
        ));
        cache.insert(make_compiled_variant(
            ShaderKey::bare("b"),
            "b".to_string(),
            [1, 1, 1],
        ));
        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_variant_cache_total_binary_bytes() {
        let mut cache = VariantCache::new(5);
        // make_compiled_variant sizes each binary as its WGSL length: 3 + 2 bytes.
        cache.insert(make_compiled_variant(
            ShaderKey::bare("a"),
            "abc".to_string(),
            [1, 1, 1],
        ));
        cache.insert(make_compiled_variant(
            ShaderKey::bare("b"),
            "xy".to_string(),
            [1, 1, 1],
        ));
        assert_eq!(cache.total_binary_bytes(), 5);
    }

    #[test]
    fn test_variant_cache_capacity_one_evicts_on_insert() {
        let mut cache = VariantCache::new(1);
        cache.insert(make_compiled_variant(
            ShaderKey::bare("a"),
            "a".to_string(),
            [1, 1, 1],
        ));
        cache.insert(make_compiled_variant(
            ShaderKey::bare("b"),
            "b".to_string(),
            [1, 1, 1],
        ));
        assert_eq!(
            cache.len(),
            1,
            "capacity=1 cache should only hold one entry"
        );
        // Same LRU policy as test_variant_cache_lru_eviction: the freshly inserted entry
        // survives and the older, least-recently-used one is evicted. Asserting only
        // "a or b is present" would accept a cache that drops the new insert.
        let has_a = cache.get(&ShaderKey::bare("a")).is_some();
        let has_b = cache.get(&ShaderKey::bare("b")).is_some();
        assert!(!has_a, "older entry a should have been evicted");
        assert!(has_b, "newest entry b should be present");
    }

    #[test]
    fn test_registry_empty_initially() {
        let reg = ShaderRegistry::new(8);
        assert!(reg.shader_names().is_empty());
        assert_eq!(reg.cached_count(), 0);
    }

    #[test]
    fn test_registry_register_multiple_shaders() {
        let mut reg = ShaderRegistry::new(8);
        reg.register(ShaderSource::new("a", "fn a(){}", [1, 1, 1]));
        reg.register(ShaderSource::new("b", "fn b(){}", [1, 1, 1]));
        assert_eq!(reg.shader_names().len(), 2);
    }

    #[test]
    fn test_registry_re_register_clears_cache() {
        let mut reg = ShaderRegistry::new(8);
        reg.register(ShaderSource::new("s", "fn s(){}", [1, 1, 1]));
        reg.get_or_compile("s", &HashMap::new()).unwrap();
        assert_eq!(reg.cached_count(), 1);
        reg.register(ShaderSource::new("s", "fn s_v2(){}", [1, 1, 1]));
        assert_eq!(
            reg.cached_count(),
            0,
            "re-register should clear old variants"
        );
    }

    #[test]
    fn test_registry_invalidate_all() {
        let mut reg = ShaderRegistry::new(8);
        reg.register(ShaderSource::new("a", "fn a(){}", [1, 1, 1]));
        reg.register(ShaderSource::new("b", "fn b(){}", [1, 1, 1]));
        reg.get_or_compile("a", &HashMap::new()).unwrap();
        reg.get_or_compile("b", &HashMap::new()).unwrap();
        assert_eq!(reg.cached_count(), 2);
        reg.invalidate_all();
        assert_eq!(reg.cached_count(), 0);
    }

    #[test]
    fn test_registry_compilation_count_increments() {
        let mut reg = ShaderRegistry::new(8);
        reg.register(ShaderSource::new("x", "fn x(){}", [1, 1, 1]));
        reg.register(ShaderSource::new("y", "fn y(){}", [1, 1, 1]));
        reg.get_or_compile("x", &HashMap::new()).unwrap();
        reg.get_or_compile("y", &HashMap::new()).unwrap();
        assert_eq!(reg.compilations, 2);
    }

    #[test]
    fn test_registry_multiple_cache_hits() {
        let mut reg = ShaderRegistry::new(8);
        reg.register(ShaderSource::new("z", "fn z(){}", [1, 1, 1]));
        // First call compiles; the remaining four are cache hits.
        for _ in 0..5 {
            reg.get_or_compile("z", &HashMap::new()).unwrap();
        }
        assert_eq!(reg.compilations, 1);
        assert_eq!(reg.cache_hits, 4);
    }

    #[test]
    fn test_hot_reload_no_entry_returns_false() {
        let tracker = HotReloadTracker::new();
        assert!(!tracker.needs_recompile("phantom_shader"));
    }

    #[test]
    fn test_hot_reload_fresh_compile() {
        let mut tracker = HotReloadTracker::new();
        tracker.touch("shader", 5);
        tracker.record_compile("shader", 5);
        assert!(!tracker.needs_recompile("shader"));
    }

    #[test]
    fn test_hot_reload_stale_shaders_empty() {
        let tracker = HotReloadTracker::new();
        assert!(tracker.stale_shaders().is_empty());
    }

    #[test]
    fn test_hot_reload_multiple_stale() {
        let mut tracker = HotReloadTracker::new();
        tracker.touch("a", 10);
        tracker.touch("b", 20);
        tracker.touch("c", 5);
        tracker.record_compile("c", 10);
        let stale = tracker.stale_shaders();
        assert!(stale.contains(&"a"), "a is stale");
        assert!(stale.contains(&"b"), "b is stale");
        assert!(!stale.contains(&"c"), "c is up to date");
    }

    #[test]
    fn test_registry_error_display_unknown() {
        let e = RegistryError::UnknownShader("foo".to_string());
        assert!(e.to_string().contains("foo"));
    }

    #[test]
    fn test_registry_error_display_missing_define() {
        let e = RegistryError::MissingDefine {
            shader: "sph".to_string(),
            define: "WG_SIZE".to_string(),
        };
        let s = e.to_string();
        assert!(s.contains("sph") && s.contains("WG_SIZE"));
    }

    #[test]
    fn test_registry_error_display_source_too_large() {
        let e = RegistryError::SourceTooLarge {
            shader: "big".to_string(),
            size: 100_000,
            limit: 65_536,
        };
        let s = e.to_string();
        assert!(s.contains("big") && s.contains("100000") && s.contains("65536"));
    }

    #[test]
    fn test_pipeline_descriptor_fields() {
        let key = ShaderKey::bare("nn_forward");
        let desc = PipelineDescriptor::new(key.clone(), 3, 16, "nn_fwd");
        assert_eq!(desc.bind_group_count, 3);
        assert_eq!(desc.push_constant_bytes, 16);
        assert_eq!(desc.label, "nn_fwd");
        assert_eq!(desc.key, key);
    }

    #[test]
    fn test_pipeline_descriptor_validate_with_bind_groups() {
        for count in 1..=4 {
            let key = ShaderKey::bare("test");
            let desc = PipelineDescriptor::new(key, count, 0, "lbl");
            assert!(
                desc.validate().is_ok(),
                "bind_group_count={count} should be valid"
            );
        }
    }

    #[test]
    fn test_compile_variant_basic() {
        let mut reg = ShaderRegistry::new(4);
        reg.register(simple_source("dense"));
        let mut d = HashMap::new();
        d.insert("WG_SIZE".to_string(), "128".to_string());
        let opts = ShaderCompileOptions::new();
        let v = reg.compile_variant("dense", &d, &opts).unwrap();
        assert!(v.wgsl.contains("128"), "define should be substituted");
    }

    #[test]
    fn test_compile_variant_unknown_shader() {
        let mut reg = ShaderRegistry::new(4);
        let opts = ShaderCompileOptions::new();
        let res = reg.compile_variant("missing", &HashMap::new(), &opts);
        assert!(matches!(res, Err(RegistryError::UnknownShader(_))));
    }

    #[test]
    fn test_compile_variant_source_too_large() {
        let mut reg = ShaderRegistry::new(4);
        reg.register(simple_source("big"));
        let mut d = HashMap::new();
        d.insert("WG_SIZE".to_string(), "64".to_string());
        let mut opts = ShaderCompileOptions::new();
        // Any real shader source exceeds a 5-byte limit.
        opts.max_source_bytes = 5;
        let res = reg.compile_variant("big", &d, &opts);
        assert!(matches!(res, Err(RegistryError::SourceTooLarge { .. })));
    }

    #[test]
    fn test_compile_variant_extra_defines_are_merged() {
        let src = ShaderSource::new("s", "{{A}} {{B}}", [32, 1, 1]);
        let mut reg = ShaderRegistry::new(4);
        reg.register(src);
        let mut d = HashMap::new();
        d.insert("A".to_string(), "hello".to_string());
        let mut opts = ShaderCompileOptions::new();
        opts.extra_defines
            .insert("B".to_string(), "world".to_string());
        let v = reg.compile_variant("s", &d, &opts).unwrap();
        assert!(v.wgsl.contains("hello"), "A should be substituted");
        assert!(
            v.wgsl.contains("world"),
            "B should be substituted from extra_defines"
        );
    }

    #[test]
    fn test_compile_variant_increments_compilation_count() {
        let mut reg = ShaderRegistry::new(4);
        reg.register(simple_source("s"));
        let mut d = HashMap::new();
        d.insert("WG_SIZE".to_string(), "16".to_string());
        let opts = ShaderCompileOptions::new();
        reg.compile_variant("s", &d, &opts).unwrap();
        reg.compile_variant("s", &d, &opts).unwrap();
        assert_eq!(
            reg.compilations, 2,
            "each compile_variant call should count"
        );
    }

    #[test]
    fn test_compute_cache_key_deterministic() {
        let src = "fn main() {}";
        assert_eq!(compute_cache_key(src), compute_cache_key(src));
    }

    #[test]
    fn test_compute_cache_key_different_inputs() {
        let k1 = compute_cache_key("fn a() {}");
        let k2 = compute_cache_key("fn b() {}");
        assert_ne!(
            k1, k2,
            "different sources should produce different cache keys"
        );
    }

    #[test]
    fn test_compute_shader_cache_key_bare_vs_defines() {
        let bare_key = ShaderKey::bare("sph");
        let mut d = HashMap::new();
        d.insert("WG_SIZE".to_string(), "64".to_string());
        let define_key = ShaderKey::new("sph", &d);
        let k1 = compute_shader_cache_key(&bare_key);
        let k2 = compute_shader_cache_key(&define_key);
        assert_ne!(k1, k2, "bare and define keys should differ");
    }

    #[test]
    fn test_compute_shader_cache_key_order_independent() {
        let mut d1 = HashMap::new();
        d1.insert("A".to_string(), "1".to_string());
        d1.insert("B".to_string(), "2".to_string());
        let mut d2 = HashMap::new();
        d2.insert("B".to_string(), "2".to_string());
        d2.insert("A".to_string(), "1".to_string());
        let k1 = compute_shader_cache_key(&ShaderKey::new("s", &d1));
        let k2 = compute_shader_cache_key(&ShaderKey::new("s", &d2));
        assert_eq!(k1, k2, "insertion order should not affect cache key");
    }

    #[test]
    fn test_hot_reload_touch_batch() {
        let mut tracker = HotReloadTracker::new();
        tracker.touch_batch(&["a", "b", "c"], 5);
        assert!(tracker.needs_recompile("a"));
        assert!(tracker.needs_recompile("b"));
        assert!(tracker.needs_recompile("c"));
    }

    #[test]
    fn test_hot_reload_flush_stale() {
        let mut tracker = HotReloadTracker::new();
        tracker.touch_batch(&["x", "y"], 10);
        tracker.flush_stale(20);
        assert!(
            !tracker.needs_recompile("x"),
            "x should be up to date after flush"
        );
        assert!(
            !tracker.needs_recompile("y"),
            "y should be up to date after flush"
        );
    }

    #[test]
    fn test_hot_reload_never_compiled() {
        let mut tracker = HotReloadTracker::new();
        tracker.touch("alpha", 1);
        tracker.touch("beta", 2);
        tracker.record_compile("beta", 3);
        let nc = tracker.never_compiled();
        assert!(nc.contains(&"alpha"), "alpha was never compiled");
        assert!(!nc.contains(&"beta"), "beta was compiled");
    }
}
#[cfg(test)]
mod extended_registry_tests {

    use crate::shader_registry::HotReloadTracker;
    use crate::shader_registry::PipelineCache;
    use crate::shader_registry::PipelineCacheKey;

    use crate::shader_registry::ShaderDependencyGraph;
    use crate::shader_registry::ShaderKey;
    use crate::shader_registry::ShaderRegistry;
    use crate::shader_registry::ShaderSource;
    use crate::shader_registry::SpecConstSet;
    use crate::shader_registry::SpecConstValue;
    use crate::shader_registry::SpecializationConstant;

    use crate::shader_registry::VariantProfile;
    use crate::shader_registry::VariantProfileRegistry;
    use std::collections::HashMap;

    #[test]
    fn variant_profile_no_base_resolve() {
        let mut reg = VariantProfileRegistry::new();
        let p = VariantProfile::new("default")
            .set("WG_SIZE", "64")
            .set("DTYPE", "f32");
        reg.register(p);
        let defines = reg.resolve("default").unwrap();
        assert_eq!(defines.get("WG_SIZE").map(|s| s.as_str()), Some("64"));
        assert_eq!(defines.get("DTYPE").map(|s| s.as_str()), Some("f32"));
    }

    #[test]
    fn variant_profile_inherits_base() {
        let mut reg = VariantProfileRegistry::new();
        reg.register(
            VariantProfile::new("base")
                .set("WG_SIZE", "64")
                .set("DTYPE", "f32"),
        );
        reg.register(
            VariantProfile::with_base("child", "base")
                .set("DTYPE", "f64")
                .set("EXTRA", "1"),
        );
        let defines = reg.resolve("child").unwrap();
        assert_eq!(defines["WG_SIZE"], "64", "WG_SIZE inherited from base");
        assert_eq!(defines["DTYPE"], "f64", "DTYPE overridden by child");
        assert_eq!(defines["EXTRA"], "1", "EXTRA added by child");
    }

    /// Resolving a profile name that was never registered yields `None`.
    #[test]
    fn variant_profile_missing_base_returns_none() {
        let reg = VariantProfileRegistry::new();
        assert!(reg.resolve("nonexistent").is_none());
    }

    #[test]
    fn variant_profile_registry_lists_names() {
        let mut reg = VariantProfileRegistry::new();
        reg.register(VariantProfile::new("b"));
        reg.register(VariantProfile::new("a"));
        let names = reg.profile_names();
        assert_eq!(names, vec!["a", "b"], "names should be sorted");
    }

    #[test]
    fn variant_profile_overrides_base_chain() {
        let mut reg = VariantProfileRegistry::new();
        reg.register(VariantProfile::new("root").set("X", "1").set("Y", "2"));
        reg.register(
            VariantProfile::with_base("mid", "root")
                .set("Y", "20")
                .set("Z", "3"),
        );
        reg.register(VariantProfile::with_base("leaf", "mid").set("Z", "30"));
        // The most-derived profile wins at every level of the chain.
        let defines = reg.resolve("leaf").unwrap();
        assert_eq!(defines["X"], "1");
        assert_eq!(defines["Y"], "20");
        assert_eq!(defines["Z"], "30");
    }

    #[test]
    fn dependency_graph_direct_dependents() {
        let mut g = ShaderDependencyGraph::new();
        g.add_dependency("sph_density", "common_utils");
        g.add_dependency("sph_pressure", "common_utils");
        let deps = g.direct_dependents("common_utils");
        assert!(deps.contains(&"sph_density"));
        assert!(deps.contains(&"sph_pressure"));
    }

    #[test]
    fn dependency_graph_transitive_dependents() {
        let mut g = ShaderDependencyGraph::new();
        g.add_dependency("b", "a");
        g.add_dependency("c", "b");
        g.add_dependency("d", "c");
        let transitive = g.transitive_dependents("a");
        assert!(transitive.contains(&"b".to_string()), "b depends on a");
        assert!(
            transitive.contains(&"c".to_string()),
            "c transitively depends on a via b"
        );
        assert!(
            transitive.contains(&"d".to_string()),
            "d transitively depends on a via b,c"
        );
    }

    #[test]
    fn dependency_graph_no_dependents() {
        let g = ShaderDependencyGraph::new();
        assert!(g.transitive_dependents("orphan").is_empty());
        assert!(g.direct_dependents("orphan").is_empty());
    }

    #[test]
    fn dependency_graph_shaders_with_deps_sorted() {
        let mut g = ShaderDependencyGraph::new();
        g.add_dependency("z_shader", "a");
        g.add_dependency("a_shader", "b");
        let names = g.shaders_with_deps();
        assert_eq!(names, vec!["a_shader", "z_shader"]);
    }

    #[test]
    fn dependency_graph_direct_dependencies() {
        let mut g = ShaderDependencyGraph::new();
        g.add_dependency("main", "utils");
        g.add_dependency("main", "math");
        let deps = g.direct_dependencies("main");
        assert_eq!(deps.len(), 2);
        assert!(deps.contains(&"utils".to_string()));
        assert!(deps.contains(&"math".to_string()));
    }

    #[test]
    fn spec_const_default_value() {
        let c = SpecializationConstant::new("WG_SIZE", SpecConstValue::Uint(64));
        assert_eq!(c.effective_value(), &SpecConstValue::Uint(64));
        assert_eq!(c.effective_value().to_wgsl(), "64u");
    }

    #[test]
    fn spec_const_override_value() {
        let c = SpecializationConstant::new("WG_SIZE", SpecConstValue::Uint(64))
            .with_override(SpecConstValue::Uint(128));
        assert_eq!(c.effective_value(), &SpecConstValue::Uint(128));
    }

    #[test]
    fn spec_const_bool_to_wgsl() {
        let c = SpecializationConstant::new("ENABLE_DEBUG", SpecConstValue::Bool(true));
        assert_eq!(c.effective_value().to_wgsl(), "true");
    }

    #[test]
    fn spec_const_float_to_wgsl() {
        let c = SpecializationConstant::new("EPSILON", SpecConstValue::Float(1e-6));
        let wgsl = c.effective_value().to_wgsl();
        // Accept fixed-point or scientific notation, but require that the emitted literal
        // actually encodes 1e-6. A trailing `|| !wgsl.is_empty()` would make this pass for
        // any non-empty string and test nothing.
        assert!(
            wgsl.contains("0.000001") || wgsl.contains("1e-6") || wgsl.contains("1E-6"),
            "unexpected WGSL literal for 1e-6: {wgsl}"
        );
    }

    #[test]
    fn spec_const_set_to_defines() {
        let mut set = SpecConstSet::new();
        set.add(SpecializationConstant::new("A", SpecConstValue::Uint(4)));
        set.add(SpecializationConstant::new(
            "B",
            SpecConstValue::Bool(false),
        ));
        let defines = set.to_defines();
        assert_eq!(defines.get("A").map(|s| s.as_str()), Some("4u"));
        assert_eq!(defines.get("B").map(|s| s.as_str()), Some("false"));
    }

    #[test]
    fn spec_const_set_has_and_get_wgsl() {
        let mut set = SpecConstSet::new();
        set.add(SpecializationConstant::new("X", SpecConstValue::Int(-1)));
        assert!(set.has("X"));
        assert!(!set.has("Y"));
        assert_eq!(set.get_wgsl("X"), Some("-1".to_string()));
        assert!(set.get_wgsl("Y").is_none());
    }

    #[test]
    fn pipeline_cache_key_equality() {
        let k1 = PipelineCacheKey::new(ShaderKey::bare("s"), [64, 1, 1], 0, 42);
        let k2 = PipelineCacheKey::new(ShaderKey::bare("s"), [64, 1, 1], 0, 42);
        assert_eq!(k1, k2);
    }

    #[test]
    fn pipeline_cache_key_different_layout() {
        let k1 = PipelineCacheKey::new(ShaderKey::bare("s"), [64, 1, 1], 0, 1);
        let k2 = PipelineCacheKey::new(ShaderKey::bare("s"), [64, 1, 1], 0, 2);
        assert_ne!(k1, k2);
    }

    #[test]
    fn pipeline_cache_hit_and_miss() {
        let mut cache = PipelineCache::new();
        let key = PipelineCacheKey::new(ShaderKey::bare("s"), [64, 1, 1], 0, 0);
        // Lookup before insert counts as a miss; lookup after insert as a hit.
        assert!(cache.get(&key).is_none());
        cache.insert(key.clone(), "sph_pipeline");
        assert_eq!(cache.get(&key), Some("sph_pipeline"));
        assert_eq!(cache.hits, 1);
        assert_eq!(cache.misses, 1);
    }

    #[test]
    fn pipeline_cache_hit_rate() {
        let mut cache = PipelineCache::new();
        let k = PipelineCacheKey::new(ShaderKey::bare("s"), [1, 1, 1], 0, 0);
        cache.insert(k.clone(), "label");
        cache.get(&k);
        cache.get(&k);
        let k2 = PipelineCacheKey::new(ShaderKey::bare("x"), [1, 1, 1], 0, 0);
        cache.get(&k2);
        // Two hits out of three lookups.
        let rate = cache.hit_rate();
        assert!((rate - 2.0 / 3.0).abs() < 1e-10, "hit_rate = {rate}");
    }

    #[test]
    fn pipeline_cache_clear() {
        let mut cache = PipelineCache::new();
        let k = PipelineCacheKey::new(ShaderKey::bare("s"), [1, 1, 1], 0, 0);
        cache.insert(k, "lbl");
        assert_eq!(cache.len(), 1);
        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn pipeline_cache_key_hash_stable() {
        let k = PipelineCacheKey::new(ShaderKey::bare("sph"), [64, 1, 1], 16, 12345);
        assert_eq!(k.hash_key(), k.hash_key(), "hash should be deterministic");
    }

    #[test]
    fn registry_apply_hot_reload_invalidates_stale() {
        let mut reg = ShaderRegistry::new(8);
        reg.register(ShaderSource::new("shader_a", "fn a(){}", [1, 1, 1]));
        reg.get_or_compile("shader_a", &HashMap::new()).unwrap();
        assert_eq!(reg.cached_count(), 1);
        let mut tracker = HotReloadTracker::new();
        tracker.touch("shader_a", 100);
        let invalidated = reg.apply_hot_reload(&tracker);
        assert!(invalidated.contains(&"shader_a".to_string()));
        assert_eq!(
            reg.cached_count(),
            0,
            "stale shader variant should be evicted"
        );
    }

    #[test]
    fn registry_registered_count() {
        let mut reg = ShaderRegistry::new(8);
        assert_eq!(reg.registered_count(), 0);
        reg.register(ShaderSource::new("a", "fn a(){}", [1, 1, 1]));
        reg.register(ShaderSource::new("b", "fn b(){}", [1, 1, 1]));
        assert_eq!(reg.registered_count(), 2);
    }

    #[test]
    fn registry_source_bytes() {
        let wgsl = "fn main() {}";
        let mut reg = ShaderRegistry::new(4);
        reg.register(ShaderSource::new("test", wgsl, [1, 1, 1]));
        assert_eq!(reg.source_bytes("test"), wgsl.len());
        assert_eq!(reg.source_bytes("nonexistent"), 0);
    }

    #[test]
    fn spec_const_set_used_with_registry() {
        let mut set = SpecConstSet::new();
        set.add(SpecializationConstant::new(
            "WG_SIZE",
            SpecConstValue::Uint(64),
        ));
        let defines = set.to_defines();
        let mut reg = ShaderRegistry::new(8);
        reg.register(ShaderSource::new(
            "compute",
            "@compute @workgroup_size({{WG_SIZE}}) fn main() {}",
            [64, 1, 1],
        ));
        let v = reg.get_or_compile("compute", &defines).unwrap();
        assert!(
            v.wgsl.contains("64u"),
            "WG_SIZE should be substituted from spec const"
        );
    }

    #[test]
    fn variant_profile_with_registry_integration() {
        let mut prof_reg = VariantProfileRegistry::new();
        prof_reg.register(
            VariantProfile::new("high")
                .set("WG_SIZE", "256")
                .set("DTYPE", "f64"),
        );
        let defines = prof_reg.resolve("high").unwrap();
        let mut shader_reg = ShaderRegistry::new(8);
        shader_reg.register(ShaderSource::new(
            "kernel",
            "wg={{WG_SIZE}} dtype={{DTYPE}}",
            [256, 1, 1],
        ));
        let v = shader_reg.get_or_compile("kernel", &defines).unwrap();
        assert!(v.wgsl.contains("256"));
        assert!(v.wgsl.contains("f64"));
    }

    #[test]
    fn spec_const_int_negative_to_wgsl() {
        let c = SpecConstValue::Int(-42);
        assert_eq!(c.to_wgsl(), "-42");
    }

    #[test]
    fn spec_const_uint_zero() {
        let c = SpecConstValue::Uint(0);
        assert_eq!(c.to_wgsl(), "0u");
    }

    #[test]
    fn spec_const_bool_false() {
        let c = SpecConstValue::Bool(false);
        assert_eq!(c.to_wgsl(), "false");
    }

    #[test]
    fn pipeline_cache_key_different_workgroup_sizes_differ() {
        let k1 = PipelineCacheKey::new(ShaderKey::bare("s"), [64, 1, 1], 0, 0);
        let k2 = PipelineCacheKey::new(ShaderKey::bare("s"), [128, 1, 1], 0, 0);
        assert_ne!(k1.hash_key(), k2.hash_key());
    }
}