Skip to main content

axonml_tensor/
sparse.rs

//! Sparse Tensor Support
//!
//! Provides sparse tensor representations for memory-efficient storage and
//! computation when tensors have many zero elements.
//!
//! # Formats
//! - COO (Coordinate): Best for construction and random access
//! - CSR (Compressed Sparse Row): Best for row-wise operations and matrix-vector products
//!
//! # Example
//! ```rust,ignore
//! use axonml_tensor::sparse::SparseTensor;
//!
//! // Create a 3x3 matrix from (row, col) coordinates
//! let coords = vec![(0, 1), (1, 0), (2, 2)];
//! let values = vec![1.0, 2.0, 3.0];
//! let sparse = SparseTensor::from_coords(&coords, &values, &[3, 3]);
//!
//! // Convert to dense
//! let dense = sparse.to_dense();
//! ```
//!
//! @version 0.1.0

25use crate::Tensor;

// =============================================================================
// Sparse Format
// =============================================================================

/// Sparse tensor storage format.
///
/// A plain, copyable tag naming a storage layout. Of the layouts listed here,
/// the containers visible in this file are [`SparseCOO`] (for
/// [`SparseFormat::COO`]) and [`SparseCSR`] (for [`SparseFormat::CSR`]); no CSC
/// container is defined alongside them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparseFormat {
    /// Coordinate format: (row, col, value) tuples
    COO,
    /// Compressed Sparse Row format
    CSR,
    /// Compressed Sparse Column format
    CSC,
}

// =============================================================================
// COO Sparse Tensor
// =============================================================================

/// Sparse tensor in COO (Coordinate) format.
///
/// Stores the non-zero elements as parallel arrays: entry `k` sits at
/// coordinate `(indices[0][k], indices[1][k], ...)` and has value `values[k]`.
/// Efficient for construction but less efficient for arithmetic.
#[derive(Debug, Clone)]
pub struct SparseCOO {
    /// Per-dimension index arrays: `indices[d][k]` is the coordinate of the
    /// `k`-th stored entry along dimension `d`.
    pub indices: Vec<Vec<usize>>,
    /// Values of the stored entries, parallel to the index arrays.
    pub values: Vec<f32>,
    /// Shape of the tensor
    pub shape: Vec<usize>,
    /// Whether the entries are known to be sorted by coordinate with duplicate
    /// coordinates merged (the state `coalesce` establishes).
    pub is_coalesced: bool,
}

