use super::super::ops::helpers::get_tensor_buffer;
use super::super::shaders::{
launch_exclusive_scan_i32, launch_spgemm_accumulate, launch_spgemm_scatter,
launch_spgemm_symbolic,
};
use super::common::validate_wgpu_dtype;
use super::merge::ScanParams;
use crate::algorithm::sparse::validate_spgemm_shapes;
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::TypeConversionOps;
use crate::runtime::wgpu::{WgpuClient, WgpuRuntime};
use crate::sparse::CsrData;
use crate::tensor::Tensor;
/// Uniform parameters for the SpGEMM symbolic (row-nnz counting) kernel.
///
/// Written as four `u32` words into a 16-byte uniform buffer, so the layout
/// must stay exactly 16 bytes (enforced by the const assertion below).
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
pub struct SpgemmSymbolicParams {
    /// Number of rows of the output matrix `C`.
    pub m: u32,
    /// Number of columns of the output matrix `C`.
    pub n: u32,
    /// Padding up to the 16-byte uniform buffer size.
    pub _pad0: u32,
    /// Padding up to the 16-byte uniform buffer size.
    pub _pad1: u32,
}

// The host allocates exactly 16 bytes for this uniform; catch layout drift at
// compile time rather than as a silent shader/host mismatch.
const _: () = assert!(std::mem::size_of::<SpgemmSymbolicParams>() == 16);
/// Uniform parameters shared by the SpGEMM numeric kernels (accumulate and
/// scatter).
///
/// Written as four `u32` words into a 16-byte uniform buffer, so the layout
/// must stay exactly 16 bytes (enforced by the const assertion below).
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
pub struct SpgemmParams {
    /// Number of rows of the output matrix `C`.
    pub m: u32,
    /// Number of columns of the output matrix `C`.
    pub n: u32,
    /// Padding up to the 16-byte uniform buffer size.
    pub _pad0: u32,
    /// Padding up to the 16-byte uniform buffer size.
    pub _pad1: u32,
}

