1#[allow(unused_imports)]
6use super::functions::*;
7use std::collections::HashMap;
8
9pub struct VariantCache {
14 pub capacity: usize,
16 pub(super) entries: Vec<(ShaderKey, CompiledVariant, u64)>,
18 pub(super) clock: u64,
20}
21impl VariantCache {
22 pub fn new(capacity: usize) -> Self {
24 Self {
25 capacity: capacity.max(1),
26 entries: Vec::new(),
27 clock: 0,
28 }
29 }
30 pub fn insert(&mut self, variant: CompiledVariant) {
32 if let Some(pos) = self.entries.iter().position(|(k, _, _)| *k == variant.key) {
33 self.clock += 1;
34 self.entries[pos] = (variant.key.clone(), variant, self.clock);
35 return;
36 }
37 if self.entries.len() >= self.capacity {
38 let lru_pos = self
39 .entries
40 .iter()
41 .enumerate()
42 .min_by_key(|(_, (_, _, t))| *t)
43 .map(|(i, _)| i)
44 .unwrap_or(0);
45 self.entries.swap_remove(lru_pos);
46 }
47 self.clock += 1;
48 self.entries
49 .push((variant.key.clone(), variant, self.clock));
50 }
51 pub fn get(&mut self, key: &ShaderKey) -> Option<&CompiledVariant> {
53 if let Some(pos) = self.entries.iter().position(|(k, _, _)| k == key) {
54 self.clock += 1;
55 self.entries[pos].2 = self.clock;
56 return Some(&self.entries[pos].1);
57 }
58 None
59 }
60 pub fn len(&self) -> usize {
62 self.entries.len()
63 }
64 pub fn is_empty(&self) -> bool {
66 self.entries.is_empty()
67 }
68 pub fn clear(&mut self) {
70 self.entries.clear();
71 self.clock = 0;
72 }
73 pub fn total_binary_bytes(&self) -> usize {
75 self.entries
76 .iter()
77 .map(|(_, v, _)| v.binary_size_bytes)
78 .sum()
79 }
80}
81#[derive(Debug, Clone)]
83pub struct CompiledVariant {
84 pub key: ShaderKey,
86 pub wgsl: String,
88 pub workgroup_size: [u32; 3],
90 pub binary_size_bytes: usize,
92}
93impl CompiledVariant {
94 fn new(key: ShaderKey, wgsl: String, workgroup_size: [u32; 3]) -> Self {
95 let binary_size_bytes = wgsl.len();
96 Self {
97 key,
98 wgsl,
99 workgroup_size,
100 binary_size_bytes,
101 }
102 }
103}
104#[derive(Debug, Default)]
106pub struct VariantProfileRegistry {
107 pub(super) profiles: HashMap<String, VariantProfile>,
108}
109impl VariantProfileRegistry {
110 pub fn new() -> Self {
112 Self::default()
113 }
114 pub fn register(&mut self, profile: VariantProfile) {
116 self.profiles.insert(profile.name.clone(), profile);
117 }
118 pub fn resolve(&self, name: &str) -> Option<HashMap<String, String>> {
122 let profile = self.profiles.get(name)?;
123 let mut merged = if let Some(base) = &profile.base {
124 self.resolve(base)?
125 } else {
126 HashMap::new()
127 };
128 for (k, v) in &profile.defines {
129 merged.insert(k.clone(), v.clone());
130 }
131 Some(merged)
132 }
133 pub fn profile_names(&self) -> Vec<&str> {
135 let mut names: Vec<&str> = self.profiles.keys().map(|s| s.as_str()).collect();
136 names.sort_unstable();
137 names
138 }
139}
140#[derive(Debug, Default)]
145pub struct ShaderDependencyGraph {
146 pub(super) depends_on: HashMap<String, Vec<String>>,
148}
149impl ShaderDependencyGraph {
150 pub fn new() -> Self {
152 Self::default()
153 }
154 pub fn add_dependency(&mut self, dependent: impl Into<String>, dependency: impl Into<String>) {
156 self.depends_on
157 .entry(dependent.into())
158 .or_default()
159 .push(dependency.into());
160 }
161 pub fn direct_dependents(&self, dependency: &str) -> Vec<&str> {
163 self.depends_on
164 .iter()
165 .filter(|(_, deps)| deps.iter().any(|d| d == dependency))
166 .map(|(name, _)| name.as_str())
167 .collect()
168 }
169 pub fn transitive_dependents(&self, changed_shader: &str) -> Vec<String> {
172 let mut visited = std::collections::HashSet::new();
173 let mut queue = std::collections::VecDeque::new();
174 queue.push_back(changed_shader.to_string());
175 while let Some(current) = queue.pop_front() {
176 for (name, deps) in &self.depends_on {
177 if deps.contains(¤t) && !visited.contains(name.as_str()) {
178 visited.insert(name.clone());
179 queue.push_back(name.clone());
180 }
181 }
182 }
183 let mut result: Vec<String> = visited.into_iter().collect();
184 result.sort_unstable();
185 result
186 }
187 pub fn direct_dependencies(&self, shader: &str) -> &[String] {
189 self.depends_on
190 .get(shader)
191 .map(Vec::as_slice)
192 .unwrap_or(&[])
193 }
194 pub fn shaders_with_deps(&self) -> Vec<&str> {
196 let mut names: Vec<&str> = self.depends_on.keys().map(|s| s.as_str()).collect();
197 names.sort_unstable();
198 names
199 }
200}
/// Identity of one shader variant: a shader name plus its defines.
///
/// Defines are stored sorted by name, so keys built from equal maps compare and hash equal
/// regardless of the map's iteration order.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ShaderKey {
    /// Name of the shader source.
    pub name: String,
    /// `(define, value)` pairs sorted by define name.
    pub defines: Vec<(String, String)>,
}
impl ShaderKey {
    /// Builds a key from `name` and a define map, sorting the defines by name.
    pub fn new(name: impl Into<String>, defines: &HashMap<String, String>) -> Self {
        let mut pairs: Vec<(String, String)> = Vec::with_capacity(defines.len());
        for (define, value) in defines {
            pairs.push((define.clone(), value.clone()));
        }
        // Map keys are unique, so an unstable sort yields exactly the stable order.
        pairs.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
        Self {
            name: name.into(),
            defines: pairs,
        }
    }

