use glam::*;
use crate::{
SelectionBuffer, SelectionOpBuffer,
core::{
self, BufferWrapper, ComputeBundle, ComputeBundleBuilder, GaussianPod,
GaussianTransformBuffer, GaussiansBuffer, ModelTransformBuffer,
},
shader,
};
/// An expression tree describing how a selection is computed on the GPU.
///
/// Trees are built with the combinator methods on [`SelectionExpr`] and
/// evaluated with `SelectionBundle::evaluate`.
#[derive(Default, Debug)]
pub enum SelectionExpr {
    /// Evaluating this records no commands, so the destination keeps whatever
    /// it already held. This is the [`Default`] variant.
    #[default]
    Identity,
    /// Union of two selections (op code 0).
    Union(Box<Self>, Box<Self>),
    /// Intersection of two selections (op code 1).
    Intersection(Box<Self>, Box<Self>),
    /// Difference of two selections (op code 3).
    Difference(Box<Self>, Box<Self>),
    /// Symmetric difference of two selections (op code 2).
    SymmetricDifference(Box<Self>, Box<Self>),
    /// Complement of a selection (op code 4).
    Complement(Box<Self>),
    /// Custom operation with one operand: custom bundle index, operand, and
    /// extra bind groups appended after the gaussians bind group.
    Unary(usize, Box<Self>, Vec<wgpu::BindGroup>),
    /// Custom operation with two operands: left operand, custom bundle index,
    /// right operand, and extra bind groups.
    Binary(
        Box<Self>,
        usize,
        Box<Self>,
        Vec<wgpu::BindGroup>,
    ),
    /// Custom operation with no operand expression: custom bundle index and
    /// extra bind groups.
    Selection(usize, Vec<wgpu::BindGroup>),
    /// An existing selection buffer, copied into the destination on evaluation.
    Buffer(SelectionBuffer),
}
impl SelectionExpr {
pub const CUSTOM_OP_START: u32 = 5;
pub fn identity() -> Self {
Self::Identity
}
pub fn union(self, other: Self) -> Self {
Self::Union(Box::new(self), Box::new(other))
}
pub fn intersection(self, other: Self) -> Self {
Self::Intersection(Box::new(self), Box::new(other))
}
pub fn difference(self, other: Self) -> Self {
Self::Difference(Box::new(self), Box::new(other))
}
pub fn symmetric_difference(self, other: Self) -> Self {
Self::SymmetricDifference(Box::new(self), Box::new(other))
}
pub fn complement(self) -> Self {
Self::Complement(Box::new(self))
}
pub fn unary(self, op: usize, bind_groups: Vec<wgpu::BindGroup>) -> Self {
Self::Unary(op, Box::new(self), bind_groups)
}
pub fn binary(self, op: usize, other: Self, bind_groups: Vec<wgpu::BindGroup>) -> Self {
Self::Binary(Box::new(self), op, Box::new(other), bind_groups)
}
pub fn selection(op: usize, bind_groups: Vec<wgpu::BindGroup>) -> Self {
Self::Selection(op, bind_groups)
}
pub fn buffer(buffer: SelectionBuffer) -> Self {
Self::Buffer(buffer)
}
pub fn update_with(&mut self, f: impl FnOnce(Self) -> Self) {
*self = f(std::mem::take(self));
}
pub fn as_u32(&self) -> Option<u32> {
match self {
SelectionExpr::Union(_, _) => Some(0),
SelectionExpr::Intersection(_, _) => Some(1),
SelectionExpr::SymmetricDifference(_, _) => Some(2),
SelectionExpr::Difference(_, _) => Some(3),
SelectionExpr::Complement(_) => Some(4),
SelectionExpr::Unary(op, _, _) => Some(*op as u32 + Self::CUSTOM_OP_START),
SelectionExpr::Binary(_, op, _, _) => Some(*op as u32 + Self::CUSTOM_OP_START),
SelectionExpr::Selection(op, _) => Some(*op as u32 + Self::CUSTOM_OP_START),
SelectionExpr::Buffer(_) => None,
SelectionExpr::Identity => None,
}
}
pub fn is_identity(&self) -> bool {
matches!(self, SelectionExpr::Identity)
}
pub fn is_primitive(&self) -> bool {
matches!(
self,
SelectionExpr::Union(..)
| SelectionExpr::Intersection(..)
| SelectionExpr::Difference(..)
| SelectionExpr::SymmetricDifference(..)
| SelectionExpr::Complement(..)
)
}
pub fn is_custom(&self) -> bool {
matches!(
self,
SelectionExpr::Unary(..) | SelectionExpr::Binary(..) | SelectionExpr::Selection(..)
)
}
pub fn is_operation(&self) -> bool {
matches!(
self,
SelectionExpr::Union(..)
| SelectionExpr::Intersection(..)
| SelectionExpr::Difference(..)
| SelectionExpr::SymmetricDifference(..)
| SelectionExpr::Complement(..)
| SelectionExpr::Unary(..)
| SelectionExpr::Binary(..)
| SelectionExpr::Selection(..)
)
}
pub fn is_buffer(&self) -> bool {
matches!(self, SelectionExpr::Buffer(_))
}
pub fn custom_op_index(&self) -> Option<usize> {
match self {
SelectionExpr::Unary(op, _, _)
| SelectionExpr::Binary(_, op, _, _)
| SelectionExpr::Selection(op, _) => Some(*op),
_ => None,
}
}
pub fn custom_bind_groups(&self) -> Option<&Vec<wgpu::BindGroup>> {
match self {
SelectionExpr::Unary(_, _, bind_groups) => Some(bind_groups),
SelectionExpr::Binary(_, _, _, bind_groups) => Some(bind_groups),
SelectionExpr::Selection(_, bind_groups) => Some(bind_groups),
_ => None,
}
}
pub fn custom_op_index_and_bind_groups(&self) -> Option<(usize, &Vec<wgpu::BindGroup>)> {
match self {
SelectionExpr::Unary(op, _, bind_groups)
| SelectionExpr::Binary(_, op, _, bind_groups)
| SelectionExpr::Selection(op, bind_groups) => Some((*op, bind_groups)),
_ => None,
}
}
}
/// Compute bundles for evaluating [`SelectionExpr`] trees over gaussians whose
/// pod type is `G`.
#[derive(Debug)]
pub struct SelectionBundle<G>
where
    G: GaussianPod,
{
    /// Bundle that runs the built-in primitive set operations.
    primitive_bundle: ComputeBundle<()>,
    /// Custom operation bundles, indexed by a custom node's op index.
    pub bundles: Vec<ComputeBundle<()>>,
    /// Ties the struct to `G` without storing a value of `G`.
    gaussian_pod_marker: std::marker::PhantomData<G>,
}
impl<G: GaussianPod> SelectionBundle<G> {
    /// Layout of bind group 0, shared by the primitive bundle and every custom
    /// selection bundle built in this impl.
    ///
    /// [`SelectionBundle::evaluate`] supplies the resources in this order:
    /// - binding 0: op code buffer (uniform)
    /// - binding 1: scratch `source` selection (read-only storage)
    /// - binding 2: destination selection (read-write storage)
    /// - binding 3: model transform (uniform)
    /// - binding 4: gaussian transform (uniform)
    /// - binding 5: gaussians (read-only storage)
    pub const GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR: wgpu::BindGroupLayoutDescriptor<'static> =
        wgpu::BindGroupLayoutDescriptor {
            label: Some("Selection Gaussians Bind Group Layout"),
            entries: &[
                wgpu::BindGroupLayoutEntry {
                    binding: 0,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 1,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 2,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 3,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 4,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 5,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
            ],
        };

    /// Creates a selection bundle.
    ///
    /// Builds the primitive bundle (see [`Self::create_primitive_bundle`]).
    /// `bundles` are the custom operation bundles, and a custom node's op index
    /// (see [`SelectionExpr::custom_op_index`]) indexes into this list.
    pub fn new(device: &wgpu::Device, bundles: Vec<ComputeBundle<()>>) -> Self {
        let primitive_bundle = Self::create_primitive_bundle(device);
        Self {
            primitive_bundle,
            bundles,
            gaussian_pod_marker: std::marker::PhantomData,
        }
    }

    /// The layout of bind group 0, built from
    /// [`Self::GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR`].
    pub fn gaussians_bind_group_layout(&self) -> &wgpu::BindGroupLayout {
        &self.primitive_bundle.bind_group_layouts()[0]
    }

    /// Records commands into `encoder` that evaluate `expr` and write the
    /// resulting selection into `dest`.
    ///
    /// - [`SelectionExpr::Identity`] records nothing, so `dest` keeps its contents.
    /// - [`SelectionExpr::Buffer`] copies `dest.buffer().size()` bytes from the
    ///   wrapped buffer into `dest`.
    /// - Every other node allocates a new op code buffer and a new scratch
    ///   `source` selection. It then recursively evaluates its left operand (if
    ///   any) into `source` and its right/only operand (if any) into `dest`.
    ///   Finally it dispatches either the primitive bundle or the custom bundle
    ///   selected by the node's op index.
    ///
    /// # Panics
    ///
    /// Panics if a custom node's op index is out of range for
    /// [`SelectionBundle::bundles`], or if the bind group cannot be created.
    #[allow(clippy::too_many_arguments)]
    pub fn evaluate(
        &self,
        device: &wgpu::Device,
        encoder: &mut wgpu::CommandEncoder,
        expr: &SelectionExpr,
        dest: &SelectionBuffer,
        model_transform: &ModelTransformBuffer,
        gaussian_transform: &GaussianTransformBuffer,
        gaussians: &GaussiansBuffer<G>,
    ) {
        let m = model_transform;
        let g = gaussian_transform;
        let gs = gaussians;

        // One exhaustive match: leaves finish here, and every operation node
        // reports which operand goes into the scratch `source` and which goes
        // straight into `dest`. Binary-shaped nodes (primitive and custom) share
        // one routing, as do the unary-shaped ones.
        let (into_source, into_dest) = match expr {
            SelectionExpr::Identity => return,
            SelectionExpr::Buffer(buffer) => {
                encoder.copy_buffer_to_buffer(
                    buffer.buffer(),
                    0,
                    dest.buffer(),
                    0,
                    dest.buffer().size(),
                );
                return;
            }
            SelectionExpr::Union(l, r)
            | SelectionExpr::Intersection(l, r)
            | SelectionExpr::Difference(l, r)
            | SelectionExpr::SymmetricDifference(l, r)
            | SelectionExpr::Binary(l, _, r, _) => (Some(l), Some(r)),
            SelectionExpr::Complement(e) | SelectionExpr::Unary(_, e, _) => (None, Some(e)),
            SelectionExpr::Selection(..) => (None, None),
        };

        let op = SelectionOpBuffer::new(device, expr.as_u32().expect("operation expression"));
        let count = gaussians.len() as u32;
        // `source` is always allocated because binding 1 of the layout must be
        // bound, even for nodes without a left operand.
        let source = SelectionBuffer::new(device, count);

        if let Some(lhs) = into_source {
            self.evaluate(device, encoder, lhs, &source, m, g, gs);
        }
        if let Some(rhs) = into_dest {
            self.evaluate(device, encoder, rhs, dest, m, g, gs);
        }

        let gaussians_bind_group = self
            .primitive_bundle
            .create_bind_group(
                device,
                0,
                [
                    op.buffer().as_entire_binding(),
                    source.buffer().as_entire_binding(),
                    dest.buffer().as_entire_binding(),
                    m.buffer().as_entire_binding(),
                    g.buffer().as_entire_binding(),
                    gs.buffer().as_entire_binding(),
                ],
            )
            .expect("gaussians bind group");

        match expr.custom_op_index_and_bind_groups() {
            // NOTE(review): primitives get `count.div_ceil(32)` while custom ops
            // get `count`. Presumably the primitive shader works on one 32-bit
            // selection word per invocation; confirm against selection::primitive.
            None => self.primitive_bundle.dispatch(
                encoder,
                count.div_ceil(32),
                [&gaussians_bind_group],
            ),
            Some((i, bind_groups)) => {
                let bundle = self.bundles.get(i).unwrap_or_else(|| {
                    panic!(
                        "custom selection op index {i} out of range ({} bundles)",
                        self.bundles.len()
                    )
                });
                // The gaussians bind group is always group 0; the node's own
                // bind groups follow it.
                let bind_groups = std::iter::once(&gaussians_bind_group)
                    .chain(bind_groups)
                    .collect::<Vec<_>>();
                bundle.dispatch(encoder, count, bind_groups);
            }
        }
    }

    /// Builds the bundle for the built-in set operations.
    ///
    /// It compiles the `wgpu_3dgs_editor::selection::primitive` shader (entry
    /// point `main`) with `G`'s WESL features, using the layout
    /// [`Self::GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR`] as bind group 0.
    ///
    /// # Panics
    ///
    /// Panics if the bundle fails to build; the error is logged first.
    pub fn create_primitive_bundle(device: &wgpu::Device) -> ComputeBundle<()> {
        ComputeBundleBuilder::new()
            .label("Selection Primitive Operations")
            .bind_group_layout(&Self::GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR)
            .resolver({
                let mut resolver = wesl::PkgResolver::new();
                resolver.add_package(&core::shader::PACKAGE);
                resolver.add_package(&shader::PACKAGE);
                resolver
            })
            .main_shader(
                "wgpu_3dgs_editor::selection::primitive"
                    .parse()
                    .expect("selection::primitive module path"),
            )
            .entry_point("main")
            .wesl_compile_options(wesl::CompileOptions {
                features: G::wesl_features(),
                ..Default::default()
            })
            .build_without_bind_groups(device)
            // Log the build error; it is discarded (mapped to `()`) before the panic.
            .map_err(|e| log::error!("{e}"))
            .expect("primitive bundle")
    }

    /// Layout of bind group 1 for the sphere selection bundle: a single uniform
    /// buffer at binding 0.
    pub const SPHERE_BIND_GROUP_LAYOUT_DESCRIPTOR: wgpu::BindGroupLayoutDescriptor<'static> =
        wgpu::BindGroupLayoutDescriptor {
            label: Some("Sphere Selection Bind Group Layout"),
            entries: &[
                wgpu::BindGroupLayoutEntry {
                    binding: 0,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
            ],
        };

    /// Builds the sphere selection bundle.
    ///
    /// It compiles `wgpu_3dgs_editor::selection::sphere` (entry point `main`)
    /// with `G`'s WESL features. The bundle uses the gaussians layout as group 0
    /// and [`Self::SPHERE_BIND_GROUP_LAYOUT_DESCRIPTOR`] as group 1.
    ///
    /// # Panics
    ///
    /// Panics if the bundle fails to build; the error is logged first.
    pub fn create_sphere_bundle(device: &wgpu::Device) -> ComputeBundle<()> {
        let mut resolver = wesl::PkgResolver::new();
        resolver.add_package(&core::shader::PACKAGE);
        resolver.add_package(&shader::PACKAGE);
        ComputeBundleBuilder::new()
            .label("Sphere Selection")
            .bind_group_layouts([
                &Self::GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR,
                &Self::SPHERE_BIND_GROUP_LAYOUT_DESCRIPTOR,
            ])
            .main_shader(
                "wgpu_3dgs_editor::selection::sphere"
                    .parse()
                    .expect("selection::sphere module path"),
            )
            .entry_point("main")
            .wesl_compile_options(wesl::CompileOptions {
                features: G::wesl_features(),
                ..Default::default()
            })
            .resolver(resolver)
            .build_without_bind_groups(device)
            .map_err(|e| log::error!("{e}"))
            .expect("sphere selection compute bundle")
    }

    /// Layout of bind group 1 for the box selection bundle: a single uniform
    /// buffer at binding 0.
    pub const BOX_BIND_GROUP_LAYOUT_DESCRIPTOR: wgpu::BindGroupLayoutDescriptor<'static> =
        wgpu::BindGroupLayoutDescriptor {
            label: Some("Box Selection Bind Group Layout"),
            entries: &[
                wgpu::BindGroupLayoutEntry {
                    binding: 0,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
            ],
        };

    /// Builds the box selection bundle.
    ///
    /// It compiles `wgpu_3dgs_editor::selection::box` (entry point `main`) with
    /// `G`'s WESL features. The bundle uses the gaussians layout as group 0 and
    /// [`Self::BOX_BIND_GROUP_LAYOUT_DESCRIPTOR`] as group 1.
    ///
    /// # Panics
    ///
    /// Panics if the bundle fails to build; the error is logged first.
    pub fn create_box_bundle(device: &wgpu::Device) -> ComputeBundle<()> {
        let mut resolver = wesl::PkgResolver::new();
        resolver.add_package(&core::shader::PACKAGE);
        resolver.add_package(&shader::PACKAGE);
        ComputeBundleBuilder::new()
            .label("Box Selection")
            .bind_group_layouts([
                &Self::GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR,
                &Self::BOX_BIND_GROUP_LAYOUT_DESCRIPTOR,
            ])
            .main_shader(
                "wgpu_3dgs_editor::selection::box"
                    .parse()
                    .expect("selection::box module path"),
            )
            .entry_point("main")
            .wesl_compile_options(wesl::CompileOptions {
                features: G::wesl_features(),
                ..Default::default()
            })
            .resolver(resolver)
            .build_without_bind_groups(device)
            .map_err(|e| log::error!("{e}"))
            .expect("box selection compute bundle")
    }
}