use nalgebra::Vector3;
use wgpu::{
BindGroup, BindGroupLayoutDescriptor, BindGroupLayoutEntry, ComputePassDescriptor,
ComputePipelineDescriptor, PipelineCompilationOptions, PipelineLayoutDescriptor, ShaderModule,
ShaderModuleDescriptor, ShaderStages,
};
use crate::{
bindings::{Bindable, BindableResourceId},
gpu::Gpu,
misc::ids::PipelineId,
};
use super::PipelineStatus;
/// A compiled compute pipeline together with the bind group it dispatches with.
///
/// Built through [`ComputePipelineBuilder::finish`]. The pipeline is registered
/// with the GPU's binding manager under `id` when built and unregistered again
/// when the value is dropped.
pub struct ComputePipeline {
// Handle used to reach the device, the binding manager and the dispatch helpers.
gpu: Gpu,
// Key under which this pipeline is tracked by the binding manager.
id: PipelineId,
// The underlying wgpu pipeline object.
pipeline: wgpu::ComputePipeline,
// Resource ids of the bound resources, in the order they were bound.
entries: Vec<BindableResourceId>,
// Bind group for group 0; rebuilt before a dispatch when the binding manager
// reports this pipeline as dirty.
bind_group: BindGroup,
}
/// Incrementally collects resource bindings for a compute shader and turns
/// them into a [`ComputePipeline`].
///
/// Created by `Gpu::compute_pipeline`; resources are added with `bind` and the
/// pipeline is produced by `finish`.
pub struct ComputePipelineBuilder {
// Cloned GPU handle that is moved into the finished pipeline.
gpu: Gpu,
// Shader module compiled from the descriptor given to `Gpu::compute_pipeline`.
module: ShaderModule,
// One layout entry per bound resource; the index in this vector is the binding slot.
bind_group_layout: Vec<BindGroupLayoutEntry>,
// Resource ids of the bound resources, parallel to `bind_group_layout`.
entries: Vec<BindableResourceId>,
}
impl ComputePipeline {
    /// Records a compute pass for `workgroups` and hands it to `Gpu::dispach`
    /// with the `immediate` flag set.
    ///
    /// NOTE(review): presumably "immediate" means the work is submitted right
    /// away rather than batched; confirm against `Gpu::dispach`.
    pub fn dispatch(&mut self, workgroups: Vector3<u32>) {
        self.dispatch_inner(workgroups, true);
    }

    /// Same as [`Self::dispatch`], additionally forwarding `callback` to
    /// `Gpu::dispach_callback`.
    ///
    /// NOTE(review): the callback is presumably invoked once the GPU work has
    /// completed; confirm against `Gpu::dispach_callback`.
    pub fn dispatch_callback(
        &mut self,
        workgroups: Vector3<u32>,
        callback: impl FnOnce() + Send + 'static,
    ) {
        self.dispatch_callback_inner(workgroups, callback, true);
    }

    /// Records a compute pass for `workgroups` and hands it to `Gpu::dispach`
    /// with the `immediate` flag cleared.
    pub fn queue_dispatch(&mut self, workgroups: Vector3<u32>) {
        self.dispatch_inner(workgroups, false);
    }

    /// Same as [`Self::queue_dispatch`], additionally forwarding `callback` to
    /// `Gpu::dispach_callback`.
    pub fn queue_dispatch_callback(
        &mut self,
        workgroups: Vector3<u32>,
        callback: impl FnOnce() + Send + 'static,
    ) {
        self.dispatch_callback_inner(workgroups, callback, false);
    }

    /// Rebuilds `self.bind_group` when the binding manager reports this
    /// pipeline as dirty; does nothing otherwise.
    ///
    /// NOTE(review): the dirty flag is only read here, never cleared in this
    /// file. Confirm that the binding manager resets it, otherwise the bind
    /// group is rebuilt on every dispatch once it has been marked dirty.
    fn recreate_bind_group(&mut self) {
        let dirty = self.gpu.binding_manager.get_pipeline(self.id).dirty;
        if !dirty {
            return;
        }

        self.bind_group = self.gpu.binding_manager.create_bind_group(
            &self.gpu.device,
            &self.pipeline.get_bind_group_layout(0),
            &self.entries,
        );
    }

