
axonml_tensor/sparse.rs

//! Sparse Tensor Support
//!
//! Provides sparse tensor representations for memory-efficient storage and
//! computation when tensors have many zero elements.
//!
//! # Formats
//! - COO (Coordinate): Best for incremental construction and conversion to other formats
//! - CSR (Compressed Sparse Row): Best for row-wise operations and matrix-vector products
//!
//! # Example
//! ```rust,ignore
//! use axonml_tensor::sparse::SparseTensor;
//!
//! // Create from a list of 2D coordinates (COO format)
//! let indices = vec![(0, 1), (1, 0), (2, 2)];
//! let values = vec![1.0, 2.0, 3.0];
//! let sparse = SparseTensor::from_coords(&indices, &values, &[3, 3]);
//!
//! // Convert to dense
//! let dense = sparse.to_dense();
//! ```
//!
//! @version 0.1.0

use crate::Tensor;

// =============================================================================
// Sparse Format
// =============================================================================

/// Sparse tensor storage format.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparseFormat {
    /// Coordinate format: (row, col, value) tuples
    COO,
    /// Compressed Sparse Row format
    CSR,
    /// Compressed Sparse Column format
    CSC,
}

// =============================================================================
// COO Sparse Tensor
// =============================================================================

/// Sparse tensor in COO (Coordinate) format.
///
/// Stores non-zero elements as a list of (index, value) pairs.
/// Efficient for construction but less efficient for arithmetic.
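///
/// # Example
///
/// A minimal construction sketch, written `rust,ignore` like the module-level
/// example (so it is not compiled as a doctest):
/// ```rust,ignore
/// // 2x2 matrix with non-zeros at (0, 1) and (1, 0)
/// let sparse = SparseCOO::from_coo_2d(&[(0, 1), (1, 0)], &[1.0, 2.0], &[2, 2]);
/// assert_eq!(sparse.nnz(), 2);
/// assert!((sparse.density() - 0.5).abs() < 1e-6);
/// ```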
#[derive(Debug, Clone)]
pub struct SparseCOO {
    /// Indices of non-zero elements, one vector per dimension
    pub indices: Vec<Vec<usize>>,
    /// Values of non-zero elements
    pub values: Vec<f32>,
    /// Shape of the tensor
    pub shape: Vec<usize>,
    /// Whether indices are sorted and duplicate entries have been merged
    pub is_coalesced: bool,
}

impl SparseCOO {
    /// Creates a new sparse COO tensor.
    ///
    /// # Arguments
    /// * `indices` - Index arrays, one per dimension
    /// * `values` - Non-zero values
    /// * `shape` - Shape of the tensor
    pub fn new(indices: Vec<Vec<usize>>, values: Vec<f32>, shape: Vec<usize>) -> Self {
        assert_eq!(indices.len(), shape.len(), "indices dimensions must match shape");
        if !indices.is_empty() {
            let nnz = indices[0].len();
            for idx in &indices {
                assert_eq!(idx.len(), nnz, "all index arrays must have same length");
            }
            assert_eq!(values.len(), nnz, "values length must match number of indices");
        }

        Self {
            indices,
            values,
            shape,
            is_coalesced: false,
        }
    }

    /// Creates from a list of 2D coordinate tuples.
    pub fn from_coo_2d(coords: &[(usize, usize)], values: &[f32], shape: &[usize]) -> Self {
        assert_eq!(shape.len(), 2, "shape must be 2D");
        assert_eq!(coords.len(), values.len(), "coords and values must have same length");

        let rows: Vec<usize> = coords.iter().map(|(r, _)| *r).collect();
        let cols: Vec<usize> = coords.iter().map(|(_, c)| *c).collect();

        Self::new(vec![rows, cols], values.to_vec(), shape.to_vec())
    }

    /// Returns the number of non-zero elements.
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// Returns the density (ratio of non-zeros to total elements).
    pub fn density(&self) -> f32 {
        let total: usize = self.shape.iter().product();
        if total == 0 {
            0.0
        } else {
            self.nnz() as f32 / total as f32
        }
    }

    /// Returns the shape.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Coalesces the sparse tensor (combines duplicate indices).
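    ///
    /// Duplicate indices have their values summed, and entries end up sorted.
    /// A short sketch (`rust,ignore`, not compiled as a doctest):
    /// ```rust,ignore
    /// // Two entries at (0, 0) and one at (1, 1)
    /// let mut sparse = SparseCOO::new(
    ///     vec![vec![0, 0, 1], vec![0, 0, 1]],
    ///     vec![1.0, 2.0, 3.0],
    ///     vec![2, 2],
    /// );
    /// sparse.coalesce();
    /// assert_eq!(sparse.nnz(), 2);               // (0, 0) entries merged
    /// assert_eq!(sparse.values, vec![3.0, 3.0]); // 1.0 + 2.0, then 3.0
    /// ```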
    pub fn coalesce(&mut self) {
        if self.is_coalesced || self.nnz() == 0 {
            self.is_coalesced = true;
            return;
        }

        // Create (indices, value) pairs and sort
        let mut entries: Vec<(Vec<usize>, f32)> = (0..self.nnz())
            .map(|i| {
                let idx: Vec<usize> = self.indices.iter().map(|dim| dim[i]).collect();
                (idx, self.values[i])
            })
            .collect();

        entries.sort_by(|a, b| a.0.cmp(&b.0));

        // Combine duplicates
        let mut new_indices: Vec<Vec<usize>> = vec![Vec::new(); self.shape.len()];
        let mut new_values = Vec::new();

        let mut prev_idx: Option<Vec<usize>> = None;

        for (idx, val) in entries {
            if prev_idx.as_ref() == Some(&idx) {
                // Duplicate: add to previous value
                if let Some(last) = new_values.last_mut() {
                    *last += val;
                }
            } else {
                // New index
                for (d, i) in idx.iter().enumerate() {
                    new_indices[d].push(*i);
                }
                new_values.push(val);
                prev_idx = Some(idx);
            }
        }

        self.indices = new_indices;
        self.values = new_values;
        self.is_coalesced = true;
    }

    /// Converts to dense tensor.
    pub fn to_dense(&self) -> Tensor<f32> {
        let total: usize = self.shape.iter().product();
        let mut data = vec![0.0f32; total];

        for i in 0..self.nnz() {
            let mut flat_idx = 0;
            let mut stride = 1;
            for d in (0..self.shape.len()).rev() {
                flat_idx += self.indices[d][i] * stride;
                stride *= self.shape[d];
            }
            data[flat_idx] += self.values[i];
        }

        Tensor::from_vec(data, &self.shape).unwrap()
    }

    /// Converts to CSR format (for 2D matrices).
    pub fn to_csr(&self) -> SparseCSR {
        assert_eq!(self.shape.len(), 2, "CSR only supports 2D tensors");

        let mut coo = self.clone();
        coo.coalesce();

        let nrows = self.shape[0];
        let nnz = coo.nnz();

        let mut row_ptr = vec![0usize; nrows + 1];
        let mut col_indices = Vec::with_capacity(nnz);
        let mut values = Vec::with_capacity(nnz);

        // Count entries per row
        for &row in &coo.indices[0] {
            row_ptr[row + 1] += 1;
        }

        // Cumulative sum
        for i in 1..=nrows {
            row_ptr[i] += row_ptr[i - 1];
        }

        // Sort by row, then column
        let mut entries: Vec<(usize, usize, f32)> = (0..nnz)
            .map(|i| (coo.indices[0][i], coo.indices[1][i], coo.values[i]))
            .collect();
        entries.sort_by_key(|(r, c, _)| (*r, *c));

        for (_, col, val) in entries {
            col_indices.push(col);
            values.push(val);
        }

        SparseCSR {
            row_ptr,
            col_indices,
            values,
            shape: self.shape.clone(),
        }
    }
}

// =============================================================================
// CSR Sparse Tensor
// =============================================================================

/// Sparse tensor in CSR (Compressed Sparse Row) format.
///
/// Efficient for row-wise operations and sparse matrix-vector products.
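///
/// For row `i`, the stored entries occupy positions
/// `row_ptr[i]..row_ptr[i + 1]` of `col_indices` and `values`. An illustrative
/// sketch (`rust,ignore`, not compiled as a doctest):
/// ```rust,ignore
/// // Matrix [[0, 1], [2, 0]]
/// let csr = SparseCSR::new(vec![0, 1, 2], vec![1, 0], vec![1.0, 2.0], vec![2, 2]);
/// assert_eq!(csr.row(0).collect::<Vec<_>>(), vec![(1, 1.0)]);
/// assert_eq!(csr.row(1).collect::<Vec<_>>(), vec![(0, 2.0)]);
/// ```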
#[derive(Debug, Clone)]
pub struct SparseCSR {
    /// Row pointers (length = nrows + 1)
    pub row_ptr: Vec<usize>,
    /// Column indices for each non-zero
    pub col_indices: Vec<usize>,
    /// Values for each non-zero
    pub values: Vec<f32>,
    /// Shape [nrows, ncols]
    pub shape: Vec<usize>,
}

impl SparseCSR {
    /// Creates a new CSR sparse matrix.
    pub fn new(
        row_ptr: Vec<usize>,
        col_indices: Vec<usize>,
        values: Vec<f32>,
        shape: Vec<usize>,
    ) -> Self {
        assert_eq!(shape.len(), 2, "CSR only supports 2D tensors");
        assert_eq!(row_ptr.len(), shape[0] + 1, "row_ptr length must be nrows + 1");
        assert_eq!(col_indices.len(), values.len(), "col_indices and values must match");

        Self {
            row_ptr,
            col_indices,
            values,
            shape,
        }
    }

    /// Returns the number of non-zero elements.
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// Returns the number of rows.
    pub fn nrows(&self) -> usize {
        self.shape[0]
    }

    /// Returns the number of columns.
    pub fn ncols(&self) -> usize {
        self.shape[1]
    }

    /// Returns the density.
    pub fn density(&self) -> f32 {
        let total = self.nrows() * self.ncols();
        if total == 0 {
            0.0
        } else {
            self.nnz() as f32 / total as f32
        }
    }

    /// Returns an iterator over the (column, value) entries of a specific row.
    pub fn row(&self, row_idx: usize) -> impl Iterator<Item = (usize, f32)> + '_ {
        let start = self.row_ptr[row_idx];
        let end = self.row_ptr[row_idx + 1];
        (start..end).map(move |i| (self.col_indices[i], self.values[i]))
    }

    /// Sparse matrix-vector multiplication: A @ x.
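    ///
    /// Computes `y[row] = sum(values[i] * x[col_indices[i]])` over the stored
    /// non-zeros of each row. A small sketch (`rust,ignore`, not compiled as a
    /// doctest):
    /// ```rust,ignore
    /// // A = [[1, 0], [0, 2]], x = [1, 2]  =>  A @ x = [1, 4]
    /// let a = SparseCSR::new(vec![0, 1, 2], vec![0, 1], vec![1.0, 2.0], vec![2, 2]);
    /// assert_eq!(a.matvec(&[1.0, 2.0]), vec![1.0, 4.0]);
    /// ```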
    pub fn matvec(&self, x: &[f32]) -> Vec<f32> {
        assert_eq!(x.len(), self.ncols(), "vector length must match ncols");

        let mut result = vec![0.0f32; self.nrows()];

        for row in 0..self.nrows() {
            let start = self.row_ptr[row];
            let end = self.row_ptr[row + 1];

            for i in start..end {
                let col = self.col_indices[i];
                let val = self.values[i];
                result[row] += val * x[col];
            }
        }

        result
    }

    /// Sparse matrix-matrix multiplication: A @ B (where B is dense).
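    ///
    /// For each stored non-zero `A[row, col]`, only row `col` of `B` is read.
    /// A small sketch (`rust,ignore`, not compiled as a doctest):
    /// ```rust,ignore
    /// // [[1, 0], [0, 2]] @ [[1, 2], [3, 4]] = [[1, 2], [6, 8]]
    /// let a = SparseCSR::new(vec![0, 1, 2], vec![0, 1], vec![1.0, 2.0], vec![2, 2]);
    /// let b = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
    /// assert_eq!(a.matmul_dense(&b).to_vec(), vec![1.0, 2.0, 6.0, 8.0]);
    /// ```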
    pub fn matmul_dense(&self, b: &Tensor<f32>) -> Tensor<f32> {
        let b_shape = b.shape();
        assert_eq!(b_shape.len(), 2, "B must be a 2D tensor");
        assert_eq!(b_shape[0], self.ncols(), "inner dimensions must match");

        let m = self.nrows();
        let n = b_shape[1];
        let b_data = b.to_vec();

        let mut result = vec![0.0f32; m * n];

        for row in 0..m {
            for (col, val) in self.row(row) {
                for j in 0..n {
                    result[row * n + j] += val * b_data[col * n + j];
                }
            }
        }

        Tensor::from_vec(result, &[m, n]).unwrap()
    }

    /// Converts to dense tensor.
    pub fn to_dense(&self) -> Tensor<f32> {
        let mut data = vec![0.0f32; self.nrows() * self.ncols()];

        for row in 0..self.nrows() {
            for (col, val) in self.row(row) {
                data[row * self.ncols() + col] = val;
            }
        }

        Tensor::from_vec(data, &self.shape).unwrap()
    }

    /// Converts to COO format.
    pub fn to_coo(&self) -> SparseCOO {
        let mut rows = Vec::with_capacity(self.nnz());
        let mut cols = Vec::with_capacity(self.nnz());

        for row in 0..self.nrows() {
            let start = self.row_ptr[row];
            let end = self.row_ptr[row + 1];
            for i in start..end {
                rows.push(row);
                cols.push(self.col_indices[i]);
            }
        }

        SparseCOO {
            indices: vec![rows, cols],
            values: self.values.clone(),
            shape: self.shape.clone(),
            is_coalesced: true,
        }
    }
}

// =============================================================================
// SparseTensor (Unified Interface)
// =============================================================================

/// Unified sparse tensor interface supporting multiple formats.
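///
/// Construction typically starts in COO and converts to CSR for arithmetic,
/// mirroring the format notes in the module docs. A minimal end-to-end sketch
/// (`rust,ignore`, not compiled as a doctest):
/// ```rust,ignore
/// let sparse = SparseTensor::from_coords(&[(0, 0), (1, 1)], &[1.0, 2.0], &[2, 2]);
/// let y = sparse.matvec(&[1.0, 2.0]); // COO is converted to CSR internally
/// assert_eq!(y, vec![1.0, 4.0]);
/// ```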
#[derive(Debug, Clone)]
pub enum SparseTensor {
    /// COO format
    COO(SparseCOO),
    /// CSR format
    CSR(SparseCSR),
}

impl SparseTensor {
    /// Creates a sparse tensor from COO data.
    pub fn from_coo(indices: Vec<Vec<usize>>, values: Vec<f32>, shape: Vec<usize>) -> Self {
        Self::COO(SparseCOO::new(indices, values, shape))
    }

    /// Creates a 2D sparse tensor from a coordinate list.
    pub fn from_coords(coords: &[(usize, usize)], values: &[f32], shape: &[usize]) -> Self {
        Self::COO(SparseCOO::from_coo_2d(coords, values, shape))
    }

    /// Creates from a dense tensor, keeping only elements whose absolute
    /// value exceeds `threshold`.
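    ///
    /// A small sketch (`rust,ignore`, not compiled as a doctest):
    /// ```rust,ignore
    /// let dense = Tensor::from_vec(vec![0.0, 1.0, 0.0, 2.0], &[2, 2]).unwrap();
    /// let sparse = SparseTensor::from_dense(&dense, 0.0); // keep |x| > 0.0
    /// assert_eq!(sparse.nnz(), 2);
    /// ```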
    pub fn from_dense(tensor: &Tensor<f32>, threshold: f32) -> Self {
        let data = tensor.to_vec();
        let shape = tensor.shape().to_vec();

        let mut indices: Vec<Vec<usize>> = vec![Vec::new(); shape.len()];
        let mut values = Vec::new();

        // Row-major strides; saturating_sub guards the 0-dimensional case
        let strides: Vec<usize> = {
            let mut s = vec![1; shape.len()];
            for i in (0..shape.len().saturating_sub(1)).rev() {
                s[i] = s[i + 1] * shape[i + 1];
            }
            s
        };

        for (flat_idx, &val) in data.iter().enumerate() {
            if val.abs() > threshold {
                let mut idx = flat_idx;
                for (d, &stride) in strides.iter().enumerate() {
                    indices[d].push(idx / stride);
                    idx %= stride;
                }
                values.push(val);
            }
        }

        Self::COO(SparseCOO::new(indices, values, shape))
    }

    /// Creates an identity matrix in sparse format.
    pub fn eye(n: usize) -> Self {
        let indices: Vec<usize> = (0..n).collect();
        let values = vec![1.0f32; n];
        Self::COO(SparseCOO::new(
            vec![indices.clone(), indices],
            values,
            vec![n, n],
        ))
    }

    /// Creates a sparse diagonal matrix.
    pub fn diag(values: &[f32]) -> Self {
        let n = values.len();
        let indices: Vec<usize> = (0..n).collect();
        Self::COO(SparseCOO::new(
            vec![indices.clone(), indices],
            values.to_vec(),
            vec![n, n],
        ))
    }

    /// Returns the number of non-zero elements.
    pub fn nnz(&self) -> usize {
        match self {
            Self::COO(coo) => coo.nnz(),
            Self::CSR(csr) => csr.nnz(),
        }
    }

    /// Returns the shape.
    pub fn shape(&self) -> &[usize] {
        match self {
            Self::COO(coo) => &coo.shape,
            Self::CSR(csr) => &csr.shape,
        }
    }

    /// Returns the density.
    pub fn density(&self) -> f32 {
        match self {
            Self::COO(coo) => coo.density(),
            Self::CSR(csr) => csr.density(),
        }
    }

    /// Converts to dense tensor.
    pub fn to_dense(&self) -> Tensor<f32> {
        match self {
            Self::COO(coo) => coo.to_dense(),
            Self::CSR(csr) => csr.to_dense(),
        }
    }

    /// Converts to CSR format.
    pub fn to_csr(&self) -> SparseCSR {
        match self {
            Self::COO(coo) => coo.to_csr(),
            Self::CSR(csr) => csr.clone(),
        }
    }

    /// Converts to COO format.
    pub fn to_coo(&self) -> SparseCOO {
        match self {
            Self::COO(coo) => coo.clone(),
            Self::CSR(csr) => csr.to_coo(),
        }
    }

    /// Sparse matrix-vector multiplication.
    pub fn matvec(&self, x: &[f32]) -> Vec<f32> {
        match self {
            Self::COO(coo) => coo.to_csr().matvec(x),
            Self::CSR(csr) => csr.matvec(x),
        }
    }

    /// Sparse-dense matrix multiplication.
    pub fn matmul(&self, dense: &Tensor<f32>) -> Tensor<f32> {
        match self {
            Self::COO(coo) => coo.to_csr().matmul_dense(dense),
            Self::CSR(csr) => csr.matmul_dense(dense),
        }
    }

    /// Element-wise multiplication with a scalar.
    pub fn mul_scalar(&self, scalar: f32) -> Self {
        match self {
            Self::COO(coo) => {
                let values: Vec<f32> = coo.values.iter().map(|v| v * scalar).collect();
                Self::COO(SparseCOO::new(coo.indices.clone(), values, coo.shape.clone()))
            }
            Self::CSR(csr) => {
                let values: Vec<f32> = csr.values.iter().map(|v| v * scalar).collect();
                Self::CSR(SparseCSR::new(
                    csr.row_ptr.clone(),
                    csr.col_indices.clone(),
                    values,
                    csr.shape.clone(),
                ))
            }
        }
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_coo_creation() {
        let indices = vec![vec![0, 1, 2], vec![1, 0, 2]];
        let values = vec![1.0, 2.0, 3.0];
        let sparse = SparseCOO::new(indices, values, vec![3, 3]);

        assert_eq!(sparse.nnz(), 3);
        assert_eq!(sparse.shape(), &[3, 3]);
    }

    #[test]
    fn test_sparse_coo_to_dense() {
        let coords = vec![(0, 1), (1, 0), (2, 2)];
        let values = vec![1.0, 2.0, 3.0];
        let sparse = SparseCOO::from_coo_2d(&coords, &values, &[3, 3]);

        let dense = sparse.to_dense();
        let data = dense.to_vec();

        assert_eq!(data[0 * 3 + 1], 1.0); // (0, 1)
        assert_eq!(data[1 * 3 + 0], 2.0); // (1, 0)
        assert_eq!(data[2 * 3 + 2], 3.0); // (2, 2)
    }

    #[test]
    fn test_sparse_coo_coalesce() {
        let indices = vec![vec![0, 0, 1], vec![0, 0, 1]];
        let values = vec![1.0, 2.0, 3.0];
        let mut sparse = SparseCOO::new(indices, values, vec![2, 2]);

        sparse.coalesce();

        assert_eq!(sparse.nnz(), 2); // Duplicates combined
        let dense = sparse.to_dense();
        assert_eq!(dense.to_vec()[0], 3.0); // 1.0 + 2.0
    }

    #[test]
    fn test_sparse_csr_creation() {
        let row_ptr = vec![0, 1, 2, 3];
        let col_indices = vec![1, 0, 2];
        let values = vec![1.0, 2.0, 3.0];
        let csr = SparseCSR::new(row_ptr, col_indices, values, vec![3, 3]);

        assert_eq!(csr.nnz(), 3);
        assert_eq!(csr.nrows(), 3);
        assert_eq!(csr.ncols(), 3);
    }

    #[test]
    fn test_sparse_csr_matvec() {
        // Matrix: [[1, 0], [0, 2]]
        let row_ptr = vec![0, 1, 2];
        let col_indices = vec![0, 1];
        let values = vec![1.0, 2.0];
        let csr = SparseCSR::new(row_ptr, col_indices, values, vec![2, 2]);

        let x = vec![1.0, 2.0];
        let result = csr.matvec(&x);

        assert_eq!(result, vec![1.0, 4.0]);
    }

    #[test]
    fn test_sparse_coo_to_csr() {
        let coords = vec![(0, 1), (1, 0), (2, 2)];
        let values = vec![1.0, 2.0, 3.0];
        let coo = SparseCOO::from_coo_2d(&coords, &values, &[3, 3]);

        let csr = coo.to_csr();

        assert_eq!(csr.nnz(), 3);
        assert_eq!(csr.nrows(), 3);
    }

    #[test]
    fn test_sparse_tensor_from_dense() {
        let dense = Tensor::from_vec(vec![0.0, 1.0, 0.0, 2.0], &[2, 2]).unwrap();
        let sparse = SparseTensor::from_dense(&dense, 0.0);

        assert_eq!(sparse.nnz(), 2);
    }

    #[test]
    fn test_sparse_tensor_eye() {
        let eye = SparseTensor::eye(3);
        let dense = eye.to_dense();
        let data = dense.to_vec();

        assert_eq!(data[0], 1.0);
        assert_eq!(data[4], 1.0);
        assert_eq!(data[8], 1.0);
        assert_eq!(data[1], 0.0);
    }

    #[test]
    fn test_sparse_tensor_diag() {
        let diag = SparseTensor::diag(&[1.0, 2.0, 3.0]);
        let dense = diag.to_dense();
        let data = dense.to_vec();

        assert_eq!(data[0], 1.0);
        assert_eq!(data[4], 2.0);
        assert_eq!(data[8], 3.0);
    }

    #[test]
    fn test_sparse_density() {
        let coords = vec![(0, 0), (1, 1)];
        let values = vec![1.0, 2.0];
        let sparse = SparseTensor::from_coords(&coords, &values, &[4, 4]);

        assert!((sparse.density() - 0.125).abs() < 1e-6); // 2/16
    }

    #[test]
    fn test_sparse_mul_scalar() {
        let coords = vec![(0, 0)];
        let values = vec![2.0];
        let sparse = SparseTensor::from_coords(&coords, &values, &[2, 2]);

        let scaled = sparse.mul_scalar(3.0);
        let dense = scaled.to_dense();

        assert_eq!(dense.to_vec()[0], 6.0);
    }

    #[test]
    fn test_sparse_matmul() {
        // Sparse: [[1, 0], [0, 2]]
        let coords = vec![(0, 0), (1, 1)];
        let values = vec![1.0, 2.0];
        let sparse = SparseTensor::from_coords(&coords, &values, &[2, 2]);

        // Dense: [[1, 2], [3, 4]]
        let dense = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let result = sparse.matmul(&dense);
        let data = result.to_vec();

        // [[1, 0], [0, 2]] @ [[1, 2], [3, 4]] = [[1, 2], [6, 8]]
        assert_eq!(data, vec![1.0, 2.0, 6.0, 8.0]);
    }
}