use crate::{render_types::GpuContext, wgpu};
use crate::wgpu::util::DeviceExt;
use std::num::NonZeroU64;
use std::sync::Arc;
use futures::FutureExt;
use futures_intrusive::channel::shared::oneshot_channel;
use encase::{ShaderType, UniformBuffer};
/// Selects which reduction `compute_reduce` performs.
#[derive(Debug, Clone)]
pub enum ReduceMode {
    /// Minimum of the input values.
    Min,
    /// Maximum of the input values.
    Max,
    /// Sum of the input values.
    Sum,
    /// Minimum and maximum computed together.
    Extent,
    /// Counts of values falling into `num_bins` bins between `data_min` and `data_max`.
    Histogram {
        /// Number of bins in the histogram.
        num_bins: u32,
        /// Lower bound of the binned range.
        data_min: f32,
        /// Upper bound of the binned range.
        data_max: f32,
    },
}

impl ReduceMode {
    /// Numeric tag for this mode, as stored in the `mode` field of the uniform block.
    fn discriminant(&self) -> u32 {
        match self {
            Self::Histogram { .. } => 4,
            Self::Extent => 3,
            Self::Sum => 2,
            Self::Max => 1,
            Self::Min => 0,
        }
    }

    /// Returns `true` only for the [`ReduceMode::Histogram`] variant.
    fn is_histogram(&self) -> bool {
        match self {
            Self::Histogram { .. } => true,
            Self::Min | Self::Max | Self::Sum | Self::Extent => false,
        }
    }
}
/// Parameter block that `compute_reduce` serialises with `encase::UniformBuffer`
/// and binds as the uniform buffer at binding 0.
///
/// NOTE(review): the field order presumably has to match the uniform struct
/// declared in `shaders/reduce.wgsl` — verify before reordering or adding fields.
#[derive(ShaderType)]
struct ReduceUniforms {
/// Value of `ReduceMode::discriminant()` for the requested mode.
mode: u32,
/// Length of the input array, narrowed to `u32`.
num_elements: u32,
/// Histogram bin count; `compute_reduce` writes 0 for non-histogram modes.
num_bins: u32,
/// Lower bound of the histogram range; 0.0 for non-histogram modes.
data_min: f32,
/// Upper bound of the histogram range; 0.0 for non-histogram modes.
data_max: f32,
}
/// Runs one reduction pass over `input_arr` on the GPU and returns the raw downloaded output.
///
/// The pass dispatches `input_arr.len().div_ceil(64)` workgroups of `shaders/reduce.wgsl`
/// (entry point `main_histogram` for [`ReduceMode::Histogram`], `main_scalar` otherwise),
/// copies the output storage buffer into a mappable buffer and reads it back.
///
/// # Returns
///
/// The downloaded buffer reinterpreted as `f32` words:
///
/// * `Min` / `Max` / `Sum`: one word per dispatched workgroup.
/// * `Extent`: two words per dispatched workgroup.
/// * `Histogram`: `num_bins` words.
///
/// An empty `input_arr`, or a histogram with `num_bins == 0`, is answered without touching
/// the GPU: the result is `num_bins` zero words (so an empty vector for the scalar modes).
/// Those cases would otherwise require zero-sized buffers, which cannot be bound or mapped.
///
/// # Panics
///
/// Panics if the uniform block cannot be serialised or if mapping the download buffer fails.
pub async fn compute_reduce(
    gpu_context: &GpuContext<'_>,
    input_arr: Arc<Vec<f32>>,
    mode: ReduceMode,
) -> Vec<f32> {
    let GpuContext { device, queue } = gpu_context;
    let is_histogram = mode.is_histogram();
    // Non-histogram modes carry no binning parameters; they are zeroed in the uniform block.
    let (num_bins, data_min, data_max) = match &mode {
        ReduceMode::Histogram { num_bins, data_min, data_max } => (*num_bins, *data_min, *data_max),
        _ => (0, 0.0, 0.0),
    };

    // Degenerate requests would need zero-sized input/output buffers. Return what a
    // successful pass would have produced: no partials for the scalar modes (`num_bins`
    // is 0 there) or all-zero bins for a histogram (the bit pattern of 0.0f32 is 0u32).
    if input_arr.is_empty() || (is_histogram && num_bins == 0) {
        return vec![0.0; num_bins as usize];
    }

    // NOTE(review): 64 is presumably the workgroup size declared in the shader — confirm.
    let workgroup_count = input_arr.len().div_ceil(64);
    let shader = device.create_shader_module(wgpu::include_wgsl!("shaders/reduce.wgsl"));

    let uniforms = ReduceUniforms {
        mode: mode.discriminant(),
        // NOTE(review): silently truncates for inputs longer than u32::MAX elements.
        num_elements: input_arr.len() as u32,
        num_bins,
        data_min,
        data_max,
    };
    let mut buffer = UniformBuffer::new(Vec::<u8>::new());
    buffer.write(&uniforms).unwrap();
    let uniform_bytes = buffer.into_inner();
    let uniform_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("reduce_uniforms"),
        size: uniform_bytes.len() as u64,
        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });
    queue.write_buffer(&uniform_buffer, 0, &uniform_bytes);

    let input_data_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("reduce_input"),
        contents: bytemuck::cast_slice(&input_arr),
        usage: wgpu::BufferUsages::STORAGE,
    });

    // Output is sized in 4-byte words: one per bin, two per workgroup for Extent,
    // otherwise one per workgroup.
    let output_size_bytes: u64 = if is_histogram {
        (num_bins as u64) * 4
    } else if matches!(mode, ReduceMode::Extent) {
        (workgroup_count as u64) * 2 * 4
    } else {
        (workgroup_count as u64) * 4
    };
    let output_data_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("reduce_output"),
        size: output_size_bytes,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });
    // Storage buffers are not created mappable here, so results are staged through this buffer.
    let download_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("reduce_download"),
        size: output_size_bytes,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });

    let uniform_entry = wgpu::BindGroupLayoutEntry {
        binding: 0,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: wgpu::BindingType::Buffer {
            ty: wgpu::BufferBindingType::Uniform,
            min_binding_size: NonZeroU64::new(uniform_bytes.len() as u64),
            has_dynamic_offset: false,
        },
        count: None,
    };
    let input_entry = wgpu::BindGroupLayoutEntry {
        binding: 1,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: wgpu::BindingType::Buffer {
            ty: wgpu::BufferBindingType::Storage { read_only: true },
            min_binding_size: Some(NonZeroU64::new(4).unwrap()),
            has_dynamic_offset: false,
        },
        count: None,
    };
    // The histogram entry point takes its output at binding 3, the scalar one at binding 2.
    let output_entry = wgpu::BindGroupLayoutEntry {
        binding: if is_histogram { 3 } else { 2 },
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: wgpu::BindingType::Buffer {
            ty: wgpu::BufferBindingType::Storage { read_only: false },
            min_binding_size: Some(NonZeroU64::new(4).unwrap()),
            has_dynamic_offset: false,
        },
        count: None,
    };
    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        label: None,
        entries: &[uniform_entry, input_entry, output_entry],
    });
    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: None,
        layout: &bind_group_layout,
        entries: &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: uniform_buffer.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: input_data_buffer.as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: if is_histogram { 3 } else { 2 },
                resource: output_data_buffer.as_entire_binding(),
            },
        ],
    });
    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        label: None,
        bind_group_layouts: &[Some(&bind_group_layout)],
        immediate_size: 0,
    });
    let entry_point = if is_histogram { "main_histogram" } else { "main_scalar" };
    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        label: None,
        layout: Some(&pipeline_layout),
        module: &shader,
        entry_point: Some(entry_point),
        compilation_options: wgpu::PipelineCompilationOptions::default(),
        cache: None,
    });

    let mut encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    {
        let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: None,
            timestamp_writes: None,
        });
        compute_pass.set_pipeline(&pipeline);
        compute_pass.set_bind_group(0, &bind_group, &[]);
        // NOTE(review): all workgroups go in the x dimension; inputs needing more groups
        // than the device's per-dimension dispatch limit are not split — confirm callers
        // stay below it.
        compute_pass.dispatch_workgroups(workgroup_count as u32, 1, 1);
    }
    encoder.copy_buffer_to_buffer(&output_data_buffer, 0, &download_buffer, 0, output_size_bytes);
    queue.submit([encoder.finish()]);

    let buffer_slice = download_buffer.slice(..);
    // On wasm the thread cannot block, so completion is awaited through a oneshot channel.
    #[cfg(target_arch = "wasm32")]
    {
        let (sender, receiver) = oneshot_channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |res| {
            if res.is_err() {
                panic!("Failed to map buffer for reading");
            }
            sender.send(res).ok();
        });
        let _ = device.poll(wgpu::PollType::Poll);
        receiver.receive().await.unwrap().unwrap();
    }
    // Elsewhere, poll the device until the submitted work and the mapping have completed.
    #[cfg(not(target_arch = "wasm32"))]
    {
        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
            if result.is_err() {
                panic!("Failed to map buffer for reading");
            }
        });
        let _ = device.poll(wgpu::PollType::wait_indefinitely());
    }

    let data = buffer_slice.get_mapped_range();
    bytemuck::allocation::pod_collect_to_vec(&data)
}
/// CPU fallback: smallest value in `input`, or `f32::INFINITY` when `input` is empty.
fn cpu_reduce_min(input: &[f32]) -> f32 {
    let mut smallest = f32::INFINITY;
    for &value in input {
        smallest = smallest.min(value);
    }
    smallest
}
/// CPU fallback: largest value in `input`, or `f32::NEG_INFINITY` when `input` is empty.
fn cpu_reduce_max(input: &[f32]) -> f32 {
    let mut largest = f32::NEG_INFINITY;
    for &value in input {
        largest = largest.max(value);
    }
    largest
}
/// CPU fallback: sum of all values in `input`, accumulated front to back.
fn cpu_reduce_sum(input: &[f32]) -> f32 {
    // `Iterator::sum` is kept so the empty-input result stays exactly what std defines.
    let total: f32 = input.iter().sum();
    total
}
/// CPU fallback: `(min, max)` of `input` in a single pass.
///
/// An empty slice yields `(f32::INFINITY, f32::NEG_INFINITY)`.
fn cpu_reduce_extent(input: &[f32]) -> (f32, f32) {
    let mut lowest = f32::INFINITY;
    let mut highest = f32::NEG_INFINITY;
    for &value in input {
        lowest = lowest.min(value);
        highest = highest.max(value);
    }
    (lowest, highest)
}
/// CPU fallback: counts how many values of `input` fall into each of `num_bins` equal-width
/// bins spanning `[data_min, data_max]`.
///
/// * Values outside the range are clamped into the first or last bin.
/// * NaN values are counted in bin 0 (a NaN position converts to index 0).
/// * When `data_max - data_min <= 0.0` every value is counted in bin 0.
/// * When `num_bins` is 0 the result is an empty vector, whatever `input` contains.
fn cpu_reduce_histogram(input: &[f32], num_bins: u32, data_min: f32, data_max: f32) -> Vec<u32> {
    let mut bins = vec![0u32; num_bins as usize];
    // Without bins there is nothing to count into; leave before `num_bins - 1` can underflow.
    if num_bins == 0 {
        return bins;
    }

    let range = data_max - data_min;
    // Loop-invariant conversions hoisted out of the per-value work.
    let bin_count = num_bins as f32;
    let last_bin = (num_bins - 1) as f32;

    for &v in input {
        let bin = if range <= 0.0 {
            0
        } else {
            let t = (v - data_min) / range;
            (t * bin_count).clamp(0.0, last_bin) as u32
        };
        bins[bin as usize] += 1;
    }
    bins
}
/// Returns the minimum of `input_arr`, on the GPU when a context is supplied.
///
/// With `Some(ctx)` the values returned by `compute_reduce` in `ReduceMode::Min` are folded
/// with `f32::min`, starting from `f32::INFINITY`. With `None` the slice is reduced by
/// `cpu_reduce_min`.
pub async fn reduce_min(gpu_context: Option<&GpuContext<'_>>, input_arr: Arc<Vec<f32>>) -> f32 {
    let Some(ctx) = gpu_context else {
        return cpu_reduce_min(&input_arr);
    };
    let mut smallest = f32::INFINITY;
    for partial in compute_reduce(ctx, input_arr, ReduceMode::Min).await {
        smallest = smallest.min(partial);
    }
    smallest
}
/// Returns the maximum of `input_arr`, on the GPU when a context is supplied.
///
/// With `Some(ctx)` the values returned by `compute_reduce` in `ReduceMode::Max` are folded
/// with `f32::max`, starting from `f32::NEG_INFINITY`. With `None` the slice is reduced by
/// `cpu_reduce_max`.
pub async fn reduce_max(gpu_context: Option<&GpuContext<'_>>, input_arr: Arc<Vec<f32>>) -> f32 {
    let Some(ctx) = gpu_context else {
        return cpu_reduce_max(&input_arr);
    };
    let mut largest = f32::NEG_INFINITY;
    for partial in compute_reduce(ctx, input_arr, ReduceMode::Max).await {
        largest = largest.max(partial);
    }
    largest
}
/// Returns the sum of `input_arr`, on the GPU when a context is supplied.
///
/// With `Some(ctx)` the values returned by `compute_reduce` in `ReduceMode::Sum` are added
/// together. With `None` the slice is summed by `cpu_reduce_sum`.
pub async fn reduce_sum(gpu_context: Option<&GpuContext<'_>>, input_arr: Arc<Vec<f32>>) -> f32 {
    let Some(ctx) = gpu_context else {
        return cpu_reduce_sum(&input_arr);
    };
    let partial_sums = compute_reduce(ctx, input_arr, ReduceMode::Sum).await;
    let total: f32 = partial_sums.iter().sum();
    total
}
/// Returns `(min, max)` of `input_arr`, on the GPU when a context is supplied.
///
/// With `Some(ctx)` the output of `compute_reduce` in `ReduceMode::Extent` is read as
/// interleaved values: even positions feed the minimum, odd positions the maximum.
/// With `None` the slice is reduced by `cpu_reduce_extent`.
pub async fn reduce_extent(
    gpu_context: Option<&GpuContext<'_>>,
    input_arr: Arc<Vec<f32>>,
) -> (f32, f32) {
    let Some(ctx) = gpu_context else {
        return cpu_reduce_extent(&input_arr);
    };
    let interleaved = compute_reduce(ctx, input_arr, ReduceMode::Extent).await;
    let mut lowest = f32::INFINITY;
    let mut highest = f32::NEG_INFINITY;
    for (position, &value) in interleaved.iter().enumerate() {
        if position % 2 == 0 {
            lowest = lowest.min(value);
        } else {
            highest = highest.max(value);
        }
    }
    (lowest, highest)
}
/// Builds a `num_bins`-bin histogram of `input_arr` over the range `data_min..data_max`.
///
/// With `Some(ctx)` the histogram is produced by `compute_reduce` in
/// `ReduceMode::Histogram`; each returned 32-bit word is reinterpreted bit-for-bit as a
/// `u32` count. With `None` the histogram is computed by `cpu_reduce_histogram`.
pub async fn reduce_histogram_with_known_extent(
    gpu_context: Option<&GpuContext<'_>>,
    input_arr: Arc<Vec<f32>>,
    num_bins: u32,
    data_min: f32,
    data_max: f32,
) -> Vec<u32> {
    let Some(ctx) = gpu_context else {
        return cpu_reduce_histogram(&input_arr, num_bins, data_min, data_max);
    };
    let words = compute_reduce(
        ctx,
        input_arr,
        ReduceMode::Histogram { num_bins, data_min, data_max },
    )
    .await;
    // The download is typed as f32 but carries integer counts; recover the raw bits.
    words.iter().map(|word| word.to_bits()).collect()
}
/// Builds a `num_bins`-bin histogram of `input_arr` when its value range is not known.
///
/// Runs two passes over the data: `reduce_extent` to find `(min, max)`, then
/// `reduce_histogram_with_known_extent` with that range. Both passes use the same
/// `gpu_context`, so both run on the GPU or both on the CPU.
pub async fn reduce_histogram_with_unknown_extent(
    gpu_context: Option<&GpuContext<'_>>,
    input_arr: Arc<Vec<f32>>,
    num_bins: u32,
) -> Vec<u32> {
    // The extent pass takes its own handle so the data can be handed on to the second pass.
    let shared_input = Arc::clone(&input_arr);
    let extent = reduce_extent(gpu_context, shared_input).await;
    let (lower, upper) = extent;
    reduce_histogram_with_known_extent(gpu_context, input_arr, num_bins, lower, upper).await
}