use wgpu::{Buffer, Queue};
use super::pipeline::{LayoutKey, PipelineCache, workgroup_count};
use crate::dtype::DType;
use crate::error::{Error, Result};
// WGSL sources embedded into the binary at compile time via `include_str!`.
// Most kernels ship one file per element type (f32 / i32 / u32); the lookup
// from (operation, dtype) to a source string lives in `shader_info`.

// index_select
const INDEX_SELECT_SHADER_F32: &str = include_str!("index_select_f32.wgsl");
const INDEX_SELECT_SHADER_I32: &str = include_str!("index_select_i32.wgsl");
const INDEX_SELECT_SHADER_U32: &str = include_str!("index_select_u32.wgsl");
// index_put
const INDEX_PUT_SHADER_F32: &str = include_str!("index_put_f32.wgsl");
const INDEX_PUT_SHADER_I32: &str = include_str!("index_put_i32.wgsl");
const INDEX_PUT_SHADER_U32: &str = include_str!("index_put_u32.wgsl");
// gather
const GATHER_SHADER_F32: &str = include_str!("gather_f32.wgsl");
const GATHER_SHADER_I32: &str = include_str!("gather_i32.wgsl");
const GATHER_SHADER_U32: &str = include_str!("gather_u32.wgsl");
// scatter (these modules are also used for the "copy" entry points in `shader_info`)
const SCATTER_SHADER_F32: &str = include_str!("scatter_f32.wgsl");
const SCATTER_SHADER_I32: &str = include_str!("scatter_i32.wgsl");
const SCATTER_SHADER_U32: &str = include_str!("scatter_u32.wgsl");
// masked_fill
const MASKED_FILL_SHADER_F32: &str = include_str!("masked_fill_f32.wgsl");
const MASKED_FILL_SHADER_I32: &str = include_str!("masked_fill_i32.wgsl");
const MASKED_FILL_SHADER_U32: &str = include_str!("masked_fill_u32.wgsl");
// masked_select (these modules are also used for the "masked_count" and
// "masked_prefix_sum" entry points in `shader_info`)
const MASKED_SELECT_SHADER_F32: &str = include_str!("masked_select_f32.wgsl");
const MASKED_SELECT_SHADER_I32: &str = include_str!("masked_select_i32.wgsl");
const MASKED_SELECT_SHADER_U32: &str = include_str!("masked_select_u32.wgsl");
// embedding_lookup
const EMBEDDING_LOOKUP_SHADER_F32: &str = include_str!("embedding_lookup_f32.wgsl");
const EMBEDDING_LOOKUP_SHADER_I32: &str = include_str!("embedding_lookup_i32.wgsl");
const EMBEDDING_LOOKUP_SHADER_U32: &str = include_str!("embedding_lookup_u32.wgsl");
// gather_nd
const GATHER_ND_SHADER_F32: &str = include_str!("gather_nd_f32.wgsl");
const GATHER_ND_SHADER_I32: &str = include_str!("gather_nd_i32.wgsl");
const GATHER_ND_SHADER_U32: &str = include_str!("gather_nd_u32.wgsl");
// scatter_reduce: sum / max / min / prod
const SCATTER_REDUCE_SUM_SHADER_F32: &str = include_str!("scatter_reduce_sum_f32.wgsl");
const SCATTER_REDUCE_SUM_SHADER_I32: &str = include_str!("scatter_reduce_sum_i32.wgsl");
const SCATTER_REDUCE_SUM_SHADER_U32: &str = include_str!("scatter_reduce_sum_u32.wgsl");
const SCATTER_REDUCE_MAX_SHADER_F32: &str = include_str!("scatter_reduce_max_f32.wgsl");
const SCATTER_REDUCE_MAX_SHADER_I32: &str = include_str!("scatter_reduce_max_i32.wgsl");
const SCATTER_REDUCE_MAX_SHADER_U32: &str = include_str!("scatter_reduce_max_u32.wgsl");
const SCATTER_REDUCE_MIN_SHADER_F32: &str = include_str!("scatter_reduce_min_f32.wgsl");
const SCATTER_REDUCE_MIN_SHADER_I32: &str = include_str!("scatter_reduce_min_i32.wgsl");
const SCATTER_REDUCE_MIN_SHADER_U32: &str = include_str!("scatter_reduce_min_u32.wgsl");
const SCATTER_REDUCE_PROD_SHADER_F32: &str = include_str!("scatter_reduce_prod_f32.wgsl");
const SCATTER_REDUCE_PROD_SHADER_I32: &str = include_str!("scatter_reduce_prod_i32.wgsl");
const SCATTER_REDUCE_PROD_SHADER_U32: &str = include_str!("scatter_reduce_prod_u32.wgsl");
// scatter_reduce count / mean-division kernels: only an f32 variant is embedded here
const SCATTER_REDUCE_COUNT_SHADER_F32: &str = include_str!("scatter_reduce_count_f32.wgsl");
const SCATTER_REDUCE_MEAN_DIV_SHADER_F32: &str = include_str!("scatter_reduce_mean_div_f32.wgsl");
// slice_assign
const SLICE_ASSIGN_SHADER_F32: &str = include_str!("slice_assign_f32.wgsl");
const SLICE_ASSIGN_SHADER_I32: &str = include_str!("slice_assign_i32.wgsl");
const SLICE_ASSIGN_SHADER_U32: &str = include_str!("slice_assign_u32.wgsl");
// gather_2d
const GATHER_2D_SHADER_F32: &str = include_str!("gather_2d_f32.wgsl");
const GATHER_2D_SHADER_I32: &str = include_str!("gather_2d_i32.wgsl");
const GATHER_2D_SHADER_U32: &str = include_str!("gather_2d_u32.wgsl");
// Kernels that are not selected through `shader_info`
const VALIDATE_INDICES_SHADER: &str = include_str!("validate_indices.wgsl");
const BINCOUNT_UNWEIGHTED_SHADER: &str = include_str!("bincount_i32.wgsl");
const BINCOUNT_WEIGHTED_SHADER_F32: &str = include_str!("bincount_weighted_f32.wgsl");
/// Resolves an `(op, dtype)` pair to the kernel that implements it.
///
/// Returns a triple `(wgsl_source, module_key, entry_point)`:
/// * `wgsl_source` - the embedded WGSL text,
/// * `module_key` - the key callers use when caching the compiled module/pipeline,
/// * `entry_point` - the compute entry point to run inside that module.
///
/// For most operations the module key and entry point are the same string. The
/// exceptions are `"copy"`, which maps to the scatter modules, and
/// `"masked_count"` / `"masked_prefix_sum"`, which map to the masked_select modules.
///
/// # Errors
/// Returns `Error::UnsupportedDType` when the table has no entry for the pair.
fn shader_info(
    op: &'static str,
    dtype: DType,
) -> Result<(&'static str, &'static str, &'static str)> {
    type Info = (&'static str, &'static str, &'static str);

    /// Entry whose module key doubles as its entry-point name.
    fn same_name(source: &'static str, name: &'static str) -> Info {
        (source, name, name)
    }

    let info: Info = match (op, dtype) {
        ("index_select", DType::F32) => same_name(INDEX_SELECT_SHADER_F32, "index_select_f32"),
        ("index_select", DType::I32) => same_name(INDEX_SELECT_SHADER_I32, "index_select_i32"),
        ("index_select", DType::U32) => same_name(INDEX_SELECT_SHADER_U32, "index_select_u32"),

        ("index_put", DType::F32) => same_name(INDEX_PUT_SHADER_F32, "index_put_f32"),
        ("index_put", DType::I32) => same_name(INDEX_PUT_SHADER_I32, "index_put_i32"),
        ("index_put", DType::U32) => same_name(INDEX_PUT_SHADER_U32, "index_put_u32"),

        ("gather", DType::F32) => same_name(GATHER_SHADER_F32, "gather_f32"),
        ("gather", DType::I32) => same_name(GATHER_SHADER_I32, "gather_i32"),
        ("gather", DType::U32) => same_name(GATHER_SHADER_U32, "gather_u32"),

        ("scatter", DType::F32) => same_name(SCATTER_SHADER_F32, "scatter_f32"),
        ("scatter", DType::I32) => same_name(SCATTER_SHADER_I32, "scatter_i32"),
        ("scatter", DType::U32) => same_name(SCATTER_SHADER_U32, "scatter_u32"),

        // "copy" entry points are looked up inside the scatter modules.
        ("copy", DType::F32) => (SCATTER_SHADER_F32, "scatter_f32", "copy_f32"),
        ("copy", DType::I32) => (SCATTER_SHADER_I32, "scatter_i32", "copy_i32"),
        ("copy", DType::U32) => (SCATTER_SHADER_U32, "scatter_u32", "copy_u32"),

        ("masked_fill", DType::F32) => same_name(MASKED_FILL_SHADER_F32, "masked_fill_f32"),
        ("masked_fill", DType::I32) => same_name(MASKED_FILL_SHADER_I32, "masked_fill_i32"),
        ("masked_fill", DType::U32) => same_name(MASKED_FILL_SHADER_U32, "masked_fill_u32"),

        ("masked_select", DType::F32) => {
            same_name(MASKED_SELECT_SHADER_F32, "masked_select_f32")
        }
        ("masked_select", DType::I32) => {
            same_name(MASKED_SELECT_SHADER_I32, "masked_select_i32")
        }
        ("masked_select", DType::U32) => {
            same_name(MASKED_SELECT_SHADER_U32, "masked_select_u32")
        }

        // The count and prefix-sum entry points are looked up inside the masked_select modules.
        ("masked_count", DType::F32) => {
            (MASKED_SELECT_SHADER_F32, "masked_select_f32", "masked_count")
        }
        ("masked_count", DType::I32) => {
            (MASKED_SELECT_SHADER_I32, "masked_select_i32", "masked_count")
        }
        ("masked_count", DType::U32) => {
            (MASKED_SELECT_SHADER_U32, "masked_select_u32", "masked_count")
        }
        ("masked_prefix_sum", DType::F32) => {
            (MASKED_SELECT_SHADER_F32, "masked_select_f32", "masked_prefix_sum")
        }
        ("masked_prefix_sum", DType::I32) => {
            (MASKED_SELECT_SHADER_I32, "masked_select_i32", "masked_prefix_sum")
        }
        ("masked_prefix_sum", DType::U32) => {
            (MASKED_SELECT_SHADER_U32, "masked_select_u32", "masked_prefix_sum")
        }

        ("embedding_lookup", DType::F32) => {
            same_name(EMBEDDING_LOOKUP_SHADER_F32, "embedding_lookup_f32")
        }
        ("embedding_lookup", DType::I32) => {
            same_name(EMBEDDING_LOOKUP_SHADER_I32, "embedding_lookup_i32")
        }
        ("embedding_lookup", DType::U32) => {
            same_name(EMBEDDING_LOOKUP_SHADER_U32, "embedding_lookup_u32")
        }

        ("gather_nd", DType::F32) => same_name(GATHER_ND_SHADER_F32, "gather_nd_f32"),
        ("gather_nd", DType::I32) => same_name(GATHER_ND_SHADER_I32, "gather_nd_i32"),
        ("gather_nd", DType::U32) => same_name(GATHER_ND_SHADER_U32, "gather_nd_u32"),

        ("scatter_reduce_sum", DType::F32) => {
            same_name(SCATTER_REDUCE_SUM_SHADER_F32, "scatter_reduce_sum_f32")
        }
        ("scatter_reduce_sum", DType::I32) => {
            same_name(SCATTER_REDUCE_SUM_SHADER_I32, "scatter_reduce_sum_i32")
        }
        ("scatter_reduce_sum", DType::U32) => {
            same_name(SCATTER_REDUCE_SUM_SHADER_U32, "scatter_reduce_sum_u32")
        }

        ("scatter_reduce_max", DType::F32) => {
            same_name(SCATTER_REDUCE_MAX_SHADER_F32, "scatter_reduce_max_f32")
        }
        ("scatter_reduce_max", DType::I32) => {
            same_name(SCATTER_REDUCE_MAX_SHADER_I32, "scatter_reduce_max_i32")
        }
        ("scatter_reduce_max", DType::U32) => {
            same_name(SCATTER_REDUCE_MAX_SHADER_U32, "scatter_reduce_max_u32")
        }

        ("scatter_reduce_min", DType::F32) => {
            same_name(SCATTER_REDUCE_MIN_SHADER_F32, "scatter_reduce_min_f32")
        }
        ("scatter_reduce_min", DType::I32) => {
            same_name(SCATTER_REDUCE_MIN_SHADER_I32, "scatter_reduce_min_i32")
        }
        ("scatter_reduce_min", DType::U32) => {
            same_name(SCATTER_REDUCE_MIN_SHADER_U32, "scatter_reduce_min_u32")
        }

        ("scatter_reduce_prod", DType::F32) => {
            same_name(SCATTER_REDUCE_PROD_SHADER_F32, "scatter_reduce_prod_f32")
        }
        ("scatter_reduce_prod", DType::I32) => {
            same_name(SCATTER_REDUCE_PROD_SHADER_I32, "scatter_reduce_prod_i32")
        }
        ("scatter_reduce_prod", DType::U32) => {
            same_name(SCATTER_REDUCE_PROD_SHADER_U32, "scatter_reduce_prod_u32")
        }

        // Only f32 variants exist for the count and mean-division kernels.
        ("scatter_reduce_count", DType::F32) => {
            same_name(SCATTER_REDUCE_COUNT_SHADER_F32, "scatter_reduce_count_f32")
        }
        ("scatter_reduce_mean_div", DType::F32) => same_name(
            SCATTER_REDUCE_MEAN_DIV_SHADER_F32,
            "scatter_reduce_mean_div_f32",
        ),

        ("slice_assign", DType::F32) => same_name(SLICE_ASSIGN_SHADER_F32, "slice_assign_f32"),
        ("slice_assign", DType::I32) => same_name(SLICE_ASSIGN_SHADER_I32, "slice_assign_i32"),
        ("slice_assign", DType::U32) => same_name(SLICE_ASSIGN_SHADER_U32, "slice_assign_u32"),

        ("gather_2d", DType::F32) => same_name(GATHER_2D_SHADER_F32, "gather_2d_f32"),
        ("gather_2d", DType::I32) => same_name(GATHER_2D_SHADER_I32, "gather_2d_i32"),
        ("gather_2d", DType::U32) => same_name(GATHER_2D_SHADER_U32, "gather_2d_u32"),

        _ => return Err(Error::UnsupportedDType { dtype, op }),
    };

    Ok(info)
}
/// Records and submits one compute pass that runs the `index_select` kernel for `dtype`.
///
/// # Arguments
/// * `cache` - supplies the device, shader module, layout, pipeline and bind group.
/// * `queue` - queue the finished command buffer is submitted to.
/// * `input`, `indices`, `output`, `params_buffer` - buffers handed to the bind group in
///   that order (layout key: 3 storage buffers, 1 uniform buffer).
/// * `total_output` - count passed to `workgroup_count` to size the dispatch along X.
/// * `dtype` - selects the shader variant through `shader_info`.
///
/// # Errors
/// Propagates the `shader_info` error when `dtype` has no `index_select` kernel.
pub fn launch_index_select(
    cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    indices: &Buffer,
    output: &Buffer,
    params_buffer: &Buffer,
    total_output: usize,
    dtype: DType,
) -> Result<()> {
    const LABEL: &str = "index_select";

    let (wgsl, key, entry) = shader_info(LABEL, dtype)?;
    let shader_module = cache.get_or_create_module(key, wgsl);
    let layout_key = LayoutKey {
        num_storage_buffers: 3,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    };
    let bindings_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline =
        cache.get_or_create_pipeline(key, entry, &shader_module, &bindings_layout);
    let bindings =
        cache.create_bind_group(&bindings_layout, &[input, indices, output, params_buffer]);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(LABEL) };
    let mut commands = cache.device().create_command_encoder(&encoder_desc);
    {
        // The pass must be dropped before `finish()` can be called on the encoder.
        let pass_desc = wgpu::ComputePassDescriptor {
            label: Some(LABEL),
            timestamp_writes: None,
        };
        let mut compute = commands.begin_compute_pass(&pass_desc);
        compute.set_pipeline(&compute_pipeline);
        compute.set_bind_group(0, Some(&bindings), &[]);
        compute.dispatch_workgroups(workgroup_count(total_output), 1, 1);
    }
    queue.submit(Some(commands.finish()));
    Ok(())
}
/// Records and submits one compute pass that runs the `index_put` kernel for `dtype`.
///
/// # Arguments
/// * `cache` - supplies the device, shader module, layout, pipeline and bind group.
/// * `queue` - queue the finished command buffer is submitted to.
/// * `indices`, `src`, `output`, `params_buffer` - buffers handed to the bind group in
///   that order (layout key: 3 storage buffers, 1 uniform buffer).
/// * `total_src` - count passed to `workgroup_count` to size the dispatch along X.
/// * `dtype` - selects the shader variant through `shader_info`.
///
/// # Errors
/// Propagates the `shader_info` error when `dtype` has no `index_put` kernel.
pub fn launch_index_put(
    cache: &PipelineCache,
    queue: &Queue,
    indices: &Buffer,
    src: &Buffer,
    output: &Buffer,
    params_buffer: &Buffer,
    total_src: usize,
    dtype: DType,
) -> Result<()> {
    const LABEL: &str = "index_put";

    let (wgsl, key, entry) = shader_info(LABEL, dtype)?;
    let shader_module = cache.get_or_create_module(key, wgsl);
    let layout_key = LayoutKey {
        num_storage_buffers: 3,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    };
    let bindings_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline =
        cache.get_or_create_pipeline(key, entry, &shader_module, &bindings_layout);
    let bindings =
        cache.create_bind_group(&bindings_layout, &[indices, src, output, params_buffer]);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(LABEL) };
    let mut commands = cache.device().create_command_encoder(&encoder_desc);
    {
        // The pass must be dropped before `finish()` can be called on the encoder.
        let pass_desc = wgpu::ComputePassDescriptor {
            label: Some(LABEL),
            timestamp_writes: None,
        };
        let mut compute = commands.begin_compute_pass(&pass_desc);
        compute.set_pipeline(&compute_pipeline);
        compute.set_bind_group(0, Some(&bindings), &[]);
        compute.dispatch_workgroups(workgroup_count(total_src), 1, 1);
    }
    queue.submit(Some(commands.finish()));
    Ok(())
}
/// Records and submits one compute pass that runs the `validate_indices` kernel.
///
/// Does nothing (and submits nothing) when `index_len` is zero.
///
/// # Arguments
/// * `cache` - supplies the device, shader module, layout, pipeline and bind group.
/// * `queue` - queue the finished command buffer is submitted to.
/// * `indices`, `error_count`, `params_buffer` - buffers handed to the bind group in
///   that order (layout key: 2 storage buffers, 1 uniform buffer).
/// * `index_len` - count passed to `workgroup_count` to size the dispatch along X.
///
/// # Errors
/// This function has no failing path of its own; it always returns `Ok(())`.
pub fn launch_validate_indices(
    cache: &PipelineCache,
    queue: &Queue,
    indices: &Buffer,
    error_count: &Buffer,
    params_buffer: &Buffer,
    index_len: usize,
) -> Result<()> {
    // The same string serves as module key, entry point and debug label.
    const KERNEL: &str = "validate_indices";

    if index_len == 0 {
        return Ok(());
    }

    let shader_module = cache.get_or_create_module(KERNEL, VALIDATE_INDICES_SHADER);
    let layout_key = LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    };
    let bindings_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline =
        cache.get_or_create_pipeline(KERNEL, KERNEL, &shader_module, &bindings_layout);
    let bindings =
        cache.create_bind_group(&bindings_layout, &[indices, error_count, params_buffer]);

    let encoder_desc = wgpu::CommandEncoderDescriptor {
        label: Some(KERNEL),
    };
    let mut commands = cache.device().create_command_encoder(&encoder_desc);
    {
        // The pass must be dropped before `finish()` can be called on the encoder.
        let pass_desc = wgpu::ComputePassDescriptor {
            label: Some(KERNEL),
            timestamp_writes: None,
        };
        let mut compute = commands.begin_compute_pass(&pass_desc);
        compute.set_pipeline(&compute_pipeline);
        compute.set_bind_group(0, Some(&bindings), &[]);
        compute.dispatch_workgroups(workgroup_count(index_len), 1, 1);
    }
    queue.submit(Some(commands.finish()));
    Ok(())
}
/// Records and submits one compute pass that runs the `gather` kernel for `dtype`.
///
/// # Arguments
/// * `cache` - supplies the device, shader module, layout, pipeline and bind group.
/// * `queue` - queue the finished command buffer is submitted to.
/// * `input`, `indices`, `output`, `params_buffer` - buffers handed to the bind group in
///   that order (layout key: 3 storage buffers, 1 uniform buffer).
/// * `total_elements` - count passed to `workgroup_count` to size the dispatch along X.
/// * `dtype` - selects the shader variant through `shader_info`.
///
/// # Errors
/// Propagates the `shader_info` error when `dtype` has no `gather` kernel.
pub fn launch_gather(
    cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    indices: &Buffer,
    output: &Buffer,
    params_buffer: &Buffer,
    total_elements: usize,
    dtype: DType,
) -> Result<()> {
    const LABEL: &str = "gather";

    let (wgsl, key, entry) = shader_info(LABEL, dtype)?;
    let shader_module = cache.get_or_create_module(key, wgsl);
    let layout_key = LayoutKey {
        num_storage_buffers: 3,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    };
    let bindings_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline =
        cache.get_or_create_pipeline(key, entry, &shader_module, &bindings_layout);
    let bindings =
        cache.create_bind_group(&bindings_layout, &[input, indices, output, params_buffer]);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(LABEL) };
    let mut commands = cache.device().create_command_encoder(&encoder_desc);
    {
        // The pass must be dropped before `finish()` can be called on the encoder.
        let pass_desc = wgpu::ComputePassDescriptor {
            label: Some(LABEL),
            timestamp_writes: None,
        };
        let mut compute = commands.begin_compute_pass(&pass_desc);
        compute.set_pipeline(&compute_pipeline);
        compute.set_bind_group(0, Some(&bindings), &[]);
        compute.dispatch_workgroups(workgroup_count(total_elements), 1, 1);
    }
    queue.submit(Some(commands.finish()));
    Ok(())
}
/// Records and submits one compute pass that runs the `copy` kernel for `dtype`.
///
/// The kernel is resolved through `shader_info("copy", ..)`, which points at the
/// `copy_*` entry point inside the scatter shader module.
///
/// # Arguments
/// * `cache` - supplies the device, shader module, layout, pipeline and bind group.
/// * `queue` - queue the finished command buffer is submitted to.
/// * `src`, `dst`, `params_buffer` - buffers handed to the bind group in that order
///   (layout key: 2 storage buffers, 1 uniform buffer).
/// * `numel` - count passed to `workgroup_count` to size the dispatch along X.
/// * `dtype` - selects the shader variant through `shader_info`.
///
/// # Errors
/// Propagates the `shader_info` error when `dtype` has no `copy` kernel.
pub fn launch_copy(
    cache: &PipelineCache,
    queue: &Queue,
    src: &Buffer,
    dst: &Buffer,
    params_buffer: &Buffer,
    numel: usize,
    dtype: DType,
) -> Result<()> {
    const LABEL: &str = "copy";

    let (wgsl, key, entry) = shader_info(LABEL, dtype)?;
    let shader_module = cache.get_or_create_module(key, wgsl);
    let layout_key = LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    };
    let bindings_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline =
        cache.get_or_create_pipeline(key, entry, &shader_module, &bindings_layout);
    let bindings = cache.create_bind_group(&bindings_layout, &[src, dst, params_buffer]);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(LABEL) };
    let mut commands = cache.device().create_command_encoder(&encoder_desc);
    {
        // The pass must be dropped before `finish()` can be called on the encoder.
        let pass_desc = wgpu::ComputePassDescriptor {
            label: Some(LABEL),
            timestamp_writes: None,
        };
        let mut compute = commands.begin_compute_pass(&pass_desc);
        compute.set_pipeline(&compute_pipeline);
        compute.set_bind_group(0, Some(&bindings), &[]);
        compute.dispatch_workgroups(workgroup_count(numel), 1, 1);
    }
    queue.submit(Some(commands.finish()));
    Ok(())
}
/// Records and submits one compute pass that runs the `scatter` kernel for `dtype`.
///
/// # Arguments
/// * `cache` - supplies the device, shader module, layout, pipeline and bind group.
/// * `queue` - queue the finished command buffer is submitted to.
/// * `src`, `indices`, `output`, `params_buffer` - buffers handed to the bind group in
///   that order (layout key: 3 storage buffers, 1 uniform buffer).
/// * `src_total` - count passed to `workgroup_count` to size the dispatch along X.
/// * `dtype` - selects the shader variant through `shader_info`.
///
/// # Errors
/// Propagates the `shader_info` error when `dtype` has no `scatter` kernel.
pub fn launch_scatter(
    cache: &PipelineCache,
    queue: &Queue,
    src: &Buffer,
    indices: &Buffer,
    output: &Buffer,
    params_buffer: &Buffer,
    src_total: usize,
    dtype: DType,
) -> Result<()> {
    const LABEL: &str = "scatter";

    let (wgsl, key, entry) = shader_info(LABEL, dtype)?;
    let shader_module = cache.get_or_create_module(key, wgsl);
    let layout_key = LayoutKey {
        num_storage_buffers: 3,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    };
    let bindings_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline =
        cache.get_or_create_pipeline(key, entry, &shader_module, &bindings_layout);
    let bindings =
        cache.create_bind_group(&bindings_layout, &[src, indices, output, params_buffer]);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(LABEL) };
    let mut commands = cache.device().create_command_encoder(&encoder_desc);
    {
        // The pass must be dropped before `finish()` can be called on the encoder.
        let pass_desc = wgpu::ComputePassDescriptor {
            label: Some(LABEL),
            timestamp_writes: None,
        };
        let mut compute = commands.begin_compute_pass(&pass_desc);
        compute.set_pipeline(&compute_pipeline);
        compute.set_bind_group(0, Some(&bindings), &[]);
        compute.dispatch_workgroups(workgroup_count(src_total), 1, 1);
    }
    queue.submit(Some(commands.finish()));
    Ok(())
}
/// Records and submits one compute pass that runs the `masked_fill` kernel for `dtype`.
///
/// # Arguments
/// * `cache` - supplies the device, shader module, layout, pipeline and bind group.
/// * `queue` - queue the finished command buffer is submitted to.
/// * `input`, `mask`, `output`, `params_buffer` - buffers handed to the bind group in
///   that order (layout key: 3 storage buffers, 1 uniform buffer).
/// * `numel` - count passed to `workgroup_count` to size the dispatch along X.
/// * `dtype` - selects the shader variant through `shader_info`.
///
/// # Errors
/// Propagates the `shader_info` error when `dtype` has no `masked_fill` kernel.
pub fn launch_masked_fill(
    cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    mask: &Buffer,
    output: &Buffer,
    params_buffer: &Buffer,
    numel: usize,
    dtype: DType,
) -> Result<()> {
    const LABEL: &str = "masked_fill";

    let (wgsl, key, entry) = shader_info(LABEL, dtype)?;
    let shader_module = cache.get_or_create_module(key, wgsl);
    let layout_key = LayoutKey {
        num_storage_buffers: 3,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    };
    let bindings_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline =
        cache.get_or_create_pipeline(key, entry, &shader_module, &bindings_layout);
    let bindings =
        cache.create_bind_group(&bindings_layout, &[input, mask, output, params_buffer]);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(LABEL) };
    let mut commands = cache.device().create_command_encoder(&encoder_desc);
    {
        // The pass must be dropped before `finish()` can be called on the encoder.
        let pass_desc = wgpu::ComputePassDescriptor {
            label: Some(LABEL),
            timestamp_writes: None,
        };
        let mut compute = commands.begin_compute_pass(&pass_desc);
        compute.set_pipeline(&compute_pipeline);
        compute.set_bind_group(0, Some(&bindings), &[]);
        compute.dispatch_workgroups(workgroup_count(numel), 1, 1);
    }
    queue.submit(Some(commands.finish()));
    Ok(())
}
/// Records and submits one compute pass that runs the `masked_count` entry point.
///
/// The kernel is resolved through `shader_info("masked_count", ..)`, which points at
/// the `masked_count` entry point inside the masked_select module for `dtype`.
///
/// # Arguments
/// * `cache` - supplies the device, shader module, layout, pipeline and bind group.
/// * `queue` - queue the finished command buffer is submitted to.
/// * `mask`, `count_result`, `params_buffer` - buffers handed to the bind group in that
///   order (layout key: 2 storage buffers, 1 uniform buffer).
/// * `numel` - count passed to `workgroup_count` to size the dispatch along X.
/// * `dtype` - selects which masked_select module is used.
///
/// # Errors
/// Propagates the `shader_info` error when `dtype` has no `masked_count` kernel.
pub fn launch_masked_count(
    cache: &PipelineCache,
    queue: &Queue,
    mask: &Buffer,
    count_result: &Buffer,
    params_buffer: &Buffer,
    numel: usize,
    dtype: DType,
) -> Result<()> {
    const LABEL: &str = "masked_count";

    let (wgsl, key, entry) = shader_info(LABEL, dtype)?;
    let shader_module = cache.get_or_create_module(key, wgsl);
    let layout_key = LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    };
    let bindings_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline =
        cache.get_or_create_pipeline(key, entry, &shader_module, &bindings_layout);
    let bindings =
        cache.create_bind_group(&bindings_layout, &[mask, count_result, params_buffer]);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(LABEL) };
    let mut commands = cache.device().create_command_encoder(&encoder_desc);
    {
        // The pass must be dropped before `finish()` can be called on the encoder.
        let pass_desc = wgpu::ComputePassDescriptor {
            label: Some(LABEL),
            timestamp_writes: None,
        };
        let mut compute = commands.begin_compute_pass(&pass_desc);
        compute.set_pipeline(&compute_pipeline);
        compute.set_bind_group(0, Some(&bindings), &[]);
        compute.dispatch_workgroups(workgroup_count(numel), 1, 1);
    }
    queue.submit(Some(commands.finish()));
    Ok(())
}
/// Records and submits one compute pass that runs the `masked_prefix_sum` entry point.
///
/// The kernel is resolved through `shader_info("masked_prefix_sum", ..)`, which points
/// at the `masked_prefix_sum` entry point inside the masked_select module for `dtype`.
/// Unlike the other launchers, the dispatch is always a single workgroup (1, 1, 1).
///
/// # Arguments
/// * `cache` - supplies the device, shader module, layout, pipeline and bind group.
/// * `queue` - queue the finished command buffer is submitted to.
/// * `mask`, `prefix_sum`, `params_buffer` - buffers handed to the bind group in that
///   order (layout key: 2 storage buffers, 1 uniform buffer).
/// * `_numel` - unused by this function; the dispatch size does not depend on it.
/// * `dtype` - selects which masked_select module is used.
///
/// # Errors
/// Propagates the `shader_info` error when `dtype` has no `masked_prefix_sum` kernel.
pub fn launch_masked_prefix_sum(
    cache: &PipelineCache,
    queue: &Queue,
    mask: &Buffer,
    prefix_sum: &Buffer,
    params_buffer: &Buffer,
    _numel: usize,
    dtype: DType,
) -> Result<()> {
    const LABEL: &str = "masked_prefix_sum";

    let (wgsl, key, entry) = shader_info(LABEL, dtype)?;
    let shader_module = cache.get_or_create_module(key, wgsl);
    let layout_key = LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    };
    let bindings_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline =
        cache.get_or_create_pipeline(key, entry, &shader_module, &bindings_layout);
    let bindings = cache.create_bind_group(&bindings_layout, &[mask, prefix_sum, params_buffer]);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(LABEL) };
    let mut commands = cache.device().create_command_encoder(&encoder_desc);
    {
        // The pass must be dropped before `finish()` can be called on the encoder.
        let pass_desc = wgpu::ComputePassDescriptor {
            label: Some(LABEL),
            timestamp_writes: None,
        };
        let mut compute = commands.begin_compute_pass(&pass_desc);
        compute.set_pipeline(&compute_pipeline);
        compute.set_bind_group(0, Some(&bindings), &[]);
        // Fixed single-workgroup dispatch, independent of the element count.
        compute.dispatch_workgroups(1, 1, 1);
    }
    queue.submit(Some(commands.finish()));
    Ok(())
}
/// Records and submits one compute pass that runs the `masked_select` kernel for `dtype`.
///
/// # Arguments
/// * `cache` - supplies the device, shader module, layout, pipeline and bind group.
/// * `queue` - queue the finished command buffer is submitted to.
/// * `input`, `mask`, `prefix_sum`, `output`, `params_buffer` - buffers handed to the
///   bind group in that order (layout key: 4 storage buffers, 1 uniform buffer).
/// * `numel` - count passed to `workgroup_count` to size the dispatch along X.
/// * `dtype` - selects the shader variant through `shader_info`.
///
/// # Errors
/// Propagates the `shader_info` error when `dtype` has no `masked_select` kernel.
pub fn launch_masked_select(
    cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    mask: &Buffer,
    prefix_sum: &Buffer,
    output: &Buffer,
    params_buffer: &Buffer,
    numel: usize,
    dtype: DType,
) -> Result<()> {
    const LABEL: &str = "masked_select";

    let (wgsl, key, entry) = shader_info(LABEL, dtype)?;
    let shader_module = cache.get_or_create_module(key, wgsl);
    let layout_key = LayoutKey {
        num_storage_buffers: 4,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    };
    let bindings_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline =
        cache.get_or_create_pipeline(key, entry, &shader_module, &bindings_layout);
    let bound_buffers = [input, mask, prefix_sum, output, params_buffer];
    let bindings = cache.create_bind_group(&bindings_layout, &bound_buffers);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(LABEL) };
    let mut commands = cache.device().create_command_encoder(&encoder_desc);
    {
        // The pass must be dropped before `finish()` can be called on the encoder.
        let pass_desc = wgpu::ComputePassDescriptor {
            label: Some(LABEL),
            timestamp_writes: None,
        };
        let mut compute = commands.begin_compute_pass(&pass_desc);
        compute.set_pipeline(&compute_pipeline);
        compute.set_bind_group(0, Some(&bindings), &[]);
        compute.dispatch_workgroups(workgroup_count(numel), 1, 1);
    }
    queue.submit(Some(commands.finish()));
    Ok(())
}
/// Records and submits one compute pass that runs the `gather_nd` kernel for `dtype`.
///
/// # Arguments
/// * `cache` - supplies the device, shader module, layout, pipeline and bind group.
/// * `queue` - queue the finished command buffer is submitted to.
/// * `input`, `indices`, `output`, `params_buffer` - buffers handed to the bind group in
///   that order (layout key: 4 storage buffers, 0 uniform buffers).
/// * `total_output` - count passed to `workgroup_count` to size the dispatch along X.
/// * `dtype` - selects the shader variant through `shader_info`.
///
/// # Errors
/// Propagates the `shader_info` error when `dtype` has no `gather_nd` kernel.
pub fn launch_gather_nd(
    cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    indices: &Buffer,
    output: &Buffer,
    params_buffer: &Buffer,
    total_output: usize,
    dtype: DType,
) -> Result<()> {
    const LABEL: &str = "gather_nd";

    let (wgsl, key, entry) = shader_info(LABEL, dtype)?;
    let shader_module = cache.get_or_create_module(key, wgsl);
    // NOTE(review): this layout counts `params_buffer` as a fourth storage buffer and
    // declares no uniform buffer, whereas the sibling launchers with the same four
    // bindings use 3 storage + 1 uniform. Presumably the gather_nd WGSL declares its
    // params in the storage address space - confirm against gather_nd_*.wgsl.
    let layout_key = LayoutKey {
        num_storage_buffers: 4,
        num_uniform_buffers: 0,
        num_readonly_storage: 0,
    };
    let bindings_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline =
        cache.get_or_create_pipeline(key, entry, &shader_module, &bindings_layout);
    let bindings =
        cache.create_bind_group(&bindings_layout, &[input, indices, output, params_buffer]);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(LABEL) };
    let mut commands = cache.device().create_command_encoder(&encoder_desc);
    {
        // The pass must be dropped before `finish()` can be called on the encoder.
        let pass_desc = wgpu::ComputePassDescriptor {
            label: Some(LABEL),
            timestamp_writes: None,
        };
        let mut compute = commands.begin_compute_pass(&pass_desc);
        compute.set_pipeline(&compute_pipeline);
        compute.set_bind_group(0, Some(&bindings), &[]);
        compute.dispatch_workgroups(workgroup_count(total_output), 1, 1);
    }
    queue.submit(Some(commands.finish()));
    Ok(())
}
pub fn launch_bincount(
cache: &PipelineCache,
queue: &Queue,
input: &Buffer,
weights: Option<&Buffer>,
output: &Buffer,
params_buffer: &Buffer,
n: usize,
weights_dtype: Option<DType>,
) -> Result<()> {
let (name, shader) = if let Some(dtype) = weights_dtype {
if dtype != DType::F32 {
return Err(Error::UnsupportedDType {
dtype,
op: "bincount_weighted",
});
}
("bincount_weighted_f32", BINCOUNT_WEIGHTED_SHADER_F32)
} else {
("bincount_i32", BINCOUNT_UNWEIGHTED_SHADER)
};
let module = cache.get_or_create_module(name, shader);
let (layout, bind_group) = if let Some(weights_buf) = weights {
let layout = cache.get_or_create_layout(LayoutKey {
num_storage_buffers: 3,
num_uniform_buffers: 1,
num_readonly_storage: 2, });
let bind_group =
cache.create_bind_group(&layout, &[input, weights_buf, output, params_buffer]);
(layout, bind_group)
} else {
let layout = cache.get_or_create_layout(LayoutKey {
num_storage_buffers: 2,
num_uniform_buffers: 1,
num_readonly_storage: 1, });
let bind_group = cache.create_bind_group(&layout, &[input, output, params_buffer]);
(layout, bind_group)
};
let pipeline = cache.get_or_create_pipeline(name, name, &module, &layout);
let mut encoder = cache
.device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("bincount"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("bincount"),
timestamp_writes: None,
});
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, Some(&bind_group), &[]);
pass.dispatch_workgroups(workgroup_count(n), 1, 1);
}
queue.submit(std::iter::once(encoder.finish()));
Ok(())
}
/// Records and submits one compute pass that runs a `scatter_reduce_{sum,max,min}` kernel.
///
/// # Arguments
/// * `cache` - supplies the device, shader module, layout, pipeline and bind group.
/// * `queue` - queue the finished command buffer is submitted to.
/// * `src`, `indices`, `dst`, `params_buffer` - buffers handed to the bind group in that
///   order (layout key: 3 storage buffers, 1 uniform buffer).
/// * `total_src` - count passed to `workgroup_count` to size the dispatch along X.
/// * `dtype` - selects the shader variant through `shader_info`.
/// * `op` - reduction to apply: `"sum"`, `"max"` or `"min"`.
///
/// # Errors
/// * `Error::InvalidArgument` when `op` is not one of the three accepted strings.
/// * The `shader_info` error when `dtype` has no kernel for the chosen reduction.
pub fn launch_scatter_reduce(
    cache: &PipelineCache,
    queue: &Queue,
    src: &Buffer,
    indices: &Buffer,
    dst: &Buffer,
    params_buffer: &Buffer,
    total_src: usize,
    dtype: DType,
    op: &str,
) -> Result<()> {
    const LABEL: &str = "scatter_reduce";

    // Map the caller's short reduction name onto the table key used by `shader_info`.
    let kernel_op: &'static str = if op == "sum" {
        "scatter_reduce_sum"
    } else if op == "max" {
        "scatter_reduce_max"
    } else if op == "min" {
        "scatter_reduce_min"
    } else {
        return Err(Error::InvalidArgument {
            arg: "op",
            reason: format!("scatter_reduce op must be sum, max, or min, got {}", op),
        });
    };

    let (wgsl, key, entry) = shader_info(kernel_op, dtype)?;
    let shader_module = cache.get_or_create_module(key, wgsl);
    let layout_key = LayoutKey {
        num_storage_buffers: 3,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    };
    let bindings_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline =
        cache.get_or_create_pipeline(key, entry, &shader_module, &bindings_layout);
    let bindings = cache.create_bind_group(&bindings_layout, &[src, indices, dst, params_buffer]);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(LABEL) };
    let mut commands = cache.device().create_command_encoder(&encoder_desc);
    {
        // The pass must be dropped before `finish()` can be called on the encoder.
        let pass_desc = wgpu::ComputePassDescriptor {
            label: Some(LABEL),
            timestamp_writes: None,
        };
        let mut compute = commands.begin_compute_pass(&pass_desc);
        compute.set_pipeline(&compute_pipeline);
        compute.set_bind_group(0, Some(&bindings), &[]);
        compute.dispatch_workgroups(workgroup_count(total_src), 1, 1);
    }
    queue.submit(Some(commands.finish()));
    Ok(())
}
/// Records and submits one compute pass that runs the `scatter_reduce_prod` kernel
/// for `dtype`.
///
/// # Arguments
/// * `cache` - supplies the device, shader module, layout, pipeline and bind group.
/// * `queue` - queue the finished command buffer is submitted to.
/// * `src`, `indices`, `dst`, `params_buffer` - buffers handed to the bind group in that
///   order (layout key: 3 storage buffers, 1 uniform buffer).
/// * `total_src` - count passed to `workgroup_count` to size the dispatch along X.
/// * `dtype` - selects the shader variant through `shader_info`.
///
/// # Errors
/// Propagates the `shader_info` error when `dtype` has no `scatter_reduce_prod` kernel.
pub fn launch_scatter_reduce_prod(
    cache: &PipelineCache,
    queue: &Queue,
    src: &Buffer,
    indices: &Buffer,
    dst: &Buffer,
    params_buffer: &Buffer,
    total_src: usize,
    dtype: DType,
) -> Result<()> {
    const LABEL: &str = "scatter_reduce_prod";

    let (wgsl, key, entry) = shader_info(LABEL, dtype)?;
    let shader_module = cache.get_or_create_module(key, wgsl);
    let layout_key = LayoutKey {
        num_storage_buffers: 3,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    };
    let bindings_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline =
        cache.get_or_create_pipeline(key, entry, &shader_module, &bindings_layout);
    let bindings = cache.create_bind_group(&bindings_layout, &[src, indices, dst, params_buffer]);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(LABEL) };
    let mut commands = cache.device().create_command_encoder(&encoder_desc);
    {
        // The pass must be dropped before `finish()` can be called on the encoder.
        let pass_desc = wgpu::ComputePassDescriptor {
            label: Some(LABEL),
            timestamp_writes: None,
        };
        let mut compute = commands.begin_compute_pass(&pass_desc);
        compute.set_pipeline(&compute_pipeline);
        compute.set_bind_group(0, Some(&bindings), &[]);
        compute.dispatch_workgroups(workgroup_count(total_src), 1, 1);
    }
    queue.submit(Some(commands.finish()));
    Ok(())
}
/// Records and submits one compute pass that runs the `scatter_reduce_count` kernel.
///
/// # Arguments
/// * `cache` - supplies the device, shader module, layout, pipeline and bind group.
/// * `queue` - queue the finished command buffer is submitted to.
/// * `indices`, `count`, `params_buffer` - buffers handed to the bind group in that
///   order (layout key: 2 storage buffers, 1 uniform buffer).
/// * `total_src` - count passed to `workgroup_count` to size the dispatch along X.
/// * `dtype` - selects the shader variant through `shader_info`; the table in this
///   file only has an `F32` entry for this kernel.
///
/// # Errors
/// Propagates the `shader_info` error when `dtype` has no `scatter_reduce_count` kernel.
pub fn launch_scatter_reduce_count(
    cache: &PipelineCache,
    queue: &Queue,
    indices: &Buffer,
    count: &Buffer,
    params_buffer: &Buffer,
    total_src: usize,
    dtype: DType,
) -> Result<()> {
    const LABEL: &str = "scatter_reduce_count";

    let (wgsl, key, entry) = shader_info(LABEL, dtype)?;
    let shader_module = cache.get_or_create_module(key, wgsl);
    let layout_key = LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    };
    let bindings_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline =
        cache.get_or_create_pipeline(key, entry, &shader_module, &bindings_layout);
    let bindings = cache.create_bind_group(&bindings_layout, &[indices, count, params_buffer]);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(LABEL) };
    let mut commands = cache.device().create_command_encoder(&encoder_desc);
    {
        // The pass must be dropped before `finish()` can be called on the encoder.
        let pass_desc = wgpu::ComputePassDescriptor {
            label: Some(LABEL),
            timestamp_writes: None,
        };
        let mut compute = commands.begin_compute_pass(&pass_desc);
        compute.set_pipeline(&compute_pipeline);
        compute.set_bind_group(0, Some(&bindings), &[]);
        compute.dispatch_workgroups(workgroup_count(total_src), 1, 1);
    }
    queue.submit(Some(commands.finish()));
    Ok(())
}
/// Records and submits one compute pass that runs the `scatter_reduce_mean_div` kernel.
///
/// # Arguments
/// * `cache` - supplies the device, shader module, layout, pipeline and bind group.
/// * `queue` - queue the finished command buffer is submitted to.
/// * `sum_buf`, `count_buf`, `output`, `params_buffer` - buffers handed to the bind
///   group in that order (layout key: 3 storage buffers, 1 uniform buffer).
/// * `n` - count passed to `workgroup_count` to size the dispatch along X.
/// * `dtype` - selects the shader variant through `shader_info`; the table in this
///   file only has an `F32` entry for this kernel.
///
/// # Errors
/// Propagates the `shader_info` error when `dtype` has no `scatter_reduce_mean_div`
/// kernel.
pub fn launch_scatter_reduce_mean_div(
    cache: &PipelineCache,
    queue: &Queue,
    sum_buf: &Buffer,
    count_buf: &Buffer,
    output: &Buffer,
    params_buffer: &Buffer,
    n: usize,
    dtype: DType,
) -> Result<()> {
    const LABEL: &str = "scatter_reduce_mean_div";

    let (wgsl, key, entry) = shader_info(LABEL, dtype)?;
    let shader_module = cache.get_or_create_module(key, wgsl);
    let layout_key = LayoutKey {
        num_storage_buffers: 3,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    };
    let bindings_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline =
        cache.get_or_create_pipeline(key, entry, &shader_module, &bindings_layout);
    let bindings =
        cache.create_bind_group(&bindings_layout, &[sum_buf, count_buf, output, params_buffer]);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(LABEL) };
    let mut commands = cache.device().create_command_encoder(&encoder_desc);
    {
        // The pass must be dropped before `finish()` can be called on the encoder.
        let pass_desc = wgpu::ComputePassDescriptor {
            label: Some(LABEL),
            timestamp_writes: None,
        };
        let mut compute = commands.begin_compute_pass(&pass_desc);
        compute.set_pipeline(&compute_pipeline);
        compute.set_bind_group(0, Some(&bindings), &[]);
        compute.dispatch_workgroups(workgroup_count(n), 1, 1);
    }
    queue.submit(Some(commands.finish()));
    Ok(())
}
/// Records and submits one compute pass that runs the `embedding_lookup` kernel
/// for `dtype`.
///
/// # Arguments
/// * `cache` - supplies the device, shader module, layout, pipeline and bind group.
/// * `queue` - queue the finished command buffer is submitted to.
/// * `embeddings`, `indices`, `output`, `params_buffer` - buffers handed to the bind
///   group in that order (layout key: 3 storage buffers, 1 uniform buffer).
/// * `num_indices` - count passed to `workgroup_count` to size the dispatch along X.
/// * `dtype` - selects the shader variant through `shader_info`.
///
/// # Errors
/// Propagates the `shader_info` error when `dtype` has no `embedding_lookup` kernel.
pub fn launch_embedding_lookup(
    cache: &PipelineCache,
    queue: &Queue,
    embeddings: &Buffer,
    indices: &Buffer,
    output: &Buffer,
    params_buffer: &Buffer,
    num_indices: usize,
    dtype: DType,
) -> Result<()> {
    const LABEL: &str = "embedding_lookup";

    let (wgsl, key, entry) = shader_info(LABEL, dtype)?;
    let shader_module = cache.get_or_create_module(key, wgsl);
    let layout_key = LayoutKey {
        num_storage_buffers: 3,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    };
    let bindings_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline =
        cache.get_or_create_pipeline(key, entry, &shader_module, &bindings_layout);
    let bound_buffers = [embeddings, indices, output, params_buffer];
    let bindings = cache.create_bind_group(&bindings_layout, &bound_buffers);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(LABEL) };
    let mut commands = cache.device().create_command_encoder(&encoder_desc);
    {
        // The pass must be dropped before `finish()` can be called on the encoder.
        let pass_desc = wgpu::ComputePassDescriptor {
            label: Some(LABEL),
            timestamp_writes: None,
        };
        let mut compute = commands.begin_compute_pass(&pass_desc);
        compute.set_pipeline(&compute_pipeline);
        compute.set_bind_group(0, Some(&bindings), &[]);
        compute.dispatch_workgroups(workgroup_count(num_indices), 1, 1);
    }
    queue.submit(Some(commands.finish()));
    Ok(())
}
/// Records and submits one compute pass that runs the `slice_assign` kernel for `dtype`.
///
/// # Arguments
/// * `cache` - supplies the device, shader module, layout, pipeline and bind group.
/// * `queue` - queue the finished command buffer is submitted to.
/// * `src`, `output`, `params_buffer` - buffers handed to the bind group in that order
///   (layout key: 2 storage buffers, 1 uniform buffer).
/// * `total_src` - count passed to `workgroup_count` to size the dispatch along X.
/// * `dtype` - selects the shader variant through `shader_info`.
///
/// # Errors
/// Propagates the `shader_info` error when `dtype` has no `slice_assign` kernel.
pub fn launch_slice_assign(
    cache: &PipelineCache,
    queue: &Queue,
    src: &Buffer,
    output: &Buffer,
    params_buffer: &Buffer,
    total_src: usize,
    dtype: DType,
) -> Result<()> {
    const LABEL: &str = "slice_assign";

    let (wgsl, key, entry) = shader_info(LABEL, dtype)?;
    let shader_module = cache.get_or_create_module(key, wgsl);
    let layout_key = LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    };
    let bindings_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline =
        cache.get_or_create_pipeline(key, entry, &shader_module, &bindings_layout);
    let bindings = cache.create_bind_group(&bindings_layout, &[src, output, params_buffer]);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(LABEL) };
    let mut commands = cache.device().create_command_encoder(&encoder_desc);
    {
        // The pass must be dropped before `finish()` can be called on the encoder.
        let pass_desc = wgpu::ComputePassDescriptor {
            label: Some(LABEL),
            timestamp_writes: None,
        };
        let mut compute = commands.begin_compute_pass(&pass_desc);
        compute.set_pipeline(&compute_pipeline);
        compute.set_bind_group(0, Some(&bindings), &[]);
        compute.dispatch_workgroups(workgroup_count(total_src), 1, 1);
    }
    queue.submit(Some(commands.finish()));
    Ok(())
}
/// Records and submits one compute pass that runs the `gather_2d` kernel for `dtype`.
///
/// # Arguments
/// * `cache` - supplies the device, shader module, layout, pipeline and bind group.
/// * `queue` - queue the finished command buffer is submitted to.
/// * `input`, `rows`, `cols`, `output`, `params_buffer` - buffers handed to the bind
///   group in that order (layout key: 4 storage buffers of which 3 are keyed as
///   read-only, plus 1 uniform buffer).
/// * `num_indices` - count passed to `workgroup_count` to size the dispatch along X.
/// * `dtype` - selects the shader variant through `shader_info`.
///
/// # Errors
/// Propagates the `shader_info` error when `dtype` has no `gather_2d` kernel.
#[allow(clippy::too_many_arguments)]
pub fn launch_gather_2d(
    cache: &PipelineCache,
    queue: &Queue,
    input: &Buffer,
    rows: &Buffer,
    cols: &Buffer,
    output: &Buffer,
    params_buffer: &Buffer,
    num_indices: usize,
    dtype: DType,
) -> Result<()> {
    const LABEL: &str = "gather_2d";

    let (wgsl, key, entry) = shader_info(LABEL, dtype)?;
    let shader_module = cache.get_or_create_module(key, wgsl);
    let layout_key = LayoutKey {
        num_storage_buffers: 4,
        num_uniform_buffers: 1,
        num_readonly_storage: 3,
    };
    let bindings_layout = cache.get_or_create_layout(layout_key);
    let compute_pipeline =
        cache.get_or_create_pipeline(key, entry, &shader_module, &bindings_layout);
    let bound_buffers = [input, rows, cols, output, params_buffer];
    let bindings = cache.create_bind_group(&bindings_layout, &bound_buffers);

    let encoder_desc = wgpu::CommandEncoderDescriptor { label: Some(LABEL) };
    let mut commands = cache.device().create_command_encoder(&encoder_desc);
    {
        // The pass must be dropped before `finish()` can be called on the encoder.
        let pass_desc = wgpu::ComputePassDescriptor {
            label: Some(LABEL),
            timestamp_writes: None,
        };
        let mut compute = commands.begin_compute_pass(&pass_desc);
        compute.set_pipeline(&compute_pipeline);
        compute.set_bind_group(0, Some(&bindings), &[]);
        compute.dispatch_workgroups(workgroup_count(num_indices), 1, 1);
    }
    queue.submit(Some(commands.finish()));
    Ok(())
}