use crate::{CooTensor, CscTensor, CsrTensor, SparseFormat, SparseTensor, TorshResult};
use std::collections::BTreeMap;
use torsh_core::{device::DeviceType, DType, Shape, TorshError};
use torsh_tensor::Tensor;
/// Dynamic Sparse Row (DSR) tensor: a 2D sparse matrix stored as one
/// ordered column -> value map per row, suited to incremental element
/// insertion/removal (see `set`/`remove`/`insert_row_elements` below).
#[derive(Debug, Clone)]
pub struct DsrTensor {
    // Always 2-dimensional (enforced by `DsrTensor::new`).
    shape: Shape,
    dtype: DType,
    // Construction pins this to `DeviceType::Cpu`.
    device: DeviceType,
    // One map per row: column index -> stored (non-zero) value.
    rows: Vec<BTreeMap<usize, f32>>,
    // Count of stored entries, kept in sync by every mutation method so
    // `nnz()` is O(1) instead of summing all row lengths.
    cached_nnz: usize,
}
impl DsrTensor {
    /// Magnitude below which a value is treated as a structural zero and is
    /// not stored. Shared by every mutation path so the "no near-zero
    /// entries are stored" invariant is applied with one consistent
    /// boundary (the original code duplicated the literal with mixed
    /// `<`/`>` comparisons).
    const ZERO_EPS: f32 = 1e-12;

    /// Creates an empty DSR tensor on the CPU.
    ///
    /// # Errors
    /// Returns `InvalidArgument` if `shape` is not 2-dimensional.
    pub fn new(shape: Shape, dtype: DType) -> TorshResult<Self> {
        if shape.dims().len() != 2 {
            return Err(TorshError::InvalidArgument(
                "DSR format only supports 2D tensors".to_string(),
            ));
        }
        // One ordered column->value map per row.
        let rows = vec![BTreeMap::new(); shape.dims()[0]];
        Ok(Self {
            shape,
            dtype,
            device: DeviceType::Cpu,
            rows,
            cached_nnz: 0,
        })
    }

    /// Builds a DSR tensor from a COO tensor by inserting every triplet.
    pub fn from_coo(coo: &CooTensor) -> TorshResult<Self> {
        let mut dsr = Self::new(coo.shape().clone(), coo.dtype())?;
        for (row, col, val) in coo.triplets() {
            dsr.set(row, col, val)?;
        }
        Ok(dsr)
    }

    /// Builds a DSR tensor from a CSR tensor (via an intermediate COO).
    pub fn from_csr(csr: &CsrTensor) -> TorshResult<Self> {
        let coo = csr.to_coo()?;
        Self::from_coo(&coo)
    }

    /// Builds a DSR tensor from a dense tensor, keeping values whose
    /// magnitude is at least `_threshold`.
    ///
    /// NOTE(review): this is an unfinished stub — the dense values are
    /// never read (the original body contained an empty nested loop,
    /// removed here), so the result is always empty and `_threshold` is
    /// ignored. Completing it requires an element read API on
    /// `torsh_tensor::Tensor`, which is not visible from this file;
    /// confirm the accessor and copy every value with
    /// `value.abs() >= _threshold` into the result.
    pub fn from_dense(dense: &Tensor, _threshold: f32) -> TorshResult<Self> {
        let shape = dense.shape();
        if shape.dims().len() != 2 {
            return Err(TorshError::InvalidArgument(
                "DSR format only supports 2D tensors".to_string(),
            ));
        }
        let dsr = Self::new(shape.clone(), dense.dtype())?;
        // TODO: populate `dsr` from `dense` once a Tensor read API is available.
        Ok(dsr)
    }