    /// Shared implementation of `dispatch` and `queue_dispatch`.
    ///
    /// Refreshes the bind group if needed, then records a single compute pass
    /// that binds the pipeline and bind group 0 and dispatches `workgroups`.
    fn dispatch_inner(&mut self, workgroups: Vector3<u32>, immediate: bool) {
        self.recreate_bind_group();

        // Borrow only the pieces the recording closure needs.
        let pipeline = &self.pipeline;
        let bind_group = &self.bind_group;
        let (x, y, z) = (workgroups.x, workgroups.y, workgroups.z);

        self.gpu.dispach(
            |encoder| {
                let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor::default());
                pass.set_pipeline(pipeline);
                pass.set_bind_group(0, Some(bind_group), &[]);
                pass.dispatch_workgroups(x, y, z);
            },
            immediate,
        );
    }

    /// Shared implementation of `dispatch_callback` and
    /// `queue_dispatch_callback`.
    ///
    /// Records the same compute pass as `dispatch_inner`, but submits it
    /// through `Gpu::dispach_callback` so that `callback` travels with it.
    fn dispatch_callback_inner(
        &mut self,
        workgroups: Vector3<u32>,
        callback: impl FnOnce() + Send + 'static,
        immediate: bool,
    ) {
        self.recreate_bind_group();

        // Borrow only the pieces the recording closure needs.
        let pipeline = &self.pipeline;
        let bind_group = &self.bind_group;
        let (x, y, z) = (workgroups.x, workgroups.y, workgroups.z);

        self.gpu.dispach_callback(
            |encoder| {
                let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor::default());
                pass.set_pipeline(pipeline);
                pass.set_bind_group(0, Some(bind_group), &[]);
                pass.dispatch_workgroups(x, y, z);
            },
            callback,
            immediate,
        );
    }
}
impl ComputePipelineBuilder {
    /// Binds `entry` to the next free binding slot of bind group 0.
    ///
    /// Slots are assigned in call order starting at 0, and every binding is
    /// visible to the compute stage only. Returns the builder for chaining.
    pub fn bind(mut self, entry: &impl Bindable) -> Self {
        self.entries.push(entry.resource_id());

        // The slot index is the number of layout entries recorded so far.
        let slot = self.bind_group_layout.len() as u32;
        let layout_entry = BindGroupLayoutEntry {
            binding: slot,
            visibility: ShaderStages::COMPUTE,
            ty: entry.binding_type(),
            count: entry.count(),
        };
        self.bind_group_layout.push(layout_entry);

        self
    }

    /// Consumes the builder and produces a ready-to-dispatch [`ComputePipeline`].
    ///
    /// Creates the bind group layout, pipeline layout and compute pipeline
    /// (entry point `main`), registers the pipeline with the binding manager
    /// as not dirty, and builds the initial bind group.
    pub fn finish(self) -> ComputePipeline {
        let device = &self.gpu.device;

        let group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
            label: None,
            entries: &self.bind_group_layout,
        });
        let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
            label: None,
            bind_group_layouts: &[&group_layout],
            push_constant_ranges: &[],
        });
        let pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
            label: None,
            layout: Some(&pipeline_layout),
            module: &self.module,
            entry_point: Some("main"),
            compilation_options: PipelineCompilationOptions::default(),
            cache: None,
        });

        // Register with the binding manager before the first bind group is built.
        let id = PipelineId::new();
        let status = PipelineStatus {
            resources: self.entries.clone(),
            dirty: false,
        };
        self.gpu.binding_manager.add_pipeline(id, status);

        // The layout is queried back from the pipeline, as in `recreate_bind_group`.
        let bind_group = self.gpu.binding_manager.create_bind_group(
            &self.gpu.device,
            &pipeline.get_bind_group_layout(0),
            &self.entries,
        );

        ComputePipeline {
            gpu: self.gpu,
            id,
            pipeline,
            entries: self.entries,
            bind_group,
        }
    }
}
impl Gpu {
    /// Starts building a compute pipeline from the shader described by `source`.
    ///
    /// The shader module is created on this GPU's device; the returned builder
    /// holds a clone of this handle and starts with no bindings.
    pub fn compute_pipeline(&self, source: ShaderModuleDescriptor) -> ComputePipelineBuilder {
        ComputePipelineBuilder {
            // Fields are evaluated in written order: the module is created
            // before the handle is cloned.
            module: self.device.create_shader_module(source),
            gpu: self.clone(),
            bind_group_layout: Vec::new(),
            entries: Vec::new(),
        }
    }
}
impl Drop for ComputePipeline {
    /// Unregisters this pipeline from the binding manager.
    fn drop(&mut self) {
        let Self { gpu, id, .. } = self;
        gpu.binding_manager.remove_pipeline(*id);
    }
}