use wgpu::{Buffer, Queue};
use super::pipeline::{LayoutKey, PipelineCache};
use crate::dtype::DType;
use crate::error::{Error, Result};
const MODE_F32: &str = include_str!("statistics_f32.wgsl");
const MODE_I32: &str = include_str!("statistics_i32.wgsl");
const MODE_U32: &str = include_str!("statistics_u32.wgsl");
fn mode_shader_info(dtype: DType) -> Result<(&'static str, &'static str, &'static str)> {
Ok(match dtype {
DType::F32 => (MODE_F32, "statistics_f32", "mode_dim_f32"),
DType::I32 => (MODE_I32, "statistics_i32", "mode_dim_i32"),
DType::U32 => (MODE_U32, "statistics_u32", "mode_dim_u32"),
_ => {
return Err(Error::UnsupportedDType {
dtype,
op: "mode (WebGPU)",
});
}
})
}
pub fn launch_mode_dim(
cache: &PipelineCache,
queue: &Queue,
sorted: &Buffer,
mode_values: &Buffer,
mode_counts: &Buffer,
params_buffer: &Buffer,
num_outputs: usize,
dtype: DType,
) -> Result<()> {
let (shader, module_key, entry_point) = mode_shader_info(dtype)?;
let module = cache.get_or_create_module(module_key, shader);
let layout = cache.get_or_create_layout(LayoutKey {
num_storage_buffers: 3,
num_uniform_buffers: 1,
num_readonly_storage: 0,
});
let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout);
let bind_group =
cache.create_bind_group(&layout, &[sorted, mode_values, mode_counts, params_buffer]);
let mut encoder = cache
.device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("mode_dim"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("mode_dim"),
timestamp_writes: None,
});
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, Some(&bind_group), &[]);
pass.dispatch_workgroups(num_outputs as u32, 1, 1);
}
queue.submit(std::iter::once(encoder.finish()));
Ok(())
}
#[allow(dead_code)]
pub fn launch_mode_full(
cache: &PipelineCache,
queue: &Queue,
sorted: &Buffer,
mode_value: &Buffer,
mode_count: &Buffer,
numel_buffer: &Buffer,
dtype: DType,
) -> Result<()> {
let (shader, module_key, entry_point) = mode_shader_info(dtype)?;
let module = cache.get_or_create_module(module_key, shader);
let layout = cache.get_or_create_layout(LayoutKey {
num_storage_buffers: 3,
num_uniform_buffers: 1,
num_readonly_storage: 0,
});
let pipeline = cache.get_or_create_pipeline(module_key, entry_point, &module, &layout);
let bind_group =
cache.create_bind_group(&layout, &[sorted, mode_value, mode_count, numel_buffer]);
let mut encoder = cache
.device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("mode_full"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("mode_full"),
timestamp_writes: None,
});
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, Some(&bind_group), &[]);
pass.dispatch_workgroups(1, 1, 1);
}
queue.submit(std::iter::once(encoder.finish()));
Ok(())
}