62impl SparseCOO {
63    /// Creates a new sparse COO tensor.
64    ///
65    /// # Arguments
66    /// * `indices` - List of index tuples, one per dimension
67    /// * `values` - Non-zero values
68    /// * `shape` - Shape of the tensor
69    pub fn new(indices: Vec<Vec<usize>>, values: Vec<f32>, shape: Vec<usize>) -> Self {
70        assert_eq!(
71            indices.len(),
72            shape.len(),
73            "indices dimensions must match shape"
74        );
75        if !indices.is_empty() {
76            let nnz = indices[0].len();
77            for idx in &indices {
78                assert_eq!(idx.len(), nnz, "all index arrays must have same length");
79            }
80            assert_eq!(
81                values.len(),
82                nnz,
83                "values length must match number of indices"
84            );
85        }
86
87        Self {
88            indices,
89            values,
90            shape,
91            is_coalesced: false,
92        }
93    }
94
95    /// Creates from a list of 2D coordinate tuples.
96    pub fn from_coo_2d(coords: &[(usize, usize)], values: &[f32], shape: &[usize]) -> Self {
97        assert_eq!(shape.len(), 2, "shape must be 2D");
98        assert_eq!(
99            coords.len(),
100            values.len(),
101            "coords and values must have same length"
102        );
103
104        let rows: Vec<usize> = coords.iter().map(|(r, _)| *r).collect();
105        let cols: Vec<usize> = coords.iter().map(|(_, c)| *c).collect();
106
107        Self::new(vec![rows, cols], values.to_vec(), shape.to_vec())
108    }
109
110    /// Returns number of non-zero elements.
111    pub fn nnz(&self) -> usize {
112        self.values.len()
113    }
114
115    /// Returns the density (ratio of non-zeros to total elements).
116    pub fn density(&self) -> f32 {
117        let total: usize = self.shape.iter().product();
118        if total == 0 {
119            0.0
120        } else {
121            self.nnz() as f32 / total as f32
122        }
123    }
124
125    /// Returns the shape.
126    pub fn shape(&self) -> &[usize] {
127        &self.shape
128    }
129
130    /// Coalesces the sparse tensor (combines duplicate indices).
131    pub fn coalesce(&mut self) {
132        if self.is_coalesced || self.nnz() == 0 {
133            self.is_coalesced = true;
134            return;
135        }
136
137        // Create (indices, value) pairs and sort
138        let mut entries: Vec<(Vec<usize>, f32)> = (0..self.nnz())
139            .map(|i| {
140                let idx: Vec<usize> = self.indices.iter().map(|dim| dim[i]).collect();
141                (idx, self.values[i])
142            })
143            .collect();
144
145        entries.sort_by(|a, b| a.0.cmp(&b.0));
146
147        // Combine duplicates
148        let mut new_indices: Vec<Vec<usize>> = vec![Vec::new(); self.shape.len()];
149        let mut new_values = Vec::new();
150
151        let mut prev_idx: Option<Vec<usize>> = None;
152
153        for (idx, val) in entries {
154            if prev_idx.as_ref() == Some(&idx) {
155                // Duplicate: add to previous value
156                if let Some(last) = new_values.last_mut() {
157                    *last += val;
158                }
159            } else {
160                // New index
161                for (d, i) in idx.iter().enumerate() {
162                    new_indices[d].push(*i);
163                }
164                new_values.push(val);
165                prev_idx = Some(idx);
166            }
167        }
168
169        self.indices = new_indices;
170        self.values = new_values;
171        self.is_coalesced = true;
172    }
173
174    /// Converts to dense tensor.
175    pub fn to_dense(&self) -> Tensor<f32> {
176        let total: usize = self.shape.iter().product();
177        let mut data = vec![0.0f32; total];
178
179        for i in 0..self.nnz() {
180            let mut flat_idx = 0;
181            let mut stride = 1;
182            for d in (0..self.shape.len()).rev() {
183                flat_idx += self.indices[d][i] * stride;
184                stride *= self.shape[d];
185            }
186            data[flat_idx] += self.values[i];
187        }
188
189        Tensor::from_vec(data, &self.shape).unwrap()
190    }
191
192    /// Converts to CSR format (for 2D matrices).
193    pub fn to_csr(&self) -> SparseCSR {
194        assert_eq!(self.shape.len(), 2, "CSR only supports 2D tensors");
195
196        let mut coo = self.clone();
197        coo.coalesce();
198
199        let nrows = self.shape[0];
200        let nnz = coo.nnz();
201
202        let mut row_ptr = vec![0usize; nrows + 1];
203        let mut col_indices = Vec::with_capacity(nnz);
204        let mut values = Vec::with_capacity(nnz);
205
206        // Count entries per row
207        for &row in &coo.indices[0] {
208            row_ptr[row + 1] += 1;
209        }
210
211        // Cumulative sum
212        for i in 1..=nrows {
213            row_ptr[i] += row_ptr[i - 1];
214        }
215
216        // Sort by row, then column
217        let mut entries: Vec<(usize, usize, f32)> = (0..nnz)
218            .map(|i| (coo.indices[0][i], coo.indices[1][i], coo.values[i]))
219            .collect();
220        entries.sort_by_key(|(r, c, _)| (*r, *c));
221
222        for (_, col, val) in entries {
223            col_indices.push(col);
224            values.push(val);
225        }
226
227        SparseCSR {
228            row_ptr,
229            col_indices,
230            values,
231            shape: self.shape.clone(),
232        }
233    }
234}

// =============================================================================
// CSR Sparse Tensor
// =============================================================================

