use crate::{
CameraBuffer, GaussiansDepthBuffer, IndirectArgsBuffer, IndirectIndicesBuffer,
PreprocessorCreateError, RadixSortIndirectArgsBuffer,
core::{
BufferWrapper, ComputeBundle, ComputeBundleBuilder, GaussianPod, GaussianTransformBuffer,
GaussiansBuffer, ModelTransformBuffer,
},
wesl_utils,
};
#[cfg(feature = "viewer-selection")]
use crate::{editor::SelectionBuffer, selection};
/// Compute stage that runs the `pre`, `main` and `post` entry points of the
/// `wgpu_3dgs_viewer::preprocess` shader over a Gaussian model.
///
/// The type parameter `G` selects the Gaussian POD layout the shader is compiled for.
/// `B` is the stored bind group: [`wgpu::BindGroup`] (the default) when the preprocessor
/// owns its bind group, or `()` when the caller supplies a bind group at dispatch time.
#[derive(Debug)]
pub struct Preprocessor<G: GaussianPod, B = wgpu::BindGroup> {
    /// Layout that every bind group used with this preprocessor is created against.
    /// Read by `bind_group_layout()` and `create_bind_group()`, so no dead-code
    /// suppression is needed on this field.
    bind_group_layout: wgpu::BindGroupLayout,
    /// The owned bind group, or `()` for the "without bind group" flavour.
    bind_group: B,
    /// Bundle for the `pre` entry point.
    pre_bundle: ComputeBundle<()>,
    /// Bundle for the `main` entry point.
    bundle: ComputeBundle<()>,
    /// Bundle for the `post` entry point.
    post_bundle: ComputeBundle<()>,
    /// Ties the preprocessor to its Gaussian POD type without storing a value of it.
    gaussian_pod_marker: std::marker::PhantomData<G>,
}
/// Functionality shared by every preprocessor flavour, regardless of whether it owns a
/// bind group.
impl<G: GaussianPod, B> Preprocessor<G, B> {
    /// Returns the bind group layout this preprocessor was created with.
    pub fn bind_group_layout(&self) -> &wgpu::BindGroupLayout {
        &self.bind_group_layout
    }

    /// Returns the compute bundle for the `pre` entry point.
    pub fn pre_bundle(&self) -> &ComputeBundle<()> {
        &self.pre_bundle
    }

    /// Returns the compute bundle for the `main` entry point.
    pub fn bundle(&self) -> &ComputeBundle<()> {
        &self.bundle
    }

    /// Returns the compute bundle for the `post` entry point.
    pub fn post_bundle(&self) -> &ComputeBundle<()> {
        &self.post_bundle
    }

    /// Returns the workgroup size reported by the main compute bundle.
    pub fn workgroup_size(&self) -> u32 {
        self.bundle().workgroup_size()
    }

