use wgpu::{Buffer, Queue};
use super::pipeline::{LayoutKey, PipelineCache, workgroup_count};
use crate::dtype::DType;
use crate::error::{Error, Result};
const SPARSE_ALGORITHMS_F32: &str = include_str!("sparse_algorithms_f32.wgsl");
fn algorithms_shader_info(dtype: DType) -> Result<(&'static str, &'static str)> {
match dtype {
DType::F32 => Ok((SPARSE_ALGORITHMS_F32, "sparse_algorithms_f32")),
_ => Err(Error::UnsupportedDType {
dtype,
op: "sparse_algorithms (WebGPU)",
}),
}
}
/// Launch the dense × sparse (CSC) matrix-multiply kernel.
///
/// Buffers are bound in layout order: `a`, `col_ptrs`, `row_indices`,
/// and `b_values` as read-only storage, `c` as writable storage, and
/// `params_buffer` as the uniform block. One invocation is dispatched
/// per element of the `m`-by-`n` output (presumably `c` — confirm
/// against the shader's indexing).
///
/// # Errors
/// Returns [`Error::UnsupportedDType`] when no shader variant exists
/// for `dtype` (only `F32` is supported).
pub fn launch_dsmm_csc(
    cache: &PipelineCache,
    queue: &Queue,
    a: &Buffer,
    col_ptrs: &Buffer,
    row_indices: &Buffer,
    b_values: &Buffer,
    c: &Buffer,
    params_buffer: &Buffer,
    m: usize,
    n: usize,
    dtype: DType,
) -> Result<()> {
    let (source, module_name) = algorithms_shader_info(dtype)?;
    let shader_module = cache.get_or_create_module(module_name, source);

    // 5 storage bindings (4 read-only inputs + writable `c`) plus the uniform params.
    let layout_key = LayoutKey {
        num_storage_buffers: 5,
        num_uniform_buffers: 1,
        num_readonly_storage: 4,
    };
    let bind_layout = cache.get_or_create_layout(layout_key);
    let pipeline =
        cache.get_or_create_pipeline(module_name, "dsmm_csc_f32", &shader_module, &bind_layout);

    let bindings = [a, col_ptrs, row_indices, b_values, c, params_buffer];
    let bind_group = cache.create_bind_group(&bind_layout, &bindings);

    // One thread per element of the m x n result.
    let work_items = m * n;

    let mut cmd = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("dsmm_csc"),
        });
    {
        // Scoped so the pass borrow on `cmd` ends before `finish()`.
        let mut cpass = cmd.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("dsmm_csc"),
            timestamp_writes: None,
        });
        cpass.set_pipeline(&pipeline);
        cpass.set_bind_group(0, Some(&bind_group), &[]);
        cpass.dispatch_workgroups(workgroup_count(work_items), 1, 1);
    }
    queue.submit(Some(cmd.finish()));
    Ok(())
}
/// Launch the symbolic phase of sparse × sparse (CSR) matrix multiply.
///
/// Binds the CSR structure of `a` and `b` (`a_row_ptrs`, `a_col_indices`,
/// `b_row_ptrs`, `b_col_indices`) as read-only storage, with `row_nnz`
/// and `bitmap` as writable storage and `params_buffer` as the uniform
/// block. Dispatches one invocation per row (`m` total). `row_nnz` and
/// `bitmap` are the shader's outputs/scratch — presumably per-row
/// nonzero counts and a column-occupancy mask; verify against the WGSL.
///
/// # Errors
/// Returns [`Error::UnsupportedDType`] when no shader variant exists
/// for `dtype` (only `F32` is supported).
pub fn launch_spgemm_symbolic(
    cache: &PipelineCache,
    queue: &Queue,
    a_row_ptrs: &Buffer,
    a_col_indices: &Buffer,
    b_row_ptrs: &Buffer,
    b_col_indices: &Buffer,
    row_nnz: &Buffer,
    params_buffer: &Buffer,
    bitmap: &Buffer,
    m: usize,
    dtype: DType,
) -> Result<()> {
    let (source, module_name) = algorithms_shader_info(dtype)?;
    let shader_module = cache.get_or_create_module(module_name, source);

    // 6 storage bindings (4 read-only + writable row_nnz/bitmap) plus the uniform params.
    let layout_key = LayoutKey {
        num_storage_buffers: 6,
        num_uniform_buffers: 1,
        num_readonly_storage: 4,
    };
    let bind_layout = cache.get_or_create_layout(layout_key);
    let pipeline = cache.get_or_create_pipeline(
        module_name,
        "spgemm_symbolic_f32",
        &shader_module,
        &bind_layout,
    );

    let bindings = [
        a_row_ptrs,
        a_col_indices,
        b_row_ptrs,
        b_col_indices,
        row_nnz,
        bitmap,
        params_buffer,
    ];
    let bind_group = cache.create_bind_group(&bind_layout, &bindings);

    let mut cmd = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("spgemm_symbolic"),
        });
    {
        // Scoped so the pass borrow on `cmd` ends before `finish()`.
        let mut cpass = cmd.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("spgemm_symbolic"),
            timestamp_writes: None,
        });
        cpass.set_pipeline(&pipeline);
        cpass.set_bind_group(0, Some(&bind_group), &[]);
        // One invocation per output row.
        cpass.dispatch_workgroups(workgroup_count(m), 1, 1);
    }
    queue.submit(Some(cmd.finish()));
    Ok(())
}
/// Launch the numeric accumulation phase of sparse × sparse (CSR)
/// matrix multiply.
///
/// Binds the full CSR data of `a` and `b` (row pointers, column
/// indices, values — six buffers) as read-only storage, with `accum`
/// and `flags` as writable storage and `params_buffer` as the uniform
/// block. Dispatches one invocation per row (`m` total). `accum` and
/// `flags` are presumably dense per-row accumulator values and their
/// occupancy markers, consumed later by the scatter phase — confirm
/// against the WGSL.
///
/// # Errors
/// Returns [`Error::UnsupportedDType`] when no shader variant exists
/// for `dtype` (only `F32` is supported).
pub fn launch_spgemm_accumulate(
    cache: &PipelineCache,
    queue: &Queue,
    a_row_ptrs: &Buffer,
    a_col_indices: &Buffer,
    a_values: &Buffer,
    b_row_ptrs: &Buffer,
    b_col_indices: &Buffer,
    b_values: &Buffer,
    params_buffer: &Buffer,
    accum: &Buffer,
    flags: &Buffer,
    m: usize,
    dtype: DType,
) -> Result<()> {
    let (source, module_name) = algorithms_shader_info(dtype)?;
    let shader_module = cache.get_or_create_module(module_name, source);

    // 8 storage bindings (6 read-only CSR inputs + writable accum/flags) plus the uniform params.
    let layout_key = LayoutKey {
        num_storage_buffers: 8,
        num_uniform_buffers: 1,
        num_readonly_storage: 6,
    };
    let bind_layout = cache.get_or_create_layout(layout_key);
    let pipeline = cache.get_or_create_pipeline(
        module_name,
        "spgemm_accumulate_f32",
        &shader_module,
        &bind_layout,
    );

    let bindings = [
        a_row_ptrs,
        a_col_indices,
        a_values,
        b_row_ptrs,
        b_col_indices,
        b_values,
        accum,
        flags,
        params_buffer,
    ];
    let bind_group = cache.create_bind_group(&bind_layout, &bindings);

    let mut cmd = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("spgemm_accumulate"),
        });
    {
        // Scoped so the pass borrow on `cmd` ends before `finish()`.
        let mut cpass = cmd.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("spgemm_accumulate"),
            timestamp_writes: None,
        });
        cpass.set_pipeline(&pipeline);
        cpass.set_bind_group(0, Some(&bind_group), &[]);
        // One invocation per output row.
        cpass.dispatch_workgroups(workgroup_count(m), 1, 1);
    }
    queue.submit(Some(cmd.finish()));
    Ok(())
}
/// Launch the scatter phase of sparse × sparse (CSR) matrix multiply.
///
/// Binds `c_row_ptrs`, `accum`, and `flags` as read-only storage,
/// `c_col_indices` and `c_values` as writable storage, and
/// `params_buffer` as the uniform block, then dispatches one
/// invocation per row (`m` total). Presumably compacts the dense
/// per-row accumulator (from the accumulate phase) into the final CSR
/// column/value arrays of `c` — confirm against the WGSL.
///
/// # Errors
/// Returns [`Error::UnsupportedDType`] when no shader variant exists
/// for `dtype` (only `F32` is supported).
pub fn launch_spgemm_scatter(
    cache: &PipelineCache,
    queue: &Queue,
    c_row_ptrs: &Buffer,
    accum: &Buffer,
    flags: &Buffer,
    c_col_indices: &Buffer,
    c_values: &Buffer,
    params_buffer: &Buffer,
    m: usize,
    dtype: DType,
) -> Result<()> {
    let (source, module_name) = algorithms_shader_info(dtype)?;
    let shader_module = cache.get_or_create_module(module_name, source);

    // 5 storage bindings (3 read-only + writable c_col_indices/c_values) plus the uniform params.
    let layout_key = LayoutKey {
        num_storage_buffers: 5,
        num_uniform_buffers: 1,
        num_readonly_storage: 3,
    };
    let bind_layout = cache.get_or_create_layout(layout_key);
    let pipeline = cache.get_or_create_pipeline(
        module_name,
        "spgemm_scatter_f32",
        &shader_module,
        &bind_layout,
    );

    let bindings = [c_row_ptrs, accum, flags, c_col_indices, c_values, params_buffer];
    let bind_group = cache.create_bind_group(&bind_layout, &bindings);

    let mut cmd = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("spgemm_scatter"),
        });
    {
        // Scoped so the pass borrow on `cmd` ends before `finish()`.
        let mut cpass = cmd.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("spgemm_scatter"),
            timestamp_writes: None,
        });
        cpass.set_pipeline(&pipeline);
        cpass.set_bind_group(0, Some(&bind_group), &[]);
        // One invocation per output row.
        cpass.dispatch_workgroups(workgroup_count(m), 1, 1);
    }
    queue.submit(Some(cmd.finish()));
    Ok(())
}