    /// Key for a shader instantiated without any defines.
    pub fn bare(name: impl Into<String>) -> Self {
        Self::new(name, &HashMap::new())
    }

    /// Human-readable id: the name alone, or `name__K1_V1__K2_V2...` when defines exist.
    pub fn fingerprint(&self) -> String {
        let mut id = self.name.clone();
        for (define, value) in &self.defines {
            id.push_str("__");
            id.push_str(define);
            id.push('_');
            id.push_str(value);
        }
        id
    }
}
/// A named set of defines that may inherit from another profile via `base`.
#[derive(Debug, Clone)]
pub struct VariantProfile {
    /// Profile name (also its registry key).
    pub name: String,
    /// Defines contributed by this profile.
    pub defines: HashMap<String, String>,
    /// Name of the profile this one inherits from, if any.
    pub base: Option<String>,
}
impl VariantProfile {
    /// Creates a root profile with no defines.
    pub fn new(name: impl Into<String>) -> Self {
        let defines = HashMap::new();
        Self {
            base: None,
            defines,
            name: name.into(),
        }
    }

    /// Creates an empty profile that inherits from `base`.
    pub fn with_base(name: impl Into<String>, base: impl Into<String>) -> Self {
        Self {
            base: Some(base.into()),
            ..Self::new(name)
        }
    }