    /// Creates a bind group for the given buffers against this preprocessor's layout.
    ///
    /// The buffers are bound in parameter order, starting at binding 0 with `camera`.
    /// The selection buffers are only part of the signature when the
    /// `viewer-selection` feature is enabled.
    #[allow(clippy::too_many_arguments)]
    pub fn create_bind_group(
        &self,
        device: &wgpu::Device,
        camera: &CameraBuffer,
        model_transform: &ModelTransformBuffer,
        gaussian_transform: &GaussianTransformBuffer,
        gaussians: &GaussiansBuffer<G>,
        indirect_args: &IndirectArgsBuffer,
        radix_sort_indirect_args: &RadixSortIndirectArgsBuffer,
        indirect_indices: &IndirectIndicesBuffer,
        gaussians_depth: &GaussiansDepthBuffer,
        #[cfg(feature = "viewer-selection")] selection: &SelectionBuffer,
        #[cfg(feature = "viewer-selection")]
        invert_selection: &selection::PreprocessorInvertSelectionBuffer,
    ) -> wgpu::BindGroup {
        let layout = &self.bind_group_layout;

        // The actual descriptor assembly lives in one place so that owned and
        // caller-supplied bind groups are always built the same way.
        Preprocessor::<G>::create_bind_group_static(
            device,
            layout,
            camera,
            model_transform,
            gaussian_transform,
            gaussians,
            indirect_args,
            radix_sort_indirect_args,
            indirect_indices,
            gaussians_depth,
            #[cfg(feature = "viewer-selection")]
            selection,
            #[cfg(feature = "viewer-selection")]
            invert_selection,
        )
    }
}
impl<G: GaussianPod> Preprocessor<G> {
const LABEL: &str = "Preprocessor";
const MAIN_SHADER: &str = "wgpu_3dgs_viewer::preprocess";
pub const BIND_GROUP_LAYOUT_DESCRIPTOR: wgpu::BindGroupLayoutDescriptor<'static> =
wgpu::BindGroupLayoutDescriptor {
label: Some("Preprocessor Bind Group Layout"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 3,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 4,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 5,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 6,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 7,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
#[cfg(feature = "viewer-selection")]
wgpu::BindGroupLayoutEntry {
binding: 8,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
#[cfg(feature = "viewer-selection")]
wgpu::BindGroupLayoutEntry {
binding: 9,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
};
#[allow(clippy::too_many_arguments)]
pub fn new(
device: &wgpu::Device,
camera: &CameraBuffer,
model_transform: &ModelTransformBuffer,
gaussian_transform: &GaussianTransformBuffer,
gaussians: &GaussiansBuffer<G>,
indirect_args: &IndirectArgsBuffer,
radix_sort_indirect_args: &RadixSortIndirectArgsBuffer,
indirect_indices: &IndirectIndicesBuffer,
gaussians_depth: &GaussiansDepthBuffer,
#[cfg(feature = "viewer-selection")] selection: &SelectionBuffer,
#[cfg(feature = "viewer-selection")]
invert_selection: &selection::PreprocessorInvertSelectionBuffer,
) -> Result<Self, PreprocessorCreateError> {
if (device.limits().max_storage_buffer_binding_size as wgpu::BufferAddress)
< gaussians.buffer().size()
{
return Err(PreprocessorCreateError::ModelSizeExceedsDeviceLimit {
model_size: gaussians.buffer().size(),
device_limit: device.limits().max_storage_buffer_binding_size,
});
}
let this = Preprocessor::new_without_bind_group(device)?;
log::debug!("Creating preprocessor bind group");
let bind_group = this.create_bind_group(
device,
camera,
model_transform,
gaussian_transform,
gaussians,
indirect_args,
radix_sort_indirect_args,
indirect_indices,
gaussians_depth,
#[cfg(feature = "viewer-selection")]
selection,
#[cfg(feature = "viewer-selection")]
invert_selection,
);
Ok(Self {
bind_group_layout: this.bind_group_layout,
bind_group,
pre_bundle: this.pre_bundle,
bundle: this.bundle,
post_bundle: this.post_bundle,
gaussian_pod_marker: std::marker::PhantomData,
})
}
pub fn bind_group(&self) -> &wgpu::BindGroup {
&self.bind_group
}
pub fn preprocess(&self, encoder: &mut wgpu::CommandEncoder, gaussian_count: u32) {
self.pre_bundle.dispatch(encoder, 1, [&self.bind_group]);
self.bundle
.dispatch(encoder, gaussian_count, [&self.bind_group]);
self.post_bundle.dispatch(encoder, 1, [&self.bind_group]);
}
#[allow(clippy::too_many_arguments)]
fn create_bind_group_static(
device: &wgpu::Device,
bind_group_layout: &wgpu::BindGroupLayout,
camera: &CameraBuffer,
model_transform: &ModelTransformBuffer,
gaussian_transform: &GaussianTransformBuffer,
gaussians: &GaussiansBuffer<G>,
indirect_args: &IndirectArgsBuffer,
radix_sort_indirect_args: &RadixSortIndirectArgsBuffer,
indirect_indices: &IndirectIndicesBuffer,
gaussians_depth: &GaussiansDepthBuffer,
#[cfg(feature = "viewer-selection")] selection: &SelectionBuffer,
#[cfg(feature = "viewer-selection")]
invert_selection: &selection::PreprocessorInvertSelectionBuffer,
) -> wgpu::BindGroup {
device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("Preprocessor Bind Group"),
layout: bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: camera.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: model_transform.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: gaussian_transform.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: gaussians.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 4,
resource: indirect_args.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 5,
resource: radix_sort_indirect_args.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 6,
resource: indirect_indices.buffer().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 7,
resource: gaussians_depth.buffer().as_entire_binding(),
},
#[cfg(feature = "viewer-selection")]
wgpu::BindGroupEntry {
binding: 8,
resource: selection.buffer().as_entire_binding(),
},
#[cfg(feature = "viewer-selection")]
wgpu::BindGroupEntry {
binding: 9,
resource: invert_selection.buffer().as_entire_binding(),
},
],
})
}
}
/// The preprocessor flavour that does not own a bind group (`B = ()`); the caller
/// passes one in at dispatch time.
impl<G: GaussianPod> Preprocessor<G, ()> {
    /// Creates the bind group layout and the `pre`/`main`/`post` compute bundles
    /// without creating a bind group.
    ///
    /// The shader is compiled with the feature flags from `G::features()` plus a
    /// `selection_buffer` flag that mirrors the `viewer-selection` crate feature.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while building one of the compute bundles.
    ///
    /// # Panics
    ///
    /// Panics if the built-in preprocess module path fails to parse.
    pub fn new_without_bind_group(device: &wgpu::Device) -> Result<Self, PreprocessorCreateError> {
        let shader_path: wesl::ModulePath = Preprocessor::<G>::MAIN_SHADER
            .parse()
            .expect("preprocess module path");

        // Extra shader flag toggled by the crate feature rather than by the POD type.
        let selection_flag = ("selection_buffer", cfg!(feature = "viewer-selection"));
        let compile_options = wesl::CompileOptions {
            features: wesl::Features {
                flags: G::features()
                    .into_iter()
                    .chain(std::iter::once(selection_flag))
                    .map(|(name, enabled)| (name.to_string(), enabled.into()))
                    .collect(),
                ..Default::default()
            },
            ..Default::default()
        };

        let bind_group_layout =
            device.create_bind_group_layout(&Preprocessor::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR);

        // All three bundles share the shader, compile options and layout; only the
        // label and entry point differ.
        let pre_label = format!("Pre {}", Preprocessor::<G>::LABEL);
        let pre_bundle = ComputeBundleBuilder::new()
            .label(pre_label.as_str())
            .bind_group_layout(&Preprocessor::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR)
            .entry_point("pre")
            .main_shader(shader_path.clone())
            .wesl_compile_options(compile_options.clone())
            .resolver(wesl_utils::resolver())
            .build_without_bind_groups(device)?;

        let bundle = ComputeBundleBuilder::new()
            .label(Preprocessor::<G>::LABEL)
            .bind_group_layout(&Preprocessor::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR)
            .entry_point("main")
            .main_shader(shader_path.clone())
            .wesl_compile_options(compile_options.clone())
            .resolver(wesl_utils::resolver())
            .build_without_bind_groups(device)?;

        let post_label = format!("Post {}", Preprocessor::<G>::LABEL);
        let post_bundle = ComputeBundleBuilder::new()
            .label(post_label.as_str())
            .bind_group_layout(&Preprocessor::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR)
            .entry_point("post")
            .main_shader(shader_path)
            .wesl_compile_options(compile_options)
            .resolver(wesl_utils::resolver())
            .build_without_bind_groups(device)?;

        log::info!("Preprocessor created");

        Ok(Self {
            bind_group_layout,
            bind_group: (),
            pre_bundle,
            bundle,
            post_bundle,
            gaussian_pod_marker: std::marker::PhantomData,
        })
    }

    /// Records the `pre`, `main` and `post` dispatches, in that order, using the
    /// caller-supplied `bind_group`.
    ///
    /// `pre` and `post` are dispatched with a count of 1; `main` is dispatched with
    /// `gaussian_count`.
    pub fn preprocess(
        &self,
        encoder: &mut wgpu::CommandEncoder,
        bind_group: &wgpu::BindGroup,
        gaussian_count: u32,
    ) {
        let passes = [
            (&self.pre_bundle, 1),
            (&self.bundle, gaussian_count),
            (&self.post_bundle, 1),
        ];
        for (pass, count) in passes {
            pass.dispatch(encoder, count, [bind_group]);
        }
    }
}