use wgpu::{Buffer, Queue};
use super::pipeline::{LayoutKey, PipelineCache, workgroup_count};
use crate::dtype::DType;
use crate::error::{Error, Result};
/// WGSL source for the f32 sparse kernels, embedded at compile time via `include_str!`.
///
/// The launchers in this file build a single shader module from this source and
/// pick the entry points `csr_spmv_f32`, `csr_spmm_f32` and `csr_extract_diagonal_f32` from it.
const SPARSE_SPMV_F32: &str = include_str!("sparse_spmv_f32.wgsl");
fn spmv_shader_info(dtype: DType) -> Result<(&'static str, &'static str)> {
match dtype {
DType::F32 => Ok((SPARSE_SPMV_F32, "sparse_spmv_f32")),
_ => Err(Error::UnsupportedDType {
dtype,
op: "csr_spmv (WebGPU)",
}),
}
}
/// Launches the CSR sparse matrix–vector product kernel (`csr_spmv_f32`).
///
/// The six buffers are bound in the order `row_ptrs, col_indices, values, x, y,
/// params_buffer`. The layout has five storage buffers and one uniform buffer;
/// presumably `params_buffer` is the uniform (confirm in `PipelineCache::create_bind_group`).
/// The kernel is dispatched over a one-dimensional grid of `workgroup_count(nrows)` workgroups.
///
/// NOTE(review): the expected buffer contents are defined by `sparse_spmv_f32.wgsl`.
/// That covers the length of `row_ptrs`, the length of `y` and the layout of
/// `params_buffer`; confirm them there.
///
/// The command buffer is submitted to `queue`. This function does not wait for
/// the GPU to finish. When `nrows == 0` nothing is recorded or submitted.
///
/// # Errors
///
/// Returns `Error::UnsupportedDType` when `dtype` has no WebGPU kernel
/// (currently anything other than `DType::F32`).
pub fn launch_csr_spmv(
    cache: &PipelineCache,
    queue: &Queue,
    row_ptrs: &Buffer,
    col_indices: &Buffer,
    values: &Buffer,
    x: &Buffer,
    y: &Buffer,
    params_buffer: &Buffer,
    nrows: usize,
    dtype: DType,
) -> Result<()> {
    // Validate the dtype first so unsupported dtypes are reported even for empty inputs.
    let (shader, module_name) = spmv_shader_info(dtype)?;

    // No rows means nothing to write. Bailing out here also avoids binding
    // buffers that may be zero-sized, which wgpu can reject as a validation error.
    if nrows == 0 {
        return Ok(());
    }

    let module = cache.get_or_create_module(module_name, shader);
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 5,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let pipeline = cache.get_or_create_pipeline(module_name, "csr_spmv_f32", &module, &layout);
    let bind_group = cache.create_bind_group(
        &layout,
        &[row_ptrs, col_indices, values, x, y, params_buffer],
    );

    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("csr_spmv"),
        });
    {
        // The pass borrows `encoder` mutably; scope it so `finish()` can consume the encoder.
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("csr_spmv"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, Some(&bind_group), &[]);
        pass.dispatch_workgroups(workgroup_count(nrows), 1, 1);
    }
    queue.submit(std::iter::once(encoder.finish()));
    Ok(())
}
/// Launches the CSR sparse × dense matrix product kernel (`csr_spmm_f32`).
///
/// The six buffers are bound in the order `row_ptrs, col_indices, a_values, b, c,
/// params_buffer`. The layout has five storage buffers and one uniform buffer;
/// presumably `params_buffer` is the uniform (confirm in `PipelineCache::create_bind_group`).
/// The kernel is dispatched over a one-dimensional grid sized by
/// `workgroup_count(m * n)`, i.e. one invocation slot per output element.
///
/// NOTE(review): presumably `c` is a dense `m × n` output and `b` a dense operand
/// with `n` columns. Confirm the exact layout against `sparse_spmv_f32.wgsl`.
///
/// The command buffer is submitted to `queue`. This function does not wait for
/// the GPU to finish. When `m * n == 0` nothing is recorded or submitted.
///
/// # Errors
///
/// Returns `Error::UnsupportedDType` when `dtype` has no WebGPU kernel
/// (currently anything other than `DType::F32`).
///
/// # Panics
///
/// Panics if `m * n` overflows `usize`. An output of that size cannot exist, so
/// such dimensions indicate a caller bug. This can happen on 32-bit targets such as wasm32.
pub fn launch_csr_spmm(
    cache: &PipelineCache,
    queue: &Queue,
    row_ptrs: &Buffer,
    col_indices: &Buffer,
    a_values: &Buffer,
    b: &Buffer,
    c: &Buffer,
    params_buffer: &Buffer,
    m: usize,
    n: usize,
    dtype: DType,
) -> Result<()> {
    // Validate the dtype first so unsupported dtypes are reported even for empty inputs.
    let (shader, module_name) = spmv_shader_info(dtype)?;

    // A plain `m * n` wraps silently in release builds, which would dispatch a
    // wrong-sized grid. Fail loudly instead.
    let total_elements = m
        .checked_mul(n)
        .expect("csr_spmm: m * n overflows usize; output matrix cannot be addressed");

    // An empty output means nothing to compute. Bailing out here also avoids
    // binding buffers that may be zero-sized, which wgpu can reject.
    if total_elements == 0 {
        return Ok(());
    }

    let module = cache.get_or_create_module(module_name, shader);
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 5,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let pipeline = cache.get_or_create_pipeline(module_name, "csr_spmm_f32", &module, &layout);
    let bind_group = cache.create_bind_group(
        &layout,
        &[row_ptrs, col_indices, a_values, b, c, params_buffer],
    );

    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("csr_spmm"),
        });
    {
        // The pass borrows `encoder` mutably; scope it so `finish()` can consume the encoder.
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("csr_spmm"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, Some(&bind_group), &[]);
        pass.dispatch_workgroups(workgroup_count(total_elements), 1, 1);
    }
    queue.submit(std::iter::once(encoder.finish()));
    Ok(())
}
/// Launches the CSR diagonal-extraction kernel (`csr_extract_diagonal_f32`).
///
/// The five buffers are bound in the order `row_ptrs, col_indices, values, diag,
/// params_buffer`. The layout has four storage buffers and one uniform buffer;
/// presumably `params_buffer` is the uniform (confirm in `PipelineCache::create_bind_group`).
/// The kernel is dispatched over a one-dimensional grid of `workgroup_count(n)` workgroups.
///
/// NOTE(review): presumably `n` is the number of rows / diagonal entries and
/// `diag` holds `n` outputs. Confirm against `sparse_spmv_f32.wgsl`.
///
/// The command buffer is submitted to `queue`. This function does not wait for
/// the GPU to finish. When `n == 0` nothing is recorded or submitted.
///
/// # Errors
///
/// Returns `Error::UnsupportedDType` when `dtype` has no WebGPU kernel
/// (currently anything other than `DType::F32`).
pub fn launch_csr_extract_diagonal(
    cache: &PipelineCache,
    queue: &Queue,
    row_ptrs: &Buffer,
    col_indices: &Buffer,
    values: &Buffer,
    diag: &Buffer,
    params_buffer: &Buffer,
    n: usize,
    dtype: DType,
) -> Result<()> {
    // Validate the dtype first so unsupported dtypes are reported even for empty inputs.
    let (shader, module_name) = spmv_shader_info(dtype)?;

    // Nothing to extract. Bailing out here also avoids binding buffers that
    // may be zero-sized, which wgpu can reject as a validation error.
    if n == 0 {
        return Ok(());
    }

    let module = cache.get_or_create_module(module_name, shader);
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 4,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let pipeline =
        cache.get_or_create_pipeline(module_name, "csr_extract_diagonal_f32", &module, &layout);
    let bind_group = cache.create_bind_group(
        &layout,
        &[row_ptrs, col_indices, values, diag, params_buffer],
    );

    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("csr_extract_diagonal"),
        });
    {
        // The pass borrows `encoder` mutably; scope it so `finish()` can consume the encoder.
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("csr_extract_diagonal"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, Some(&bind_group), &[]);
        pass.dispatch_workgroups(workgroup_count(n), 1, 1);
    }
    queue.submit(std::iter::once(encoder.finish()));
    Ok(())
}