    /// Builder-style setter: adds or overwrites one define and returns the profile.
    pub fn set(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let mut profile = self;
        profile.defines.insert(key.into(), value.into());
        profile
    }
}
281#[derive(Debug, Default)]
286pub struct HotReloadTracker {
287 pub(super) timestamps: HashMap<String, u64>,
289 pub(super) compiled_at: HashMap<String, u64>,
291}
292impl HotReloadTracker {
293 pub fn new() -> Self {
295 Self::default()
296 }
297 pub fn touch(&mut self, name: impl Into<String>, time: u64) {
299 self.timestamps.insert(name.into(), time);
300 }
301 pub fn record_compile(&mut self, name: impl Into<String>, time: u64) {
303 self.compiled_at.insert(name.into(), time);
304 }
305 pub fn needs_recompile(&self, name: &str) -> bool {
307 let modified = self.timestamps.get(name).copied().unwrap_or(0);
308 let compiled = self.compiled_at.get(name).copied().unwrap_or(0);
309 modified > compiled
310 }
311 pub fn stale_shaders(&self) -> Vec<&str> {
313 self.timestamps
314 .keys()
315 .filter(|n| self.needs_recompile(n))
316 .map(|s| s.as_str())
317 .collect()
318 }
319}
320impl HotReloadTracker {
321 pub fn touch_batch(&mut self, names: &[&str], time: u64) {
323 for &name in names {
324 self.touch(name, time);
325 }
326 }
327 pub fn flush_stale(&mut self, time: u64) {
329 let stale: Vec<String> = self
330 .timestamps
331 .keys()
332 .filter(|n| self.needs_recompile(n))
333 .cloned()
334 .collect();
335 for name in stale {
336 self.compiled_at.insert(name, time);
337 }
338 }
339 pub fn never_compiled(&self) -> Vec<&str> {
341 self.timestamps
342 .keys()
343 .filter(|n| !self.compiled_at.contains_key(n.as_str()))
344 .map(|s| s.as_str())
345 .collect()
346 }
347}
348pub struct ShaderRegistry {
354 pub(super) sources: HashMap<String, ShaderSource>,
356 pub(super) cache: VariantCache,
358 pub cache_hits: u64,
360 pub compilations: u64,
362}
363impl ShaderRegistry {
364 pub fn new(cache_capacity: usize) -> Self {
366 Self {
367 sources: HashMap::new(),
368 cache: VariantCache::new(cache_capacity),
369 cache_hits: 0,
370 compilations: 0,
371 }
372 }
373 pub fn register(&mut self, source: ShaderSource) {
376 let name = source.name.clone();
377 self.sources.insert(name.clone(), source);
378 self.cache.entries.retain(|(k, _, _)| k.name != name);
379 }
380 pub fn get_or_compile(
385 &mut self,
386 name: &str,
387 defines: &HashMap<String, String>,
388 ) -> Result<&CompiledVariant, RegistryError> {
389 let key = ShaderKey::new(name, defines);
390 if self.cache.get(&key).is_some() {
391 self.cache_hits += 1;
392 return Ok(self
393 .cache
394 .get(&key)
395 .expect("cache entry must exist after insertion"));
396 }
397 let source = self
398 .sources
399 .get(name)
400 .ok_or_else(|| RegistryError::UnknownShader(name.to_string()))?;
401 for ph in &source.placeholders {
402 if !defines.contains_key(ph.as_str()) {
403 return Err(RegistryError::MissingDefine {
404 shader: name.to_string(),
405 define: ph.clone(),
406 });
407 }
408 }
409 let wgsl = source.instantiate(defines);
410 let workgroup_size = source.workgroup_size;
411 let variant = CompiledVariant::new(key.clone(), wgsl, workgroup_size);
412 self.compilations += 1;
413 self.cache.insert(variant);
414 Ok(self
415 .cache
416 .get(&key)
417 .expect("cache entry must exist after insertion"))
418 }
419 pub fn shader_names(&self) -> Vec<&str> {
421 self.sources.keys().map(|s| s.as_str()).collect()
422 }
423 pub fn cached_count(&self) -> usize {
425 self.cache.len()
426 }
427 pub fn invalidate(&mut self, name: &str) {
429 self.cache.entries.retain(|(k, _, _)| k.name != name);
430 }
431 pub fn invalidate_all(&mut self) {
433 self.cache.clear();
434 }
435}
436impl ShaderRegistry {
437 pub fn compile_variant(
443 &mut self,
444 name: &str,
445 defines: &HashMap<String, String>,
446 opts: &ShaderCompileOptions,
447 ) -> Result<CompiledVariant, RegistryError> {
448 let source = self
449 .sources
450 .get(name)
451 .ok_or_else(|| RegistryError::UnknownShader(name.to_string()))?;
452 let mut merged = defines.clone();
453 for (k, v) in &opts.extra_defines {
454 merged.entry(k.clone()).or_insert_with(|| v.clone());
455 }
456 for ph in &source.placeholders {
457 if !merged.contains_key(ph.as_str()) {
458 return Err(RegistryError::MissingDefine {
459 shader: name.to_string(),
460 define: ph.clone(),
461 });
462 }
463 }
464 let wgsl = source.instantiate(&merged);
465 if opts.max_source_bytes > 0 && wgsl.len() > opts.max_source_bytes {
466 return Err(RegistryError::SourceTooLarge {
467 shader: name.to_string(),
468 size: wgsl.len(),
469 limit: opts.max_source_bytes,
470 });
471 }
472 let workgroup_size = opts.workgroup_size;
473 let key = ShaderKey::new(name, &merged);
474 self.compilations += 1;
475 Ok(CompiledVariant::new(key, wgsl, workgroup_size))
476 }
477}
478impl ShaderRegistry {
479 pub fn apply_hot_reload(&mut self, tracker: &HotReloadTracker) -> Vec<String> {
483 let stale: Vec<String> = tracker
484 .stale_shaders()
485 .into_iter()
486 .map(|s| s.to_string())
487 .collect();
488 for name in &stale {
489 self.invalidate(name);
490 }
491 stale
492 }
493 pub fn registered_count(&self) -> usize {
495 self.sources.len()
496 }
497 pub fn source_bytes(&self, name: &str) -> usize {
499 self.sources.get(name).map(|s| s.wgsl.len()).unwrap_or(0)
500 }
501}
502#[derive(Debug, Clone)]
504pub struct ShaderSource {
505 pub name: String,
507 pub wgsl: String,
509 pub workgroup_size: [u32; 3],
511 pub placeholders: Vec<String>,
513}
514impl ShaderSource {
515 pub fn new(name: impl Into<String>, wgsl: impl Into<String>, workgroup_size: [u32; 3]) -> Self {
517 let wgsl_str: String = wgsl.into();
518 let placeholders = collect_placeholders(&wgsl_str);
519 Self {
520 name: name.into(),
521 wgsl: wgsl_str,
522 workgroup_size,
523 placeholders,
524 }
525 }
526 pub fn instantiate(&self, defines: &HashMap<String, String>) -> String {
530 let mut out = self.wgsl.clone();
531 for (k, v) in defines {
532 let token = format!("{{{{{}}}}}", k);
533 out = out.replace(&token, v);
534 }
535 out
536 }
537 pub fn threads_per_group(&self) -> u32 {
539 self.workgroup_size[0] * self.workgroup_size[1] * self.workgroup_size[2]
540 }
541}
/// Typed value of a specialization constant, rendered into WGSL source text.
#[derive(Debug, Clone, PartialEq)]
pub enum SpecConstValue {
    /// Signed integer, rendered as a bare literal (e.g. `-3`).
    Int(i64),
    /// Unsigned integer, rendered with the `u` suffix (e.g. `7u`).
    Uint(u64),
    /// Floating-point value, rendered with a decimal point or exponent (e.g. `1.0`).
    Float(f64),
    /// Boolean, rendered as `true` / `false`.
    Bool(bool),
}
impl SpecConstValue {
    /// Renders the value as a WGSL literal.
    ///
    /// Floats use `{:?}` rather than `{}`. `Display` prints `1.0_f64` as `1`, which WGSL parses
    /// as an integer literal, so e.g. `let x = 1;` would infer `i32` instead of a float type.
    /// `Debug` always keeps a `.0` or an exponent for finite values.
    ///
    /// NOTE(review): WGSL has no literals for NaN or infinities; such values render as `NaN`
    /// / `inf` and will not parse.
    pub fn to_wgsl(&self) -> String {
        match self {
            SpecConstValue::Int(v) => v.to_string(),
            SpecConstValue::Uint(v) => format!("{v}u"),
            SpecConstValue::Float(v) => format!("{v:?}"),
            SpecConstValue::Bool(v) => v.to_string(),
        }
    }
}
565#[derive(Debug, Clone, PartialEq)]
567pub struct SpecializationConstant {
568 pub name: String,
570 pub default_value: SpecConstValue,
572 pub override_value: Option<SpecConstValue>,
574}
575impl SpecializationConstant {
576 pub fn new(name: impl Into<String>, default: SpecConstValue) -> Self {
578 Self {
579 name: name.into(),
580 default_value: default,
581 override_value: None,
582 }
583 }
584 pub fn with_override(mut self, value: SpecConstValue) -> Self {
586 self.override_value = Some(value);
587 self
588 }
589 pub fn effective_value(&self) -> &SpecConstValue {
591 self.override_value.as_ref().unwrap_or(&self.default_value)
592 }
593}
594#[derive(Debug, Default)]
597pub struct PipelineCache {
598 pub(super) entries: HashMap<PipelineCacheKey, String>,
599 pub hits: u64,
601 pub misses: u64,
603}
604impl PipelineCache {
605 pub fn new() -> Self {
607 Self::default()
608 }
609 pub fn insert(&mut self, key: PipelineCacheKey, label: impl Into<String>) {
611 self.entries.insert(key, label.into());
612 }
613 pub fn get(&mut self, key: &PipelineCacheKey) -> Option<&str> {
615 if let Some(v) = self.entries.get(key) {
616 self.hits += 1;
617 Some(v.as_str())
618 } else {
619 self.misses += 1;
620 None
621 }
622 }
623 pub fn len(&self) -> usize {
625 self.entries.len()
626 }
627 pub fn is_empty(&self) -> bool {
629 self.entries.is_empty()
630 }
631 pub fn clear(&mut self) {
633 self.entries.clear();
634 }
635 pub fn hit_rate(&self) -> f64 {
637 let total = self.hits + self.misses;
638 if total == 0 {
639 return 0.0;
640 }
641 self.hits as f64 / total as f64
642 }
643}
644#[derive(Debug, Clone)]
646pub struct PipelineDescriptor {
647 pub key: ShaderKey,
649 pub bind_group_count: u32,
651 pub push_constant_bytes: u32,
653 pub label: String,
655}
656#[allow(clippy::too_many_arguments)]
657impl PipelineDescriptor {
658 pub fn new(
660 key: ShaderKey,
661 bind_group_count: u32,
662 push_constant_bytes: u32,
663 label: impl Into<String>,
664 ) -> Self {
665 Self {
666 key,
667 bind_group_count,
668 push_constant_bytes,
669 label: label.into(),
670 }
671 }
672 pub fn validate(&self) -> Result<(), RegistryError> {
674 if self.bind_group_count == 0 {
675 return Err(RegistryError::UnknownShader(
676 "pipeline has no bind groups".into(),
677 ));
678 }
679 Ok(())
680 }
681}
/// Per-call settings for `ShaderRegistry::compile_variant`.
#[derive(Debug, Clone, PartialEq)]
pub struct ShaderCompileOptions {
    /// Workgroup dimensions stamped onto the compiled variant.
    pub workgroup_size: [u32; 3],
    /// Fallback defines; any the caller already supplies take precedence.
    pub extra_defines: HashMap<String, String>,
    /// Maximum instantiated WGSL size in bytes; 0 disables the check.
    pub max_source_bytes: usize,
}
impl ShaderCompileOptions {
    /// Options with a `[64, 1, 1]` workgroup, no extra defines and no size limit.
    pub fn new() -> Self {
        let workgroup_size = [64, 1, 1];
        Self {
            max_source_bytes: 0,
            extra_defines: HashMap::new(),
            workgroup_size,
        }
    }
}
703#[derive(Debug, Clone, Default)]
705pub struct SpecConstSet {
706 pub constants: HashMap<String, SpecializationConstant>,
708}
709impl SpecConstSet {
710 pub fn new() -> Self {
712 Self::default()
713 }
714 pub fn add(&mut self, constant: SpecializationConstant) {
716 self.constants.insert(constant.name.clone(), constant);
717 }
718 pub fn to_defines(&self) -> HashMap<String, String> {
721 self.constants
722 .iter()
723 .map(|(name, c)| (name.clone(), c.effective_value().to_wgsl()))
724 .collect()
725 }
726 pub fn has(&self, name: &str) -> bool {
728 self.constants.contains_key(name)
729 }
730 pub fn get_wgsl(&self, name: &str) -> Option<String> {
732 self.constants
733 .get(name)
734 .map(|c| c.effective_value().to_wgsl())
735 }
736}
/// Errors produced by the shader registry and related validation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RegistryError {
    /// A lookup failed; the payload is a shader name or free-form message.
    UnknownShader(String),
    /// A placeholder required by `shader` had no value in the supplied defines.
    MissingDefine {
        /// Shader being instantiated.
        shader: String,
        /// Placeholder name that was not supplied.
        define: String,
    },
    /// The instantiated WGSL for `shader` exceeded the configured byte limit.
    SourceTooLarge {
        /// Shader being instantiated.
        shader: String,
        /// Actual instantiated size in bytes.
        size: usize,
        /// Configured maximum in bytes.
        limit: usize,
    },
}
759#[derive(Debug, Clone, PartialEq, Eq, Hash)]
764pub struct PipelineCacheKey {
765 pub shader_key: ShaderKey,
767 pub workgroup_size: [u32; 3],
769 pub push_constant_bytes: u32,
771 pub layout_hash: u64,
773}
774impl PipelineCacheKey {
775 pub fn new(
777 shader_key: ShaderKey,
778 workgroup_size: [u32; 3],
779 push_constant_bytes: u32,
780 layout_hash: u64,
781 ) -> Self {
782 Self {
783 shader_key,
784 workgroup_size,
785 push_constant_bytes,
786 layout_hash,
787 }
788 }
789 pub fn hash_key(&self) -> u64 {
791 let repr = format!(
792 "{}_{:?}_{}_{}",
793 self.shader_key.fingerprint(),
794 self.workgroup_size,
795 self.push_constant_bytes,
796 self.layout_hash,
797 );
798 compute_cache_key(&repr)
799 }
800}