use crate::{CooTensor, SparseFormat, SparseTensor, TorshResult};
use torsh_core::{device::DeviceType, DType, Shape, TorshError};
use torsh_tensor::{creation::zeros, Tensor};
/// Sparse 2-D tensor stored in compressed sparse column (CSC) layout.
///
/// The stored entries of column `c` live at positions
/// `col_ptr[c]..col_ptr[c + 1]` of the parallel `row_indices` / `values`
/// arrays.
pub struct CscTensor {
/// Per-column offsets into `row_indices` / `values`; the constructors
/// require (or build) exactly `cols + 1` entries.
col_ptr: Vec<usize>,
/// Row index of each stored entry; same length as `values`.
row_indices: Vec<usize>,
/// Stored entry values, parallel to `row_indices`.
values: Vec<f32>,
/// Logical `[rows, cols]` shape of the tensor.
shape: Shape,
/// Element type; every constructor in this file sets `DType::F32`.
dtype: DType,
/// Device tag; every constructor in this file sets `DeviceType::Cpu`.
device: DeviceType,
}
impl CscTensor {
/// Builds a CSC tensor from its raw arrays after validating them.
///
/// * `col_ptr` - offsets into `row_indices` / `values`; the entries of
///   column `c` occupy `col_ptr[c]..col_ptr[c + 1]`.
/// * `row_indices` - row index of every stored entry.
/// * `values` - value of every stored entry, parallel to `row_indices`.
/// * `shape` - logical `[rows, cols]` shape.
///
/// The result is tagged `DType::F32` / `DeviceType::Cpu`.
///
/// # Errors
///
/// Returns `TorshError::InvalidArgument` when:
/// * `row_indices` and `values` differ in length,
/// * `shape` is not 2-dimensional,
/// * `col_ptr` does not have `cols + 1` entries,
/// * `col_ptr` does not start at 0, does not end at the number of stored
///   entries, or decreases anywhere,
/// * any row index is `>= rows`.
pub fn new(
    col_ptr: Vec<usize>,
    row_indices: Vec<usize>,
    values: Vec<f32>,
    shape: Shape,
) -> TorshResult<Self> {
    if row_indices.len() != values.len() {
        return Err(TorshError::InvalidArgument(
            "Row indices and values must have the same length".to_string(),
        ));
    }
    if shape.ndim() != 2 {
        return Err(TorshError::InvalidArgument(
            "CSC format currently only supports 2D tensors".to_string(),
        ));
    }
    let rows = shape.dims()[0];
    let cols = shape.dims()[1];
    if col_ptr.len() != cols + 1 {
        return Err(TorshError::InvalidArgument(format!(
            "Column pointer length must be cols + 1, got {} for {} columns",
            col_ptr.len(),
            cols
        )));
    }

    // `col_ptr` has `cols + 1 >= 1` entries from here on, so indexing its
    // first and last element cannot panic. Every accessor slices the entry
    // arrays with these offsets, so they must be checked up front.
    let nnz = values.len();
    if col_ptr[0] != 0 {
        return Err(TorshError::InvalidArgument(format!(
            "Column pointer must start at 0, got {}",
            col_ptr[0]
        )));
    }
    if col_ptr[cols] != nnz {
        return Err(TorshError::InvalidArgument(format!(
            "Column pointer must end at the number of stored values ({}), got {}",
            nnz, col_ptr[cols]
        )));
    }
    if col_ptr.windows(2).any(|pair| pair[0] > pair[1]) {
        return Err(TorshError::InvalidArgument(
            "Column pointer must be non-decreasing".to_string(),
        ));
    }

    if let Some(&row) = row_indices.iter().find(|&&row| row >= rows) {
        return Err(TorshError::InvalidArgument(format!(
            "Row index {row} out of bounds for {rows} rows"
        )));
    }

    Ok(Self {
        col_ptr,
        row_indices,
        values,
        shape,
        dtype: DType::F32,
        device: DeviceType::Cpu,
    })
}
/// Alias for [`CscTensor::new`]: forwards all four arguments unchanged and
/// returns whatever `new` returns, including its validation errors.
pub fn from_raw_parts(
col_ptr: Vec<usize>,
row_indices: Vec<usize>,
values: Vec<f32>,
shape: Shape,
) -> TorshResult<Self> {
Self::new(col_ptr, row_indices, values, shape)
}
/// Creates a CSC tensor of the given shape with no stored entries.
///
/// The column pointer is `cols + 1` zeros, and the result is tagged
/// `DType::F32` / `DeviceType::Cpu`.
///
/// # Errors
///
/// Returns `TorshError::InvalidArgument` if `shape` is not 2-dimensional.
pub fn empty(shape: Shape) -> TorshResult<Self> {
    match shape.ndim() {
        2 => {
            // Read the column count before `shape` is moved into the struct.
            let n_cols = shape.dims()[1];
            Ok(Self {
                col_ptr: vec![0; n_cols + 1],
                row_indices: Vec::new(),
                values: Vec::new(),
                shape,
                dtype: DType::F32,
                device: DeviceType::Cpu,
            })
        }
        _ => Err(TorshError::InvalidArgument(
            "CSC format currently only supports 2D tensors".to_string(),
        )),
    }
}
/// Converts a dense tensor to CSC by going through COO.
///
/// `dense` and `threshold` are handed to `CooTensor::from_dense` as-is; the
/// resulting COO tensor is then converted with [`CscTensor::from_coo`].
///
/// # Errors
///
/// Propagates any error from either conversion step.
pub fn from_dense(dense: &Tensor, threshold: f32) -> TorshResult<Self> {
    Self::from_coo(&CooTensor::from_dense(dense, threshold)?)
}
pub fn from_coo(coo: &CooTensor) -> TorshResult<Self> {
let shape = coo.shape().clone();
let cols = shape.dims()[1];
let mut triplets = coo.triplets();
triplets.sort_by_key(|&(row, col, _)| (col, row));
let mut col_ptr = vec![0];
let mut row_indices = Vec::new();
let mut values = Vec::new();
let mut current_col = 0;
for (row, col, val) in triplets {
while current_col < col {
col_ptr.push(row_indices.len());
current_col += 1;
}
row_indices.push(row);
values.push(val);
}
while current_col < cols {
col_ptr.push(row_indices.len());
current_col += 1;
}
Self::new(col_ptr, row_indices, values, shape)
}
/// Returns copies of the row indices and values stored in column `col`.
///
/// # Errors
///
/// Returns `TorshError::InvalidArgument` if `col` is not smaller than the
/// number of columns.
pub fn get_col(&self, col: usize) -> TorshResult<(Vec<usize>, Vec<f32>)> {
    let n_cols = self.shape.dims()[1];
    if col < n_cols {
        // Entries of this column occupy `col_ptr[col]..col_ptr[col + 1]`.
        let span = self.col_ptr[col]..self.col_ptr[col + 1];
        let col_rows = self.row_indices[span.clone()].to_vec();
        let col_vals = self.values[span].to_vec();
        Ok((col_rows, col_vals))
    } else {
        Err(TorshError::InvalidArgument(format!(
            "Column index {col} out of bounds"
        )))
    }
}
/// Multiplies a row vector by this matrix (`vector * self`).
///
/// For every column `c`, the output holds the sum of
/// `vector[row] * value` over the entries stored in that column. The
/// output is a 1-D tensor with one element per column.
///
/// # Errors
///
/// Returns `TorshError::InvalidArgument` if `vector` is not 1-dimensional
/// or its length differs from the number of rows. Also propagates errors
/// from tensor creation and element access.
pub fn vecmat(&self, vector: &Tensor) -> TorshResult<Tensor> {
    if vector.shape().ndim() != 1 {
        return Err(TorshError::InvalidArgument(
            "Vector must be 1-dimensional".to_string(),
        ));
    }
    let vec_len = vector.shape().dims()[0];
    let n_rows = self.shape.dims()[0];
    if vec_len != n_rows {
        return Err(TorshError::InvalidArgument(format!(
            "Vector length {} doesn't match matrix rows {}",
            vec_len, n_rows
        )));
    }

    let n_cols = self.shape.dims()[1];
    let output = zeros::<f32>(&[n_cols])?;
    for col in 0..n_cols {
        let mut acc = 0.0;
        for idx in self.col_ptr[col]..self.col_ptr[col + 1] {
            let (row, weight) = (self.row_indices[idx], self.values[idx]);
            acc += vector.get(&[row])? * weight;
        }
        output.set(&[col], acc)?;
    }
    Ok(output)
}
/// Borrows the column-pointer array (offsets into the entry arrays).
pub fn col_ptr(&self) -> &Vec<usize> {
&self.col_ptr
}
/// Borrows the row index of every stored entry, in storage order.
pub fn row_indices(&self) -> &Vec<usize> {
&self.row_indices
}
/// Borrows the value of every stored entry, in storage order.
pub fn values(&self) -> &Vec<f32> {
&self.values
}
/// Produces the transpose of this matrix as a CSR tensor.
///
/// The three arrays are cloned and passed to `CsrTensor::new` unchanged,
/// together with the shape `[cols, rows]`; no entries are reordered.
/// NOTE(review): this assumes `CsrTensor::new` takes
/// `(row_ptr, col_indices, values, shape)` - confirm against its definition.
///
/// # Panics
///
/// Panics if `CsrTensor::new` rejects the data.
pub fn transpose_to_csr(&self) -> crate::CsrTensor {
    let row_ptr = self.col_ptr.clone();
    let col_indices = self.row_indices.clone();
    let entries = self.values.clone();
    let dims = self.shape.dims();
    let flipped = Shape::new(vec![dims[1], dims[0]]);
    crate::CsrTensor::new(row_ptr, col_indices, entries, flipped)
        .expect("CSR construction from valid CSC data should succeed")
}
}
/// `SparseTensor` implementation for the CSC layout.
impl SparseTensor for CscTensor {
    /// Always `SparseFormat::Csc`.
    fn format(&self) -> SparseFormat {
        SparseFormat::Csc
    }