/// Sparse tensor in CSR (Compressed Sparse Row) format.
///
/// The entries of row `r` occupy positions `row_ptr[r]..row_ptr[r + 1]` of
/// `col_indices` and `values`. Efficient for row-wise operations and sparse
/// matrix-vector products.
#[derive(Debug, Clone)]
pub struct SparseCSR {
    /// Row pointers (length = nrows + 1)
    pub row_ptr: Vec<usize>,
    /// Column indices for each non-zero
    pub col_indices: Vec<usize>,
    /// Values for each non-zero, parallel to `col_indices`
    pub values: Vec<f32>,
    /// Shape [nrows, ncols]
    pub shape: Vec<usize>,
}

255impl SparseCSR {
256    /// Creates a new CSR sparse matrix.
257    pub fn new(
258        row_ptr: Vec<usize>,
259        col_indices: Vec<usize>,
260        values: Vec<f32>,
261        shape: Vec<usize>,
262    ) -> Self {
263        assert_eq!(shape.len(), 2, "CSR only supports 2D tensors");
264        assert_eq!(
265            row_ptr.len(),
266            shape[0] + 1,
267            "row_ptr length must be nrows + 1"
268        );
269        assert_eq!(
270            col_indices.len(),
271            values.len(),
272            "col_indices and values must match"
273        );
274
275        Self {
276            row_ptr,
277            col_indices,
278            values,
279            shape,
280        }
281    }
282
283    /// Returns number of non-zero elements.
284    pub fn nnz(&self) -> usize {
285        self.values.len()
286    }
287
288    /// Returns number of rows.
289    pub fn nrows(&self) -> usize {
290        self.shape[0]
291    }
292
293    /// Returns number of columns.
294    pub fn ncols(&self) -> usize {
295        self.shape[1]
296    }
297
298    /// Returns the density.
299    pub fn density(&self) -> f32 {
300        let total = self.nrows() * self.ncols();
301        if total == 0 {
302            0.0
303        } else {
304            self.nnz() as f32 / total as f32
305        }
306    }
307
308    /// Gets entries for a specific row.
309    pub fn row(&self, row_idx: usize) -> impl Iterator<Item = (usize, f32)> + '_ {
310        let start = self.row_ptr[row_idx];
311        let end = self.row_ptr[row_idx + 1];
312        (start..end).map(move |i| (self.col_indices[i], self.values[i]))
313    }
314
315    /// Sparse matrix-vector multiplication: A @ x.
316    pub fn matvec(&self, x: &[f32]) -> Vec<f32> {
317        assert_eq!(x.len(), self.ncols(), "vector length must match ncols");
318
319        let mut result = vec![0.0f32; self.nrows()];
320
321        for row in 0..self.nrows() {
322            let start = self.row_ptr[row];
323            let end = self.row_ptr[row + 1];
324
325            for i in start..end {
326                let col = self.col_indices[i];
327                let val = self.values[i];
328                result[row] += val * x[col];
329            }
330        }
331
332        result
333    }
334
335    /// Sparse matrix-matrix multiplication: A @ B (where B is dense).
336    pub fn matmul_dense(&self, b: &Tensor<f32>) -> Tensor<f32> {
337        let b_shape = b.shape();
338        assert_eq!(b_shape[0], self.ncols(), "inner dimensions must match");
339
340        let m = self.nrows();
341        let n = b_shape[1];
342        let b_data = b.to_vec();
343
344        let mut result = vec![0.0f32; m * n];
345
346        for row in 0..m {
347            for (col, val) in self.row(row) {
348                for j in 0..n {
349                    result[row * n + j] += val * b_data[col * n + j];
350                }
351            }
352        }
353
354        Tensor::from_vec(result, &[m, n]).unwrap()
355    }
356
357    /// Converts to dense tensor.
358    pub fn to_dense(&self) -> Tensor<f32> {
359        let mut data = vec![0.0f32; self.nrows() * self.ncols()];
360
361        for row in 0..self.nrows() {
362            for (col, val) in self.row(row) {
363                data[row * self.ncols() + col] = val;
364            }
365        }
366
367        Tensor::from_vec(data, &self.shape).unwrap()
368    }
369
370    /// Converts to COO format.
371    pub fn to_coo(&self) -> SparseCOO {
372        let mut rows = Vec::with_capacity(self.nnz());
373        let mut cols = Vec::with_capacity(self.nnz());
374
375        for row in 0..self.nrows() {
376            let start = self.row_ptr[row];
377            let end = self.row_ptr[row + 1];
378            for i in start..end {
379                rows.push(row);
380                cols.push(self.col_indices[i]);
381            }
382        }
383
384        SparseCOO {
385            indices: vec![rows, cols],
386            values: self.values.clone(),
387            shape: self.shape.clone(),
388            is_coalesced: true,
389        }
390    }
391}