    /// Sets the element at (`row`, `col`).
    ///
    /// Near-zero values (|v| < `ZERO_EPS`) delete any stored entry instead
    /// of storing it, so the structure stays truly sparse.
    ///
    /// # Errors
    /// Returns `InvalidArgument` if either index is out of bounds.
    pub fn set(&mut self, row: usize, col: usize, value: f32) -> TorshResult<()> {
        if row >= self.shape.dims()[0] || col >= self.shape.dims()[1] {
            return Err(TorshError::InvalidArgument(
                "Index out of bounds".to_string(),
            ));
        }
        if value.abs() < Self::ZERO_EPS {
            // Writing a (near-)zero removes the entry if present.
            if self.rows[row].remove(&col).is_some() {
                self.cached_nnz -= 1;
            }
        } else if self.rows[row].insert(col, value).is_none() {
            // `insert` returning None means the entry is new.
            self.cached_nnz += 1;
        }
        Ok(())
    }

    /// Returns the element at (`row`, `col`); implicit zeros read as `0.0`.
    ///
    /// # Errors
    /// Returns `InvalidArgument` if either index is out of bounds.
    pub fn get(&self, row: usize, col: usize) -> TorshResult<f32> {
        if row >= self.shape.dims()[0] || col >= self.shape.dims()[1] {
            return Err(TorshError::InvalidArgument(
                "Index out of bounds".to_string(),
            ));
        }
        Ok(self.rows[row].get(&col).copied().unwrap_or(0.0))
    }

    /// Removes and returns the element at (`row`, `col`); returns `0.0`
    /// when no entry was stored there.
    ///
    /// # Errors
    /// Returns `InvalidArgument` if either index is out of bounds.
    pub fn remove(&mut self, row: usize, col: usize) -> TorshResult<f32> {
        if row >= self.shape.dims()[0] || col >= self.shape.dims()[1] {
            return Err(TorshError::InvalidArgument(
                "Index out of bounds".to_string(),
            ));
        }
        if let Some(value) = self.rows[row].remove(&col) {
            self.cached_nnz -= 1;
            Ok(value)
        } else {
            Ok(0.0)
        }
    }

    /// Inserts several `(col, value)` elements into one row.
    ///
    /// Near-zero values now remove any existing entry at that column,
    /// consistent with [`Self::set`] (previously they were silently
    /// skipped, leaving a stale old value behind). Columns are validated
    /// per element, so an out-of-bounds column aborts with earlier
    /// elements already applied.
    ///
    /// # Errors
    /// Returns `InvalidArgument` if `row` or any column is out of bounds.
    pub fn insert_row_elements(
        &mut self,
        row: usize,
        elements: &[(usize, f32)],
    ) -> TorshResult<()> {
        if row >= self.shape.dims()[0] {
            return Err(TorshError::InvalidArgument(
                "Row index out of bounds".to_string(),
            ));
        }
        for &(col, value) in elements {
            if col >= self.shape.dims()[1] {
                return Err(TorshError::InvalidArgument(
                    "Column index out of bounds".to_string(),
                ));
            }
            if value.abs() < Self::ZERO_EPS {
                if self.rows[row].remove(&col).is_some() {
                    self.cached_nnz -= 1;
                }
            } else if self.rows[row].insert(col, value).is_none() {
                self.cached_nnz += 1;
            }
        }
        Ok(())
    }

    /// Returns the stored `(col, value)` pairs of `row`, in ascending
    /// column order (guaranteed by the `BTreeMap` iteration order).
    ///
    /// # Errors
    /// Returns `InvalidArgument` if `row` is out of bounds.
    pub fn get_row_elements(&self, row: usize) -> TorshResult<Vec<(usize, f32)>> {
        if row >= self.shape.dims()[0] {
            return Err(TorshError::InvalidArgument(
                "Row index out of bounds".to_string(),
            ));
        }
        Ok(self.rows[row]
            .iter()
            .map(|(&col, &val)| (col, val))
            .collect())
    }

