use crate::algorithm::sparse::{validate_spgemm_shapes, zero_tolerance};
use crate::dtype::Element;
use crate::error::Result;
use crate::runtime::cpu::{CpuClient, CpuRuntime};
use crate::sparse::CsrData;
use crate::tensor::Tensor;
use std::collections::{HashMap, HashSet};
/// Sparse-sparse matrix multiply (SpGEMM) for CSR operands on the CPU
/// backend, using an ESC-style (expand / sort / compress) accumulation.
///
/// Validates operand shapes, then dispatches on the value dtype to the
/// typed kernel. The client handle is unused on this backend.
pub(super) fn esc_spgemm_csr(
    _client: &CpuClient,
    a_csr: &CsrData<CpuRuntime>,
    b_csr: &CsrData<CpuRuntime>,
) -> Result<CsrData<CpuRuntime>> {
    // Shape check yields the output dims [rows, cols]; the contracted
    // (inner) dimension is not needed past validation.
    let ([rows, cols], _inner) = validate_spgemm_shapes(a_csr.shape, b_csr.shape)?;
    let value_dtype = a_csr.values.dtype();
    let target_device = a_csr.values.device();
    crate::dispatch_dtype!(value_dtype, T => {
        esc_spgemm_typed::<T>(a_csr, b_csr, rows, cols, target_device)
    }, "esc_spgemm")
}
/// Dtype-specialized ESC SpGEMM kernel: computes `C = A * B` for CSR
/// operands on the CPU and returns the result as CSR with shape `[m, n]`.
///
/// Two passes:
/// 1. **Symbolic** — count the structural non-zeros of each output row.
///    This is an upper bound on the numeric nnz (near-zero results are
///    filtered later) and is used purely as a capacity hint.
/// 2. **Numeric** — accumulate each output row in a hash map keyed by
///    column, sort the surviving columns, and drop entries whose
///    magnitude is at or below the dtype's zero tolerance.
///
/// Accumulation is done in `f64` regardless of `T` to reduce rounding
/// error; values are converted back to `T` on output.
fn esc_spgemm_typed<T: Element>(
    a_csr: &CsrData<CpuRuntime>,
    b_csr: &CsrData<CpuRuntime>,
    m: usize,
    n: usize,
    device: &<CpuRuntime as crate::runtime::Runtime>::Device,
) -> Result<CsrData<CpuRuntime>> {
    // Materialize both operands on the host.
    let a_row_ptrs: Vec<i64> = a_csr.row_ptrs.to_vec();
    let a_col_indices: Vec<i64> = a_csr.col_indices.to_vec();
    let a_values: Vec<T> = a_csr.values.to_vec();
    let b_row_ptrs: Vec<i64> = b_csr.row_ptrs.to_vec();
    let b_col_indices: Vec<i64> = b_csr.col_indices.to_vec();
    let b_values: Vec<T> = b_csr.values.to_vec();

    // --- Symbolic pass: structural nnz per output row. One HashSet is
    // reused across rows (clear() keeps its allocation) instead of
    // allocating a fresh set per row.
    let mut row_nnz: Vec<i64> = vec![0; m];
    let mut col_set: HashSet<usize> = HashSet::new();
    for i in 0..m {
        col_set.clear();
        let a_start = a_row_ptrs[i] as usize;
        let a_end = a_row_ptrs[i + 1] as usize;
        for &k in &a_col_indices[a_start..a_end] {
            let k = k as usize;
            let b_start = b_row_ptrs[k] as usize;
            let b_end = b_row_ptrs[k + 1] as usize;
            for &j in &b_col_indices[b_start..b_end] {
                col_set.insert(j as usize);
            }
        }
        row_nnz[i] = col_set.len() as i64;
    }

    // --- Numeric pass. Preallocate the output buffers to the structural
    // upper bound so they never reallocate while rows are appended.
    // (The original code also built a prefix-sum row-pointer vector here
    // that was never read; the actual row pointers are derived below from
    // the post-filter lengths.)
    let nnz_upper: usize = row_nnz.iter().map(|&c| c as usize).sum();
    let mut c_col_indices: Vec<i64> = Vec::with_capacity(nnz_upper);
    let mut c_values: Vec<T> = Vec::with_capacity(nnz_upper);
    let mut c_row_ptrs_final: Vec<i64> = Vec::with_capacity(m + 1);
    c_row_ptrs_final.push(0);
    for i in 0..m {
        let a_start = a_row_ptrs[i] as usize;
        let a_end = a_row_ptrs[i + 1] as usize;
        // Accumulate row i of C in f64, keyed by output column.
        let capacity = row_nnz[i] as usize;
        let mut row_accum: HashMap<usize, f64> = HashMap::with_capacity(capacity);
        let a_cols = &a_col_indices[a_start..a_end];
        let a_vals = &a_values[a_start..a_end];
        for (&k, a_v) in a_cols.iter().zip(a_vals) {
            let k = k as usize;
            let a_val = a_v.to_f64();
            let b_start = b_row_ptrs[k] as usize;
            let b_end = b_row_ptrs[k + 1] as usize;
            let b_cols = &b_col_indices[b_start..b_end];
            let b_vals = &b_values[b_start..b_end];
            for (&j, b_v) in b_cols.iter().zip(b_vals) {
                *row_accum.entry(j as usize).or_insert(0.0) += a_val * b_v.to_f64();
            }
        }
        // CSR requires sorted column indices within each row. Keys are
        // unique, so an unstable sort is safe (and non-allocating).
        let mut row_entries: Vec<(usize, f64)> = row_accum.into_iter().collect();
        row_entries.sort_unstable_by_key(|&(col, _)| col);
        // Compress: drop entries that cancelled out to (near) zero.
        for (col, val) in row_entries {
            if val.abs() > zero_tolerance::<T>() {
                c_col_indices.push(col as i64);
                c_values.push(T::from_f64(val));
            }
        }
        c_row_ptrs_final.push(c_col_indices.len() as i64);
    }

    let final_nnz = c_col_indices.len();
    let result_row_ptrs = Tensor::from_slice(&c_row_ptrs_final, &[m + 1], device);
    let result_col_indices = Tensor::from_slice(&c_col_indices, &[final_nnz], device);
    let result_values = Tensor::from_slice(&c_values, &[final_nnz], device);
    Ok(CsrData {
        row_ptrs: result_row_ptrs,
        col_indices: result_col_indices,
        values: result_values,
        shape: [m, n],
    })
}