use super::super::ops::helpers::get_tensor_buffer;
use super::super::shaders::launch_dsmm_csc;
use super::common::validate_wgpu_dtype;
use crate::algorithm::sparse::validate_dsmm_shapes;
use crate::dtype::DType;
use crate::error::Result;
use crate::ops::TypeConversionOps;
use crate::runtime::wgpu::{WgpuClient, WgpuRuntime};
use crate::sparse::CscData;
use crate::tensor::Tensor;
/// Uniform-buffer parameters for the dense × sparse (CSC) matmul shader.
///
/// Layout is `#[repr(C)]` and padded to 16 bytes (four `u32`s) so it can be
/// uploaded verbatim as a WGSL uniform block; `bytemuck::Pod`/`Zeroable`
/// allow safe byte-level reinterpretation when writing to the GPU buffer.
/// Do not reorder or resize fields — the shader-side struct must match.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
pub struct DsmmParams {
    // Rows of the dense matrix A (and of the output C).
    pub m: u32,
    // Shared inner dimension: columns of A / rows of sparse B.
    pub k: u32,
    // Columns of sparse B (and of the output C).
    pub n: u32,
    // Explicit padding to reach 16-byte uniform alignment; always 0.
    pub _pad: u32,
}
/// Computes a dense × sparse matrix product on the GPU, parallelized over
/// the columns of the sparse CSC operand.
///
/// Presumably computes `C = A * B` where `A` is dense `(m, k)` and `B` is
/// sparse `(k, n)` in CSC form — shapes are checked by
/// `validate_dsmm_shapes`, which returns the output shape `[m, n]` and the
/// inner dimension `k` (TODO confirm against its definition).
///
/// # Errors
/// Returns an error if the dtype is unsupported on the WGPU backend, if the
/// operand shapes are incompatible, if index casting fails, or if buffer
/// extraction / kernel launch fails.
pub(super) fn column_parallel_dsmm(
    client: &WgpuClient,
    dense_a: &Tensor<WgpuRuntime>,
    sparse_b_csc: &CscData<WgpuRuntime>,
) -> Result<Tensor<WgpuRuntime>> {
    let dtype = dense_a.dtype();
    let device = dense_a.device();
    validate_wgpu_dtype(dtype, "column_parallel_dsmm")?;
    let ([m, n], k) = validate_dsmm_shapes(dense_a.shape(), sparse_b_csc.shape)?;

    // The shader assumes a contiguous row-major layout for A.
    let a_contig = dense_a.contiguous();
    // Output starts zeroed; the kernel accumulates into it.
    let output = Tensor::<WgpuRuntime>::zeros(&[m, n], dtype, device);

    // WGSL has no native 64-bit integer buffers, so sparse indices are
    // cast down to i32 before upload.
    let col_ptrs_i32 = client.cast(&sparse_b_csc.col_ptrs, DType::I32)?;
    let row_indices_i32 = client.cast(&sparse_b_csc.row_indices, DType::I32)?;

    let params = DsmmParams {
        m: m as u32,
        k: k as u32,
        n: n as u32,
        _pad: 0,
    };
    // 16 bytes = four u32 fields of DsmmParams (see its layout comment).
    let params_buffer = client.create_uniform_buffer("dsmm_params", 16);
    // BUG FIX: these two references were mojibake (`¶ms_buffer` — the
    // bytes `&para` rendered as a pilcrow); restored to `&params_buffer`.
    client.write_buffer(&params_buffer, &[params.m, params.k, params.n, params._pad]);

    let a_buffer = get_tensor_buffer(&a_contig)?;
    let col_ptrs_buffer = get_tensor_buffer(&col_ptrs_i32)?;
    let row_indices_buffer = get_tensor_buffer(&row_indices_i32)?;
    let b_values_buffer = get_tensor_buffer(&sparse_b_csc.values)?;
    let c_buffer = get_tensor_buffer(&output)?;

    launch_dsmm_csc(
        client.pipeline_cache(),
        client.wgpu_queue(),
        &a_buffer,
        &col_ptrs_buffer,
        &row_indices_buffer,
        &b_values_buffer,
        &c_buffer,
        &params_buffer,
        m,
        n,
        dtype,
    )?;
    Ok(output)
}