    /// Logical `[rows, cols]` shape.
    fn shape(&self) -> &Shape {
        &self.shape
    }

    /// Element type recorded at construction.
    fn dtype(&self) -> DType {
        self.dtype
    }

    /// Device tag recorded at construction.
    fn device(&self) -> DeviceType {
        self.device
    }

    /// Number of stored entries (the length of the value array).
    fn nnz(&self) -> usize {
        self.values.len()
    }

    /// Materializes a dense tensor: starts from zeros and writes every
    /// stored entry at its `[row, col]` position, column by column.
    fn to_dense(&self) -> TorshResult<Tensor> {
        let out = zeros::<f32>(self.shape.dims())?;
        let n_cols = self.shape.dims()[1];
        for col in 0..n_cols {
            for idx in self.col_ptr[col]..self.col_ptr[col + 1] {
                let (row, value) = (self.row_indices[idx], self.values[idx]);
                out.set(&[row, col], value)?;
            }
        }
        Ok(out)
    }

    /// Expands the column pointer into explicit column indices and builds a
    /// COO tensor with the entries in storage (column-major) order.
    fn to_coo(&self) -> TorshResult<CooTensor> {
        let n_cols = self.shape.dims()[1];
        let (mut rows, mut cols, mut vals) = (Vec::new(), Vec::new(), Vec::new());
        for col in 0..n_cols {
            for idx in self.col_ptr[col]..self.col_ptr[col + 1] {
                rows.push(self.row_indices[idx]);
                cols.push(col);
                vals.push(self.values[idx]);
            }
        }
        CooTensor::new(rows, cols, vals, self.shape.clone())
    }

