use crate::dtype::Element;
use crate::error::Result;
use crate::runtime::cpu::{CpuClient, CpuRuntime};
use crate::sparse::CscData;
use crate::tensor::Tensor;
/// Multiplies a dense matrix by a CSC-format sparse matrix on the CPU (`C = A * B`).
///
/// The operand shapes are checked with `validate_dsmm_shapes` first; any error it
/// reports is returned unchanged. The element type is taken from the sparse
/// operand's `values` tensor and dispatched to the monomorphised kernel.
///
/// NOTE(review): only the sparse operand's dtype is inspected here — confirm that
/// callers (or `validate_dsmm_shapes`) guarantee `dense_a` shares that dtype.
/// Despite the name, nothing in this path spawns parallel work.
pub(super) fn column_parallel_dsmm(
    _client: &CpuClient,
    dense_a: &Tensor<CpuRuntime>,
    sparse_b_csc: &CscData<CpuRuntime>,
) -> Result<Tensor<CpuRuntime>> {
    // Only the row count of A is taken from the validated result; `k` and `n`
    // are read straight from B's declared shape below.
    let ([m, _], _) =
        crate::algorithm::sparse::validate_dsmm_shapes(dense_a.shape(), sparse_b_csc.shape)?;
    let [k, n] = sparse_b_csc.shape;
    let device = dense_a.device();
    let dtype = sparse_b_csc.values.dtype();

    crate::dispatch_dtype!(dtype, T => {
        dsmm_typed::<T>(dense_a, sparse_b_csc, m, k, n, device)
    }, "dsmm")
}
/// Typed kernel for `column_parallel_dsmm`: computes `C = A * B`, where `A` is a
/// dense `m x k` matrix and `B` is a `k x n` sparse matrix in CSC form.
///
/// Each output element `C[row, col]` is the dot product of row `row` of `A` with
/// the stored entries of column `col` of `B`, summed in `col_ptrs` order. The sum
/// is accumulated in `f64` and converted to `T` exactly once, so low-precision
/// element types are not re-rounded after every partial product.
///
/// The indexing assumes that `a.to_vec()` yields `A` in row-major order with
/// `m * k` elements, that `col_ptrs` holds at least `n + 1` non-decreasing
/// offsets, and that every row index lies in `0..k`. Malformed inputs panic on
/// out-of-bounds slicing or indexing.
fn dsmm_typed<T: Element>(
    a: &Tensor<CpuRuntime>,
    csc: &CscData<CpuRuntime>,
    m: usize,
    k: usize,
    n: usize,
    device: &<CpuRuntime as crate::runtime::Runtime>::Device,
) -> Result<Tensor<CpuRuntime>> {
    let a_data: Vec<T> = a.to_vec();
    let col_ptrs: Vec<i64> = csc.col_ptrs.to_vec();
    let row_indices: Vec<i64> = csc.row_indices.to_vec();

    // Convert B's stored values to f64 once, rather than once per output row.
    let b_raw: Vec<T> = csc.values.to_vec();
    let b_values: Vec<f64> = b_raw.iter().map(|v| v.to_f64()).collect();

    // Row-major traversal: C is written sequentially and each row of A stays hot
    // in cache while every column of B is swept against it.
    let mut c_data: Vec<T> = Vec::with_capacity(m * n);
    for row in 0..m {
        // Plain slicing (not `chunks_exact`) so that k == 0 yields empty rows
        // instead of panicking.
        let a_row = &a_data[row * k..(row + 1) * k];
        for col in 0..n {
            let start = col_ptrs[col] as usize;
            let end = col_ptrs[col + 1] as usize;
            // Explicit `+0.0` start (not `Iterator::sum`, which may start at
            // -0.0) so empty columns produce a positive zero.
            let dot = row_indices[start..end]
                .iter()
                .zip(&b_values[start..end])
                .fold(0.0_f64, |acc, (&row_b, &b_val)| {
                    acc + a_row[row_b as usize].to_f64() * b_val
                });
            c_data.push(T::from_f64(dot));
        }
    }

    Ok(Tensor::from_slice(&c_data, &[m, n], device))
}