use super::error::PruningError;
use super::mask::SparsityMask;
use crate::autograd::Tensor;
// `size_of` only reached the prelude in Rust 1.80; import it explicitly for older toolchains.
use std::mem::size_of;
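/// Storage layout for a pruned weight matrix: compressed sparse row,
/// coordinate list, or dense blocks of the given dimensions.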
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparseFormat {
CSR,
COO,
Block { height: usize, width: usize },
}
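/// A 2-D tensor in compressed sparse row (CSR) form.
///
/// `values` holds the nonzeros in row-major order, `col_indices[i]` is the
/// column of `values[i]`, and `row_ptrs[r]..row_ptrs[r + 1]` brackets the
/// entries of row `r`, so `row_ptrs` has `nrows + 1` elements. For example,
/// the 2x3 matrix
/// ```text
/// [1 0 2]
/// [0 0 3]
/// ```
/// is stored as `values = [1.0, 2.0, 3.0]`, `col_indices = [0, 2, 2]`,
/// `row_ptrs = [0, 2, 3]`.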
#[derive(Debug, Clone)]
pub struct CSRTensor {
pub values: Vec<f32>,
pub col_indices: Vec<usize>,
pub row_ptrs: Vec<usize>,
pub nrows: usize,
pub ncols: usize,
}
impl CSRTensor {
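    /// Builds a CSR tensor from raw parts, validating the CSR invariants:
    /// matching `values`/`col_indices` lengths, `nrows + 1` row pointers
    /// that start at 0, never decrease, and end at `values.len()`, and
    /// column indices below `ncols`.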
pub fn new(
values: Vec<f32>,
col_indices: Vec<usize>,
row_ptrs: Vec<usize>,
nrows: usize,
ncols: usize,
) -> Result<Self, PruningError> {
if values.len() != col_indices.len() {
return Err(PruningError::InvalidMask {
reason: format!(
"values length ({}) != col_indices length ({})",
values.len(),
col_indices.len()
),
});
}
if row_ptrs.len() != nrows + 1 {
return Err(PruningError::InvalidMask {
reason: format!(
"row_ptrs length ({}) != nrows + 1 ({})",
row_ptrs.len(),
nrows + 1
),
});
}
        if row_ptrs[0] != 0 {
            return Err(PruningError::InvalidMask {
                reason: format!("row_ptrs must start at 0, got {}", row_ptrs[0]),
            });
        }
        for i in 1..row_ptrs.len() {
            if row_ptrs[i] < row_ptrs[i - 1] {
                return Err(PruningError::InvalidMask {
                    reason: format!(
                        "row_ptrs not monotonic at index {}: {} < {}",
                        i,
                        row_ptrs[i],
                        row_ptrs[i - 1]
                    ),
                });
            }
        }
        if row_ptrs[nrows] != values.len() {
            return Err(PruningError::InvalidMask {
                reason: format!(
                    "last row_ptr ({}) != values length ({})",
                    row_ptrs[nrows],
                    values.len()
                ),
            });
        }
for &col in &col_indices {
if col >= ncols {
return Err(PruningError::InvalidMask {
reason: format!("col_index {col} >= ncols {ncols}"),
});
}
}
Ok(Self {
values,
col_indices,
row_ptrs,
nrows,
ncols,
})
}
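    /// Converts a dense 2-D `Tensor` to CSR, dropping exact zeros.
    ///
    /// A minimal round-trip sketch (`Tensor::from_vec` takes row-major data
    /// plus a shape, as used by `to_dense` below):
    /// ```ignore
    /// let dense = Tensor::from_vec(vec![1.0, 0.0, 2.0, 0.0, 0.0, 3.0], &[2, 3]);
    /// let csr = CSRTensor::from_dense(&dense)?;
    /// assert_eq!(csr.nnz(), 3);
    /// assert_eq!(csr.row_ptrs, vec![0, 2, 3]);
    /// ```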
pub fn from_dense(tensor: &Tensor) -> Result<Self, PruningError> {
let shape = tensor.shape();
if shape.len() != 2 {
return Err(PruningError::ShapeMismatch {
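                // `[0, 0]` is a placeholder meaning "any 2-D shape".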
expected: vec![0, 0],
got: shape.to_vec(),
});
}
let nrows = shape[0];
let ncols = shape[1];
let data = tensor.data();
let mut values = Vec::new();
let mut col_indices = Vec::new();
let mut row_ptrs = vec![0];
for row in 0..nrows {
for col in 0..ncols {
let val = data[row * ncols + col];
if val != 0.0 {
values.push(val);
col_indices.push(col);
}
}
row_ptrs.push(values.len());
}
Ok(Self {
values,
col_indices,
row_ptrs,
nrows,
ncols,
})
}
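    /// Expands back to a dense row-major `Tensor`; unstored positions
    /// become 0.0.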
#[must_use]
pub fn to_dense(&self) -> Tensor {
let mut data = vec![0.0f32; self.nrows * self.ncols];
for row in 0..self.nrows {
let start = self.row_ptrs[row];
let end = self.row_ptrs[row + 1];
for idx in start..end {
let col = self.col_indices[idx];
let val = self.values[idx];
data[row * self.ncols + col] = val;
}
}
Tensor::from_vec(data, &[self.nrows, self.ncols])
}
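    /// Number of stored (nonzero) values.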
#[must_use]
pub fn nnz(&self) -> usize {
self.values.len()
}
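    /// Fraction of zero entries, `1.0 - nnz / (nrows * ncols)`; an empty
    /// tensor reports 0.0.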
#[must_use]
pub fn sparsity(&self) -> f32 {
let total = self.nrows * self.ncols;
if total == 0 {
return 0.0;
}
1.0 - (self.nnz() as f32 / total as f32)
}
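    /// Shape as `[nrows, ncols]`.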
#[must_use]
pub fn shape(&self) -> [usize; 2] {
[self.nrows, self.ncols]
}
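    /// Value at `(row, col)`, or 0.0 when the position is out of bounds or
    /// unstored; a linear scan over the row's entries.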
#[must_use]
pub fn get(&self, row: usize, col: usize) -> f32 {
if row >= self.nrows || col >= self.ncols {
return 0.0;
}
let start = self.row_ptrs[row];
let end = self.row_ptrs[row + 1];
for idx in start..end {
if self.col_indices[idx] == col {
return self.values[idx];
}
}
0.0
}
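    /// Sparse matrix-vector product `y = A * x` that touches only stored
    /// values; errors unless `x.len() == ncols`.
    ///
    /// Worked example with the matrix from the type-level docs:
    /// ```text
    /// [1 0 2]   [1]   [1*1 + 2*1]   [3]
    /// [0 0 3] * [1] = [      3*1] = [3]
    ///           [1]
    /// ```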
pub fn matvec(&self, x: &[f32]) -> Result<Vec<f32>, PruningError> {
if x.len() != self.ncols {
return Err(PruningError::ShapeMismatch {
expected: vec![self.ncols],
got: vec![x.len()],
});
}
let mut y = vec![0.0f32; self.nrows];
for row in 0..self.nrows {
let start = self.row_ptrs[row];
let end = self.row_ptrs[row + 1];
let mut sum = 0.0f32;
for idx in start..end {
let col = self.col_indices[idx];
sum += self.values[idx] * x[col];
}
y[row] = sum;
}
Ok(y)
}
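    /// Heap bytes used by the three CSR arrays (lengths, not capacities).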
#[must_use]
pub fn memory_bytes(&self) -> usize {
let values_bytes = self.values.len() * size_of::<f32>();
let col_indices_bytes = self.col_indices.len() * size_of::<usize>();
let row_ptrs_bytes = self.row_ptrs.len() * size_of::<usize>();
values_bytes + col_indices_bytes + row_ptrs_bytes
}
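    /// Bytes an equivalent dense `f32` matrix would occupy.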
#[must_use]
pub fn dense_memory_bytes(&self) -> usize {
self.nrows * self.ncols * size_of::<f32>()
}
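    /// Dense-to-sparse memory ratio; above 1.0 means CSR is smaller. With
    /// `usize` indices on 64-bit targets each nonzero costs 12 bytes versus
    /// 4 dense, so CSR only wins above roughly two-thirds sparsity.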
#[must_use]
pub fn memory_savings_ratio(&self) -> f32 {
let sparse = self.memory_bytes();
let dense = self.dense_memory_bytes();
if sparse == 0 {
return 1.0;
}
dense as f32 / sparse as f32
}
}
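/// A 2-D tensor in coordinate (COO) form: parallel `values`, `row_indices`,
/// and `col_indices` arrays in insertion order. Convenient as an incremental
/// builder; convert with [`COOTensor::to_csr`] for fast row-wise access.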
#[derive(Debug, Clone)]
pub struct COOTensor {
pub values: Vec<f32>,
pub row_indices: Vec<usize>,
pub col_indices: Vec<usize>,
pub nrows: usize,
pub ncols: usize,
}
impl COOTensor {
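    /// Creates an empty COO tensor with the given logical shape.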
#[must_use]
pub fn new(nrows: usize, ncols: usize) -> Self {
Self {
values: Vec::new(),
row_indices: Vec::new(),
col_indices: Vec::new(),
nrows,
ncols,
}
}
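    /// Converts a dense 2-D `Tensor` to COO, dropping exact zeros; errors
    /// for tensors that are not 2-D.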
pub fn from_dense(tensor: &Tensor) -> Result<Self, PruningError> {
let shape = tensor.shape();
if shape.len() != 2 {
return Err(PruningError::ShapeMismatch {
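                // `[0, 0]` is a placeholder meaning "any 2-D shape".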
expected: vec![0, 0],
got: shape.to_vec(),
});
}
let nrows = shape[0];
let ncols = shape[1];
let data = tensor.data();
let mut coo = Self::new(nrows, ncols);
for row in 0..nrows {
for col in 0..ncols {
let val = data[row * ncols + col];
if val != 0.0 {
coo.push(row, col, val);
}
}
}
Ok(coo)
}
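    /// Appends one entry. Positions are not validated or deduplicated here:
    /// `to_dense` silently skips out-of-range entries and `to_csr` rejects
    /// them with an error.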
pub fn push(&mut self, row: usize, col: usize, value: f32) {
self.values.push(value);
self.row_indices.push(row);
self.col_indices.push(col);
}
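    /// Expands to a dense `Tensor`; later duplicates of a position
    /// overwrite earlier ones and out-of-range entries are ignored.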
#[must_use]
pub fn to_dense(&self) -> Tensor {
let mut data = vec![0.0f32; self.nrows * self.ncols];
for i in 0..self.values.len() {
let row = self.row_indices[i];
let col = self.col_indices[i];
if row < self.nrows && col < self.ncols {
data[row * self.ncols + col] = self.values[i];
}
}
Tensor::from_vec(data, &[self.nrows, self.ncols])
}
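    /// Converts to CSR in three steps: sort entries by `(row, col)`, emit
    /// the sorted values and columns, then prefix-sum the per-row counts
    /// into row pointers. Duplicate positions are kept as separate entries.
    ///
    /// A minimal sketch of the conversion:
    /// ```ignore
    /// let mut coo = COOTensor::new(2, 3);
    /// coo.push(1, 2, 3.0);
    /// coo.push(0, 0, 1.0);
    /// let csr = coo.to_csr()?;
    /// assert_eq!(csr.get(1, 2), 3.0); // push order does not matter
    /// ```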
pub fn to_csr(&self) -> Result<CSRTensor, PruningError> {
let mut entries: Vec<(usize, usize, f32)> = self
.values
.iter()
.zip(self.row_indices.iter())
.zip(self.col_indices.iter())
.map(|((&v, &r), &c)| (r, c, v))
.collect();
        // Sort by (row, col) so values land in CSR order; usize tuples are Ord.
        entries.sort_by_key(|&(row, col, _)| (row, col));
let mut values = Vec::with_capacity(entries.len());
let mut col_indices = Vec::with_capacity(entries.len());
let mut row_ptrs = vec![0usize; self.nrows + 1];
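        // One pass over the sorted entries: emit values/columns and count
        // entries per row into row_ptrs[1..].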
        for (row, col, val) in entries {
            if row >= self.nrows {
                return Err(PruningError::InvalidMask {
                    reason: format!("row_index {row} >= nrows {}", self.nrows),
                });
            }
values.push(val);
col_indices.push(col);
row_ptrs[row + 1] += 1;
}
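        // Prefix-sum the counts so row_ptrs[r]..row_ptrs[r + 1] brackets row r.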
for i in 1..row_ptrs.len() {
row_ptrs[i] += row_ptrs[i - 1];
}
CSRTensor::new(values, col_indices, row_ptrs, self.nrows, self.ncols)
}
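    /// Number of stored entries (duplicates count separately).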
#[must_use]
pub fn nnz(&self) -> usize {
self.values.len()
}
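    /// Fraction of zero entries, `1.0 - nnz / (nrows * ncols)`.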
#[must_use]
pub fn sparsity(&self) -> f32 {
let total = self.nrows * self.ncols;
if total == 0 {
return 0.0;
}
1.0 - (self.nnz() as f32 / total as f32)
}
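    /// Shape as `[nrows, ncols]`.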
#[must_use]
pub fn shape(&self) -> [usize; 2] {
[self.nrows, self.ncols]
}
}
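// `apply_mask.rs` is spliced in textually, so its contents share this
// module's scope (including the `SparsityMask` import above).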
include!("apply_mask.rs");