    /// Converts to CSR by way of an intermediate COO tensor.
    fn to_csr(&self) -> TorshResult<crate::CsrTensor> {
        crate::CsrTensor::from_coo(&self.to_coo()?)
    }

    /// Already CSC, so this returns a clone of `self`.
    fn to_csc(&self) -> TorshResult<CscTensor> {
        Ok(Self::clone(self))
    }

    /// Exposes `self` as `&dyn Any` for downcasting.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
/// Field-by-field copy: the three arrays and the shape are cloned, while
/// `dtype` and `device` are copied by value.
impl Clone for CscTensor {
    fn clone(&self) -> Self {
        // Destructuring makes the compiler flag any field added later.
        let Self {
            col_ptr,
            row_indices,
            values,
            shape,
            dtype,
            device,
        } = self;
        Self {
            col_ptr: col_ptr.clone(),
            row_indices: row_indices.clone(),
            values: values.clone(),
            shape: shape.clone(),
            dtype: *dtype,
            device: *device,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::tensor_1d;

    /// Builds the 3x3 fixture shared by all tests:
    ///
    /// ```text
    /// [0 3 0]
    /// [1 0 0]
    /// [0 0 2]
    /// ```
    fn sample_csc() -> CscTensor {
        CscTensor::new(
            vec![0, 1, 2, 3],
            vec![1, 0, 2],
            vec![1.0, 3.0, 2.0],
            Shape::new(vec![3, 3]),
        )
        .unwrap()
    }

    #[test]
    fn test_csc_creation() {
        assert_eq!(sample_csc().nnz(), 3);
    }

    #[test]
    fn test_csc_vecmat() {
        let matrix = sample_csc();
        let vector = tensor_1d(&[1.0, 2.0, 3.0]).unwrap();
        let product = matrix.vecmat(&vector).unwrap();
        // [1, 2, 3] * A = [2*1, 1*3, 3*2]
        let expected = [2.0_f32, 3.0, 6.0];
        for (col, want) in expected.iter().enumerate() {
            assert_eq!(product.get(&[col]).unwrap(), *want);
        }
    }

    #[test]
    fn test_csc_to_dense() {
        let dense = sample_csc().to_dense().unwrap();
        let expected = [((0, 1), 3.0_f32), ((1, 0), 1.0), ((2, 2), 2.0)];
        for ((row, col), want) in expected.iter() {
            assert_eq!(dense.get(&[*row, *col]).unwrap(), *want);
        }
    }
}