    /// Removes every stored element of `row`.
    ///
    /// # Errors
    /// Returns `InvalidArgument` if `row` is out of bounds.
    pub fn clear_row(&mut self, row: usize) -> TorshResult<()> {
        if row >= self.shape.dims()[0] {
            return Err(TorshError::InvalidArgument(
                "Row index out of bounds".to_string(),
            ));
        }
        let row_nnz = self.rows[row].len();
        self.rows[row].clear();
        self.cached_nnz -= row_nnz;
        Ok(())
    }

    /// Returns the number of stored elements in `row`.
    ///
    /// # Errors
    /// Returns `InvalidArgument` if `row` is out of bounds.
    pub fn row_nnz(&self, row: usize) -> TorshResult<usize> {
        if row >= self.shape.dims()[0] {
            return Err(TorshError::InvalidArgument(
                "Row index out of bounds".to_string(),
            ));
        }
        Ok(self.rows[row].len())
    }

    /// Applies `f` to every stored value in place.
    ///
    /// Only explicitly stored entries are visited — implicit zeros are
    /// untouched. Results whose magnitude falls below `ZERO_EPS` are
    /// removed (via `BTreeMap::retain`, replacing the original temporary
    /// Vec + deferred-removal bookkeeping) and the nnz cache is adjusted.
    pub fn apply_inplace<F>(&mut self, mut f: F) -> TorshResult<()>
    where
        F: FnMut(f32) -> f32,
    {
        for row in &mut self.rows {
            for value in row.values_mut() {
                *value = f(*value);
            }
            // Drop entries that became (near-)zero and fix the cache.
            let before = row.len();
            row.retain(|_, v| v.abs() >= Self::ZERO_EPS);
            self.cached_nnz -= before - row.len();
        }
        Ok(())
    }

    /// Returns a new DSR tensor with rows and columns swapped.
    pub fn transpose(&self) -> TorshResult<Self> {
        let transposed_shape = Shape::new(vec![self.shape.dims()[1], self.shape.dims()[0]]);
        let mut transposed = Self::new(transposed_shape, self.dtype)?;
        for (row_idx, row) in self.rows.iter().enumerate() {
            for (&col_idx, &value) in row.iter() {
                transposed.set(col_idx, row_idx, value)?;
            }
        }
        Ok(transposed)
    }

    /// Returns all stored elements as `(row, col, value)` triplets, in
    /// row-major order.
    pub fn triplets(&self) -> Vec<(usize, usize, f32)> {
        let mut triplets = Vec::with_capacity(self.cached_nnz);
        for (row_idx, row) in self.rows.iter().enumerate() {
            for (&col_idx, &value) in row.iter() {
                triplets.push((row_idx, col_idx, value));
            }
        }
        triplets
    }

    /// Adds `other` element-wise into `self`. Sums that cancel to
    /// (near-)zero are removed, keeping the sparsity invariant.
    ///
    /// # Errors
    /// Returns `InvalidArgument` on shape mismatch.
    pub fn add_dsr(&mut self, other: &DsrTensor) -> TorshResult<()> {
        if self.shape != *other.shape() {
            return Err(TorshError::InvalidArgument(
                "Shape mismatch for DSR addition".to_string(),
            ));
        }
        for (row_idx, other_row) in other.rows.iter().enumerate() {
            for (&col_idx, &other_value) in other_row.iter() {
                let current = self.rows[row_idx].get(&col_idx).copied().unwrap_or(0.0);
                let sum = current + other_value;
                if sum.abs() < Self::ZERO_EPS {
                    if self.rows[row_idx].remove(&col_idx).is_some() {
                        self.cached_nnz -= 1;
                    }
                } else if self.rows[row_idx].insert(col_idx, sum).is_none() {
                    self.cached_nnz += 1;
                }
            }
        }
        Ok(())
    }

