Skip to main content

oxiphysics_gpu/shader_registry/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use super::types::ShaderKey;
6
7/// Extract all `{{PLACEHOLDER}}` names from a WGSL string.
8pub(super) fn collect_placeholders(src: &str) -> Vec<String> {
9    let mut result = Vec::new();
10    let mut rest = src;
11    while let Some(start) = rest.find("{{") {
12        let after_open = &rest[start + 2..];
13        if let Some(end) = after_open.find("}}") {
14            let token = after_open[..end].trim().to_string();
15            if !token.is_empty() && !result.contains(&token) {
16                result.push(token);
17            }
18            rest = &after_open[end + 2..];
19        } else {
20            break;
21        }
22    }
23    result
24}
/// Compute a 64-bit FNV-1a hash of the given WGSL source bytes.
///
/// This is intentionally a fast, non-cryptographic hash suitable for
/// cache-key generation. It is a pure function of the UTF-8 bytes of `wgsl`,
/// so equal inputs always produce equal keys. It is **not** collision-resistant
/// against adversarially chosen input.
pub fn compute_cache_key(wgsl: &str) -> u64 {
    // Standard 64-bit FNV parameters: offset basis 14695981039346656037 and
    // prime 1099511628211, written in hex. These are function-local, so no
    // visibility qualifier is needed.
    const FNV_OFFSET: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    // FNV-1a: XOR the byte in first, then multiply (modulo 2^64).
    wgsl.bytes().fold(FNV_OFFSET, |hash, byte| {
        (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
39/// Compute a composite cache key that includes the shader name and its defines.
40///
41/// The key format is `fnv64(name || "\x00" || sorted_defines_string)`.
42pub fn compute_shader_cache_key(key: &ShaderKey) -> u64 {
43    let mut repr = key.name.clone();
44    repr.push('\x00');
45    for (k, v) in &key.defines {
46        repr.push_str(k);
47        repr.push('=');
48        repr.push_str(v);
49        repr.push(';');
50    }
51    compute_cache_key(&repr)
52}
#[cfg(test)]
mod shader_registry_tests {
    //! Unit tests for the shader registry, grouped by area:
    //! - shader keys and shader sources;
    //! - the LRU variant cache and the registry itself;
    //! - hot-reload tracking and pipeline descriptors;
    //! - the free helper functions defined in this file.
    use super::*;
    use crate::shader_registry::CompiledVariant;
    use crate::shader_registry::HotReloadTracker;

    use crate::shader_registry::PipelineDescriptor;
    use crate::shader_registry::RegistryError;
    use crate::shader_registry::ShaderCompileOptions;

    use crate::shader_registry::ShaderKey;
    use crate::shader_registry::ShaderRegistry;
    use crate::shader_registry::ShaderSource;

    use crate::shader_registry::VariantCache;

    use std::collections::HashMap;
    /// Test helper: construct a CompiledVariant from key, wgsl, and workgroup_size.
    /// Mirrors the private `CompiledVariant::new`.
    fn make_compiled_variant(
        key: ShaderKey,
        wgsl: String,
        workgroup_size: [u32; 3],
    ) -> CompiledVariant {
        let binary_size_bytes = wgsl.len();
        CompiledVariant {
            key,
            wgsl,
            workgroup_size,
            binary_size_bytes,
        }
    }
    /// Test helper: a source with a single required `{{WG_SIZE}}` placeholder.
    fn simple_source(name: &str) -> ShaderSource {
        ShaderSource::new(
            name,
            "@compute @workgroup_size({{WG_SIZE}}) fn main() {}",
            [64, 1, 1],
        )
    }
    #[test]
    fn test_shader_key_fingerprint_bare() {
        let key = ShaderKey::bare("sph_density");
        assert_eq!(key.fingerprint(), "sph_density");
    }
    #[test]
    fn test_shader_key_fingerprint_with_defines() {
        let mut d = HashMap::new();
        d.insert("WG_SIZE".to_string(), "64".to_string());
        let key = ShaderKey::new("sph", &d);
        assert!(key.fingerprint().contains("WG_SIZE"));
        assert!(key.fingerprint().contains("64"));
    }
    #[test]
    fn test_shader_key_sorted_defines() {
        let mut d = HashMap::new();
        d.insert("Z_FLAG".to_string(), "1".to_string());
        d.insert("A_FLAG".to_string(), "2".to_string());
        let key = ShaderKey::new("s", &d);
        assert_eq!(
            key.defines[0].0, "A_FLAG",
            "defines should be sorted by key"
        );
    }
    #[test]
    fn test_source_placeholder_detection() {
        let src = ShaderSource::new(
            "test",
            "size={{SIZE}} dtype={{DTYPE}} size={{SIZE}}",
            [32, 1, 1],
        );
        assert_eq!(src.placeholders.len(), 2);
        assert!(src.placeholders.contains(&"SIZE".to_string()));
        assert!(src.placeholders.contains(&"DTYPE".to_string()));
    }
    #[test]
    fn test_source_instantiate_substitutes_tokens() {
        let src = ShaderSource::new("t", "workgroup={{WG}} type={{T}}", [1, 1, 1]);
        let mut d = HashMap::new();
        d.insert("WG".to_string(), "128".to_string());
        d.insert("T".to_string(), "f32".to_string());
        let out = src.instantiate(&d);
        assert!(out.contains("128"));
        assert!(out.contains("f32"));
        assert!(!out.contains("{{WG}}"));
    }
    #[test]
    fn test_variant_cache_lru_eviction() {
        let mut cache = VariantCache::new(2);
        let make = |name: &str| {
            make_compiled_variant(
                ShaderKey::bare(name),
                format!("fn {}(){{}}", name),
                [64, 1, 1],
            )
        };
        cache.insert(make("a"));
        cache.insert(make("b"));
        let _ = cache.get(&ShaderKey::bare("a"));
        cache.insert(make("c"));
        assert!(
            cache.get(&ShaderKey::bare("a")).is_some(),
            "a should survive"
        );
        assert!(
            cache.get(&ShaderKey::bare("b")).is_none(),
            "b should be evicted"
        );
        assert!(
            cache.get(&ShaderKey::bare("c")).is_some(),
            "c should be present"
        );
    }
    #[test]
    fn test_registry_get_or_compile_success() {
        let mut reg = ShaderRegistry::new(8);
        reg.register(simple_source("dense_layer"));
        let mut d = HashMap::new();
        d.insert("WG_SIZE".to_string(), "64".to_string());
        let v = reg.get_or_compile("dense_layer", &d).unwrap();
        assert!(v.wgsl.contains("64"), "placeholder should be substituted");
        assert_eq!(reg.compilations, 1);
    }
    #[test]
    fn test_registry_cache_hit() {
        let mut reg = ShaderRegistry::new(8);
        reg.register(simple_source("sph"));
        let mut d = HashMap::new();
        d.insert("WG_SIZE".to_string(), "32".to_string());
        reg.get_or_compile("sph", &d).unwrap();
        reg.get_or_compile("sph", &d).unwrap();
        assert_eq!(reg.cache_hits, 1);
        assert_eq!(reg.compilations, 1);
    }
    #[test]
    fn test_registry_unknown_shader_error() {
        let mut reg = ShaderRegistry::new(4);
        let res = reg.get_or_compile("ghost", &HashMap::new());
        assert!(matches!(res, Err(RegistryError::UnknownShader(_))));
    }
    #[test]
    fn test_registry_missing_define_error() {
        let mut reg = ShaderRegistry::new(4);
        reg.register(simple_source("needs_wg"));
        let res = reg.get_or_compile("needs_wg", &HashMap::new());
        assert!(
            matches!(res, Err(RegistryError::MissingDefine { .. })),
            "should fail when required define is absent"
        );
    }
    #[test]
    fn test_registry_invalidate_clears_cache() {
        let mut reg = ShaderRegistry::new(8);
        reg.register(simple_source("shader_x"));
        let mut d = HashMap::new();
        d.insert("WG_SIZE".to_string(), "64".to_string());
        reg.get_or_compile("shader_x", &d).unwrap();
        assert_eq!(reg.cached_count(), 1);
        reg.invalidate("shader_x");
        assert_eq!(
            reg.cached_count(),
            0,
            "invalidate should clear variant cache"
        );
    }
    #[test]
    fn test_hot_reload_tracker_needs_recompile() {
        let mut tracker = HotReloadTracker::new();
        tracker.touch("my_shader", 10);
        tracker.record_compile("my_shader", 5);
        assert!(
            tracker.needs_recompile("my_shader"),
            "modified at 10 > compiled at 5"
        );
        tracker.record_compile("my_shader", 10);
        assert!(
            !tracker.needs_recompile("my_shader"),
            "up-to-date after recompile"
        );
    }
    #[test]
    fn test_hot_reload_tracker_stale_list() {
        let mut tracker = HotReloadTracker::new();
        tracker.touch("a", 5);
        tracker.touch("b", 5);
        tracker.record_compile("a", 10);
        let stale = tracker.stale_shaders();
        assert!(!stale.contains(&"a"), "a is up-to-date");
        assert!(stale.contains(&"b"), "b has never been compiled");
    }
    #[test]
    fn test_pipeline_descriptor_validate() {
        let key = ShaderKey::bare("sph");
        let desc = PipelineDescriptor::new(key.clone(), 2, 0, "sph_pipeline");
        assert!(desc.validate().is_ok());
        let bad = PipelineDescriptor::new(key, 0, 0, "bad");
        assert!(bad.validate().is_err());
    }
    #[test]
    fn test_shader_key_bare_no_defines() {
        let key = ShaderKey::bare("my_shader");
        assert!(key.defines.is_empty());
        assert_eq!(key.name, "my_shader");
    }
    #[test]
    fn test_shader_key_equality() {
        let mut d1 = HashMap::new();
        d1.insert("A".to_string(), "1".to_string());
        let mut d2 = HashMap::new();
        d2.insert("A".to_string(), "1".to_string());
        let k1 = ShaderKey::new("s", &d1);
        let k2 = ShaderKey::new("s", &d2);
        assert_eq!(k1, k2);
    }
    #[test]
    fn test_shader_key_inequality_different_value() {
        let mut d1 = HashMap::new();
        d1.insert("A".to_string(), "1".to_string());
        let mut d2 = HashMap::new();
        d2.insert("A".to_string(), "2".to_string());
        let k1 = ShaderKey::new("s", &d1);
        let k2 = ShaderKey::new("s", &d2);
        assert_ne!(k1, k2);
    }
    #[test]
    fn test_shader_key_inequality_different_name() {
        let k1 = ShaderKey::bare("shader_a");
        let k2 = ShaderKey::bare("shader_b");
        assert_ne!(k1, k2);
    }
    #[test]
    fn test_shader_key_fingerprint_multiple_defines() {
        let mut d = HashMap::new();
        d.insert("B".to_string(), "2".to_string());
        d.insert("A".to_string(), "1".to_string());
        let key = ShaderKey::new("s", &d);
        let fp = key.fingerprint();
        let a_pos = fp.find("A_1").unwrap();
        let b_pos = fp.find("B_2").unwrap();
        assert!(
            a_pos < b_pos,
            "defines should be alphabetical in fingerprint"
        );
    }
    #[test]
    fn test_shader_source_threads_per_group() {
        let src = ShaderSource::new("t", "fn main() {}", [4, 8, 2]);
        assert_eq!(src.threads_per_group(), 64);
    }
    #[test]
    fn test_shader_source_no_placeholders() {
        let src = ShaderSource::new("t", "fn main() { let x = 1; }", [64, 1, 1]);
        assert!(src.placeholders.is_empty());
    }
    #[test]
    fn test_shader_source_instantiate_no_defines() {
        let src = ShaderSource::new("t", "fn main() {}", [1, 1, 1]);
        let out = src.instantiate(&HashMap::new());
        assert_eq!(out, "fn main() {}");
    }
    #[test]
    fn test_shader_source_instantiate_partial() {
        let src = ShaderSource::new("t", "{{A}} and {{B}}", [1, 1, 1]);
        let mut d = HashMap::new();
        d.insert("A".to_string(), "hello".to_string());
        let out = src.instantiate(&d);
        assert!(out.contains("hello"), "A should be substituted");
        assert!(out.contains("{{B}}"), "B should remain if not supplied");
    }
    #[test]
    fn test_shader_source_multiple_occurrences_replaced() {
        let src = ShaderSource::new("t", "{{X}} and {{X}} and {{X}}", [1, 1, 1]);
        let mut d = HashMap::new();
        d.insert("X".to_string(), "42".to_string());
        let out = src.instantiate(&d);
        assert_eq!(out.matches("42").count(), 3);
        assert!(!out.contains("{{X}}"));
    }
    #[test]
    fn test_variant_cache_empty_initially() {
        let cache = VariantCache::new(5);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }
    #[test]
    fn test_variant_cache_insert_and_get() {
        let mut cache = VariantCache::new(5);
        let v = make_compiled_variant(ShaderKey::bare("s"), "fn s() {}".to_string(), [1, 1, 1]);
        cache.insert(v);
        assert_eq!(cache.len(), 1);
        let result = cache.get(&ShaderKey::bare("s"));
        assert!(result.is_some());
    }
    #[test]
    fn test_variant_cache_miss() {
        let mut cache = VariantCache::new(5);
        let result = cache.get(&ShaderKey::bare("nonexistent"));
        assert!(result.is_none());
    }
    #[test]
    fn test_variant_cache_update_existing() {
        let mut cache = VariantCache::new(5);
        let v1 = make_compiled_variant(ShaderKey::bare("s"), "v1".to_string(), [1, 1, 1]);
        let v2 = make_compiled_variant(ShaderKey::bare("s"), "v2".to_string(), [1, 1, 1]);
        cache.insert(v1);
        cache.insert(v2);
        assert_eq!(cache.len(), 1, "update should not add a second entry");
        let got = cache.get(&ShaderKey::bare("s")).unwrap();
        assert_eq!(got.wgsl, "v2");
    }
    #[test]
    fn test_variant_cache_clear() {
        let mut cache = VariantCache::new(5);
        cache.insert(make_compiled_variant(
            ShaderKey::bare("a"),
            "a".to_string(),
            [1, 1, 1],
        ));
        cache.insert(make_compiled_variant(
            ShaderKey::bare("b"),
            "b".to_string(),
            [1, 1, 1],
        ));
        cache.clear();
        assert!(cache.is_empty());
    }
    #[test]
    fn test_variant_cache_total_binary_bytes() {
        let mut cache = VariantCache::new(5);
        cache.insert(make_compiled_variant(
            ShaderKey::bare("a"),
            "abc".to_string(),
            [1, 1, 1],
        ));
        cache.insert(make_compiled_variant(
            ShaderKey::bare("b"),
            "xy".to_string(),
            [1, 1, 1],
        ));
        assert_eq!(cache.total_binary_bytes(), 5);
    }
    #[test]
    fn test_variant_cache_capacity_one_evicts_on_insert() {
        let mut cache = VariantCache::new(1);
        cache.insert(make_compiled_variant(
            ShaderKey::bare("a"),
            "a".to_string(),
            [1, 1, 1],
        ));
        cache.insert(make_compiled_variant(
            ShaderKey::bare("b"),
            "b".to_string(),
            [1, 1, 1],
        ));
        assert_eq!(
            cache.len(),
            1,
            "capacity=1 cache should only hold one entry"
        );
        let has_a = cache.get(&ShaderKey::bare("a")).is_some();
        let has_b = cache.get(&ShaderKey::bare("b")).is_some();
        assert!(has_a || has_b, "at least one entry should be present");
    }
    #[test]
    fn test_registry_empty_initially() {
        let reg = ShaderRegistry::new(8);
        assert!(reg.shader_names().is_empty());
        assert_eq!(reg.cached_count(), 0);
    }
    #[test]
    fn test_registry_register_multiple_shaders() {
        let mut reg = ShaderRegistry::new(8);
        reg.register(ShaderSource::new("a", "fn a(){}", [1, 1, 1]));
        reg.register(ShaderSource::new("b", "fn b(){}", [1, 1, 1]));
        assert_eq!(reg.shader_names().len(), 2);
    }
    #[test]
    fn test_registry_re_register_clears_cache() {
        let mut reg = ShaderRegistry::new(8);
        reg.register(ShaderSource::new("s", "fn s(){}", [1, 1, 1]));
        reg.get_or_compile("s", &HashMap::new()).unwrap();
        assert_eq!(reg.cached_count(), 1);
        reg.register(ShaderSource::new("s", "fn s_v2(){}", [1, 1, 1]));
        assert_eq!(
            reg.cached_count(),
            0,
            "re-register should clear old variants"
        );
    }
    #[test]
    fn test_registry_invalidate_all() {
        let mut reg = ShaderRegistry::new(8);
        reg.register(ShaderSource::new("a", "fn a(){}", [1, 1, 1]));
        reg.register(ShaderSource::new("b", "fn b(){}", [1, 1, 1]));
        reg.get_or_compile("a", &HashMap::new()).unwrap();
        reg.get_or_compile("b", &HashMap::new()).unwrap();
        assert_eq!(reg.cached_count(), 2);
        reg.invalidate_all();
        assert_eq!(reg.cached_count(), 0);
    }
    #[test]
    fn test_registry_compilation_count_increments() {
        let mut reg = ShaderRegistry::new(8);
        reg.register(ShaderSource::new("x", "fn x(){}", [1, 1, 1]));
        reg.register(ShaderSource::new("y", "fn y(){}", [1, 1, 1]));
        reg.get_or_compile("x", &HashMap::new()).unwrap();
        reg.get_or_compile("y", &HashMap::new()).unwrap();
        assert_eq!(reg.compilations, 2);
    }
    #[test]
    fn test_registry_multiple_cache_hits() {
        let mut reg = ShaderRegistry::new(8);
        reg.register(ShaderSource::new("z", "fn z(){}", [1, 1, 1]));
        for _ in 0..5 {
            reg.get_or_compile("z", &HashMap::new()).unwrap();
        }
        assert_eq!(reg.compilations, 1);
        assert_eq!(reg.cache_hits, 4);
    }
    #[test]
    fn test_hot_reload_no_entry_returns_false() {
        let tracker = HotReloadTracker::new();
        assert!(!tracker.needs_recompile("phantom_shader"));
    }
    #[test]
    fn test_hot_reload_fresh_compile() {
        let mut tracker = HotReloadTracker::new();
        tracker.touch("shader", 5);
        tracker.record_compile("shader", 5);
        assert!(!tracker.needs_recompile("shader"));
    }
    #[test]
    fn test_hot_reload_stale_shaders_empty() {
        let tracker = HotReloadTracker::new();
        assert!(tracker.stale_shaders().is_empty());
    }
    #[test]
    fn test_hot_reload_multiple_stale() {
        let mut tracker = HotReloadTracker::new();
        tracker.touch("a", 10);
        tracker.touch("b", 20);
        tracker.touch("c", 5);
        tracker.record_compile("c", 10);
        let stale = tracker.stale_shaders();
        assert!(stale.contains(&"a"), "a is stale");
        assert!(stale.contains(&"b"), "b is stale");
        assert!(!stale.contains(&"c"), "c is up to date");
    }
    #[test]
    fn test_registry_error_display_unknown() {
        let e = RegistryError::UnknownShader("foo".to_string());
        assert!(e.to_string().contains("foo"));
    }
    #[test]
    fn test_registry_error_display_missing_define() {
        let e = RegistryError::MissingDefine {
            shader: "sph".to_string(),
            define: "WG_SIZE".to_string(),
        };
        let s = e.to_string();
        assert!(s.contains("sph") && s.contains("WG_SIZE"));
    }
    #[test]
    fn test_registry_error_display_source_too_large() {
        let e = RegistryError::SourceTooLarge {
            shader: "big".to_string(),
            size: 100_000,
            limit: 65_536,
        };
        let s = e.to_string();
        assert!(s.contains("big") && s.contains("100000") && s.contains("65536"));
    }
    #[test]
    fn test_pipeline_descriptor_fields() {
        let key = ShaderKey::bare("nn_forward");
        let desc = PipelineDescriptor::new(key.clone(), 3, 16, "nn_fwd");
        assert_eq!(desc.bind_group_count, 3);
        assert_eq!(desc.push_constant_bytes, 16);
        assert_eq!(desc.label, "nn_fwd");
        assert_eq!(desc.key, key);
    }
    #[test]
    fn test_pipeline_descriptor_validate_with_bind_groups() {
        for count in 1..=4 {
            let key = ShaderKey::bare("test");
            let desc = PipelineDescriptor::new(key, count, 0, "lbl");
            assert!(
                desc.validate().is_ok(),
                "bind_group_count={count} should be valid"
            );
        }
    }
    #[test]
    fn test_compile_variant_basic() {
        let mut reg = ShaderRegistry::new(4);
        reg.register(simple_source("dense"));
        let mut d = HashMap::new();
        d.insert("WG_SIZE".to_string(), "128".to_string());
        let opts = ShaderCompileOptions::new();
        let v = reg.compile_variant("dense", &d, &opts).unwrap();
        assert!(v.wgsl.contains("128"), "define should be substituted");
    }
    #[test]
    fn test_compile_variant_unknown_shader() {
        let mut reg = ShaderRegistry::new(4);
        let opts = ShaderCompileOptions::new();
        let res = reg.compile_variant("missing", &HashMap::new(), &opts);
        assert!(matches!(res, Err(RegistryError::UnknownShader(_))));
    }
    #[test]
    fn test_compile_variant_source_too_large() {
        let mut reg = ShaderRegistry::new(4);
        reg.register(simple_source("big"));
        let mut d = HashMap::new();
        d.insert("WG_SIZE".to_string(), "64".to_string());
        let mut opts = ShaderCompileOptions::new();
        opts.max_source_bytes = 5;
        let res = reg.compile_variant("big", &d, &opts);
        assert!(matches!(res, Err(RegistryError::SourceTooLarge { .. })));
    }
    #[test]
    fn test_compile_variant_extra_defines_are_merged() {
        let src = ShaderSource::new("s", "{{A}} {{B}}", [32, 1, 1]);
        let mut reg = ShaderRegistry::new(4);
        reg.register(src);
        let mut d = HashMap::new();
        d.insert("A".to_string(), "hello".to_string());
        let mut opts = ShaderCompileOptions::new();
        opts.extra_defines
            .insert("B".to_string(), "world".to_string());
        let v = reg.compile_variant("s", &d, &opts).unwrap();
        assert!(v.wgsl.contains("hello"), "A should be substituted");
        assert!(
            v.wgsl.contains("world"),
            "B should be substituted from extra_defines"
        );
    }
    #[test]
    fn test_compile_variant_increments_compilation_count() {
        let mut reg = ShaderRegistry::new(4);
        reg.register(simple_source("s"));
        let mut d = HashMap::new();
        d.insert("WG_SIZE".to_string(), "16".to_string());
        let opts = ShaderCompileOptions::new();
        reg.compile_variant("s", &d, &opts).unwrap();
        reg.compile_variant("s", &d, &opts).unwrap();
        assert_eq!(
            reg.compilations, 2,
            "each compile_variant call should count"
        );
    }
    #[test]
    fn test_compute_cache_key_deterministic() {
        let src = "fn main() {}";
        assert_eq!(compute_cache_key(src), compute_cache_key(src));
    }
    #[test]
    fn test_compute_cache_key_different_inputs() {
        let k1 = compute_cache_key("fn a() {}");
        let k2 = compute_cache_key("fn b() {}");
        assert_ne!(
            k1, k2,
            "different sources should produce different cache keys"
        );
    }
    #[test]
    fn test_compute_cache_key_matches_fnv1a_reference_vectors() {
        // Published 64-bit FNV-1a test vectors. The empty input hashes to the
        // offset basis, and "a" hashes to 0xaf63dc4c8601ec8c. These pin the
        // algorithm itself, not just determinism.
        assert_eq!(compute_cache_key(""), 0xcbf2_9ce4_8422_2325);
        assert_eq!(compute_cache_key("a"), 0xaf63_dc4c_8601_ec8c);
    }
    #[test]
    fn test_compute_shader_cache_key_bare_vs_defines() {
        let bare_key = ShaderKey::bare("sph");
        let mut d = HashMap::new();
        d.insert("WG_SIZE".to_string(), "64".to_string());
        let define_key = ShaderKey::new("sph", &d);
        let k1 = compute_shader_cache_key(&bare_key);
        let k2 = compute_shader_cache_key(&define_key);
        assert_ne!(k1, k2, "bare and define keys should differ");
    }
    #[test]
    fn test_compute_shader_cache_key_order_independent() {
        let mut d1 = HashMap::new();
        d1.insert("A".to_string(), "1".to_string());
        d1.insert("B".to_string(), "2".to_string());
        let mut d2 = HashMap::new();
        d2.insert("B".to_string(), "2".to_string());
        d2.insert("A".to_string(), "1".to_string());
        let k1 = compute_shader_cache_key(&ShaderKey::new("s", &d1));
        let k2 = compute_shader_cache_key(&ShaderKey::new("s", &d2));
        assert_eq!(k1, k2, "insertion order should not affect cache key");
    }
    #[test]
    fn test_collect_placeholders_trims_dedups_and_skips_empty() {
        // Whitespace inside the braces is trimmed, `{{}}` is ignored, repeats
        // are reported once, and a trailing unterminated `{{` stops the scan.
        let names = collect_placeholders("{{ A }} {{}} {{B}} {{A}} {{open");
        assert_eq!(names, vec!["A".to_string(), "B".to_string()]);
    }
    #[test]
    fn test_hot_reload_touch_batch() {
        let mut tracker = HotReloadTracker::new();
        tracker.touch_batch(&["a", "b", "c"], 5);
        assert!(tracker.needs_recompile("a"));
        assert!(tracker.needs_recompile("b"));
        assert!(tracker.needs_recompile("c"));
    }
    #[test]
    fn test_hot_reload_flush_stale() {
        let mut tracker = HotReloadTracker::new();
        tracker.touch_batch(&["x", "y"], 10);
        tracker.flush_stale(20);
        assert!(
            !tracker.needs_recompile("x"),
            "x should be up to date after flush"
        );
        assert!(
            !tracker.needs_recompile("y"),
            "y should be up to date after flush"
        );
    }
    #[test]
    fn test_hot_reload_never_compiled() {
        let mut tracker = HotReloadTracker::new();
        tracker.touch("alpha", 1);
        tracker.touch("beta", 2);
        tracker.record_compile("beta", 3);
        let nc = tracker.never_compiled();
        assert!(nc.contains(&"alpha"), "alpha was never compiled");
        assert!(!nc.contains(&"beta"), "beta was compiled");
    }
}
672#[cfg(test)]
673mod extended_registry_tests {
674
675    use crate::shader_registry::HotReloadTracker;
676    use crate::shader_registry::PipelineCache;
677    use crate::shader_registry::PipelineCacheKey;
678
679    use crate::shader_registry::ShaderDependencyGraph;
680    use crate::shader_registry::ShaderKey;
681    use crate::shader_registry::ShaderRegistry;
682    use crate::shader_registry::ShaderSource;
683    use crate::shader_registry::SpecConstSet;
684    use crate::shader_registry::SpecConstValue;
685    use crate::shader_registry::SpecializationConstant;
686
687    use crate::shader_registry::VariantProfile;
688    use crate::shader_registry::VariantProfileRegistry;
689    use std::collections::HashMap;
690    #[test]
691    fn variant_profile_no_base_resolve() {
692        let mut reg = VariantProfileRegistry::new();
693        let p = VariantProfile::new("default")
694            .set("WG_SIZE", "64")
695            .set("DTYPE", "f32");
696        reg.register(p);
697        let defines = reg.resolve("default").unwrap();
698        assert_eq!(defines.get("WG_SIZE").map(|s| s.as_str()), Some("64"));
699        assert_eq!(defines.get("DTYPE").map(|s| s.as_str()), Some("f32"));
700    }
701    #[test]
702    fn variant_profile_inherits_base() {
703        let mut reg = VariantProfileRegistry::new();
704        reg.register(
705            VariantProfile::new("base")
706                .set("WG_SIZE", "64")
707                .set("DTYPE", "f32"),
708        );
709        reg.register(
710            VariantProfile::with_base("child", "base")
711                .set("DTYPE", "f64")
712                .set("EXTRA", "1"),
713        );
714        let defines = reg.resolve("child").unwrap();
715        assert_eq!(defines["WG_SIZE"], "64", "WG_SIZE inherited from base");
716        assert_eq!(defines["DTYPE"], "f64", "DTYPE overridden by child");
717        assert_eq!(defines["EXTRA"], "1", "EXTRA added by child");
718    }
719    #[test]
720    fn variant_profile_missing_base_returns_none() {
721        let reg = VariantProfileRegistry::new();
722        assert!(reg.resolve("nonexistent").is_none());
723    }
724    #[test]
725    fn variant_profile_registry_lists_names() {
726        let mut reg = VariantProfileRegistry::new();
727        reg.register(VariantProfile::new("b"));
728        reg.register(VariantProfile::new("a"));
729        let names = reg.profile_names();
730        assert_eq!(names, vec!["a", "b"], "names should be sorted");
731    }
732    #[test]
733    fn variant_profile_overrides_base_chain() {
734        let mut reg = VariantProfileRegistry::new();
735        reg.register(VariantProfile::new("root").set("X", "1").set("Y", "2"));
736        reg.register(
737            VariantProfile::with_base("mid", "root")
738                .set("Y", "20")
739                .set("Z", "3"),
740        );
741        reg.register(VariantProfile::with_base("leaf", "mid").set("Z", "30"));
742        let defines = reg.resolve("leaf").unwrap();
743        assert_eq!(defines["X"], "1");
744        assert_eq!(defines["Y"], "20");
745        assert_eq!(defines["Z"], "30");
746    }
747    #[test]
748    fn dependency_graph_direct_dependents() {
749        let mut g = ShaderDependencyGraph::new();
750        g.add_dependency("sph_density", "common_utils");
751        g.add_dependency("sph_pressure", "common_utils");
752        let deps = g.direct_dependents("common_utils");
753        assert!(deps.contains(&"sph_density"));
754        assert!(deps.contains(&"sph_pressure"));
755    }
756    #[test]
757    fn dependency_graph_transitive_dependents() {
758        let mut g = ShaderDependencyGraph::new();
759        g.add_dependency("b", "a");
760        g.add_dependency("c", "b");
761        g.add_dependency("d", "c");
762        let transitive = g.transitive_dependents("a");
763        assert!(transitive.contains(&"b".to_string()), "b depends on a");
764        assert!(
765            transitive.contains(&"c".to_string()),
766            "c transitively depends on a via b"
767        );
768        assert!(
769            transitive.contains(&"d".to_string()),
770            "d transitively depends on a via b,c"
771        );
772    }
773    #[test]
774    fn dependency_graph_no_dependents() {
775        let g = ShaderDependencyGraph::new();
776        assert!(g.transitive_dependents("orphan").is_empty());
777        assert!(g.direct_dependents("orphan").is_empty());
778    }
779    #[test]
780    fn dependency_graph_shaders_with_deps_sorted() {
781        let mut g = ShaderDependencyGraph::new();
782        g.add_dependency("z_shader", "a");
783        g.add_dependency("a_shader", "b");
784        let names = g.shaders_with_deps();
785        assert_eq!(names, vec!["a_shader", "z_shader"]);
786    }
787    #[test]
788    fn dependency_graph_direct_dependencies() {
789        let mut g = ShaderDependencyGraph::new();
790        g.add_dependency("main", "utils");
791        g.add_dependency("main", "math");
792        let deps = g.direct_dependencies("main");
793        assert_eq!(deps.len(), 2);
794        assert!(deps.contains(&"utils".to_string()));
795        assert!(deps.contains(&"math".to_string()));
796    }
797    #[test]
798    fn spec_const_default_value() {
799        let c = SpecializationConstant::new("WG_SIZE", SpecConstValue::Uint(64));
800        assert_eq!(c.effective_value(), &SpecConstValue::Uint(64));
801        assert_eq!(c.effective_value().to_wgsl(), "64u");
802    }
803    #[test]
804    fn spec_const_override_value() {
805        let c = SpecializationConstant::new("WG_SIZE", SpecConstValue::Uint(64))
806            .with_override(SpecConstValue::Uint(128));
807        assert_eq!(c.effective_value(), &SpecConstValue::Uint(128));
808    }
809    #[test]
810    fn spec_const_bool_to_wgsl() {
811        let c = SpecializationConstant::new("ENABLE_DEBUG", SpecConstValue::Bool(true));
812        assert_eq!(c.effective_value().to_wgsl(), "true");
813    }
814    #[test]
815    fn spec_const_float_to_wgsl() {
816        let c = SpecializationConstant::new("EPSILON", SpecConstValue::Float(1e-6));
817        let wgsl = c.effective_value().to_wgsl();
818        assert!(
819            wgsl.contains("0.000001")
820                || wgsl.contains("1e-6")
821                || wgsl.contains("1E-6")
822                || !wgsl.is_empty()
823        );
824    }
825    #[test]
826    fn spec_const_set_to_defines() {
827        let mut set = SpecConstSet::new();
828        set.add(SpecializationConstant::new("A", SpecConstValue::Uint(4)));
829        set.add(SpecializationConstant::new(
830            "B",
831            SpecConstValue::Bool(false),
832        ));
833        let defines = set.to_defines();
834        assert_eq!(defines.get("A").map(|s| s.as_str()), Some("4u"));
835        assert_eq!(defines.get("B").map(|s| s.as_str()), Some("false"));
836    }
837    #[test]
838    fn spec_const_set_has_and_get_wgsl() {
839        let mut set = SpecConstSet::new();
840        set.add(SpecializationConstant::new("X", SpecConstValue::Int(-1)));
841        assert!(set.has("X"));
842        assert!(!set.has("Y"));
843        assert_eq!(set.get_wgsl("X"), Some("-1".to_string()));
844        assert!(set.get_wgsl("Y").is_none());
845    }
846    #[test]
847    fn pipeline_cache_key_equality() {
848        let k1 = PipelineCacheKey::new(ShaderKey::bare("s"), [64, 1, 1], 0, 42);
849        let k2 = PipelineCacheKey::new(ShaderKey::bare("s"), [64, 1, 1], 0, 42);
850        assert_eq!(k1, k2);
851    }
852    #[test]
853    fn pipeline_cache_key_different_layout() {
854        let k1 = PipelineCacheKey::new(ShaderKey::bare("s"), [64, 1, 1], 0, 1);
855        let k2 = PipelineCacheKey::new(ShaderKey::bare("s"), [64, 1, 1], 0, 2);
856        assert_ne!(k1, k2);
857    }
858    #[test]
859    fn pipeline_cache_hit_and_miss() {
860        let mut cache = PipelineCache::new();
861        let key = PipelineCacheKey::new(ShaderKey::bare("s"), [64, 1, 1], 0, 0);
862        assert!(cache.get(&key).is_none());
863        cache.insert(key.clone(), "sph_pipeline");
864        assert_eq!(cache.get(&key), Some("sph_pipeline"));
865        assert_eq!(cache.hits, 1);
866        assert_eq!(cache.misses, 1);
867    }
868    #[test]
869    fn pipeline_cache_hit_rate() {
870        let mut cache = PipelineCache::new();
871        let k = PipelineCacheKey::new(ShaderKey::bare("s"), [1, 1, 1], 0, 0);
872        cache.insert(k.clone(), "label");
873        cache.get(&k);
874        cache.get(&k);
875        let k2 = PipelineCacheKey::new(ShaderKey::bare("x"), [1, 1, 1], 0, 0);
876        cache.get(&k2);
877        let rate = cache.hit_rate();
878        assert!((rate - 2.0 / 3.0).abs() < 1e-10, "hit_rate = {rate}");
879    }
880    #[test]
881    fn pipeline_cache_clear() {
882        let mut cache = PipelineCache::new();
883        let k = PipelineCacheKey::new(ShaderKey::bare("s"), [1, 1, 1], 0, 0);
884        cache.insert(k, "lbl");
885        assert_eq!(cache.len(), 1);
886        cache.clear();
887        assert!(cache.is_empty());
888    }
889    #[test]
890    fn pipeline_cache_key_hash_stable() {
891        let k = PipelineCacheKey::new(ShaderKey::bare("sph"), [64, 1, 1], 16, 12345);
892        assert_eq!(k.hash_key(), k.hash_key(), "hash should be deterministic");
893    }
894    #[test]
895    fn registry_apply_hot_reload_invalidates_stale() {
896        let mut reg = ShaderRegistry::new(8);
897        reg.register(ShaderSource::new("shader_a", "fn a(){}", [1, 1, 1]));
898        reg.get_or_compile("shader_a", &HashMap::new()).unwrap();
899        assert_eq!(reg.cached_count(), 1);
900        let mut tracker = HotReloadTracker::new();
901        tracker.touch("shader_a", 100);
902        let invalidated = reg.apply_hot_reload(&tracker);
903        assert!(invalidated.contains(&"shader_a".to_string()));
904        assert_eq!(
905            reg.cached_count(),
906            0,
907            "stale shader variant should be evicted"
908        );
909    }
910    #[test]
911    fn registry_registered_count() {
912        let mut reg = ShaderRegistry::new(8);
913        assert_eq!(reg.registered_count(), 0);
914        reg.register(ShaderSource::new("a", "fn a(){}", [1, 1, 1]));
915        reg.register(ShaderSource::new("b", "fn b(){}", [1, 1, 1]));
916        assert_eq!(reg.registered_count(), 2);
917    }
918    #[test]
919    fn registry_source_bytes() {
920        let wgsl = "fn main() {}";
921        let mut reg = ShaderRegistry::new(4);
922        reg.register(ShaderSource::new("test", wgsl, [1, 1, 1]));
923        assert_eq!(reg.source_bytes("test"), wgsl.len());
924        assert_eq!(reg.source_bytes("nonexistent"), 0);
925    }
926    #[test]
927    fn spec_const_set_used_with_registry() {
928        let mut set = SpecConstSet::new();
929        set.add(SpecializationConstant::new(
930            "WG_SIZE",
931            SpecConstValue::Uint(64),
932        ));
933        let defines = set.to_defines();
934        let mut reg = ShaderRegistry::new(8);
935        reg.register(ShaderSource::new(
936            "compute",
937            "@compute @workgroup_size({{WG_SIZE}}) fn main() {}",
938            [64, 1, 1],
939        ));
940        let v = reg.get_or_compile("compute", &defines).unwrap();
941        assert!(
942            v.wgsl.contains("64u"),
943            "WG_SIZE should be substituted from spec const"
944        );
945    }
946    #[test]
947    fn variant_profile_with_registry_integration() {
948        let mut prof_reg = VariantProfileRegistry::new();
949        prof_reg.register(
950            VariantProfile::new("high")
951                .set("WG_SIZE", "256")
952                .set("DTYPE", "f64"),
953        );
954        let defines = prof_reg.resolve("high").unwrap();
955        let mut shader_reg = ShaderRegistry::new(8);
956        shader_reg.register(ShaderSource::new(
957            "kernel",
958            "wg={{WG_SIZE}} dtype={{DTYPE}}",
959            [256, 1, 1],
960        ));
961        let v = shader_reg.get_or_compile("kernel", &defines).unwrap();
962        assert!(v.wgsl.contains("256"));
963        assert!(v.wgsl.contains("f64"));
964    }
965    #[test]
966    fn spec_const_int_negative_to_wgsl() {
967        let c = SpecConstValue::Int(-42);
968        assert_eq!(c.to_wgsl(), "-42");
969    }
970    #[test]
971    fn spec_const_uint_zero() {
972        let c = SpecConstValue::Uint(0);
973        assert_eq!(c.to_wgsl(), "0u");
974    }
975    #[test]
976    fn spec_const_bool_false() {
977        let c = SpecConstValue::Bool(false);
978        assert_eq!(c.to_wgsl(), "false");
979    }
980    #[test]
981    fn pipeline_cache_key_different_workgroup_sizes_differ() {
982        let k1 = PipelineCacheKey::new(ShaderKey::bare("s"), [64, 1, 1], 0, 0);
983        let k2 = PipelineCacheKey::new(ShaderKey::bare("s"), [128, 1, 1], 0, 0);
984        assert_ne!(k1.hash_key(), k2.hash_key());
985    }
986}