use wgpu::{Buffer, Queue};
use super::pipeline::{LayoutKey, PipelineCache, workgroup_count};
use crate::dtype::DType;
use crate::error::{Error, Result};
/// WGSL source for the generic sparse linear-algebra kernels
/// (`sparse_scatter_f32`, `sparse_axpy_f32`, `sparse_gather_clear_f32`,
/// `sparse_divide_pivot_f32`, `sparse_clear_f32` entry points used below).
const SPARSE_LINALG: &str = include_str!("sparse_linalg.wgsl");
/// WGSL source for the f32 L/U split and lower-triangle extraction kernels
/// (`split_lu_*` and `extract_lower_*` entry points used below).
const SPARSE_LINALG_SPLIT_F32: &str = include_str!("sparse_linalg_split_f32.wgsl");
/// Dispatches the `split_lu_count` compute kernel over `n` invocations.
///
/// Binds, in order, the storage buffers `row_ptrs`, `col_indices`,
/// `l_counts`, `u_counts` followed by the uniform `params_buffer`; this
/// order must match the binding order the layout/shader expect
/// (NOTE(review): confirm against `sparse_linalg_split_f32.wgsl`).
/// No dtype guard is needed here because only index/count buffers are
/// touched, not matrix values.
///
/// # Errors
///
/// Currently always returns `Ok(())`; the `Result` keeps the signature
/// uniform with the other launchers in this module.
pub fn launch_split_lu_count(
    cache: &PipelineCache,
    queue: &Queue,
    row_ptrs: &Buffer,
    col_indices: &Buffer,
    l_counts: &Buffer,
    u_counts: &Buffer,
    params_buffer: &Buffer,
    n: usize,
) -> Result<()> {
    const LABEL: &str = "split_lu_count";
    // All pipeline pieces come from the cache, so repeat launches reuse them.
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 4,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let module = cache.get_or_create_module("sparse_linalg_split_f32", SPARSE_LINALG_SPLIT_F32);
    let pipeline = cache.get_or_create_pipeline("sparse_linalg_split_f32", LABEL, &module, &layout);
    let bindings = cache.create_bind_group(
        &layout,
        &[row_ptrs, col_indices, l_counts, u_counts, params_buffer],
    );
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    pass.set_pipeline(&pipeline);
    pass.set_bind_group(0, Some(&bindings), &[]);
    pass.dispatch_workgroups(workgroup_count(n), 1, 1);
    // The pass borrows the encoder; end it before finishing the encoder.
    drop(pass);
    queue.submit(Some(encoder.finish()));
    Ok(())
}
/// Dispatches the `split_lu_scatter_l_f32` compute kernel over `n`
/// invocations.
///
/// Binds, in order, the storage buffers `row_ptrs`, `col_indices`,
/// `values`, `l_row_ptrs`, `l_col_indices`, `l_values` followed by the
/// uniform `params_buffer`; this order must match what the shader's
/// bindings expect (NOTE(review): confirm against
/// `sparse_linalg_split_f32.wgsl`).
///
/// # Errors
///
/// Returns [`Error::UnsupportedDType`] when `dtype` is not `F32` —
/// the shader only has an f32 entry point.
pub fn launch_split_lu_scatter_l(
    cache: &PipelineCache,
    queue: &Queue,
    row_ptrs: &Buffer,
    col_indices: &Buffer,
    values: &Buffer,
    l_row_ptrs: &Buffer,
    l_col_indices: &Buffer,
    l_values: &Buffer,
    params_buffer: &Buffer,
    n: usize,
    dtype: DType,
) -> Result<()> {
    if dtype != DType::F32 {
        return Err(Error::UnsupportedDType {
            dtype,
            op: "split_lu_scatter_l (WebGPU)",
        });
    }
    const LABEL: &str = "split_lu_scatter_l";
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 6,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let module = cache.get_or_create_module("sparse_linalg_split_f32", SPARSE_LINALG_SPLIT_F32);
    let pipeline = cache.get_or_create_pipeline(
        "sparse_linalg_split_f32",
        "split_lu_scatter_l_f32",
        &module,
        &layout,
    );
    let buffers = [
        row_ptrs,
        col_indices,
        values,
        l_row_ptrs,
        l_col_indices,
        l_values,
        params_buffer,
    ];
    let bindings = cache.create_bind_group(&layout, &buffers);
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    pass.set_pipeline(&pipeline);
    pass.set_bind_group(0, Some(&bindings), &[]);
    pass.dispatch_workgroups(workgroup_count(n), 1, 1);
    // The pass borrows the encoder; end it before finishing the encoder.
    drop(pass);
    queue.submit(Some(encoder.finish()));
    Ok(())
}
/// Dispatches the `split_lu_scatter_u_f32` compute kernel over `n`
/// invocations.
///
/// Binds, in order, the storage buffers `row_ptrs`, `col_indices`,
/// `values`, `u_row_ptrs`, `u_col_indices`, `u_values` followed by the
/// uniform `params_buffer`; this order must match what the shader's
/// bindings expect (NOTE(review): confirm against
/// `sparse_linalg_split_f32.wgsl`).
///
/// # Errors
///
/// Returns [`Error::UnsupportedDType`] when `dtype` is not `F32` —
/// the shader only has an f32 entry point.
pub fn launch_split_lu_scatter_u(
    cache: &PipelineCache,
    queue: &Queue,
    row_ptrs: &Buffer,
    col_indices: &Buffer,
    values: &Buffer,
    u_row_ptrs: &Buffer,
    u_col_indices: &Buffer,
    u_values: &Buffer,
    params_buffer: &Buffer,
    n: usize,
    dtype: DType,
) -> Result<()> {
    if dtype != DType::F32 {
        return Err(Error::UnsupportedDType {
            dtype,
            op: "split_lu_scatter_u (WebGPU)",
        });
    }
    const LABEL: &str = "split_lu_scatter_u";
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 6,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let module = cache.get_or_create_module("sparse_linalg_split_f32", SPARSE_LINALG_SPLIT_F32);
    let pipeline = cache.get_or_create_pipeline(
        "sparse_linalg_split_f32",
        "split_lu_scatter_u_f32",
        &module,
        &layout,
    );
    let buffers = [
        row_ptrs,
        col_indices,
        values,
        u_row_ptrs,
        u_col_indices,
        u_values,
        params_buffer,
    ];
    let bindings = cache.create_bind_group(&layout, &buffers);
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    pass.set_pipeline(&pipeline);
    pass.set_bind_group(0, Some(&bindings), &[]);
    pass.dispatch_workgroups(workgroup_count(n), 1, 1);
    // The pass borrows the encoder; end it before finishing the encoder.
    drop(pass);
    queue.submit(Some(encoder.finish()));
    Ok(())
}
/// Dispatches the `extract_lower_count` compute kernel over `n`
/// invocations.
///
/// Binds, in order, the storage buffers `row_ptrs`, `col_indices`,
/// `l_counts` followed by the uniform `params_buffer`. Like
/// `launch_split_lu_count`, no dtype guard is needed: only index/count
/// buffers are touched.
///
/// # Errors
///
/// Currently always returns `Ok(())`; the `Result` keeps the signature
/// uniform with the other launchers in this module.
pub fn launch_extract_lower_count(
    cache: &PipelineCache,
    queue: &Queue,
    row_ptrs: &Buffer,
    col_indices: &Buffer,
    l_counts: &Buffer,
    params_buffer: &Buffer,
    n: usize,
) -> Result<()> {
    const LABEL: &str = "extract_lower_count";
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 3,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let module = cache.get_or_create_module("sparse_linalg_split_f32", SPARSE_LINALG_SPLIT_F32);
    let pipeline = cache.get_or_create_pipeline("sparse_linalg_split_f32", LABEL, &module, &layout);
    let bindings =
        cache.create_bind_group(&layout, &[row_ptrs, col_indices, l_counts, params_buffer]);
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    pass.set_pipeline(&pipeline);
    pass.set_bind_group(0, Some(&bindings), &[]);
    pass.dispatch_workgroups(workgroup_count(n), 1, 1);
    // The pass borrows the encoder; end it before finishing the encoder.
    drop(pass);
    queue.submit(Some(encoder.finish()));
    Ok(())
}
/// Dispatches the `extract_lower_scatter_f32` compute kernel over `n`
/// invocations.
///
/// Binds, in order, the storage buffers `row_ptrs`, `col_indices`,
/// `values`, `l_row_ptrs`, `l_col_indices`, `l_values` followed by the
/// uniform `params_buffer`; this order must match what the shader's
/// bindings expect (NOTE(review): confirm against
/// `sparse_linalg_split_f32.wgsl`).
///
/// # Errors
///
/// Returns [`Error::UnsupportedDType`] when `dtype` is not `F32` —
/// the shader only has an f32 entry point.
pub fn launch_extract_lower_scatter(
    cache: &PipelineCache,
    queue: &Queue,
    row_ptrs: &Buffer,
    col_indices: &Buffer,
    values: &Buffer,
    l_row_ptrs: &Buffer,
    l_col_indices: &Buffer,
    l_values: &Buffer,
    params_buffer: &Buffer,
    n: usize,
    dtype: DType,
) -> Result<()> {
    if dtype != DType::F32 {
        return Err(Error::UnsupportedDType {
            dtype,
            op: "extract_lower_scatter (WebGPU)",
        });
    }
    const LABEL: &str = "extract_lower_scatter";
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 6,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let module = cache.get_or_create_module("sparse_linalg_split_f32", SPARSE_LINALG_SPLIT_F32);
    let pipeline = cache.get_or_create_pipeline(
        "sparse_linalg_split_f32",
        "extract_lower_scatter_f32",
        &module,
        &layout,
    );
    let buffers = [
        row_ptrs,
        col_indices,
        values,
        l_row_ptrs,
        l_col_indices,
        l_values,
        params_buffer,
    ];
    let bindings = cache.create_bind_group(&layout, &buffers);
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    pass.set_pipeline(&pipeline);
    pass.set_bind_group(0, Some(&bindings), &[]);
    pass.dispatch_workgroups(workgroup_count(n), 1, 1);
    // The pass borrows the encoder; end it before finishing the encoder.
    drop(pass);
    queue.submit(Some(encoder.finish()));
    Ok(())
}
/// Host-side mirror of a shader uniform block for the sparse LU kernels.
///
/// `#[repr(C)]` plus `bytemuck::Pod` lets it be written into a uniform
/// buffer as raw bytes. Layout is 8 bytes with no padding (f32 + u32).
/// NOTE(review): field order and size must match the corresponding WGSL
/// struct — confirm against `sparse_linalg.wgsl`.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
pub struct SparseLuParams {
    /// Scale factor applied by the kernel (presumably the axpy
    /// multiplier — TODO confirm against the shader).
    pub scale: f32,
    /// Number of nonzero entries the kernel iterates over.
    pub nnz: u32,
}
/// Dispatches the `sparse_scatter_f32` compute kernel over `nnz`
/// invocations.
///
/// Binds, in order, the storage buffers `values`, `row_indices`, `work`;
/// this order must match the shader's bindings (NOTE(review): confirm
/// against `sparse_linalg.wgsl`). The `_f32` suffix means the caller is
/// responsible for ensuring the buffers hold f32 data.
///
/// # Errors
///
/// Currently always returns `Ok(())`; the `Result` keeps the signature
/// uniform with the other launchers in this module.
pub fn launch_sparse_scatter_f32(
    cache: &PipelineCache,
    queue: &Queue,
    values: &Buffer,
    row_indices: &Buffer,
    work: &Buffer,
    nnz: usize,
) -> Result<()> {
    const LABEL: &str = "sparse_scatter_f32";
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 3,
        num_uniform_buffers: 0,
        num_readonly_storage: 0,
    });
    let module = cache.get_or_create_module("sparse_linalg", SPARSE_LINALG);
    let pipeline = cache.get_or_create_pipeline("sparse_linalg", LABEL, &module, &layout);
    let bindings = cache.create_bind_group(&layout, &[values, row_indices, work]);
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    pass.set_pipeline(&pipeline);
    pass.set_bind_group(0, Some(&bindings), &[]);
    pass.dispatch_workgroups(workgroup_count(nnz), 1, 1);
    // The pass borrows the encoder; end it before finishing the encoder.
    drop(pass);
    queue.submit(Some(encoder.finish()));
    Ok(())
}
/// Dispatches the `sparse_axpy_f32` compute kernel over `nnz`
/// invocations.
///
/// Binds `params_buffer` (uniform) first, then the storage buffers
/// `values`, `row_indices`, `work`. Note the uniform comes first here,
/// unlike the split/extract launchers where it is bound last; this must
/// match the shader's binding order (NOTE(review): confirm against
/// `sparse_linalg.wgsl`).
///
/// # Errors
///
/// Currently always returns `Ok(())`; the `Result` keeps the signature
/// uniform with the other launchers in this module.
pub fn launch_sparse_axpy_f32(
    cache: &PipelineCache,
    queue: &Queue,
    params_buffer: &Buffer,
    values: &Buffer,
    row_indices: &Buffer,
    work: &Buffer,
    nnz: usize,
) -> Result<()> {
    const LABEL: &str = "sparse_axpy_f32";
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 3,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let module = cache.get_or_create_module("sparse_linalg", SPARSE_LINALG);
    let pipeline = cache.get_or_create_pipeline("sparse_linalg", LABEL, &module, &layout);
    let bindings = cache.create_bind_group(&layout, &[params_buffer, values, row_indices, work]);
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    pass.set_pipeline(&pipeline);
    pass.set_bind_group(0, Some(&bindings), &[]);
    pass.dispatch_workgroups(workgroup_count(nnz), 1, 1);
    // The pass borrows the encoder; end it before finishing the encoder.
    drop(pass);
    queue.submit(Some(encoder.finish()));
    Ok(())
}
/// Dispatches the `sparse_gather_clear_f32` compute kernel over `nnz`
/// invocations.
///
/// Binds, in order, the storage buffers `work`, `row_indices`, `output`;
/// this order must match the shader's bindings (NOTE(review): confirm
/// against `sparse_linalg.wgsl`).
///
/// # Errors
///
/// Currently always returns `Ok(())`; the `Result` keeps the signature
/// uniform with the other launchers in this module.
pub fn launch_sparse_gather_clear_f32(
    cache: &PipelineCache,
    queue: &Queue,
    work: &Buffer,
    row_indices: &Buffer,
    output: &Buffer,
    nnz: usize,
) -> Result<()> {
    const LABEL: &str = "sparse_gather_clear_f32";
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 3,
        num_uniform_buffers: 0,
        num_readonly_storage: 0,
    });
    let module = cache.get_or_create_module("sparse_linalg", SPARSE_LINALG);
    let pipeline = cache.get_or_create_pipeline("sparse_linalg", LABEL, &module, &layout);
    let bindings = cache.create_bind_group(&layout, &[work, row_indices, output]);
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    pass.set_pipeline(&pipeline);
    pass.set_bind_group(0, Some(&bindings), &[]);
    pass.dispatch_workgroups(workgroup_count(nnz), 1, 1);
    // The pass borrows the encoder; end it before finishing the encoder.
    drop(pass);
    queue.submit(Some(encoder.finish()));
    Ok(())
}
/// Host-side mirror of the shader uniform block for the
/// `sparse_divide_pivot_f32` kernel.
///
/// `#[repr(C)]` plus `bytemuck::Pod` lets it be written into a uniform
/// buffer as raw bytes. Layout is 8 bytes with no padding (f32 + u32).
/// NOTE(review): field order and size must match the corresponding WGSL
/// struct — confirm against `sparse_linalg.wgsl`.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
pub struct DividePivotParams {
    /// Presumably the precomputed reciprocal of the pivot so the kernel
    /// multiplies instead of dividing — TODO confirm against the shader.
    pub inv_pivot: f32,
    /// Number of nonzero entries the kernel iterates over.
    pub nnz: u32,
}
/// Dispatches the `sparse_divide_pivot_f32` compute kernel over `nnz`
/// invocations.
///
/// Binds `params_buffer` (uniform, see [`DividePivotParams`]) first,
/// then the storage buffers `work` and `row_indices`; this order must
/// match the shader's bindings (NOTE(review): confirm against
/// `sparse_linalg.wgsl`).
///
/// # Errors
///
/// Currently always returns `Ok(())`; the `Result` keeps the signature
/// uniform with the other launchers in this module.
pub fn launch_sparse_divide_pivot_f32(
    cache: &PipelineCache,
    queue: &Queue,
    params_buffer: &Buffer,
    work: &Buffer,
    row_indices: &Buffer,
    nnz: usize,
) -> Result<()> {
    const LABEL: &str = "sparse_divide_pivot_f32";
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 1,
        num_readonly_storage: 0,
    });
    let module = cache.get_or_create_module("sparse_linalg", SPARSE_LINALG);
    let pipeline = cache.get_or_create_pipeline("sparse_linalg", LABEL, &module, &layout);
    let bindings = cache.create_bind_group(&layout, &[params_buffer, work, row_indices]);
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    pass.set_pipeline(&pipeline);
    pass.set_bind_group(0, Some(&bindings), &[]);
    pass.dispatch_workgroups(workgroup_count(nnz), 1, 1);
    // The pass borrows the encoder; end it before finishing the encoder.
    drop(pass);
    queue.submit(Some(encoder.finish()));
    Ok(())
}
/// Dispatches the `sparse_clear_f32` compute kernel over `nnz`
/// invocations.
///
/// Binds, in order, the storage buffers `work` and `row_indices`; this
/// order must match the shader's bindings (NOTE(review): confirm
/// against `sparse_linalg.wgsl`).
///
/// # Errors
///
/// Currently always returns `Ok(())`; the `Result` keeps the signature
/// uniform with the other launchers in this module.
pub fn launch_sparse_clear_f32(
    cache: &PipelineCache,
    queue: &Queue,
    work: &Buffer,
    row_indices: &Buffer,
    nnz: usize,
) -> Result<()> {
    const LABEL: &str = "sparse_clear_f32";
    let layout = cache.get_or_create_layout(LayoutKey {
        num_storage_buffers: 2,
        num_uniform_buffers: 0,
        num_readonly_storage: 0,
    });
    let module = cache.get_or_create_module("sparse_linalg", SPARSE_LINALG);
    let pipeline = cache.get_or_create_pipeline("sparse_linalg", LABEL, &module, &layout);
    let bindings = cache.create_bind_group(&layout, &[work, row_indices]);
    let mut encoder = cache
        .device()
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some(LABEL) });
    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
        label: Some(LABEL),
        timestamp_writes: None,
    });
    pass.set_pipeline(&pipeline);
    pass.set_bind_group(0, Some(&bindings), &[]);
    pass.dispatch_workgroups(workgroup_count(nnz), 1, 1);
    // The pass borrows the encoder; end it before finishing the encoder.
    drop(pass);
    queue.submit(Some(encoder.finish()));
    Ok(())
}