use crate::algorithm::sparse::validate_dsmm_shapes;
use crate::error::Result;
use crate::runtime::cuda::kernels::launch_dsmm_csc;
use crate::runtime::cuda::{CudaClient, CudaRuntime};
use crate::sparse::CscData;
use crate::tensor::Tensor;
/// Runs the `launch_dsmm_csc` CUDA kernel on a dense tensor and a CSC-format
/// sparse operand, returning a freshly allocated result tensor.
///
/// Steps performed here:
/// 1. `validate_dsmm_shapes` checks the two operand shapes and yields the
///    output dimensions `[m, n]` together with a third dimension `k`.
/// 2. `dense_a` is passed through `contiguous()` before its pointer is taken.
/// 3. An `[m, n]` output tensor is created with `zeros`, using the dtype and
///    device of `dense_a`.
/// 4. The kernel is dispatched on the dtype of `dense_a` and launched on the
///    client's context and stream.
///
/// # Errors
///
/// Propagates any error from `validate_dsmm_shapes` and any error produced by
/// the dtype dispatch / kernel launch.
pub(super) fn column_parallel_dsmm(
    client: &CudaClient,
    dense_a: &Tensor<CudaRuntime>,
    sparse_b_csc: &CscData<CudaRuntime>,
) -> Result<Tensor<CudaRuntime>> {
    // Validate first so nothing is allocated when the shapes are incompatible.
    let (out_dims, k) = validate_dsmm_shapes(dense_a.shape(), sparse_b_csc.shape)?;
    let [m, n] = out_dims;

    // The result inherits element type and placement from the dense operand.
    let elem_dtype = dense_a.dtype();
    let target_device = dense_a.device();

    // Keep the contiguous tensor bound to a local: the raw pointer below is
    // taken from it, so it has to stay alive at least until the launch call.
    let dense_contig = dense_a.contiguous();
    let result = Tensor::<CudaRuntime>::zeros(&out_dims, elem_dtype, target_device);

    // Raw pointers handed to the kernel: dense input, the three CSC arrays,
    // then the output buffer.
    let dense_ptr = dense_contig.ptr();
    let csc_col_ptrs = sparse_b_csc.col_ptrs.ptr();
    let csc_row_indices = sparse_b_csc.row_indices.ptr();
    let csc_values = sparse_b_csc.values.ptr();
    let result_ptr = result.ptr();

    // SAFETY: every pointer passed below was obtained from a tensor that is
    // still alive in this scope (`dense_contig`, `result`) or borrowed by the
    // caller (`sparse_b_csc`), and the dimensions come from
    // `validate_dsmm_shapes`.
    // NOTE(review): the dispatch uses only `dense_a`'s dtype; this assumes
    // `sparse_b_csc.values` holds the same element type — TODO confirm that
    // callers guarantee this.
    // NOTE(review): `dense_contig` is dropped when this function returns; if
    // the launch is asynchronous this presumably relies on stream-ordered
    // deallocation — verify against the allocator.
    unsafe {
        crate::dispatch_dtype!(elem_dtype, T => {
            launch_dsmm_csc::<T>(
                &client.context,
                &client.stream,
                client.device.index,
                dense_ptr,
                csc_col_ptrs,
                csc_row_indices,
                csc_values,
                result_ptr,
                m,
                k,
                n,
            )
        }, "dsmm")?;
    }

    Ok(result)
}