    /// Multiplies every stored value by `scalar`.
    ///
    /// Scaling by (near-)zero clears the tensor outright. Otherwise values
    /// are scaled in place and results that fall below `ZERO_EPS` are
    /// pruned (previously they were kept, violating the "no near-zero
    /// entries" invariant maintained by `set`).
    pub fn scale(&mut self, scalar: f32) -> TorshResult<()> {
        if scalar.abs() < Self::ZERO_EPS {
            for row in &mut self.rows {
                row.clear();
            }
            self.cached_nnz = 0;
        } else {
            for row in &mut self.rows {
                for value in row.values_mut() {
                    *value *= scalar;
                }
                let before = row.len();
                row.retain(|_, v| v.abs() >= Self::ZERO_EPS);
                self.cached_nnz -= before - row.len();
            }
        }
        Ok(())
    }
}
impl SparseTensor for DsrTensor {
    /// NOTE(review): reported as `Csr` — presumably `SparseFormat` has no
    /// dedicated DSR variant and CSR is the closest row-oriented match;
    /// confirm against the `SparseFormat` definition before relying on it.
    fn format(&self) -> SparseFormat {
        SparseFormat::Csr
    }

    fn shape(&self) -> &Shape {
        &self.shape
    }

    fn dtype(&self) -> DType {
        self.dtype
    }

    fn device(&self) -> DeviceType {
        self.device
    }

    /// Number of explicitly stored entries, maintained incrementally by
    /// the mutation methods (O(1)).
    fn nnz(&self) -> usize {
        self.cached_nnz
    }

    /// Converts to a dense tensor.
    ///
    /// NOTE(review): unfinished stub — it allocates an all-zero tensor and
    /// never copies the stored values, so all non-zero data is silently
    /// lost. Completing it requires an element-write API on
    /// `torsh_tensor::Tensor` that is not visible from this file; once
    /// available, scatter `self.triplets()` into the result.
    fn to_dense(&self) -> TorshResult<Tensor> {
        use torsh_tensor::creation::zeros;
        let dense = zeros::<f32>(self.shape.dims())?;
        // TODO: scatter `self.triplets()` into `dense`.
        Ok(dense)
    }

    /// Converts to COO by splitting the triplet list into three columns.
    /// The vectors are preallocated from the cached nnz, replacing the
    /// original fold over growing vectors (which reallocated repeatedly).
    fn to_coo(&self) -> TorshResult<CooTensor> {
        let nnz = self.cached_nnz;
        let mut rows = Vec::with_capacity(nnz);
        let mut cols = Vec::with_capacity(nnz);
        let mut vals = Vec::with_capacity(nnz);
        for (r, c, v) in self.triplets() {
            rows.push(r);
            cols.push(c);
            vals.push(v);
        }
        CooTensor::new(rows, cols, vals, self.shape.clone())
    }

    /// Converts to CSR via an intermediate COO tensor.
    fn to_csr(&self) -> TorshResult<CsrTensor> {
        let coo = self.to_coo()?;
        CsrTensor::from_coo(&coo)
    }

