use wgpu::{Buffer, Queue};
use super::pipeline::{LayoutKey, PipelineCache, workgroup_count};
use crate::dtype::DType;
use crate::error::{Error, Result};
const SPARSE_CONVERSIONS_INDICES: &str = include_str!("sparse_conversions_indices.wgsl");
const SPARSE_CONVERSIONS_F32: &str = include_str!("sparse_conversions_f32.wgsl");
const SPARSE_CONVERSIONS_I32: &str = include_str!("sparse_conversions_i32.wgsl");
const SPARSE_CONVERSIONS_U32: &str = include_str!("sparse_conversions_u32.wgsl");
fn typed_shader(dtype: DType) -> Result<(&'static str, &'static str)> {
match dtype {
DType::F32 => Ok(("sparse_conversions_f32", SPARSE_CONVERSIONS_F32)),
DType::I32 => Ok(("sparse_conversions_i32", SPARSE_CONVERSIONS_I32)),
DType::U32 => Ok(("sparse_conversions_u32", SPARSE_CONVERSIONS_U32)),
_ => Err(Error::UnsupportedDType {
dtype,
op: "sparse_conversions (WebGPU)",
}),
}
}
/// Submits the `expand_row_ptrs` kernel from the shared index-shader
/// module, dispatched with enough workgroups to cover `nrows` items
/// (presumably expanding CSR row pointers into per-nonzero row
/// indices — see sparse_conversions_indices.wgsl for the kernel body).
///
/// Bindings, in order: `row_ptrs`, `row_indices` (storage) and
/// `params` (uniform).
pub fn launch_expand_row_ptrs(
    cache: &PipelineCache,
    queue: &Queue,
    row_ptrs: &Buffer,
    row_indices: &Buffer,
    params: &Buffer,
    nrows: usize,
) -> Result<()> {
    // Entry-point name doubles as the encoder/pass label.
    const LABEL: &str = "expand_row_ptrs";
    let shader =
        cache.get_or_create_module("sparse_conversions_indices", SPARSE_CONVERSIONS_INDICES);
    let bind_layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let compute =
        cache.get_or_create_pipeline("sparse_conversions_indices", LABEL, &shader, &bind_layout);
    let bindings = cache.create_bind_group(&bind_layout, &[row_ptrs, row_indices, params]);
    let mut cmd = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut cpass = cmd.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    cpass.set_pipeline(&compute);
    cpass.set_bind_group(0, Some(&bindings), &[]);
    cpass.dispatch_workgroups(workgroup_count(nrows), 1, 1);
    // End the pass (it borrows the encoder) before finishing the buffer.
    drop(cpass);
    queue.submit(Some(cmd.finish()));
    Ok(())
}
/// Submits the `expand_col_ptrs` kernel from the shared index-shader
/// module, dispatched with enough workgroups to cover `ncols` items
/// (the CSC counterpart of `launch_expand_row_ptrs`).
///
/// Bindings, in order: `col_ptrs`, `col_indices` (storage) and
/// `params` (uniform).
pub fn launch_expand_col_ptrs(
    cache: &PipelineCache,
    queue: &Queue,
    col_ptrs: &Buffer,
    col_indices: &Buffer,
    params: &Buffer,
    ncols: usize,
) -> Result<()> {
    // Entry-point name doubles as the encoder/pass label.
    const LABEL: &str = "expand_col_ptrs";
    let shader =
        cache.get_or_create_module("sparse_conversions_indices", SPARSE_CONVERSIONS_INDICES);
    let bind_layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let compute =
        cache.get_or_create_pipeline("sparse_conversions_indices", LABEL, &shader, &bind_layout);
    let bindings = cache.create_bind_group(&bind_layout, &[col_ptrs, col_indices, params]);
    let mut cmd = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut cpass = cmd.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    cpass.set_pipeline(&compute);
    cpass.set_bind_group(0, Some(&bindings), &[]);
    cpass.dispatch_workgroups(workgroup_count(ncols), 1, 1);
    // End the pass (it borrows the encoder) before finishing the buffer.
    drop(cpass);
    queue.submit(Some(cmd.finish()));
    Ok(())
}
/// Submits the `histogram` kernel from the shared index-shader module,
/// dispatched with enough workgroups to cover `nnz` items (presumably
/// counting index occurrences into `counts` — see the WGSL source).
///
/// Bindings, in order: `indices`, `counts` (storage) and `params`
/// (uniform).
pub fn launch_histogram(
    cache: &PipelineCache,
    queue: &Queue,
    indices: &Buffer,
    counts: &Buffer,
    params: &Buffer,
    nnz: usize,
) -> Result<()> {
    // Entry-point name doubles as the encoder/pass label.
    const LABEL: &str = "histogram";
    let shader =
        cache.get_or_create_module("sparse_conversions_indices", SPARSE_CONVERSIONS_INDICES);
    let bind_layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let compute =
        cache.get_or_create_pipeline("sparse_conversions_indices", LABEL, &shader, &bind_layout);
    let bindings = cache.create_bind_group(&bind_layout, &[indices, counts, params]);
    let mut cmd = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut cpass = cmd.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    cpass.set_pipeline(&compute);
    cpass.set_bind_group(0, Some(&bindings), &[]);
    cpass.dispatch_workgroups(workgroup_count(nnz), 1, 1);
    // End the pass (it borrows the encoder) before finishing the buffer.
    drop(cpass);
    queue.submit(Some(cmd.finish()));
    Ok(())
}
/// Submits the dtype-specialized `coo_to_csr_scatter` kernel,
/// dispatched with enough workgroups to cover `nnz` items. The
/// `row_ptrs_atomic` buffer name suggests the kernel claims output
/// slots atomically — TODO confirm against the WGSL source.
///
/// Bindings, in order: COO inputs (`in_row_indices`, `in_col_indices`,
/// `in_values`), `row_ptrs_atomic`, CSR outputs (`out_col_indices`,
/// `out_values`), then the uniform `params`.
///
/// # Errors
/// Propagates `Error::UnsupportedDType` from `typed_shader`.
pub fn launch_coo_to_csr_scatter(
    cache: &PipelineCache,
    queue: &Queue,
    in_row_indices: &Buffer,
    in_col_indices: &Buffer,
    in_values: &Buffer,
    row_ptrs_atomic: &Buffer,
    out_col_indices: &Buffer,
    out_values: &Buffer,
    params: &Buffer,
    nnz: usize,
    dtype: DType,
) -> Result<()> {
    // Entry-point name doubles as the encoder/pass label.
    const LABEL: &str = "coo_to_csr_scatter";
    let (module_key, source) = typed_shader(dtype)?;
    let shader = cache.get_or_create_module(module_key, source);
    let bind_layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 6,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let compute = cache.get_or_create_pipeline(module_key, LABEL, &shader, &bind_layout);
    let bindings = cache.create_bind_group(
        &bind_layout,
        &[
            in_row_indices,
            in_col_indices,
            in_values,
            row_ptrs_atomic,
            out_col_indices,
            out_values,
            params,
        ],
    );
    let mut cmd = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut cpass = cmd.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    cpass.set_pipeline(&compute);
    cpass.set_bind_group(0, Some(&bindings), &[]);
    cpass.dispatch_workgroups(workgroup_count(nnz), 1, 1);
    // End the pass (it borrows the encoder) before finishing the buffer.
    drop(cpass);
    queue.submit(Some(cmd.finish()));
    Ok(())
}
/// Submits the dtype-specialized `coo_to_csc_scatter` kernel,
/// dispatched with enough workgroups to cover `nnz` items (the CSC
/// counterpart of `launch_coo_to_csr_scatter`).
///
/// Bindings, in order: COO inputs (`in_row_indices`, `in_col_indices`,
/// `in_values`), `col_ptrs_atomic`, CSC outputs (`out_row_indices`,
/// `out_values`), then the uniform `params`.
///
/// # Errors
/// Propagates `Error::UnsupportedDType` from `typed_shader`.
pub fn launch_coo_to_csc_scatter(
    cache: &PipelineCache,
    queue: &Queue,
    in_row_indices: &Buffer,
    in_col_indices: &Buffer,
    in_values: &Buffer,
    col_ptrs_atomic: &Buffer,
    out_row_indices: &Buffer,
    out_values: &Buffer,
    params: &Buffer,
    nnz: usize,
    dtype: DType,
) -> Result<()> {
    // Entry-point name doubles as the encoder/pass label.
    const LABEL: &str = "coo_to_csc_scatter";
    let (module_key, source) = typed_shader(dtype)?;
    let shader = cache.get_or_create_module(module_key, source);
    let bind_layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 6,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let compute = cache.get_or_create_pipeline(module_key, LABEL, &shader, &bind_layout);
    let bindings = cache.create_bind_group(
        &bind_layout,
        &[
            in_row_indices,
            in_col_indices,
            in_values,
            col_ptrs_atomic,
            out_row_indices,
            out_values,
            params,
        ],
    );
    let mut cmd = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut cpass = cmd.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    cpass.set_pipeline(&compute);
    cpass.set_bind_group(0, Some(&bindings), &[]);
    cpass.dispatch_workgroups(workgroup_count(nnz), 1, 1);
    // End the pass (it borrows the encoder) before finishing the buffer.
    drop(cpass);
    queue.submit(Some(cmd.finish()));
    Ok(())
}
/// Submits the dtype-specialized `csr_to_csc_scatter` kernel,
/// dispatched with enough workgroups to cover `nrows` items (one
/// dispatch thread per CSR row, judging by the dispatch size).
///
/// Bindings, in order: CSR inputs (`in_row_ptrs`, `in_col_indices`,
/// `in_values`), `col_ptrs_atomic`, CSC outputs (`out_row_indices`,
/// `out_values`), then the uniform `params`.
///
/// # Errors
/// Propagates `Error::UnsupportedDType` from `typed_shader`.
pub fn launch_csr_to_csc_scatter(
    cache: &PipelineCache,
    queue: &Queue,
    in_row_ptrs: &Buffer,
    in_col_indices: &Buffer,
    in_values: &Buffer,
    col_ptrs_atomic: &Buffer,
    out_row_indices: &Buffer,
    out_values: &Buffer,
    params: &Buffer,
    nrows: usize,
    dtype: DType,
) -> Result<()> {
    // Entry-point name doubles as the encoder/pass label.
    const LABEL: &str = "csr_to_csc_scatter";
    let (module_key, source) = typed_shader(dtype)?;
    let shader = cache.get_or_create_module(module_key, source);
    let bind_layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 6,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let compute = cache.get_or_create_pipeline(module_key, LABEL, &shader, &bind_layout);
    let bindings = cache.create_bind_group(
        &bind_layout,
        &[
            in_row_ptrs,
            in_col_indices,
            in_values,
            col_ptrs_atomic,
            out_row_indices,
            out_values,
            params,
        ],
    );
    let mut cmd = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut cpass = cmd.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    cpass.set_pipeline(&compute);
    cpass.set_bind_group(0, Some(&bindings), &[]);
    cpass.dispatch_workgroups(workgroup_count(nrows), 1, 1);
    // End the pass (it borrows the encoder) before finishing the buffer.
    drop(cpass);
    queue.submit(Some(cmd.finish()));
    Ok(())
}
/// Submits the dtype-specialized `csc_to_csr_scatter` kernel,
/// dispatched with enough workgroups to cover `ncols` items (the
/// mirror of `launch_csr_to_csc_scatter`).
///
/// Bindings, in order: CSC inputs (`in_col_ptrs`, `in_row_indices`,
/// `in_values`), `row_ptrs_atomic`, CSR outputs (`out_col_indices`,
/// `out_values`), then the uniform `params`.
///
/// # Errors
/// Propagates `Error::UnsupportedDType` from `typed_shader`.
pub fn launch_csc_to_csr_scatter(
    cache: &PipelineCache,
    queue: &Queue,
    in_col_ptrs: &Buffer,
    in_row_indices: &Buffer,
    in_values: &Buffer,
    row_ptrs_atomic: &Buffer,
    out_col_indices: &Buffer,
    out_values: &Buffer,
    params: &Buffer,
    ncols: usize,
    dtype: DType,
) -> Result<()> {
    // Entry-point name doubles as the encoder/pass label.
    const LABEL: &str = "csc_to_csr_scatter";
    let (module_key, source) = typed_shader(dtype)?;
    let shader = cache.get_or_create_module(module_key, source);
    let bind_layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 6,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let compute = cache.get_or_create_pipeline(module_key, LABEL, &shader, &bind_layout);
    let bindings = cache.create_bind_group(
        &bind_layout,
        &[
            in_col_ptrs,
            in_row_indices,
            in_values,
            row_ptrs_atomic,
            out_col_indices,
            out_values,
            params,
        ],
    );
    let mut cmd = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut cpass = cmd.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    cpass.set_pipeline(&compute);
    cpass.set_bind_group(0, Some(&bindings), &[]);
    cpass.dispatch_workgroups(workgroup_count(ncols), 1, 1);
    // End the pass (it borrows the encoder) before finishing the buffer.
    drop(cpass);
    queue.submit(Some(cmd.finish()));
    Ok(())
}
/// Submits the `copy_ptrs` kernel from the shared index-shader module,
/// dispatched with enough workgroups to cover `n` items (presumably a
/// device-side element copy of pointer arrays from `src` to `dst`).
///
/// Bindings, in order: `src`, `dst` (storage) and `params` (uniform).
pub fn launch_copy_ptrs(
    cache: &PipelineCache,
    queue: &Queue,
    src: &Buffer,
    dst: &Buffer,
    params: &Buffer,
    n: usize,
) -> Result<()> {
    // Entry-point name doubles as the encoder/pass label.
    const LABEL: &str = "copy_ptrs";
    let shader =
        cache.get_or_create_module("sparse_conversions_indices", SPARSE_CONVERSIONS_INDICES);
    let bind_layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let compute =
        cache.get_or_create_pipeline("sparse_conversions_indices", LABEL, &shader, &bind_layout);
    let bindings = cache.create_bind_group(&bind_layout, &[src, dst, params]);
    let mut cmd = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut cpass = cmd.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    cpass.set_pipeline(&compute);
    cpass.set_bind_group(0, Some(&bindings), &[]);
    cpass.dispatch_workgroups(workgroup_count(n), 1, 1);
    // End the pass (it borrows the encoder) before finishing the buffer.
    drop(cpass);
    queue.submit(Some(cmd.finish()));
    Ok(())
}
/// Submits the dtype-specialized `csr_to_dense` kernel, dispatched
/// with enough workgroups to cover `nrows` items (one dispatch thread
/// per CSR row, judging by the dispatch size).
///
/// Bindings, in order: `row_ptrs`, `col_indices`, `values`, `dense`
/// (storage) and `params` (uniform).
///
/// # Errors
/// Propagates `Error::UnsupportedDType` from `typed_shader`.
pub fn launch_csr_to_dense(
    cache: &PipelineCache,
    queue: &Queue,
    row_ptrs: &Buffer,
    col_indices: &Buffer,
    values: &Buffer,
    dense: &Buffer,
    params: &Buffer,
    nrows: usize,
    dtype: DType,
) -> Result<()> {
    // Entry-point name doubles as the encoder/pass label.
    const LABEL: &str = "csr_to_dense";
    let (module_key, source) = typed_shader(dtype)?;
    let shader = cache.get_or_create_module(module_key, source);
    let bind_layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 4,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let compute = cache.get_or_create_pipeline(module_key, LABEL, &shader, &bind_layout);
    let bindings =
        cache.create_bind_group(&bind_layout, &[row_ptrs, col_indices, values, dense, params]);
    let mut cmd = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut cpass = cmd.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    cpass.set_pipeline(&compute);
    cpass.set_bind_group(0, Some(&bindings), &[]);
    cpass.dispatch_workgroups(workgroup_count(nrows), 1, 1);
    // End the pass (it borrows the encoder) before finishing the buffer.
    drop(cpass);
    queue.submit(Some(cmd.finish()));
    Ok(())
}
/// Submits the dtype-specialized `count_nonzeros` kernel, dispatched
/// with enough workgroups to cover `total_elems` items (presumably
/// tallying nonzero entries of `dense` into `count` — see the typed
/// WGSL source).
///
/// Bindings, in order: `dense`, `count` (storage) and `params`
/// (uniform).
///
/// # Errors
/// Propagates `Error::UnsupportedDType` from `typed_shader`.
pub fn launch_count_nonzeros(
    cache: &PipelineCache,
    queue: &Queue,
    dense: &Buffer,
    count: &Buffer,
    params: &Buffer,
    total_elems: usize,
    dtype: DType,
) -> Result<()> {
    // Entry-point name doubles as the encoder/pass label.
    const LABEL: &str = "count_nonzeros";
    let (module_key, source) = typed_shader(dtype)?;
    let shader = cache.get_or_create_module(module_key, source);
    let bind_layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let compute = cache.get_or_create_pipeline(module_key, LABEL, &shader, &bind_layout);
    let bindings = cache.create_bind_group(&bind_layout, &[dense, count, params]);
    let mut cmd = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut cpass = cmd.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    cpass.set_pipeline(&compute);
    cpass.set_bind_group(0, Some(&bindings), &[]);
    cpass.dispatch_workgroups(workgroup_count(total_elems), 1, 1);
    // End the pass (it borrows the encoder) before finishing the buffer.
    drop(cpass);
    queue.submit(Some(cmd.finish()));
    Ok(())
}
/// Submits the dtype-specialized `dense_to_coo_scatter` kernel,
/// dispatched with enough workgroups to cover `total_elems` items.
/// The `write_pos` buffer name suggests output slots are claimed via
/// an atomic cursor — TODO confirm against the WGSL source.
///
/// Bindings, in order: `dense`, COO outputs (`row_indices`,
/// `col_indices`, `values`), `write_pos` (storage) and `params`
/// (uniform).
///
/// # Errors
/// Propagates `Error::UnsupportedDType` from `typed_shader`.
pub fn launch_dense_to_coo_scatter(
    cache: &PipelineCache,
    queue: &Queue,
    dense: &Buffer,
    row_indices: &Buffer,
    col_indices: &Buffer,
    values: &Buffer,
    write_pos: &Buffer,
    params: &Buffer,
    total_elems: usize,
    dtype: DType,
) -> Result<()> {
    // Entry-point name doubles as the encoder/pass label.
    const LABEL: &str = "dense_to_coo_scatter";
    let (module_key, source) = typed_shader(dtype)?;
    let shader = cache.get_or_create_module(module_key, source);
    let bind_layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 5,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let compute = cache.get_or_create_pipeline(module_key, LABEL, &shader, &bind_layout);
    let bindings = cache.create_bind_group(
        &bind_layout,
        &[dense, row_indices, col_indices, values, write_pos, params],
    );
    let mut cmd = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut cpass = cmd.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    cpass.set_pipeline(&compute);
    cpass.set_bind_group(0, Some(&bindings), &[]);
    cpass.dispatch_workgroups(workgroup_count(total_elems), 1, 1);
    // End the pass (it borrows the encoder) before finishing the buffer.
    drop(cpass);
    queue.submit(Some(cmd.finish()));
    Ok(())
}