// =============================================================================
// SparseTensor (Unified Interface)
// =============================================================================

397/// Unified sparse tensor interface supporting multiple formats.
398#[derive(Debug, Clone)]
399pub enum SparseTensor {
400    /// COO format
401    COO(SparseCOO),
402    /// CSR format
403    CSR(SparseCSR),
404}
405
406impl SparseTensor {
407    /// Creates a sparse tensor from COO data.
408    pub fn from_coo(indices: Vec<Vec<usize>>, values: Vec<f32>, shape: Vec<usize>) -> Self {
409        Self::COO(SparseCOO::new(indices, values, shape))
410    }
411
412    /// Creates a 2D sparse tensor from coordinate list.
413    pub fn from_coords(coords: &[(usize, usize)], values: &[f32], shape: &[usize]) -> Self {
414        Self::COO(SparseCOO::from_coo_2d(coords, values, shape))
415    }
416
417    /// Creates from a dense tensor, keeping only non-zero elements.
418    pub fn from_dense(tensor: &Tensor<f32>, threshold: f32) -> Self {
419        let data = tensor.to_vec();
420        let shape = tensor.shape().to_vec();
421
422        let mut indices: Vec<Vec<usize>> = vec![Vec::new(); shape.len()];
423        let mut values = Vec::new();
424
425        let strides: Vec<usize> = {
426            let mut s = vec![1; shape.len()];
427            for i in (0..shape.len() - 1).rev() {
428                s[i] = s[i + 1] * shape[i + 1];
429            }
430            s
431        };
432
433        for (flat_idx, &val) in data.iter().enumerate() {
434            if val.abs() > threshold {
435                let mut idx = flat_idx;
436                for (d, &stride) in strides.iter().enumerate() {
437                    indices[d].push(idx / stride);
438                    idx %= stride;
439                }
440                values.push(val);
441            }
442        }
443
444        Self::COO(SparseCOO::new(indices, values, shape))
445    }
446
447    /// Creates an identity matrix in sparse format.
448    pub fn eye(n: usize) -> Self {
449        let indices: Vec<usize> = (0..n).collect();
450        let values = vec![1.0f32; n];
451        Self::COO(SparseCOO::new(
452            vec![indices.clone(), indices],
453            values,
454            vec![n, n],
455        ))
456    }
457
458    /// Creates a sparse diagonal matrix.
459    pub fn diag(values: &[f32]) -> Self {
460        let n = values.len();
461        let indices: Vec<usize> = (0..n).collect();
462        Self::COO(SparseCOO::new(
463            vec![indices.clone(), indices],
464            values.to_vec(),
465            vec![n, n],
466        ))
467    }
468
469    /// Returns number of non-zero elements.
470    pub fn nnz(&self) -> usize {
471        match self {
472            Self::COO(coo) => coo.nnz(),
473            Self::CSR(csr) => csr.nnz(),
474        }
475    }
476
477    /// Returns the shape.
478    pub fn shape(&self) -> &[usize] {
479        match self {
480            Self::COO(coo) => &coo.shape,
481            Self::CSR(csr) => &csr.shape,
482        }
483    }
484
485    /// Returns the density.
486    pub fn density(&self) -> f32 {
487        match self {
488            Self::COO(coo) => coo.density(),
489            Self::CSR(csr) => csr.density(),
490        }
491    }
492
493    /// Converts to dense tensor.
494    pub fn to_dense(&self) -> Tensor<f32> {
495        match self {
496            Self::COO(coo) => coo.to_dense(),
497            Self::CSR(csr) => csr.to_dense(),
498        }
499    }
500
501    /// Converts to CSR format.
502    pub fn to_csr(&self) -> SparseCSR {
503        match self {
504            Self::COO(coo) => coo.to_csr(),
505            Self::CSR(csr) => csr.clone(),
506        }
507    }
508
509    /// Converts to COO format.
510    pub fn to_coo(&self) -> SparseCOO {
511        match self {
512            Self::COO(coo) => coo.clone(),
513            Self::CSR(csr) => csr.to_coo(),
514        }
515    }
516
517    /// Sparse matrix-vector multiplication.
518    pub fn matvec(&self, x: &[f32]) -> Vec<f32> {
519        match self {
520            Self::COO(coo) => coo.to_csr().matvec(x),
521            Self::CSR(csr) => csr.matvec(x),
522        }
523    }
524
525    /// Sparse-dense matrix multiplication.
526    pub fn matmul(&self, dense: &Tensor<f32>) -> Tensor<f32> {
527        match self {
528            Self::COO(coo) => coo.to_csr().matmul_dense(dense),
529            Self::CSR(csr) => csr.matmul_dense(dense),
530        }
531    }
532
533    /// Element-wise multiplication with a scalar.
534    pub fn mul_scalar(&self, scalar: f32) -> Self {
535        match self {
536            Self::COO(coo) => {
537                let values: Vec<f32> = coo.values.iter().map(|v| v * scalar).collect();
538                Self::COO(SparseCOO::new(
539                    coo.indices.clone(),
540                    values,
541                    coo.shape.clone(),
542                ))
543            }
544            Self::CSR(csr) => {
545                let values: Vec<f32> = csr.values.iter().map(|v| v * scalar).collect();
546                Self::CSR(SparseCSR::new(
547                    csr.row_ptr.clone(),
548                    csr.col_indices.clone(),
549                    values,
550                    csr.shape.clone(),
551                ))
552            }
553        }
554    }
555}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// `SparseCOO::new` records the entry count and shape it was given.
    #[test]
    fn test_sparse_coo_creation() {
        let indices = vec![vec![0, 1, 2], vec![1, 0, 2]];
        let values = vec![1.0, 2.0, 3.0];
        let sparse = SparseCOO::new(indices, values, vec![3, 3]);

        assert_eq!(sparse.nnz(), 3);
        assert_eq!(sparse.shape(), &[3, 3]);
    }

    /// Dense conversion places each entry at its row-major offset.
    #[test]
    fn test_sparse_coo_to_dense() {
        let coords = vec![(0, 1), (1, 0), (2, 2)];
        let values = vec![1.0, 2.0, 3.0];
        let sparse = SparseCOO::from_coo_2d(&coords, &values, &[3, 3]);

        let dense = sparse.to_dense();
        let data = dense.to_vec();

        // Flat offset of (row, col) in a 3x3 matrix is row * 3 + col.
        assert_eq!(data[1], 1.0); // (0, 1)
        assert_eq!(data[3], 2.0); // (1, 0)
        assert_eq!(data[8], 3.0); // (2, 2)
    }

    /// Coalescing merges entries that share a coordinate by summing them.
    #[test]
    fn test_sparse_coo_coalesce() {
        let indices = vec![vec![0, 0, 1], vec![0, 0, 1]];
        let values = vec![1.0, 2.0, 3.0];
        let mut sparse = SparseCOO::new(indices, values, vec![2, 2]);

        sparse.coalesce();

        assert_eq!(sparse.nnz(), 2); // Duplicates combined
        assert!(sparse.is_coalesced);
        let dense = sparse.to_dense();
        assert_eq!(dense.to_vec()[0], 3.0); // 1.0 + 2.0
    }

    /// `SparseCSR::new` exposes the dimensions and entry count.
    #[test]
    fn test_sparse_csr_creation() {
        let row_ptr = vec![0, 1, 2, 3];
        let col_indices = vec![1, 0, 2];
        let values = vec![1.0, 2.0, 3.0];
        let csr = SparseCSR::new(row_ptr, col_indices, values, vec![3, 3]);

        assert_eq!(csr.nnz(), 3);
        assert_eq!(csr.nrows(), 3);
        assert_eq!(csr.ncols(), 3);
    }

    /// CSR matrix-vector product on a diagonal matrix.
    #[test]
    fn test_sparse_csr_matvec() {
        // Matrix: [[1, 0], [0, 2]]
        let row_ptr = vec![0, 1, 2];
        let col_indices = vec![0, 1];
        let values = vec![1.0, 2.0];
        let csr = SparseCSR::new(row_ptr, col_indices, values, vec![2, 2]);

        let x = vec![1.0, 2.0];
        let result = csr.matvec(&x);

        assert_eq!(result, vec![1.0, 4.0]);
    }

    /// COO to CSR conversion produces the expected row pointers and columns.
    #[test]
    fn test_sparse_coo_to_csr() {
        let coords = vec![(0, 1), (1, 0), (2, 2)];
        let values = vec![1.0, 2.0, 3.0];
        let coo = SparseCOO::from_coo_2d(&coords, &values, &[3, 3]);

        let csr = coo.to_csr();

        assert_eq!(csr.nnz(), 3);
        assert_eq!(csr.nrows(), 3);
        assert_eq!(csr.row_ptr, vec![0, 1, 2, 3]);
        assert_eq!(csr.col_indices, vec![1, 0, 2]);
        assert_eq!(csr.values, vec![1.0, 2.0, 3.0]);
    }

    /// Only elements whose magnitude exceeds the threshold are kept.
    #[test]
    fn test_sparse_tensor_from_dense() {
        let dense = Tensor::from_vec(vec![0.0, 1.0, 0.0, 2.0], &[2, 2]).unwrap();
        let sparse = SparseTensor::from_dense(&dense, 0.0);

        assert_eq!(sparse.nnz(), 2);
    }

    /// `eye` puts ones on the diagonal and nothing elsewhere.
    #[test]
    fn test_sparse_tensor_eye() {
        let eye = SparseTensor::eye(3);
        let dense = eye.to_dense();
        let data = dense.to_vec();

        assert_eq!(data[0], 1.0);
        assert_eq!(data[4], 1.0);
        assert_eq!(data[8], 1.0);
        assert_eq!(data[1], 0.0);
    }

    /// `diag` stores the given values along the diagonal.
    #[test]
    fn test_sparse_tensor_diag() {
        let diag = SparseTensor::diag(&[1.0, 2.0, 3.0]);
        let dense = diag.to_dense();
        let data = dense.to_vec();

        assert_eq!(data[0], 1.0);
        assert_eq!(data[4], 2.0);
        assert_eq!(data[8], 3.0);
    }

    /// Density is stored entries over total cells.
    #[test]
    fn test_sparse_density() {
        let coords = vec![(0, 0), (1, 1)];
        let values = vec![1.0, 2.0];
        let sparse = SparseTensor::from_coords(&coords, &values, &[4, 4]);

        assert!((sparse.density() - 0.125).abs() < 1e-6); // 2/16
    }

    /// Scalar multiplication scales the stored values.
    #[test]
    fn test_sparse_mul_scalar() {
        let coords = vec![(0, 0)];
        let values = vec![2.0];
        let sparse = SparseTensor::from_coords(&coords, &values, &[2, 2]);

        let scaled = sparse.mul_scalar(3.0);
        let dense = scaled.to_dense();

        assert_eq!(dense.to_vec()[0], 6.0);
    }

    /// Sparse-dense product against a hand-computed result.
    #[test]
    fn test_sparse_matmul() {
        // Sparse: [[1, 0], [0, 2]]
        let coords = vec![(0, 0), (1, 1)];
        let values = vec![1.0, 2.0];
        let sparse = SparseTensor::from_coords(&coords, &values, &[2, 2]);

        // Dense: [[1, 2], [3, 4]]
        let dense = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let result = sparse.matmul(&dense);
        let data = result.to_vec();

        // [[1, 0], [0, 2]] @ [[1, 2], [3, 4]] = [[1, 2], [6, 8]]
        assert_eq!(data, vec![1.0, 2.0, 6.0, 8.0]);
    }
}