#[allow(unused_imports)]
use super::functions_2::*;
use std::collections::HashMap;
#[allow(unused_imports)]
use super::functions::*;
use super::functions::{
BOUNDARY_ENFORCE_WGSL, BROADPHASE_SORT_SHADER, INTEGRATE_WGSL, LBM_BGK_D2Q9_WGSL,
LBM_STREAMING_SHADER, RIGID_INTEGRATE_SHADER, SPH_DENSITY_WGSL, SPH_FORCE_WGSL,
};
/// Lightweight description of a shader module derived by scanning WGSL text.
///
/// Built by [`SpirVModule::from_wgsl`]; the byte payload is whatever
/// `mock_compile_to_spirv` returns.
#[derive(Debug, Clone)]
pub struct SpirVModule {
    /// Identifiers that follow the first `fn ` on each source line.
    pub entry_points: Vec<String>,
    /// Number of textual `@binding(` occurrences in the source.
    pub binding_count: usize,
    /// Workgroup size as reported by `parse_workgroup_size`.
    pub workgroup_size: [u32; 3],
    /// Bytes produced by `mock_compile_to_spirv`.
    pub spirv_bytes: Vec<u8>,
}
impl SpirVModule {
    /// Builds a module description by textual scanning of `source`.
    ///
    /// Every line containing `fn ` contributes the identifier (alphanumerics and `_`)
    /// that follows its first occurrence. Helper functions, and `fn ` appearing in
    /// comments, are therefore collected alongside real entry points. The first
    /// collected name, or `"main"` when none is found, is passed to
    /// `mock_compile_to_spirv`.
    pub fn from_wgsl(source: &str) -> Self {
        let entry_points: Vec<String> = source
            .lines()
            .filter_map(|raw_line| {
                let line = raw_line.trim();
                let after_keyword = &line[line.find("fn ")? + 3..];
                let ident: String = after_keyword
                    .chars()
                    .take_while(|&ch| ch.is_alphanumeric() || ch == '_')
                    .collect();
                (!ident.is_empty()).then_some(ident)
            })
            .collect();
        let binding_count = source.matches("@binding(").count();
        let workgroup_size = parse_workgroup_size(source);
        let primary = entry_points.first().map_or("main", String::as_str);
        let spirv_bytes = mock_compile_to_spirv(source, primary);
        Self {
            entry_points,
            binding_count,
            workgroup_size,
            spirv_bytes,
        }
    }
    /// Returns `true` if `name` is one of the collected `entry_points`.
    pub fn has_entry_point(&self, name: &str) -> bool {
        self.entry_points
            .iter()
            .any(|candidate| candidate.as_str() == name)
    }
    /// Length of `spirv_bytes` in bytes.
    pub fn byte_size(&self) -> usize {
        self.spirv_bytes.len()
    }
}
/// A push-constant byte range `[offset, offset + size)` visible to `stage`.
#[derive(Debug, Clone, Copy)]
pub struct PushConstantRange {
    pub offset: u32,
    pub size: u32,
    pub stage: ShaderStage,
}
impl PushConstantRange {
    /// Creates a range from its raw parts; no validation is performed.
    pub fn new(offset: u32, size: u32, stage: ShaderStage) -> Self {
        Self {
            offset,
            size,
            stage,
        }
    }
    /// Returns `true` when the range ends at or before byte 128.
    ///
    /// 128 bytes presumably mirrors the Vulkan-guaranteed minimum
    /// `maxPushConstantsSize`; confirm against the target device limits.
    /// The end offset is computed in `u64`. A `u32` sum could overflow, which panics in
    /// debug builds and wraps to a small value (a false "fits") in release builds.
    pub fn fits_standard_limit(&self) -> bool {
        u64::from(self.offset) + u64::from(self.size) <= 128
    }
}
/// Sampler address mode, stored in [`SamplerDesc::address_mode`].
///
/// The `SamplerDesc` presets use `ClampToEdge` (`linear`, `nearest`) and
/// `Repeat` (`anisotropic`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddressMode {
ClampToEdge,
Repeat,
MirrorRepeat,
}
/// Ordered list of descriptor bindings belonging to descriptor set `group`.
///
/// Bindings are appended in call order; duplicate binding indices are not rejected.
#[derive(Debug, Clone)]
pub struct DescriptorSetLayout {
    pub group: u32,
    pub bindings: Vec<DescriptorBinding>,
}
impl DescriptorSetLayout {
    /// Creates an empty layout for set `group`.
    pub fn new(group: u32) -> Self {
        Self {
            group,
            bindings: vec![],
        }
    }
    /// Appends a storage-buffer binding with the caller-chosen `read_only` flag.
    pub fn add_storage_buffer(&mut self, binding: u32, stage: ShaderStage, read_only: bool) {
        let entry = DescriptorBinding {
            descriptor_type: DescriptorType::StorageBuffer,
            binding,
            stage,
            read_only,
        };
        self.bindings.push(entry);
    }
    /// Appends a uniform-buffer binding (always recorded as read-only).
    pub fn add_uniform_buffer(&mut self, binding: u32, stage: ShaderStage) {
        let entry = DescriptorBinding {
            descriptor_type: DescriptorType::UniformBuffer,
            binding,
            stage,
            read_only: true,
        };
        self.bindings.push(entry);
    }
    /// Appends a combined image-sampler binding (always recorded as read-only).
    pub fn add_sampler(&mut self, binding: u32, stage: ShaderStage) {
        let entry = DescriptorBinding {
            descriptor_type: DescriptorType::CombinedImageSampler,
            binding,
            stage,
            read_only: true,
        };
        self.bindings.push(entry);
    }
    /// Number of bindings added so far.
    pub fn len(&self) -> usize {
        self.bindings.len()
    }
    /// `true` when no binding has been added.
    pub fn is_empty(&self) -> bool {
        self.bindings.is_empty()
    }
}
/// One storage-buffer entry recorded by `BindGroupLayout::add_storage`.
#[derive(Debug, Clone)]
pub struct StorageBinding {
/// Buffer name; `BindGroupLayout::to_wgsl_snippet` lower-cases it for the WGSL variable name.
pub name: String,
/// `@binding` index emitted in the WGSL declaration.
pub binding: u32,
/// `true` emits `var<storage, read>`, `false` emits `var<storage, read_write>`.
pub read_only: bool,
}
/// Tracks named shader sources and reports whether a re-submitted source changed.
///
/// Each watched name maps to its latest source text and the `simple_hash` of that text.
#[derive(Debug, Default)]
pub struct ShaderHotReloadManager {
    pub(super) sources: HashMap<String, String>,
    pub(super) hashes: HashMap<String, u64>,
}
impl ShaderHotReloadManager {
    /// Creates a manager with nothing watched.
    pub fn new() -> Self {
        Self::default()
    }
    /// Starts watching `name` with `source`, replacing any previous entry.
    pub fn watch(&mut self, name: &str, source: &str) {
        let hash = simple_hash(source);
        self.sources.insert(name.to_string(), source.to_string());
        self.hashes.insert(name.to_string(), hash);
    }
    /// Stops watching `name`; a no-op if it was not watched.
    pub fn unwatch(&mut self, name: &str) {
        self.sources.remove(name);
        self.hashes.remove(name);
    }
    /// Records `new_source` for `name` and returns whether it differed from the stored text.
    ///
    /// Returns `false` only when `name` is already stored with identical text. Otherwise the
    /// source and hash are (re)stored and `true` is returned. This includes a `name` that was
    /// not previously watched, which becomes watched.
    ///
    /// The hash is only a cheap first filter. On a hash match the stored text is compared as
    /// well, so a `simple_hash` collision cannot cause a genuine edit to be dropped.
    pub fn update(&mut self, name: &str, new_source: &str) -> bool {
        let new_hash = simple_hash(new_source);
        let unchanged = self.hashes.get(name) == Some(&new_hash)
            && self
                .sources
                .get(name)
                .is_some_and(|old| old.as_str() == new_source);
        if unchanged {
            return false;
        }
        self.sources
            .insert(name.to_string(), new_source.to_string());
        self.hashes.insert(name.to_string(), new_hash);
        true
    }
    /// `true` if `name` is currently watched.
    pub fn is_watched(&self, name: &str) -> bool {
        self.sources.contains_key(name)
    }
    /// The latest stored source for `name`, if watched.
    pub fn get_source(&self, name: &str) -> Option<&str> {
        self.sources.get(name).map(String::as_str)
    }
    /// All watched names, in unspecified (`HashMap` iteration) order.
    pub fn watched_names(&self) -> Vec<&str> {
        self.sources.keys().map(String::as_str).collect()
    }
    /// Number of watched names.
    pub fn len(&self) -> usize {
        self.sources.len()
    }
    /// `true` when nothing is watched.
    pub fn is_empty(&self) -> bool {
        self.sources.is_empty()
    }
}
/// Kind of resource a [`DescriptorBinding`] refers to.
///
/// `DescriptorSetLayout`'s helpers produce `UniformBuffer`, `StorageBuffer` and
/// `CombinedImageSampler`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DescriptorType {
UniformBuffer,
StorageBuffer,
CombinedImageSampler,
StorageImage,
}
/// Name-keyed store of [`ShaderMetadata`]; registering an existing name replaces it.
#[derive(Debug, Default)]
pub struct ShaderMetaRegistry {
    pub(super) entries: HashMap<String, ShaderMetadata>,
}
impl ShaderMetaRegistry {
    /// Returns an empty registry.
    pub fn new() -> Self {
        Self {
            entries: HashMap::new(),
        }
    }
    /// Stores `meta` under `name`, discarding any previous entry.
    pub fn register(&mut self, name: &str, meta: ShaderMetadata) {
        let key = String::from(name);
        self.entries.insert(key, meta);
    }
    /// Looks up metadata by name.
    pub fn lookup(&self, name: &str) -> Option<&ShaderMetadata> {
        self.entries.get(name)
    }
    /// All registered names, in unspecified (`HashMap` iteration) order.
    pub fn all_names(&self) -> Vec<&str> {
        self.entries.keys().map(String::as_str).collect()
    }
    /// Number of registered entries.
    pub fn len(&self) -> usize {
        self.entries.len()
    }
    /// `true` when nothing is registered.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
/// A named constant with a textual default value and a human-readable description.
#[derive(Debug, Clone)]
pub struct SpecializationConstant {
    pub name: String,
    pub default_value: String,
    pub description: String,
}
impl SpecializationConstant {
    /// Builds a constant, copying each borrowed string into owned storage.
    pub fn new(name: &str, default_value: &str, description: &str) -> Self {
        let (name, default_value, description) = (
            name.to_owned(),
            default_value.to_owned(),
            description.to_owned(),
        );
        Self {
            name,
            default_value,
            description,
        }
    }
}
/// A text template containing `{KEY}` placeholders.
#[derive(Debug, Clone)]
pub struct ShaderTemplate {
    pub template: String,
}
impl ShaderTemplate {
    /// Wraps `template` without validating it.
    pub fn new(template: impl Into<String>) -> Self {
        Self {
            template: template.into(),
        }
    }
    /// Substitutes every `{key}` whose `key` is present in `params`.
    ///
    /// The template is scanned once, left to right. A candidate placeholder is the text
    /// between a `{` and the next `}` with no other brace in between. Candidates not found
    /// in `params` are copied through unchanged, and so are ordinary WGSL braces.
    ///
    /// Substituted values are never re-scanned, so the result no longer depends on
    /// `HashMap` iteration order when a value itself contains `{OTHER}`. Keys containing
    /// `{` or `}` never match.
    pub fn instantiate(&self, params: &HashMap<&str, &str>) -> String {
        let mut out = String::with_capacity(self.template.len());
        let mut rest = self.template.as_str();
        while let Some(open) = rest.find('{') {
            out.push_str(&rest[..open]);
            let after = &rest[open + 1..];
            // The candidate ends at the next brace, but only if that brace is a `}`.
            let value = after
                .find(['{', '}'])
                .filter(|&end| after[end..].starts_with('}'))
                .and_then(|end| params.get(&after[..end]).map(|v| (end, *v)));
            match value {
                Some((end, v)) => {
                    out.push_str(v);
                    rest = &after[end + 1..];
                }
                None => {
                    out.push('{');
                    rest = after;
                }
            }
        }
        out.push_str(rest);
        out
    }
    /// Lists distinct placeholder names in first-appearance order.
    ///
    /// A placeholder is `{NAME}`, where `NAME` starts with an ASCII uppercase letter and
    /// contains only ASCII uppercase letters, digits and `_`. When a `{` is followed by
    /// another `{` before any `}`, scanning restarts at the inner brace. An unterminated
    /// `{A` therefore does not hide a following `{B}`.
    pub fn placeholders(&self) -> Vec<String> {
        let mut names: Vec<String> = Vec::new();
        let mut rest = self.template.as_str();
        while let Some(open) = rest.find('{') {
            let after = &rest[open + 1..];
            let Some(end) = after.find(['{', '}']) else {
                break;
            };
            if !after[end..].starts_with('}') {
                rest = &after[end..];
                continue;
            }
            let candidate = &after[..end];
            let well_formed = candidate.starts_with(|c: char| c.is_ascii_uppercase())
                && candidate
                    .chars()
                    .all(|c| c.is_ascii_uppercase() || c.is_ascii_digit() || c == '_');
            if well_formed && !names.iter().any(|n| n.as_str() == candidate) {
                names.push(candidate.to_string());
            }
            rest = &after[end + 1..];
        }
        names
    }
    /// `true` when every name reported by [`Self::placeholders`] is a key of `params`.
    pub fn all_placeholders_provided(&self, params: &HashMap<&str, &str>) -> bool {
        self.placeholders()
            .iter()
            .all(|p| params.contains_key(p.as_str()))
    }
}
/// Texture/attachment formats referenced by the attachment descriptors in this module.
///
/// `RenderPassDesc::new_simple_color` uses `Rgba8Unorm`, and `new_with_depth` uses
/// `Rgba16Float` plus `Depth32Float`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TextureFormat {
Rgba8Unorm,
Rgba8Srgb,
Rgba16Float,
R32Float,
Depth32Float,
Depth24PlusStencil8,
}
/// Flat list of uniform and storage bindings that can be rendered as WGSL declarations.
///
/// NOTE(review): the `_group` argument of `add_uniform`/`add_storage` is discarded, and
/// `to_wgsl_snippet` always emits `@group(0)`, so multi-group layouts are not represented.
#[derive(Debug, Clone, Default)]
pub struct BindGroupLayout {
    pub(super) uniforms: Vec<UniformBinding>,
    pub(super) storages: Vec<StorageBinding>,
}
impl BindGroupLayout {
    /// Creates a layout with no bindings.
    pub fn new() -> Self {
        Self {
            uniforms: Vec::new(),
            storages: Vec::new(),
        }
    }
    /// Records a uniform binding; `_group` is ignored.
    pub fn add_uniform(&mut self, name: &str, _group: u32, binding: u32, size_bytes: u32) {
        let entry = UniformBinding {
            name: String::from(name),
            binding,
            size_bytes,
        };
        self.uniforms.push(entry);
    }
    /// Records a storage binding; `_group` is ignored.
    pub fn add_storage(&mut self, name: &str, _group: u32, binding: u32, read_only: bool) {
        let entry = StorageBinding {
            name: String::from(name),
            binding,
            read_only,
        };
        self.storages.push(entry);
    }
    /// Total number of uniform plus storage bindings.
    pub fn binding_count(&self) -> usize {
        self.storages.len() + self.uniforms.len()
    }
    /// `true` when no binding of either kind is recorded.
    pub fn is_empty(&self) -> bool {
        self.binding_count() == 0
    }
    /// Renders one WGSL declaration line per binding: all uniforms first, then all storages.
    ///
    /// Uniform variables use the lower-cased name and the original name as their type.
    /// Storage variables are declared as `array<f32>`.
    pub fn to_wgsl_snippet(&self) -> String {
        let uniform_lines = self.uniforms.iter().map(|u| {
            format!(
                "@group(0) @binding({}) var<uniform> {}: {};\n",
                u.binding,
                u.name.to_lowercase(),
                u.name
            )
        });
        let storage_lines = self.storages.iter().map(|s| {
            let access = if s.read_only { "read" } else { "read_write" };
            format!(
                "@group(0) @binding({}) var<storage, {}> {}: array<f32>;\n",
                s.binding,
                access,
                s.name.to_lowercase()
            )
        });
        uniform_lines.chain(storage_lines).collect()
    }
}
/// Depth attachment of a [`RenderPassDesc`].
#[derive(Debug, Clone)]
pub struct DepthAttachmentDesc {
/// Depth format; `RenderPassDesc::new_with_depth` uses `Depth32Float`.
pub format: TextureFormat,
pub load_op: LoadOp,
pub store_op: StoreOp,
/// Depth clear value (presumably applied when `load_op` is `Clear`; confirm). `new_with_depth` sets `1.0`.
pub clear_depth: f32,
}
/// Resolves includes, applies specialization, validates, and caches WGSL shader text.
///
/// NOTE: the compile cache is keyed by shader `name` only. A later `compile` call with the
/// same name returns the first cached result even if `source` or `spec_map` differ. Call
/// `clear_cache` in that situation.
pub struct ShaderCompilationPipeline {
    pub(super) includes: HashMap<String, String>,
    pub(super) cache: ShaderCache,
}
impl ShaderCompilationPipeline {
    /// Creates a pipeline with no includes and an empty cache.
    pub fn new() -> Self {
        Self {
            includes: HashMap::new(),
            cache: ShaderCache::new(),
        }
    }
    /// Registers (or replaces) include `name` with `source`.
    ///
    /// Cached outputs were produced by `resolve_includes` against the previous include set.
    /// Any change to that set therefore clears the compile cache; otherwise `compile` would
    /// keep returning stale text. Re-registering identical text is a no-op.
    pub fn add_include(&mut self, name: &str, source: &str) {
        if self
            .includes
            .get(name)
            .is_some_and(|existing| existing.as_str() == source)
        {
            return;
        }
        self.includes.insert(name.to_string(), source.to_string());
        self.cache.clear();
    }
    /// Compiles `source` under `name`, returning the cached result if `name` was compiled before.
    ///
    /// Pipeline: `resolve_includes` with the registered includes, then `SpecializationMap::apply`
    /// when `spec_map` is given, then `validate_wgsl_structure`. Only successful results are cached.
    ///
    /// # Errors
    /// Returns `Err` with a message when structural validation fails.
    pub fn compile(
        &mut self,
        name: &str,
        source: &str,
        spec_map: Option<&SpecializationMap>,
    ) -> Result<String, String> {
        if let Some(cached) = self.cache.entries.get(name) {
            return Ok(cached.clone());
        }
        let includes_ref: HashMap<&str, &str> = self
            .includes
            .iter()
            .map(|(k, v)| (k.as_str(), v.as_str()))
            .collect();
        let resolved = resolve_includes(source, &includes_ref);
        let specialized = match spec_map {
            Some(sm) => sm.apply(&resolved),
            None => resolved,
        };
        if !validate_wgsl_structure(&specialized) {
            return Err(format!("shader '{}' failed structural validation", name));
        }
        self.cache
            .entries
            .insert(name.to_string(), specialized.clone());
        Ok(specialized)
    }
    /// Number of cached compiled shaders.
    pub fn cache_size(&self) -> usize {
        self.cache.len()
    }
    /// Drops every cached result.
    pub fn clear_cache(&mut self) {
        self.cache.clear();
    }
}
/// Name-keyed collection of [`ComputeShaderDesc`] values.
#[derive(Debug, Default)]
pub struct ShaderRegistry {
    pub(super) shaders: HashMap<String, ComputeShaderDesc>,
}
impl ShaderRegistry {
    /// Returns an empty registry.
    pub fn new() -> Self {
        Self {
            shaders: HashMap::new(),
        }
    }
    /// Stores `desc` under `name`, replacing any previous entry.
    pub fn register(&mut self, name: impl Into<String>, desc: ComputeShaderDesc) {
        let key: String = name.into();
        self.shaders.insert(key, desc);
    }
    /// Looks up a shader by name.
    pub fn get(&self, name: &str) -> Option<&ComputeShaderDesc> {
        self.shaders.get(name)
    }
    /// Number of registered shaders.
    pub fn len(&self) -> usize {
        self.shaders.len()
    }
    /// `true` when nothing is registered.
    pub fn is_empty(&self) -> bool {
        self.shaders.is_empty()
    }
    /// Registered names, in unspecified (`HashMap` iteration) order.
    pub fn names(&self) -> impl Iterator<Item = &str> {
        self.shaders.keys().map(String::as_str)
    }
    /// Removes and returns the shader registered under `name`.
    pub fn unregister(&mut self, name: &str) -> Option<ComputeShaderDesc> {
        self.shaders.remove(name)
    }
    /// `true` if `name` is registered.
    pub fn contains(&self, name: &str) -> bool {
        self.shaders.contains_key(name)
    }
    /// Registry pre-populated with the built-in WGSL kernels.
    ///
    /// Every built-in uses entry point `"main"` and workgroup size `[64, 1, 1]`.
    pub fn with_builtins() -> Self {
        let builtins = [
            ("sph_density", SPH_DENSITY_WGSL),
            ("sph_force", SPH_FORCE_WGSL),
            ("integrate", INTEGRATE_WGSL),
            ("lbm_bgk_d2q9", LBM_BGK_D2Q9_WGSL),
            ("lbm_streaming", LBM_STREAMING_SHADER),
            ("rigid_integrate", RIGID_INTEGRATE_SHADER),
            ("broadphase_sort", BROADPHASE_SORT_SHADER),
            ("boundary_enforce", BOUNDARY_ENFORCE_WGSL),
        ];
        let mut registry = Self::new();
        for (name, wgsl) in builtins {
            registry.register(name, ComputeShaderDesc::new("main", [64, 1, 1], wgsl));
        }
        registry
    }
}
/// A compute shader: entry-point name, workgroup dimensions, and WGSL source text.
#[derive(Debug, Clone)]
pub struct ComputeShaderDesc {
    pub entry_point: String,
    pub workgroup_size: [u32; 3],
    pub source: String,
}
impl ComputeShaderDesc {
    /// Builds a descriptor; no validation of the source or sizes is performed.
    pub fn new(
        entry_point: impl Into<String>,
        workgroup_size: [u32; 3],
        source: impl Into<String>,
    ) -> Self {
        Self {
            entry_point: entry_point.into(),
            workgroup_size,
            source: source.into(),
        }
    }
    /// Total invocations per workgroup (`x * y * z`), saturating at `u32::MAX`.
    ///
    /// A plain `u32` product panics on overflow in debug builds and wraps in release.
    /// A wrapped, small value could then wrongly pass a device-limit check.
    pub fn threads_per_workgroup(&self) -> u32 {
        let [x, y, z] = self.workgroup_size;
        x.saturating_mul(y).saturating_mul(z)
    }
    /// Number of textual `@binding(` occurrences in `source` (comments included).
    pub fn binding_count(&self) -> usize {
        self.source.matches("@binding(").count()
    }
}
/// Sampler configuration with three preset constructors.
#[derive(Debug, Clone)]
pub struct SamplerDesc {
    pub filter_min: FilterMode,
    pub filter_mag: FilterMode,
    pub address_mode: AddressMode,
    pub anisotropy: u32,
    pub lod_bias: f32,
    pub lod_max: f32,
}
impl SamplerDesc {
    /// Linear min/mag filtering, clamp-to-edge, anisotropy 1, `lod_max` 16.
    pub fn linear() -> Self {
        Self {
            filter_min: FilterMode::Linear,
            filter_mag: FilterMode::Linear,
            lod_max: 16.0,
            ..Self::nearest()
        }
    }
    /// Nearest min/mag filtering, clamp-to-edge, anisotropy 1, `lod_max` 0.
    pub fn nearest() -> Self {
        Self {
            filter_min: FilterMode::Nearest,
            filter_mag: FilterMode::Nearest,
            address_mode: AddressMode::ClampToEdge,
            anisotropy: 1,
            lod_bias: 0.0,
            lod_max: 0.0,
        }
    }
    /// The linear preset with `Repeat` addressing and the given anisotropy (stored unchecked).
    pub fn anisotropic(max_anisotropy: u32) -> Self {
        Self {
            address_mode: AddressMode::Repeat,
            anisotropy: max_anisotropy,
            ..Self::linear()
        }
    }
}
/// Render-pass description: color attachments, an optional depth attachment, and a name.
#[derive(Debug, Clone)]
pub struct RenderPassDesc {
    pub color_attachments: Vec<ColorAttachmentDesc>,
    pub depth_attachment: Option<DepthAttachmentDesc>,
    pub name: String,
}
impl RenderPassDesc {
    /// One `Rgba8Unorm` color attachment (Clear/Store), no depth; named `"SimpleColor"`.
    pub fn new_simple_color() -> Self {
        let color = ColorAttachmentDesc {
            format: TextureFormat::Rgba8Unorm,
            load_op: LoadOp::Clear,
            store_op: StoreOp::Store,
            clear_color: [0.0, 0.0, 0.0, 1.0],
        };
        Self {
            color_attachments: vec![color],
            depth_attachment: None,
            name: String::from("SimpleColor"),
        }
    }
    /// One `Rgba16Float` color attachment plus a `Depth32Float` depth attachment (clear
    /// value 1.0), both Clear/Store; named `"ColorDepth"`.
    pub fn new_with_depth() -> Self {
        let color = ColorAttachmentDesc {
            format: TextureFormat::Rgba16Float,
            load_op: LoadOp::Clear,
            store_op: StoreOp::Store,
            clear_color: [0.0, 0.0, 0.0, 1.0],
        };
        let depth = DepthAttachmentDesc {
            format: TextureFormat::Depth32Float,
            load_op: LoadOp::Clear,
            store_op: StoreOp::Store,
            clear_depth: 1.0,
        };
        Self {
            color_attachments: vec![color],
            depth_attachment: Some(depth),
            name: String::from("ColorDepth"),
        }
    }
    /// Number of color attachments, plus one when a depth attachment is present.
    pub fn total_attachment_count(&self) -> usize {
        self.color_attachments.len() + usize::from(self.depth_attachment.is_some())
    }
}
/// Category tag carried by [`ShaderMetadata::variant`].
///
/// Derives `Hash` + `Eq`, so it can be used as a map/set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ShaderVariant {
Physics,
Collision,
Sph,
Lbm,
RigidBody,
NeuralInference,
}
/// Describes a compute shader: category, entry point, workgroup size and bind-group count.
#[derive(Debug, Clone)]
pub struct ShaderMetadata {
    pub variant: ShaderVariant,
    pub entry_point: String,
    pub workgroup_size: [u32; 3],
    pub bind_group_count: u32,
}
impl ShaderMetadata {
    /// Builds metadata from its parts; no validation is performed.
    pub fn new(
        variant: ShaderVariant,
        entry_point: impl Into<String>,
        workgroup_size: [u32; 3],
        bind_group_count: u32,
    ) -> Self {
        Self {
            variant,
            entry_point: entry_point.into(),
            workgroup_size,
            bind_group_count,
        }
    }
    /// Total invocations per workgroup (`x * y * z`), saturating at `u32::MAX`.
    ///
    /// A plain `u32` product panics on overflow in debug builds and wraps in release.
    /// A wrapped, small value could then wrongly pass a device-limit check.
    pub fn threads_per_workgroup(&self) -> u32 {
        let [x, y, z] = self.workgroup_size;
        x.saturating_mul(y).saturating_mul(z)
    }
}
/// String-to-string cache; `ShaderCompilationPipeline` stores compiled shader text here.
#[derive(Debug, Default)]
pub struct ShaderCache {
    pub(super) entries: HashMap<String, String>,
}
impl ShaderCache {
    /// Returns an empty cache.
    pub fn new() -> Self {
        Self {
            entries: HashMap::new(),
        }
    }
    /// Returns the entry for `key`, calling `compute` to create it only when absent.
    pub fn get_or_insert(&mut self, key: &str, compute: impl FnOnce() -> String) -> &str {
        let slot = self.entries.entry(key.to_owned()).or_insert_with(compute);
        slot.as_str()
    }
    /// `true` if `key` has an entry.
    pub fn contains(&self, key: &str) -> bool {
        self.entries.contains_key(key)
    }
    /// Number of entries.
    pub fn len(&self) -> usize {
        self.entries.len()
    }
    /// `true` when the cache holds nothing.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
    /// Removes every entry.
    pub fn clear(&mut self) {
        self.entries.clear();
    }
    /// Removes and returns the entry for `key`.
    pub fn remove(&mut self, key: &str) -> Option<String> {
        self.entries.remove(key)
    }
}
/// Shader stage(s) a resource is visible to; used by `DescriptorBinding::stage` and
/// `PushConstantRange::stage`.
///
/// `All` presumably means every stage; confirm against the consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShaderStage {
Vertex,
Fragment,
Compute,
All,
}
/// A template whose `defines` keys are replaced literally (no delimiters) by their values.
#[derive(Debug, Clone)]
pub struct ShaderTemplateV2 {
    pub source: String,
    pub defines: HashMap<String, String>,
}
impl ShaderTemplateV2 {
    /// Stores the source and define table without validation.
    pub fn new(source: impl Into<String>, defines: HashMap<String, String>) -> Self {
        Self {
            source: source.into(),
            defines,
        }
    }
    /// Returns `source` with every occurrence of a `defines` key replaced by its value.
    ///
    /// The source is processed in a single left-to-right pass. At each position the longest
    /// matching key wins, with ties broken by lexicographic key order. Substituted values are
    /// not re-scanned.
    ///
    /// The result therefore does not depend on `HashMap` iteration order, even when keys
    /// overlap (e.g. `SIZE` vs `SIZE_X`) or a value contains another key. Empty keys are
    /// ignored, since they would otherwise match at every position.
    pub fn instantiate(&self) -> String {
        let mut keys: Vec<(&str, &str)> = self
            .defines
            .iter()
            .filter(|(k, _)| !k.is_empty())
            .map(|(k, v)| (k.as_str(), v.as_str()))
            .collect();
        keys.sort_unstable_by(|a, b| b.0.len().cmp(&a.0.len()).then_with(|| a.0.cmp(b.0)));
        let mut out = String::with_capacity(self.source.len());
        let mut rest = self.source.as_str();
        while let Some(ch) = rest.chars().next() {
            match keys.iter().find(|(k, _)| rest.starts_with(*k)) {
                Some((k, v)) => {
                    out.push_str(v);
                    rest = &rest[k.len()..];
                }
                None => {
                    out.push(ch);
                    rest = &rest[ch.len_utf8()..];
                }
            }
        }
        out
    }
}
/// Color attachment of a [`RenderPassDesc`].
#[derive(Debug, Clone)]
pub struct ColorAttachmentDesc {
pub format: TextureFormat,
pub load_op: LoadOp,
pub store_op: StoreOp,
/// Four-component clear color; the `RenderPassDesc` constructors set `[0.0, 0.0, 0.0, 1.0]`.
pub clear_color: [f32; 4],
}
/// One uniform-buffer entry recorded by `BindGroupLayout::add_uniform`.
#[derive(Debug, Clone)]
pub struct UniformBinding {
/// Used verbatim as the WGSL type name and lower-cased as the variable name by `to_wgsl_snippet`.
pub name: String,
/// `@binding` index emitted in the WGSL declaration.
pub binding: u32,
/// Size passed to `add_uniform`; `to_wgsl_snippet` does not use it.
pub size_bytes: u32,
}
/// Store operation recorded on color and depth attachment descriptors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StoreOp {
Store,
Discard,
}
/// One binding slot within a [`DescriptorSetLayout`].
#[derive(Debug, Clone)]
pub struct DescriptorBinding {
pub binding: u32,
pub descriptor_type: DescriptorType,
pub stage: ShaderStage,
/// Always `true` for uniform buffers and samplers added via `DescriptorSetLayout`;
/// caller-chosen for storage buffers.
pub read_only: bool,
}
/// Filter mode used for `SamplerDesc::filter_min` and `SamplerDesc::filter_mag`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FilterMode {
Nearest,
Linear,
}
/// Load operation recorded on color and depth attachment descriptors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadOp {
Load,
Clear,
DontCare,
}
/// Declared specialization constants plus per-name value overrides.
///
/// `apply` only rewrites declarations spelled exactly `const NAME = DEFAULT;`. An override
/// for a name that was never `define`d is returned by `get` but ignored by `apply`.
#[derive(Debug, Clone, Default)]
pub struct SpecializationMap {
    pub(super) constants: Vec<SpecializationConstant>,
    pub(super) overrides: HashMap<String, String>,
}
impl SpecializationMap {
    /// Returns a map with no constants and no overrides.
    pub fn new() -> Self {
        Self {
            constants: Vec::new(),
            overrides: HashMap::new(),
        }
    }
    /// Declares a constant with its default value and description (duplicates are kept).
    pub fn define(&mut self, name: &str, default_value: &str, description: &str) {
        let constant = SpecializationConstant::new(name, default_value, description);
        self.constants.push(constant);
    }
    /// Sets (or replaces) the override value for `name`.
    pub fn set(&mut self, name: &str, value: &str) {
        self.overrides.insert(String::from(name), String::from(value));
    }
    /// Effective value for `name`: the override if any, else the first matching default.
    pub fn get(&self, name: &str) -> Option<&str> {
        match self.overrides.get(name) {
            Some(value) => Some(value.as_str()),
            None => self
                .constants
                .iter()
                .find(|c| c.name == name)
                .map(|c| c.default_value.as_str()),
        }
    }
    /// Number of declared constants (overrides are not counted).
    pub fn len(&self) -> usize {
        self.constants.len()
    }
    /// `true` when no constant is declared.
    pub fn is_empty(&self) -> bool {
        self.constants.is_empty()
    }
    /// Rewrites `const NAME = DEFAULT;` to `const NAME = VALUE;` for each declared constant,
    /// in declaration order, where `VALUE` is the override or the default.
    pub fn apply(&self, source: &str) -> String {
        self.constants
            .iter()
            .fold(source.to_string(), |text, c| {
                let chosen = self
                    .overrides
                    .get(&c.name)
                    .map_or(c.default_value.as_str(), String::as_str);
                let from = format!("const {} = {};", c.name, c.default_value);
                let to = format!("const {} = {};", c.name, chosen);
                text.replace(&from, &to)
            })
    }
}
/// A uniform buffer's name, bind location, and size.
#[derive(Debug, Clone)]
pub struct UniformBufferDesc {
    pub name: String,
    pub group: u32,
    pub binding: u32,
    pub size_bytes: u32,
}
impl UniformBufferDesc {
    /// Builds a descriptor, copying `name`.
    pub fn new(name: &str, group: u32, binding: u32, size_bytes: u32) -> Self {
        Self {
            name: name.to_owned(),
            group,
            binding,
            size_bytes,
        }
    }
    /// WGSL declaration using the lower-cased name as the variable and the name as its type.
    pub fn wgsl_annotation(&self) -> String {
        let var_name = self.name.to_lowercase();
        format!(
            "@group({group}) @binding({binding}) var<uniform> {var_name}: {ty};",
            group = self.group,
            binding = self.binding,
            ty = self.name
        )
    }
}
/// Byte-budgeted cache of shader bytecode with oldest-first eviction.
///
/// Entries are evicted in insertion order (re-inserting a name moves it to the back)
/// until the summed bytecode length is at most `max_size`.
#[derive(Debug)]
pub struct BytecodeShaderCache {
    pub cache: HashMap<String, Vec<u8>>,
    pub(super) insertion_order: Vec<String>,
    pub max_size: usize,
}
impl BytecodeShaderCache {
    /// Creates an empty cache with a total-byte budget of `max_size`.
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: HashMap::new(),
            insertion_order: Vec::new(),
            max_size,
        }
    }
    /// Inserts or replaces `name`, then evicts the oldest entries while over budget.
    ///
    /// An entry that alone exceeds `max_size` is evicted as well, which leaves the cache empty.
    pub fn insert(&mut self, name: &str, bytecode: Vec<u8>) {
        if self.cache.contains_key(name) {
            self.insertion_order.retain(|k| k != name);
        }
        self.cache.insert(name.to_string(), bytecode);
        self.insertion_order.push(name.to_string());
        // Keep a running total instead of re-summing every entry per eviction. Then drop the
        // evicted prefix of `insertion_order` with one `drain` rather than an O(n)
        // `remove(0)` per entry.
        let mut total = self.total_bytes();
        let mut evicted = 0;
        while total > self.max_size && evicted < self.insertion_order.len() {
            if let Some(bytes) = self.cache.remove(&self.insertion_order[evicted]) {
                total -= bytes.len();
            }
            evicted += 1;
        }
        self.insertion_order.drain(..evicted);
    }
    /// Returns the bytecode stored under `name`.
    pub fn get(&self, name: &str) -> Option<&Vec<u8>> {
        self.cache.get(name)
    }
    /// Removes the oldest entry, if any.
    pub fn evict_oldest(&mut self) {
        if self.insertion_order.is_empty() {
            return;
        }
        let oldest = self.insertion_order.remove(0);
        self.cache.remove(&oldest);
    }
    /// Sum of the lengths of all stored bytecode blobs.
    pub fn total_bytes(&self) -> usize {
        self.cache.values().map(Vec::len).sum()
    }
    /// Number of cached entries.
    pub fn len(&self) -> usize {
        self.cache.len()
    }
    /// `true` when the cache holds nothing.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }
}