// The host allocates exactly 16 bytes for this uniform; catch layout drift at
// compile time rather than as a silent shader/host mismatch.
const _: () = assert!(std::mem::size_of::<SpgemmParams>() == 16);
pub(super) fn esc_spgemm_csr(
client: &WgpuClient,
a_csr: &CsrData<WgpuRuntime>,
b_csr: &CsrData<WgpuRuntime>,
) -> Result<CsrData<WgpuRuntime>> {
let ([m, n], _k) = validate_spgemm_shapes(a_csr.shape, b_csr.shape)?;
let dtype = a_csr.values.dtype();
let device = a_csr.values.device();
validate_wgpu_dtype(dtype, "esc_spgemm_csr")?;
const MAX_DENSE_ELEMENTS: usize = 64 * 1024 * 1024; if m * n > MAX_DENSE_ELEMENTS {
return Err(Error::Internal(format!(
"SpGEMM output matrix {}x{} = {} elements exceeds WebGPU limit of {} elements",
m,
n,
m * n,
MAX_DENSE_ELEMENTS
)));
}
let a_row_ptrs_i32 = client.cast(&a_csr.row_ptrs, DType::I32)?;
let a_col_indices_i32 = client.cast(&a_csr.col_indices, DType::I32)?;
let b_row_ptrs_i32 = client.cast(&b_csr.row_ptrs, DType::I32)?;
let b_col_indices_i32 = client.cast(&b_csr.col_indices, DType::I32)?;
let row_nnz = Tensor::<WgpuRuntime>::zeros(&[m], DType::I32, device);
let words_per_row = (n + 31) / 32;
let bitmap_size = (m * words_per_row * 4) as u64; let bitmap = client.wgpu_device().create_buffer(&wgpu::BufferDescriptor {
label: Some("spgemm_bitmap"),
size: bitmap_size.max(4), usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
if bitmap_size > 0 {
client
.wgpu_queue()
.write_buffer(&bitmap, 0, &vec![0u8; bitmap_size as usize]);
}
let symbolic_params = SpgemmSymbolicParams {
m: m as u32,
n: n as u32,
_pad0: 0,
_pad1: 0,
};
let symbolic_params_buffer = client.create_uniform_buffer("spgemm_symbolic_params", 16);
client.write_buffer(
&symbolic_params_buffer,
&[
symbolic_params.m,
symbolic_params.n,
symbolic_params._pad0,
symbolic_params._pad1,
],
);
let a_row_ptrs_buffer = get_tensor_buffer(&a_row_ptrs_i32)?;
let a_col_indices_buffer = get_tensor_buffer(&a_col_indices_i32)?;
let b_row_ptrs_buffer = get_tensor_buffer(&b_row_ptrs_i32)?;
let b_col_indices_buffer = get_tensor_buffer(&b_col_indices_i32)?;
let row_nnz_buffer = get_tensor_buffer(&row_nnz)?;
launch_spgemm_symbolic(
client.pipeline_cache(),
client.wgpu_queue(),
&a_row_ptrs_buffer,
&a_col_indices_buffer,
&b_row_ptrs_buffer,
&b_col_indices_buffer,
&row_nnz_buffer,
&symbolic_params_buffer,
&bitmap,
m,
dtype,
)?;
let c_row_ptrs_i32 = Tensor::<WgpuRuntime>::zeros(&[m + 1], DType::I32, device);
let scan_params = ScanParams {
n: m as u32,
_pad0: 0,
_pad1: 0,
_pad2: 0,
};
let scan_params_buffer = client.create_uniform_buffer("spgemm_scan_params", 16);
client.write_buffer(
&scan_params_buffer,
&[
scan_params.n,
scan_params._pad0,
scan_params._pad1,
scan_params._pad2,
],
);
let c_row_ptrs_buffer = get_tensor_buffer(&c_row_ptrs_i32)?;
launch_exclusive_scan_i32(
client.pipeline_cache(),
client.wgpu_queue(),
&row_nnz_buffer,
&c_row_ptrs_buffer,
&scan_params_buffer,
)?;
let staging = client.create_staging_buffer("total_nnz_staging", 4);
let mut encoder =
client
.wgpu_device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("total_nnz_copy"),
});
let offset = (m * 4) as u64;
encoder.copy_buffer_to_buffer(&c_row_ptrs_buffer, offset, &staging, 0, 4);
client.submit_and_wait(encoder);
let mut total_nnz_arr = [0i32; 1];
client.read_buffer(&staging, &mut total_nnz_arr)?;
let total_nnz = total_nnz_arr[0];
if total_nnz == 0 {
let c_row_ptrs = Tensor::<WgpuRuntime>::zeros(&[m + 1], DType::I64, device);
let c_col_indices = Tensor::<WgpuRuntime>::zeros(&[0], DType::I64, device);
let c_values = Tensor::<WgpuRuntime>::zeros(&[0], dtype, device);
return CsrData::new(c_row_ptrs, c_col_indices, c_values, [m, n]);
}
let c_col_indices = Tensor::<WgpuRuntime>::zeros(&[total_nnz as usize], DType::I32, device);
let c_values = Tensor::<WgpuRuntime>::zeros(&[total_nnz as usize], dtype, device);
let elem_size = match dtype {
DType::F32 => 4,
DType::F16 => 2,
_ => 4, };
let accum_size = (m * n * elem_size) as u64;
let accum = client.wgpu_device().create_buffer(&wgpu::BufferDescriptor {
label: Some("spgemm_accum"),
size: accum_size.max(4),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let flags_size = (m * n * 4) as u64; let flags = client.wgpu_device().create_buffer(&wgpu::BufferDescriptor {
label: Some("spgemm_flags"),
size: flags_size.max(4),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let numeric_params_buffer = client.create_uniform_buffer("spgemm_numeric_params", 16);
let numeric_params = SpgemmParams {
m: m as u32,
n: n as u32,
_pad0: 0,
_pad1: 0,
};
client.write_buffer(
&numeric_params_buffer,
&[
numeric_params.m,
numeric_params.n,
numeric_params._pad0,
numeric_params._pad1,
],
);
let a_values_buffer = get_tensor_buffer(&a_csr.values)?;
let b_values_buffer = get_tensor_buffer(&b_csr.values)?;
let c_col_indices_buffer = get_tensor_buffer(&c_col_indices)?;
let c_values_buffer = get_tensor_buffer(&c_values)?;
launch_spgemm_accumulate(
client.pipeline_cache(),
client.wgpu_queue(),
&a_row_ptrs_buffer,
&a_col_indices_buffer,
&a_values_buffer,
&b_row_ptrs_buffer,
&b_col_indices_buffer,
&b_values_buffer,
&numeric_params_buffer,
&accum,
&flags,
m,
dtype,
)?;
launch_spgemm_scatter(
client.pipeline_cache(),
client.wgpu_queue(),
&c_row_ptrs_buffer,
&accum,
&flags,
&c_col_indices_buffer,
&c_values_buffer,
&numeric_params_buffer,
m,
dtype,
)?;
let c_row_ptrs = client.cast(&c_row_ptrs_i32, DType::I64)?;
let c_col_indices_i64 = client.cast(&c_col_indices, DType::I64)?;
CsrData::new(c_row_ptrs, c_col_indices_i64, c_values, [m, n])
}