    /// Converts to CSC via an intermediate COO tensor.
    fn to_csc(&self) -> TorshResult<CscTensor> {
        let coo = self.to_coo()?;
        CscTensor::from_coo(&coo)
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Convenience constructor: an empty f32 DSR tensor of the given size.
    fn empty_dsr(rows: usize, cols: usize) -> DsrTensor {
        DsrTensor::new(Shape::new(vec![rows, cols]), DType::F32).unwrap()
    }

    #[test]
    fn test_dsr_creation() {
        let dsr = empty_dsr(3, 4);
        assert_eq!(dsr.nnz(), 0);
        assert_eq!(dsr.shape().dims(), &[3, 4]);
    }

    #[test]
    fn test_dsr_set_get() {
        let mut dsr = empty_dsr(3, 4);
        // First write creates the entry.
        dsr.set(1, 2, 5.0).unwrap();
        assert_relative_eq!(dsr.get(1, 2).unwrap(), 5.0);
        assert_eq!(dsr.nnz(), 1);
        // Overwriting keeps nnz constant.
        dsr.set(1, 2, 10.0).unwrap();
        assert_relative_eq!(dsr.get(1, 2).unwrap(), 10.0);
        assert_eq!(dsr.nnz(), 1);
        // Writing zero removes the entry.
        dsr.set(1, 2, 0.0).unwrap();
        assert_relative_eq!(dsr.get(1, 2).unwrap(), 0.0);
        assert_eq!(dsr.nnz(), 0);
    }

    #[test]
    fn test_dsr_dynamic_operations() {
        let mut dsr = empty_dsr(3, 3);
        for &(r, c, v) in &[(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0), (2, 0, 4.0)] {
            dsr.set(r, c, v).unwrap();
        }
        assert_eq!(dsr.nnz(), 4);
        // Row elements come back in ascending column order.
        assert_eq!(dsr.get_row_elements(0).unwrap(), vec![(0, 1.0), (2, 2.0)]);
        assert_eq!(dsr.row_nnz(0).unwrap(), 2);
        assert_eq!(dsr.row_nnz(1).unwrap(), 1);
        // Clearing a row updates both the row and the global nnz.
        dsr.clear_row(0).unwrap();
        assert_eq!(dsr.nnz(), 2);
        assert_eq!(dsr.row_nnz(0).unwrap(), 0);
    }

    #[test]
    fn test_dsr_conversions() {
        let mut dsr = empty_dsr(3, 3);
        dsr.set(0, 0, 1.0).unwrap();
        dsr.set(1, 1, 2.0).unwrap();
        dsr.set(2, 2, 3.0).unwrap();
        // Round-trip through COO preserves every entry.
        let coo = dsr.to_coo().unwrap();
        assert_eq!(coo.nnz(), 3);
        let back = DsrTensor::from_coo(&coo).unwrap();
        assert_eq!(back.nnz(), 3);
        assert_relative_eq!(back.get(0, 0).unwrap(), 1.0);
        assert_relative_eq!(back.get(1, 1).unwrap(), 2.0);
        assert_relative_eq!(back.get(2, 2).unwrap(), 3.0);
    }

    #[test]
    fn test_dsr_transpose() {
        let mut dsr = empty_dsr(2, 3);
        dsr.set(0, 1, 5.0).unwrap();
        dsr.set(1, 2, 10.0).unwrap();
        let t = dsr.transpose().unwrap();
        // Shape flips and every (r, c) moves to (c, r).
        assert_eq!(t.shape().dims(), &[3, 2]);
        assert_relative_eq!(t.get(1, 0).unwrap(), 5.0);
        assert_relative_eq!(t.get(2, 1).unwrap(), 10.0);
    }

    #[test]
    fn test_dsr_addition() {
        let mut lhs = empty_dsr(2, 2);
        let mut rhs = empty_dsr(2, 2);
        lhs.set(0, 0, 1.0).unwrap();
        lhs.set(1, 1, 2.0).unwrap();
        rhs.set(0, 0, 3.0).unwrap();
        rhs.set(0, 1, 4.0).unwrap();
        lhs.add_dsr(&rhs).unwrap();
        // Overlapping entries sum; disjoint entries carry over.
        assert_relative_eq!(lhs.get(0, 0).unwrap(), 4.0);
        assert_relative_eq!(lhs.get(0, 1).unwrap(), 4.0);
        assert_relative_eq!(lhs.get(1, 1).unwrap(), 2.0);
    }

    #[test]
    fn test_dsr_scaling() {
        let mut dsr = empty_dsr(2, 2);
        dsr.set(0, 0, 2.0).unwrap();
        dsr.set(1, 1, 4.0).unwrap();
        dsr.scale(0.5).unwrap();
        assert_relative_eq!(dsr.get(0, 0).unwrap(), 1.0);
        assert_relative_eq!(dsr.get(1, 1).unwrap(), 2.0);
        assert_eq!(dsr.nnz(), 2);
        // Scaling by zero empties the tensor.
        dsr.scale(0.0).unwrap();
        assert_eq!(dsr.nnz(), 0);
    }
}