use super::super::ops::helpers::get_tensor_buffer;
use super::super::shaders::{launch_csr_spmm, launch_csr_spmv};
use super::super::{WgpuClient, WgpuRuntime};
use super::common::validate_wgpu_dtype;
use crate::dtype::Element;
use crate::error::{Error, Result};
use crate::tensor::Tensor;
/// Uniform parameters for the CSR sparse matrix-vector product (SpMV) kernel.
///
/// The fields are written to a 16-byte uniform buffer in declaration order.
/// `#[repr(C)]` keeps that order stable. The two trailing padding words bring
/// the struct to four `u32`s (16 bytes). NOTE(review): presumably this
/// mirrors the WGSL-side params struct; confirm against the shader.
#[derive(Clone, Copy, Debug, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
pub struct SpmvParams {
    /// Number of rows of the sparse matrix, which is also the output vector length.
    pub nrows: u32,
    /// Number of columns of the sparse matrix, which is also the input vector length.
    pub ncols: u32,
    /// Padding word; the SpMV launch code writes zero here.
    pub _pad0: u32,
    /// Padding word; the SpMV launch code writes zero here.
    pub _pad1: u32,
}
/// Uniform parameters for the CSR sparse-times-dense matrix product (SpMM) kernel.
///
/// For `C = A * B` with sparse `A` (`m x k`) and dense `B` (`k x n`), the three
/// dimensions are followed by one padding word, giving 16 bytes in total.
/// `#[repr(C)]` keeps the field order stable. NOTE(review): presumably this
/// mirrors the WGSL-side params struct; confirm against the shader.
#[derive(Clone, Copy, Debug, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(C)]
pub struct SpmmParams {
    /// Rows of the sparse matrix `A` and of the output `C`.
    pub m: u32,
    /// Columns of `A`, which must equal the rows of the dense matrix `B`.
    pub k: u32,
    /// Columns of `B` and of the output `C`.
    pub n: u32,
    /// Padding word; the SpMM launch code writes zero here.
    pub _pad: u32,
}
impl WgpuClient {
pub(crate) fn spmv_csr_impl<T: Element>(
&self,
row_ptrs: &Tensor<WgpuRuntime>,
col_indices: &Tensor<WgpuRuntime>,
values: &Tensor<WgpuRuntime>,
x: &Tensor<WgpuRuntime>,
shape: [usize; 2],
) -> Result<Tensor<WgpuRuntime>> {
let [nrows, ncols] = shape;
let dtype = values.dtype();
validate_wgpu_dtype(dtype, "spmv_csr")?;
if x.numel() != ncols {
return Err(Error::ShapeMismatch {
expected: vec![ncols],
got: vec![x.numel()],
});
}
let y = Tensor::<WgpuRuntime>::zeros(&[nrows], dtype, &self.device_id);
let params = SpmvParams {
nrows: nrows as u32,
ncols: ncols as u32,
_pad0: 0,
_pad1: 0,
};
let params_buffer = self.create_uniform_buffer("spmv_params", 16);
self.write_buffer(
¶ms_buffer,
&[params.nrows, params.ncols, params._pad0, params._pad1],
);
let row_ptrs_buffer = get_tensor_buffer(row_ptrs)?;
let col_indices_buffer = get_tensor_buffer(col_indices)?;
let values_buffer = get_tensor_buffer(values)?;
let x_buffer = get_tensor_buffer(x)?;
let y_buffer = get_tensor_buffer(&y)?;
launch_csr_spmv(
self.pipeline_cache(),
self.wgpu_queue(),
&row_ptrs_buffer,
&col_indices_buffer,
&values_buffer,
&x_buffer,
&y_buffer,
¶ms_buffer,
nrows,
dtype,
)?;
Ok(y)
}
pub(crate) fn spmm_csr_impl<T: Element>(
&self,
row_ptrs: &Tensor<WgpuRuntime>,
col_indices: &Tensor<WgpuRuntime>,
values: &Tensor<WgpuRuntime>,
b: &Tensor<WgpuRuntime>,
shape: [usize; 2],
) -> Result<Tensor<WgpuRuntime>> {
let [m, k] = shape;
let dtype = values.dtype();
validate_wgpu_dtype(dtype, "spmm_csr")?;
if b.ndim() != 2 {
return Err(Error::Internal(format!(
"Expected 2D tensor for SpMM, got {}D",
b.ndim()
)));
}
let b_shape = b.shape();
let b_k = b_shape[0];
let n = b_shape[1];
if b_k != k {
return Err(Error::ShapeMismatch {
expected: vec![k],
got: vec![b_k],
});
}
let c = Tensor::<WgpuRuntime>::zeros(&[m, n], dtype, &self.device_id);
let params = SpmmParams {
m: m as u32,
k: k as u32,
n: n as u32,
_pad: 0,
};
let params_buffer = self.create_uniform_buffer("spmm_params", 16);
self.write_buffer(¶ms_buffer, &[params.m, params.k, params.n, params._pad]);
let row_ptrs_buffer = get_tensor_buffer(row_ptrs)?;
let col_indices_buffer = get_tensor_buffer(col_indices)?;
let values_buffer = get_tensor_buffer(values)?;
let b_buffer = get_tensor_buffer(b)?;
let c_buffer = get_tensor_buffer(&c)?;
launch_csr_spmm(
self.pipeline_cache(),
self.wgpu_queue(),
&row_ptrs_buffer,
&col_indices_buffer,
&values_buffer,
&b_buffer,
&c_buffer,
¶ms_buffer,
m,
n,
dtype,
)?;
Ok(c)
}
}