Skip to main content

oxiphysics_gpu/shaders/
types.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5#[allow(unused_imports)]
6use super::functions_2::*;
7use std::collections::HashMap;
8
9#[allow(unused_imports)]
10use super::functions::*;
11use super::functions::{
12    BOUNDARY_ENFORCE_WGSL, BROADPHASE_SORT_SHADER, INTEGRATE_WGSL, LBM_BGK_D2Q9_WGSL,
13    LBM_STREAMING_SHADER, RIGID_INTEGRATE_SHADER, SPH_DENSITY_WGSL, SPH_FORCE_WGSL,
14};
15
16/// Reflection data extracted from a (mock) SPIR-V module or WGSL source.
17#[derive(Debug, Clone)]
18pub struct SpirVModule {
19    /// Entry point function names.
20    pub entry_points: Vec<String>,
21    /// Number of bindings.
22    pub binding_count: usize,
23    /// Workgroup size \[x, y, z\].
24    pub workgroup_size: [u32; 3],
25    /// Raw SPIR-V bytes.
26    pub spirv_bytes: Vec<u8>,
27}
28impl SpirVModule {
29    /// Build a SpirV module from WGSL source (mock reflection).
30    pub fn from_wgsl(source: &str) -> Self {
31        let mut entry_points = Vec::new();
32        for line in source.lines() {
33            let trimmed = line.trim();
34            if let Some(pos) = trimmed.find("fn ") {
35                let rest = &trimmed[pos + 3..];
36                let name: String = rest
37                    .chars()
38                    .take_while(|c| c.is_alphanumeric() || *c == '_')
39                    .collect();
40                if !name.is_empty() {
41                    entry_points.push(name);
42                }
43            }
44        }
45        let binding_count = source.matches("@binding(").count();
46        let workgroup_size = parse_workgroup_size(source);
47        let spirv_bytes = mock_compile_to_spirv(
48            source,
49            entry_points.first().map(|s| s.as_str()).unwrap_or("main"),
50        );
51        Self {
52            entry_points,
53            binding_count,
54            workgroup_size,
55            spirv_bytes,
56        }
57    }
58    /// Check whether this module has a specific entry point.
59    pub fn has_entry_point(&self, name: &str) -> bool {
60        self.entry_points.iter().any(|e| e == name)
61    }
62    /// Size of the SPIR-V binary in bytes.
63    pub fn byte_size(&self) -> usize {
64        self.spirv_bytes.len()
65    }
66}
67/// A push constant range that can be set per draw/dispatch call.
68#[derive(Debug, Clone, Copy)]
69pub struct PushConstantRange {
70    /// Byte offset into the push constant block.
71    pub offset: u32,
72    /// Size in bytes.
73    pub size: u32,
74    /// Shader stage.
75    pub stage: ShaderStage,
76}
77impl PushConstantRange {
78    /// Create a new push constant range.
79    pub fn new(offset: u32, size: u32, stage: ShaderStage) -> Self {
80        Self {
81            offset,
82            size,
83            stage,
84        }
85    }
86    /// Check if this range fits within the standard 128-byte GPU limit.
87    pub fn fits_standard_limit(&self) -> bool {
88        self.offset + self.size <= 128
89    }
90}
/// How a sampler treats texture coordinates outside the \[0, 1\] range.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AddressMode {
    /// Coordinates are clamped into \[0, 1\].
    ClampToEdge,
    /// Coordinates wrap around (repeat).
    Repeat,
    /// Coordinates repeat with every other tile mirrored.
    MirrorRepeat,
}
101/// A layout describing all bindings in a descriptor set group.
102#[derive(Debug, Clone)]
103pub struct DescriptorSetLayout {
104    /// Group index (set number).
105    pub group: u32,
106    /// All binding entries.
107    pub bindings: Vec<DescriptorBinding>,
108}
109impl DescriptorSetLayout {
110    /// Create a new empty layout for the given group.
111    pub fn new(group: u32) -> Self {
112        Self {
113            group,
114            bindings: Vec::new(),
115        }
116    }
117    /// Add a storage buffer binding.
118    pub fn add_storage_buffer(&mut self, binding: u32, stage: ShaderStage, read_only: bool) {
119        self.bindings.push(DescriptorBinding {
120            binding,
121            descriptor_type: DescriptorType::StorageBuffer,
122            stage,
123            read_only,
124        });
125    }
126    /// Add a uniform buffer binding.
127    pub fn add_uniform_buffer(&mut self, binding: u32, stage: ShaderStage) {
128        self.bindings.push(DescriptorBinding {
129            binding,
130            descriptor_type: DescriptorType::UniformBuffer,
131            stage,
132            read_only: true,
133        });
134    }
135    /// Add a combined image sampler binding.
136    pub fn add_sampler(&mut self, binding: u32, stage: ShaderStage) {
137        self.bindings.push(DescriptorBinding {
138            binding,
139            descriptor_type: DescriptorType::CombinedImageSampler,
140            stage,
141            read_only: true,
142        });
143    }
144    /// Number of bindings in this layout.
145    pub fn len(&self) -> usize {
146        self.bindings.len()
147    }
148    /// Whether the layout has no bindings.
149    pub fn is_empty(&self) -> bool {
150        self.bindings.is_empty()
151    }
152}
/// One storage-buffer entry of a [`BindGroupLayout`].
#[derive(Debug, Clone)]
pub struct StorageBinding {
    /// Identifier used for the buffer in generated WGSL.
    pub name: String,
    /// The `@binding(N)` index.
    pub binding: u32,
    /// `true` for `read` access, `false` for `read_write`.
    pub read_only: bool,
}
163/// Manages shader source files and detects changes for hot-reloading.
164#[derive(Debug, Default)]
165pub struct ShaderHotReloadManager {
166    /// Map from shader name to current source.
167    pub(super) sources: HashMap<String, String>,
168    /// Map from shader name to source hash (for change detection).
169    pub(super) hashes: HashMap<String, u64>,
170}
171impl ShaderHotReloadManager {
172    /// Create a new hot-reload manager.
173    pub fn new() -> Self {
174        Self::default()
175    }
176    /// Start watching a shader file with initial source.
177    pub fn watch(&mut self, name: &str, source: &str) {
178        let hash = simple_hash(source);
179        self.sources.insert(name.to_string(), source.to_string());
180        self.hashes.insert(name.to_string(), hash);
181    }
182    /// Stop watching a shader.
183    pub fn unwatch(&mut self, name: &str) {
184        self.sources.remove(name);
185        self.hashes.remove(name);
186    }
187    /// Update a shader's source. Returns `true` if the source changed.
188    pub fn update(&mut self, name: &str, new_source: &str) -> bool {
189        let new_hash = simple_hash(new_source);
190        if let Some(old_hash) = self.hashes.get(name)
191            && *old_hash == new_hash
192        {
193            return false;
194        }
195        self.sources
196            .insert(name.to_string(), new_source.to_string());
197        self.hashes.insert(name.to_string(), new_hash);
198        true
199    }
200    /// Check if a shader is being watched.
201    pub fn is_watched(&self, name: &str) -> bool {
202        self.sources.contains_key(name)
203    }
204    /// Get the current source for a watched shader.
205    pub fn get_source(&self, name: &str) -> Option<&str> {
206        self.sources.get(name).map(|s| s.as_str())
207    }
208    /// Return names of all watched shaders.
209    pub fn watched_names(&self) -> Vec<&str> {
210        self.sources.keys().map(|s| s.as_str()).collect()
211    }
212    /// Number of watched shaders.
213    pub fn len(&self) -> usize {
214        self.sources.len()
215    }
216    /// Whether no shaders are watched.
217    pub fn is_empty(&self) -> bool {
218        self.sources.is_empty()
219    }
220}
/// Kind of resource a descriptor binding refers to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DescriptorType {
    /// A uniform buffer.
    UniformBuffer,
    /// A storage buffer.
    StorageBuffer,
    /// A combined image + sampler.
    CombinedImageSampler,
    /// A storage image.
    StorageImage,
}
233/// Registry mapping shader names to their [`ShaderMetadata`].
234#[derive(Debug, Default)]
235pub struct ShaderMetaRegistry {
236    pub(super) entries: HashMap<String, ShaderMetadata>,
237}
238impl ShaderMetaRegistry {
239    /// Create an empty registry.
240    pub fn new() -> Self {
241        Self::default()
242    }
243    /// Register a shader under `name`.
244    pub fn register(&mut self, name: &str, meta: ShaderMetadata) {
245        self.entries.insert(name.to_string(), meta);
246    }
247    /// Look up a shader by name.
248    pub fn lookup(&self, name: &str) -> Option<&ShaderMetadata> {
249        self.entries.get(name)
250    }
251    /// Return all registered shader names.
252    pub fn all_names(&self) -> Vec<&str> {
253        self.entries.keys().map(|s| s.as_str()).collect()
254    }
255    /// Number of registered shaders.
256    pub fn len(&self) -> usize {
257        self.entries.len()
258    }
259    /// True if no shaders are registered.
260    pub fn is_empty(&self) -> bool {
261        self.entries.is_empty()
262    }
263}
/// One named constant that can be specialized before shader compilation.
#[derive(Debug, Clone)]
pub struct SpecializationConstant {
    /// Identifier of the constant as written in the shader source.
    pub name: String,
    /// Value used when no override is supplied, kept as source text.
    pub default_value: String,
    /// Human-readable explanation of what the constant controls.
    pub description: String,
}
impl SpecializationConstant {
    /// Build a constant from borrowed strings (each one is copied).
    pub fn new(name: &str, default_value: &str, description: &str) -> Self {
        let (name, default_value, description) = (
            name.to_owned(),
            default_value.to_owned(),
            description.to_owned(),
        );
        Self {
            name,
            default_value,
            description,
        }
    }
}
/// A parameterised WGSL shader template.
///
/// Placeholders of the form `{KEY}` in the template string are replaced by
/// the corresponding values supplied via [`ShaderTemplate::instantiate`].
#[derive(Debug, Clone)]
pub struct ShaderTemplate {
    /// The raw template source with `{KEY}` placeholders.
    pub template: String,
}
impl ShaderTemplate {
    /// Create a new shader template from a source string.
    pub fn new(template: impl Into<String>) -> Self {
        Self {
            template: template.into(),
        }
    }
    /// Instantiate the template by replacing all `{KEY}` placeholders.
    ///
    /// The template is scanned once from left to right, so inserted values
    /// are never scanned again. A value that itself contains `{OTHER}` is
    /// copied verbatim, and the output does not depend on `HashMap` iteration
    /// order. A placeholder is only recognised when no other brace appears
    /// between its `{` and `}`, so WGSL block braces are left intact.
    ///
    /// # Arguments
    /// * `params` – map from placeholder name (without braces) to replacement string.
    ///
    /// # Returns
    /// A new `String` with all recognised placeholders replaced.  Unknown
    /// placeholders are left as-is.
    pub fn instantiate(&self, params: &HashMap<&str, &str>) -> String {
        let text = self.template.as_str();
        let mut out = String::with_capacity(text.len());
        let mut copied = 0; // bytes of `text` already written to `out`
        let mut search = 0; // where to look for the next `{`
        while let Some(offset) = text[search..].find('{') {
            let open = search + offset;
            let hit = Self::brace_token_at(text, open)
                .and_then(|(key, end)| params.get(key).map(|value| (*value, end)));
            match hit {
                Some((value, end)) => {
                    out.push_str(&text[copied..open]);
                    out.push_str(value);
                    copied = end;
                    search = end;
                }
                None => search = open + 1,
            }
        }
        out.push_str(&text[copied..]);
        out
    }
    /// Return a list of placeholder names found in the template.
    ///
    /// Only returns names that look like valid template placeholders:
    /// `{IDENTIFIER}` where IDENTIFIER is uppercase letters, digits, and
    /// underscores, starting with an uppercase letter. Names are listed once,
    /// in order of first appearance. A `{` whose candidate name runs into
    /// another `{` is skipped without hiding the placeholder that follows.
    pub fn placeholders(&self) -> Vec<String> {
        let text = self.template.as_str();
        let mut names: Vec<String> = Vec::new();
        let mut search = 0;
        while let Some(offset) = text[search..].find('{') {
            let open = search + offset;
            match Self::brace_token_at(text, open) {
                Some((name, end)) if Self::is_placeholder_name(name) => {
                    if !names.iter().any(|n| n == name) {
                        names.push(name.to_string());
                    }
                    search = end;
                }
                _ => search = open + 1,
            }
        }
        names
    }
    /// Check if all placeholders are provided in the params map.
    pub fn all_placeholders_provided(&self, params: &HashMap<&str, &str>) -> bool {
        self.placeholders()
            .iter()
            .all(|p| params.contains_key(p.as_str()))
    }
    /// If a `{NAME}` token with no nested braces starts at byte `open`
    /// (which must hold `{`), return `NAME` and the byte index just past `}`.
    fn brace_token_at(text: &str, open: usize) -> Option<(&str, usize)> {
        let inner = &text[open + 1..];
        let close = inner.find(|c: char| c == '{' || c == '}')?;
        if inner.as_bytes()[close] == b'}' {
            Some((&inner[..close], open + 1 + close + 1))
        } else {
            None
        }
    }
    /// `true` for `[A-Z][A-Z0-9_]*`, the placeholder naming convention.
    fn is_placeholder_name(name: &str) -> bool {
        name.starts_with(|c: char| c.is_ascii_uppercase())
            && name
                .chars()
                .all(|c| c.is_ascii_uppercase() || c.is_ascii_digit() || c == '_')
    }
}
/// Pixel format of a render-pass attachment.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TextureFormat {
    /// Four 8-bit unsigned-normalized channels.
    Rgba8Unorm,
    /// Four 8-bit channels in the sRGB color space.
    Rgba8Srgb,
    /// Four 16-bit float channels (HDR).
    Rgba16Float,
    /// One 32-bit float channel.
    R32Float,
    /// 32-bit float depth.
    Depth32Float,
    /// 24-bit depth plus 8-bit stencil.
    Depth24PlusStencil8,
}
380/// A flexible bind group layout builder.
381#[derive(Debug, Clone, Default)]
382pub struct BindGroupLayout {
383    pub(super) uniforms: Vec<UniformBinding>,
384    pub(super) storages: Vec<StorageBinding>,
385}
386impl BindGroupLayout {
387    /// Create an empty layout.
388    pub fn new() -> Self {
389        Self::default()
390    }
391    /// Add a uniform buffer binding.
392    pub fn add_uniform(&mut self, name: &str, _group: u32, binding: u32, size_bytes: u32) {
393        self.uniforms.push(UniformBinding {
394            name: name.to_string(),
395            binding,
396            size_bytes,
397        });
398    }
399    /// Add a storage buffer binding.
400    pub fn add_storage(&mut self, name: &str, _group: u32, binding: u32, read_only: bool) {
401        self.storages.push(StorageBinding {
402            name: name.to_string(),
403            binding,
404            read_only,
405        });
406    }
407    /// Total number of bindings.
408    pub fn binding_count(&self) -> usize {
409        self.uniforms.len() + self.storages.len()
410    }
411    /// Whether this layout has no bindings.
412    pub fn is_empty(&self) -> bool {
413        self.uniforms.is_empty() && self.storages.is_empty()
414    }
415    /// Generate a WGSL snippet declaring all bindings.
416    pub fn to_wgsl_snippet(&self) -> String {
417        let mut out = String::new();
418        for u in &self.uniforms {
419            out.push_str(&format!(
420                "@group(0) @binding({}) var<uniform> {}: {};\n",
421                u.binding,
422                u.name.to_lowercase(),
423                u.name
424            ));
425        }
426        for s in &self.storages {
427            let access = if s.read_only { "read" } else { "read_write" };
428            out.push_str(&format!(
429                "@group(0) @binding({}) var<storage, {}> {}: array<f32>;\n",
430                s.binding,
431                access,
432                s.name.to_lowercase()
433            ));
434        }
435        out
436    }
437}
438/// Description of a depth/stencil attachment.
439#[derive(Debug, Clone)]
440pub struct DepthAttachmentDesc {
441    /// Format of the depth attachment.
442    pub format: TextureFormat,
443    /// Load operation for depth.
444    pub load_op: LoadOp,
445    /// Store operation for depth.
446    pub store_op: StoreOp,
447    /// Clear depth value.
448    pub clear_depth: f32,
449}
450/// A pipeline for compiling WGSL shaders with preprocessing steps.
451///
452/// Steps: resolve includes -> apply specialization -> validate -> cache.
453pub struct ShaderCompilationPipeline {
454    pub(super) includes: HashMap<String, String>,
455    pub(super) cache: ShaderCache,
456}
457impl ShaderCompilationPipeline {
458    /// Create a new compilation pipeline.
459    pub fn new() -> Self {
460        Self {
461            includes: HashMap::new(),
462            cache: ShaderCache::new(),
463        }
464    }
465    /// Register an include file for `#include` resolution.
466    pub fn add_include(&mut self, name: &str, source: &str) {
467        self.includes.insert(name.to_string(), source.to_string());
468    }
469    /// Compile a shader source through the pipeline.
470    ///
471    /// Returns the processed source string. Caches the result.
472    pub fn compile(
473        &mut self,
474        name: &str,
475        source: &str,
476        spec_map: Option<&SpecializationMap>,
477    ) -> Result<String, String> {
478        if let Some(cached) = self.cache.entries.get(name) {
479            return Ok(cached.clone());
480        }
481        let includes_ref: HashMap<&str, &str> = self
482            .includes
483            .iter()
484            .map(|(k, v)| (k.as_str(), v.as_str()))
485            .collect();
486        let resolved = resolve_includes(source, &includes_ref);
487        let specialized = if let Some(sm) = spec_map {
488            sm.apply(&resolved)
489        } else {
490            resolved
491        };
492        if !validate_wgsl_structure(&specialized) {
493            return Err(format!("shader '{}' failed structural validation", name));
494        }
495        self.cache
496            .entries
497            .insert(name.to_string(), specialized.clone());
498        Ok(specialized)
499    }
500    /// Return the number of cached shaders.
501    pub fn cache_size(&self) -> usize {
502        self.cache.len()
503    }
504    /// Clear the compilation cache.
505    pub fn clear_cache(&mut self) {
506        self.cache.clear();
507    }
508}
509/// A registry of named compute shader descriptors.
510#[derive(Debug, Default)]
511pub struct ShaderRegistry {
512    pub(super) shaders: HashMap<String, ComputeShaderDesc>,
513}
514impl ShaderRegistry {
515    /// Create a new empty registry.
516    pub fn new() -> Self {
517        Self::default()
518    }
519    /// Register a shader under the given name.
520    pub fn register(&mut self, name: impl Into<String>, desc: ComputeShaderDesc) {
521        self.shaders.insert(name.into(), desc);
522    }
523    /// Retrieve a shader descriptor by name.
524    pub fn get(&self, name: &str) -> Option<&ComputeShaderDesc> {
525        self.shaders.get(name)
526    }
527    /// Return the number of registered shaders.
528    pub fn len(&self) -> usize {
529        self.shaders.len()
530    }
531    /// Return true if the registry is empty.
532    pub fn is_empty(&self) -> bool {
533        self.shaders.is_empty()
534    }
535    /// Return an iterator over registered shader names.
536    pub fn names(&self) -> impl Iterator<Item = &str> {
537        self.shaders.keys().map(|s| s.as_str())
538    }
539    /// Remove a shader from the registry.
540    pub fn unregister(&mut self, name: &str) -> Option<ComputeShaderDesc> {
541        self.shaders.remove(name)
542    }
543    /// Check if a shader is registered.
544    pub fn contains(&self, name: &str) -> bool {
545        self.shaders.contains_key(name)
546    }
547    /// Create a registry pre-populated with the built-in shaders.
548    pub fn with_builtins() -> Self {
549        let mut reg = Self::new();
550        reg.register(
551            "sph_density",
552            ComputeShaderDesc::new("main", [64, 1, 1], SPH_DENSITY_WGSL),
553        );
554        reg.register(
555            "sph_force",
556            ComputeShaderDesc::new("main", [64, 1, 1], SPH_FORCE_WGSL),
557        );
558        reg.register(
559            "integrate",
560            ComputeShaderDesc::new("main", [64, 1, 1], INTEGRATE_WGSL),
561        );
562        reg.register(
563            "lbm_bgk_d2q9",
564            ComputeShaderDesc::new("main", [64, 1, 1], LBM_BGK_D2Q9_WGSL),
565        );
566        reg.register(
567            "lbm_streaming",
568            ComputeShaderDesc::new("main", [64, 1, 1], LBM_STREAMING_SHADER),
569        );
570        reg.register(
571            "rigid_integrate",
572            ComputeShaderDesc::new("main", [64, 1, 1], RIGID_INTEGRATE_SHADER),
573        );
574        reg.register(
575            "broadphase_sort",
576            ComputeShaderDesc::new("main", [64, 1, 1], BROADPHASE_SORT_SHADER),
577        );
578        reg.register(
579            "boundary_enforce",
580            ComputeShaderDesc::new("main", [64, 1, 1], BOUNDARY_ENFORCE_WGSL),
581        );
582        reg
583    }
584}
/// Descriptor for a WGSL compute shader.
#[derive(Debug, Clone)]
pub struct ComputeShaderDesc {
    /// Name of the entry-point function (e.g. `"main"`).
    pub entry_point: String,
    /// Workgroup size `[x, y, z]`.
    pub workgroup_size: [u32; 3],
    /// Full WGSL source string.
    pub source: String,
}
impl ComputeShaderDesc {
    /// Create a new compute shader descriptor.
    pub fn new(
        entry_point: impl Into<String>,
        workgroup_size: [u32; 3],
        source: impl Into<String>,
    ) -> Self {
        Self {
            entry_point: entry_point.into(),
            workgroup_size,
            source: source.into(),
        }
    }
    /// Number of threads per workgroup (`x * y * z`).
    ///
    /// Saturates at `u32::MAX` instead of overflowing. An absurd workgroup
    /// size therefore still compares as "too large" against any device limit,
    /// rather than panicking (debug) or wrapping to a small value (release).
    pub fn threads_per_workgroup(&self) -> u32 {
        self.workgroup_size
            .iter()
            .fold(1u32, |acc, &dim| acc.saturating_mul(dim))
    }
    /// Count the number of binding annotations in the shader source.
    pub fn binding_count(&self) -> usize {
        self.source.matches("@binding(").count()
    }
}
617/// Sampler descriptor for creating GPU samplers.
618#[derive(Debug, Clone)]
619pub struct SamplerDesc {
620    /// Minification filter.
621    pub filter_min: FilterMode,
622    /// Magnification filter.
623    pub filter_mag: FilterMode,
624    /// Address mode for all axes.
625    pub address_mode: AddressMode,
626    /// Anisotropy level (1 = disabled).
627    pub anisotropy: u32,
628    /// LOD bias.
629    pub lod_bias: f32,
630    /// Maximum LOD level.
631    pub lod_max: f32,
632}
633impl SamplerDesc {
634    /// Create a linear sampler with clamp-to-edge.
635    pub fn linear() -> Self {
636        Self {
637            filter_min: FilterMode::Linear,
638            filter_mag: FilterMode::Linear,
639            address_mode: AddressMode::ClampToEdge,
640            anisotropy: 1,
641            lod_bias: 0.0,
642            lod_max: 16.0,
643        }
644    }
645    /// Create a nearest-neighbor sampler.
646    pub fn nearest() -> Self {
647        Self {
648            filter_min: FilterMode::Nearest,
649            filter_mag: FilterMode::Nearest,
650            address_mode: AddressMode::ClampToEdge,
651            anisotropy: 1,
652            lod_bias: 0.0,
653            lod_max: 0.0,
654        }
655    }
656    /// Create an anisotropic sampler with repeat addressing.
657    pub fn anisotropic(max_anisotropy: u32) -> Self {
658        Self {
659            filter_min: FilterMode::Linear,
660            filter_mag: FilterMode::Linear,
661            address_mode: AddressMode::Repeat,
662            anisotropy: max_anisotropy,
663            lod_bias: 0.0,
664            lod_max: 16.0,
665        }
666    }
667}
668/// Description of a full render pass.
669#[derive(Debug, Clone)]
670pub struct RenderPassDesc {
671    /// Color attachments.
672    pub color_attachments: Vec<ColorAttachmentDesc>,
673    /// Optional depth/stencil attachment.
674    pub depth_attachment: Option<DepthAttachmentDesc>,
675    /// Name of this render pass (for debugging).
676    pub name: String,
677}
678impl RenderPassDesc {
679    /// Create a simple render pass with one RGBA8 color attachment, no depth.
680    pub fn new_simple_color() -> Self {
681        Self {
682            color_attachments: vec![ColorAttachmentDesc {
683                format: TextureFormat::Rgba8Unorm,
684                load_op: LoadOp::Clear,
685                store_op: StoreOp::Store,
686                clear_color: [0.0, 0.0, 0.0, 1.0],
687            }],
688            depth_attachment: None,
689            name: "SimpleColor".to_string(),
690        }
691    }
692    /// Create a render pass with an HDR color attachment and depth buffer.
693    pub fn new_with_depth() -> Self {
694        Self {
695            color_attachments: vec![ColorAttachmentDesc {
696                format: TextureFormat::Rgba16Float,
697                load_op: LoadOp::Clear,
698                store_op: StoreOp::Store,
699                clear_color: [0.0, 0.0, 0.0, 1.0],
700            }],
701            depth_attachment: Some(DepthAttachmentDesc {
702                format: TextureFormat::Depth32Float,
703                load_op: LoadOp::Clear,
704                store_op: StoreOp::Store,
705                clear_depth: 1.0,
706            }),
707            name: "ColorDepth".to_string(),
708        }
709    }
710    /// Number of total attachments (color + optional depth).
711    pub fn total_attachment_count(&self) -> usize {
712        self.color_attachments.len()
713            + if self.depth_attachment.is_some() {
714                1
715            } else {
716                0
717            }
718    }
719}
/// Broad family a compute shader belongs to.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ShaderVariant {
    /// General physics integration and force kernels.
    Physics,
    /// Collision detection and response.
    Collision,
    /// Smoothed particle hydrodynamics.
    Sph,
    /// Lattice Boltzmann method.
    Lbm,
    /// Rigid-body dynamics.
    RigidBody,
    /// Neural-network inference.
    NeuralInference,
}
736/// Metadata describing a single compute shader.
737#[derive(Debug, Clone)]
738pub struct ShaderMetadata {
739    /// The high-level shader category.
740    pub variant: ShaderVariant,
741    /// Name of the entry-point function (e.g. `"main"`).
742    pub entry_point: String,
743    /// Workgroup size `[x, y, z]`.
744    pub workgroup_size: [u32; 3],
745    /// Number of bind groups required by the shader.
746    pub bind_group_count: u32,
747}
748impl ShaderMetadata {
749    /// Create a new metadata record.
750    pub fn new(
751        variant: ShaderVariant,
752        entry_point: impl Into<String>,
753        workgroup_size: [u32; 3],
754        bind_group_count: u32,
755    ) -> Self {
756        Self {
757            variant,
758            entry_point: entry_point.into(),
759            workgroup_size,
760            bind_group_count,
761        }
762    }
763    /// Total number of threads per workgroup.
764    pub fn threads_per_workgroup(&self) -> u32 {
765        self.workgroup_size[0] * self.workgroup_size[1] * self.workgroup_size[2]
766    }
767}
768/// A simple cache for compiled shader sources.
769///
770/// Caches shader source strings by a composite key of name + specialization.
771#[derive(Debug, Default)]
772pub struct ShaderCache {
773    pub(super) entries: HashMap<String, String>,
774}
775impl ShaderCache {
776    /// Create a new empty cache.
777    pub fn new() -> Self {
778        Self::default()
779    }
780    /// Get a cached shader source by key, or compute and cache it.
781    pub fn get_or_insert(&mut self, key: &str, compute: impl FnOnce() -> String) -> &str {
782        self.entries.entry(key.to_string()).or_insert_with(compute)
783    }
784    /// Check if a shader is cached.
785    pub fn contains(&self, key: &str) -> bool {
786        self.entries.contains_key(key)
787    }
788    /// Return the number of cached entries.
789    pub fn len(&self) -> usize {
790        self.entries.len()
791    }
792    /// Check if the cache is empty.
793    pub fn is_empty(&self) -> bool {
794        self.entries.is_empty()
795    }
796    /// Clear the cache.
797    pub fn clear(&mut self) {
798        self.entries.clear();
799    }
800    /// Remove a specific entry.
801    pub fn remove(&mut self, key: &str) -> Option<String> {
802        self.entries.remove(key)
803    }
804}
/// Pipeline stage(s) a binding or push-constant range is visible to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShaderStage {
    /// The vertex stage.
    Vertex,
    /// The fragment stage.
    Fragment,
    /// The compute stage.
    Compute,
    /// Every stage.
    All,
}
/// A parameterised shader template that replaces `#define KEY` tokens.
///
/// Instantiation replaces occurrences of each key found in `defines`
/// with the corresponding value.  The key must appear as a whole word
/// (surrounded by non-identifier characters or at start/end of line).
#[derive(Debug, Clone)]
pub struct ShaderTemplateV2 {
    /// Raw template source text.
    pub source: String,
    /// Map from token name to replacement value.
    pub defines: HashMap<String, String>,
}
impl ShaderTemplateV2 {
    /// Create a new template with the given source and defines.
    pub fn new(source: impl Into<String>, defines: HashMap<String, String>) -> Self {
        Self {
            source: source.into(),
            defines,
        }
    }
    /// Instantiate the template by performing all `#define` substitutions.
    ///
    /// Every whole-word occurrence of a key is replaced by its value. An
    /// occurrence counts only when the characters immediately before and
    /// after it are not identifier characters (alphanumeric or `_`), so key
    /// `N` leaves `NUM` and `N_MAX` untouched. The source is scanned once, so
    /// inserted values are never substituted again and the result does not
    /// depend on `HashMap` iteration order. If several keys match at the same
    /// position, the longest wins. Empty keys are ignored.
    pub fn instantiate(&self) -> String {
        fn is_ident(c: char) -> bool {
            c.is_alphanumeric() || c == '_'
        }
        let mut keys: Vec<(&str, &str)> = self
            .defines
            .iter()
            .filter(|(k, _)| !k.is_empty())
            .map(|(k, v)| (k.as_str(), v.as_str()))
            .collect();
        keys.sort_by(|a, b| b.0.len().cmp(&a.0.len()).then_with(|| a.0.cmp(b.0)));

        let src = self.source.as_str();
        let mut out = String::with_capacity(src.len());
        let mut prev: Option<char> = None; // previous character of `src`
        let mut pos = 0;
        while let Some(ch) = src[pos..].chars().next() {
            let rest = &src[pos..];
            let at_word_start = prev.map_or(true, |c| !is_ident(c));
            let matched = if at_word_start {
                keys.iter().copied().find(|&(key, _)| {
                    rest.starts_with(key)
                        && rest[key.len()..]
                            .chars()
                            .next()
                            .map_or(true, |c| !is_ident(c))
                })
            } else {
                None
            };
            match matched {
                Some((key, value)) => {
                    out.push_str(value);
                    prev = key.chars().next_back();
                    pos += key.len();
                }
                None => {
                    out.push(ch);
                    prev = Some(ch);
                    pos += ch.len_utf8();
                }
            }
        }
        out
    }
}
849/// Description of a color attachment in a render pass.
850#[derive(Debug, Clone)]
851pub struct ColorAttachmentDesc {
852    /// Texture format of the attachment.
853    pub format: TextureFormat,
854    /// Load operation.
855    pub load_op: LoadOp,
856    /// Store operation.
857    pub store_op: StoreOp,
858    /// Clear color \[r, g, b, a\].
859    pub clear_color: [f32; 4],
860}
/// One uniform-buffer entry of a [`BindGroupLayout`].
#[derive(Debug, Clone)]
pub struct UniformBinding {
    /// Identifier of the buffer; also used as its WGSL type name.
    pub name: String,
    /// The `@binding(N)` index.
    pub binding: u32,
    /// Buffer size in bytes.
    pub size_bytes: u32,
}
/// What a render pass does with an attachment's contents when it ends.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StoreOp {
    /// Keep the rendered contents.
    Store,
    /// Throw the contents away after rendering.
    Discard,
}
879/// A single binding entry in a descriptor set layout.
880#[derive(Debug, Clone)]
881pub struct DescriptorBinding {
882    /// Binding index.
883    pub binding: u32,
884    /// Type of the descriptor.
885    pub descriptor_type: DescriptorType,
886    /// Shader stage that uses this binding.
887    pub stage: ShaderStage,
888    /// Whether a storage buffer is read-only.
889    pub read_only: bool,
890}
/// How texels are combined when a texture is sampled.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FilterMode {
    /// Point sampling: take the nearest texel.
    Nearest,
    /// Interpolate between neighbouring texels (bilinear/trilinear).
    Linear,
}
/// What a render pass does with an attachment's contents when it begins.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LoadOp {
    /// Keep whatever the attachment already contains.
    Load,
    /// Overwrite the attachment with a clear value.
    Clear,
    /// Previous contents are irrelevant.
    DontCare,
}
909/// A collection of specialization constants for a shader.
910#[derive(Debug, Clone, Default)]
911pub struct SpecializationMap {
912    pub(super) constants: Vec<SpecializationConstant>,
913    pub(super) overrides: HashMap<String, String>,
914}
915impl SpecializationMap {
916    /// Create a new empty specialization map.
917    pub fn new() -> Self {
918        Self::default()
919    }
920    /// Define a specialization constant with a default value.
921    pub fn define(&mut self, name: &str, default_value: &str, description: &str) {
922        self.constants.push(SpecializationConstant::new(
923            name,
924            default_value,
925            description,
926        ));
927    }
928    /// Override a constant's value.
929    pub fn set(&mut self, name: &str, value: &str) {
930        self.overrides.insert(name.to_string(), value.to_string());
931    }
932    /// Get the effective value for a constant (override or default).
933    pub fn get(&self, name: &str) -> Option<&str> {
934        if let Some(v) = self.overrides.get(name) {
935            return Some(v.as_str());
936        }
937        for c in &self.constants {
938            if c.name == name {
939                return Some(c.default_value.as_str());
940            }
941        }
942        None
943    }
944    /// Return the number of defined constants.
945    pub fn len(&self) -> usize {
946        self.constants.len()
947    }
948    /// Check if there are no constants defined.
949    pub fn is_empty(&self) -> bool {
950        self.constants.is_empty()
951    }
952    /// Apply specialization constants to a shader source by replacing
953    /// `const {NAME} = {DEFAULT};` patterns with overridden values.
954    pub fn apply(&self, source: &str) -> String {
955        let mut result = source.to_string();
956        for c in &self.constants {
957            let value = self
958                .overrides
959                .get(&c.name)
960                .map(|s| s.as_str())
961                .unwrap_or(&c.default_value);
962            let old = format!("const {} = {};", c.name, c.default_value);
963            let new = format!("const {} = {};", c.name, value);
964            result = result.replace(&old, &new);
965        }
966        result
967    }
968}
969/// Description of a uniform buffer binding.
970#[derive(Debug, Clone)]
971pub struct UniformBufferDesc {
972    /// Name of the uniform block.
973    pub name: String,
974    /// Bind group index.
975    pub group: u32,
976    /// Binding slot within the group.
977    pub binding: u32,
978    /// Size of the uniform buffer in bytes.
979    pub size_bytes: u32,
980}
981impl UniformBufferDesc {
982    /// Create a new uniform buffer description.
983    pub fn new(name: &str, group: u32, binding: u32, size_bytes: u32) -> Self {
984        Self {
985            name: name.to_string(),
986            group,
987            binding,
988            size_bytes,
989        }
990    }
991    /// Generate the WGSL binding annotation string.
992    pub fn wgsl_annotation(&self) -> String {
993        format!(
994            "@group({}) @binding({}) var<uniform> {}: {};",
995            self.group,
996            self.binding,
997            self.name.to_lowercase(),
998            self.name
999        )
1000    }
1001}
1002/// Cache for compiled shader bytecode (mock compiled blobs).
1003///
1004/// Supports a maximum total byte budget; when the budget is exceeded,
1005/// `evict_oldest` removes the oldest-inserted entry.
1006#[derive(Debug)]
1007pub struct BytecodeShaderCache {
1008    /// Map from shader name to compiled bytecode.
1009    pub cache: HashMap<String, Vec<u8>>,
1010    /// Ordered insertion keys for LRU-like eviction.
1011    pub(super) insertion_order: Vec<String>,
1012    /// Maximum total bytes allowed.
1013    pub max_size: usize,
1014}
1015impl BytecodeShaderCache {
1016    /// Create a new cache with the given byte budget.
1017    pub fn new(max_size: usize) -> Self {
1018        Self {
1019            cache: HashMap::new(),
1020            insertion_order: Vec::new(),
1021            max_size,
1022        }
1023    }
1024    /// Insert (or replace) a shader's compiled bytecode.
1025    ///
1026    /// If the total size would exceed `max_size`, `evict_oldest` is called
1027    /// before inserting.
1028    pub fn insert(&mut self, name: &str, bytecode: Vec<u8>) {
1029        if self.cache.contains_key(name) {
1030            self.insertion_order.retain(|k| k != name);
1031        }
1032        self.cache.insert(name.to_string(), bytecode);
1033        self.insertion_order.push(name.to_string());
1034        while self.total_bytes() > self.max_size && !self.insertion_order.is_empty() {
1035            self.evict_oldest();
1036        }
1037    }
1038    /// Retrieve compiled bytecode by shader name.
1039    pub fn get(&self, name: &str) -> Option<&Vec<u8>> {
1040        self.cache.get(name)
1041    }
1042    /// Evict the oldest cached entry.  No-op if cache is empty.
1043    pub fn evict_oldest(&mut self) {
1044        if let Some(oldest) = self.insertion_order.first().cloned() {
1045            self.insertion_order.remove(0);
1046            self.cache.remove(&oldest);
1047        }
1048    }
1049    /// Total bytes currently stored across all cached entries.
1050    pub fn total_bytes(&self) -> usize {
1051        self.cache.values().map(|v| v.len()).sum()
1052    }
1053    /// Number of cached entries.
1054    pub fn len(&self) -> usize {
1055        self.cache.len()
1056    }
1057    /// True if the cache is empty.
1058    pub fn is_empty(&self) -> bool {
1059        self.cache.is_empty()
1060    }
1061}