Skip to main content

axonml_tensor/
sparse.rs

//! Sparse Tensor Support
//!
//! # File
//! `crates/axonml-tensor/src/sparse.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

17use crate::Tensor;
18
// =============================================================================
// Sparse Format
// =============================================================================

/// Identifies which layout a sparse tensor's non-zero entries are stored in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparseFormat {
    /// Coordinate list: every stored value carries its full coordinate tuple
    COO,
    /// Compressed Sparse Row: column indices grouped per row via row offsets
    CSR,
    /// Compressed Sparse Column: row indices grouped per column via column offsets
    CSC,
}

// =============================================================================
// COO Sparse Tensor
// =============================================================================

/// Sparse tensor in COO (Coordinate) format with `f32` values.
///
/// Coordinates are stored structure-of-arrays style: `indices[d][i]` is the position along
/// dimension `d` of the `i`-th stored value `values[i]`. Duplicate coordinates are allowed
/// until [`SparseCOO::coalesce`] merges them. Cheap to build incrementally; convert to
/// [`SparseCSR`] for row-oriented arithmetic such as matrix-vector products.
#[derive(Debug, Clone)]
pub struct SparseCOO {
    /// One index array per dimension (`indices.len() == shape.len()`); each array has one
    /// entry per stored value
    pub indices: Vec<Vec<usize>>,
    /// Stored values, parallel to the per-dimension index arrays
    pub values: Vec<f32>,
    /// Shape of the tensor
    pub shape: Vec<usize>,
    /// Whether entries are sorted in row-major (lexicographic) coordinate order with
    /// duplicate coordinates merged, as produced by [`SparseCOO::coalesce`]
    pub is_coalesced: bool,
}

55impl SparseCOO {
56    /// Creates a new sparse COO tensor.
57    ///
58    /// # Arguments
59    /// * `indices` - List of index tuples, one per dimension
60    /// * `values` - Non-zero values
61    /// * `shape` - Shape of the tensor
62    pub fn new(indices: Vec<Vec<usize>>, values: Vec<f32>, shape: Vec<usize>) -> Self {
63        assert_eq!(
64            indices.len(),
65            shape.len(),
66            "indices dimensions must match shape"
67        );
68        if !indices.is_empty() {
69            let nnz = indices[0].len();
70            for idx in &indices {
71                assert_eq!(idx.len(), nnz, "all index arrays must have same length");
72            }
73            assert_eq!(
74                values.len(),
75                nnz,
76                "values length must match number of indices"
77            );
78        }
79
80        Self {
81            indices,
82            values,
83            shape,
84            is_coalesced: false,
85        }
86    }
87
88    /// Creates from a list of 2D coordinate tuples.
89    pub fn from_coo_2d(coords: &[(usize, usize)], values: &[f32], shape: &[usize]) -> Self {
90        assert_eq!(shape.len(), 2, "shape must be 2D");
91        assert_eq!(
92            coords.len(),
93            values.len(),
94            "coords and values must have same length"
95        );
96
97        let rows: Vec<usize> = coords.iter().map(|(r, _)| *r).collect();
98        let cols: Vec<usize> = coords.iter().map(|(_, c)| *c).collect();
99
100        Self::new(vec![rows, cols], values.to_vec(), shape.to_vec())
101    }
102
103    /// Returns number of non-zero elements.
104    pub fn nnz(&self) -> usize {
105        self.values.len()
106    }
107
108    /// Returns the density (ratio of non-zeros to total elements).
109    pub fn density(&self) -> f32 {
110        let total: usize = self.shape.iter().product();
111        if total == 0 {
112            0.0
113        } else {
114            self.nnz() as f32 / total as f32
115        }
116    }
117
118    /// Returns the shape.
119    pub fn shape(&self) -> &[usize] {
120        &self.shape
121    }
122
123    /// Coalesces the sparse tensor (combines duplicate indices).
124    pub fn coalesce(&mut self) {
125        if self.is_coalesced || self.nnz() == 0 {
126            self.is_coalesced = true;
127            return;
128        }
129
130        // Create (indices, value) pairs and sort
131        let mut entries: Vec<(Vec<usize>, f32)> = (0..self.nnz())
132            .map(|i| {
133                let idx: Vec<usize> = self.indices.iter().map(|dim| dim[i]).collect();
134                (idx, self.values[i])
135            })
136            .collect();
137
138        entries.sort_by(|a, b| a.0.cmp(&b.0));
139
140        // Combine duplicates
141        let mut new_indices: Vec<Vec<usize>> = vec![Vec::new(); self.shape.len()];
142        let mut new_values = Vec::new();
143
144        let mut prev_idx: Option<Vec<usize>> = None;
145
146        for (idx, val) in entries {
147            if prev_idx.as_ref() == Some(&idx) {
148                // Duplicate: add to previous value
149                if let Some(last) = new_values.last_mut() {
150                    *last += val;
151                }
152            } else {
153                // New index
154                for (d, i) in idx.iter().enumerate() {
155                    new_indices[d].push(*i);
156                }
157                new_values.push(val);
158                prev_idx = Some(idx);
159            }
160        }
161
162        self.indices = new_indices;
163        self.values = new_values;
164        self.is_coalesced = true;
165    }
166
167    /// Converts to dense tensor.
168    pub fn to_dense(&self) -> Tensor<f32> {
169        let total: usize = self.shape.iter().product();
170        let mut data = vec![0.0f32; total];
171
172        for i in 0..self.nnz() {
173            let mut flat_idx = 0;
174            let mut stride = 1;
175            for d in (0..self.shape.len()).rev() {
176                flat_idx += self.indices[d][i] * stride;
177                stride *= self.shape[d];
178            }
179            data[flat_idx] += self.values[i];
180        }
181
182        Tensor::from_vec(data, &self.shape).expect("tensor creation failed")
183    }
184
185    /// Converts to CSR format (for 2D matrices).
186    pub fn to_csr(&self) -> SparseCSR {
187        assert_eq!(self.shape.len(), 2, "CSR only supports 2D tensors");
188
189        let mut coo = self.clone();
190        coo.coalesce();
191
192        let nrows = self.shape[0];
193        let nnz = coo.nnz();
194
195        let mut row_ptr = vec![0usize; nrows + 1];
196        let mut col_indices = Vec::with_capacity(nnz);
197        let mut values = Vec::with_capacity(nnz);
198
199        // Count entries per row
200        for &row in &coo.indices[0] {
201            row_ptr[row + 1] += 1;
202        }
203
204        // Cumulative sum
205        for i in 1..=nrows {
206            row_ptr[i] += row_ptr[i - 1];
207        }
208
209        // Sort by row, then column
210        let mut entries: Vec<(usize, usize, f32)> = (0..nnz)
211            .map(|i| (coo.indices[0][i], coo.indices[1][i], coo.values[i]))
212            .collect();
213        entries.sort_by_key(|(r, c, _)| (*r, *c));
214
215        for (_, col, val) in entries {
216            col_indices.push(col);
217            values.push(val);
218        }
219
220        SparseCSR {
221            row_ptr,
222            col_indices,
223            values,
224            shape: self.shape.clone(),
225        }
226    }
227}
228
// =============================================================================
// CSR Sparse Tensor
// =============================================================================

/// Two-dimensional sparse matrix stored in CSR (Compressed Sparse Row) layout.
///
/// The stored entries of row `r` live at positions `row_ptr[r]..row_ptr[r + 1]` of
/// `col_indices` and `values`, which makes row-wise traversal and sparse
/// matrix-vector products cheap.
#[derive(Debug, Clone)]
pub struct SparseCSR {
    /// Offsets into `col_indices`/`values`: one start offset per row plus a trailing
    /// end offset (length `nrows + 1`)
    pub row_ptr: Vec<usize>,
    /// Column index of each stored entry
    pub col_indices: Vec<usize>,
    /// Value of each stored entry
    pub values: Vec<f32>,
    /// Matrix dimensions as `[nrows, ncols]`
    pub shape: Vec<usize>,
}

248impl SparseCSR {
249    /// Creates a new CSR sparse matrix.
250    pub fn new(
251        row_ptr: Vec<usize>,
252        col_indices: Vec<usize>,
253        values: Vec<f32>,
254        shape: Vec<usize>,
255    ) -> Self {
256        assert_eq!(shape.len(), 2, "CSR only supports 2D tensors");
257        assert_eq!(
258            row_ptr.len(),
259            shape[0] + 1,
260            "row_ptr length must be nrows + 1"
261        );
262        assert_eq!(
263            col_indices.len(),
264            values.len(),
265            "col_indices and values must match"
266        );
267
268        Self {
269            row_ptr,
270            col_indices,
271            values,
272            shape,
273        }
274    }
275
276    /// Returns number of non-zero elements.
277    pub fn nnz(&self) -> usize {
278        self.values.len()
279    }
280
281    /// Returns number of rows.
282    pub fn nrows(&self) -> usize {
283        self.shape[0]
284    }
285
286    /// Returns number of columns.
287    pub fn ncols(&self) -> usize {
288        self.shape[1]
289    }
290
291    /// Returns the density.
292    pub fn density(&self) -> f32 {
293        let total = self.nrows() * self.ncols();
294        if total == 0 {
295            0.0
296        } else {
297            self.nnz() as f32 / total as f32
298        }
299    }
300
301    /// Gets entries for a specific row.
302    pub fn row(&self, row_idx: usize) -> impl Iterator<Item = (usize, f32)> + '_ {
303        let start = self.row_ptr[row_idx];
304        let end = self.row_ptr[row_idx + 1];
305        (start..end).map(move |i| (self.col_indices[i], self.values[i]))
306    }
307
308    /// Sparse matrix-vector multiplication: A @ x.
309    pub fn matvec(&self, x: &[f32]) -> Vec<f32> {
310        assert_eq!(x.len(), self.ncols(), "vector length must match ncols");
311
312        let mut result = vec![0.0f32; self.nrows()];
313
314        for row in 0..self.nrows() {
315            let start = self.row_ptr[row];
316            let end = self.row_ptr[row + 1];
317
318            for i in start..end {
319                let col = self.col_indices[i];
320                let val = self.values[i];
321                result[row] += val * x[col];
322            }
323        }
324
325        result
326    }
327
328    /// Sparse matrix-matrix multiplication: A @ B (where B is dense).
329    pub fn matmul_dense(&self, b: &Tensor<f32>) -> Tensor<f32> {
330        let b_shape = b.shape();
331        assert_eq!(b_shape[0], self.ncols(), "inner dimensions must match");
332
333        let m = self.nrows();
334        let n = b_shape[1];
335        let b_data = b.to_vec();
336
337        let mut result = vec![0.0f32; m * n];
338
339        for row in 0..m {
340            for (col, val) in self.row(row) {
341                for j in 0..n {
342                    result[row * n + j] += val * b_data[col * n + j];
343                }
344            }
345        }
346
347        Tensor::from_vec(result, &[m, n]).expect("tensor creation failed")
348    }
349
350    /// Converts to dense tensor.
351    pub fn to_dense(&self) -> Tensor<f32> {
352        let mut data = vec![0.0f32; self.nrows() * self.ncols()];
353
354        for row in 0..self.nrows() {
355            for (col, val) in self.row(row) {
356                data[row * self.ncols() + col] = val;
357            }
358        }
359
360        Tensor::from_vec(data, &self.shape).expect("tensor creation failed")
361    }
362
363    /// Converts to COO format.
364    pub fn to_coo(&self) -> SparseCOO {
365        let mut rows = Vec::with_capacity(self.nnz());
366        let mut cols = Vec::with_capacity(self.nnz());
367
368        for row in 0..self.nrows() {
369            let start = self.row_ptr[row];
370            let end = self.row_ptr[row + 1];
371            for i in start..end {
372                rows.push(row);
373                cols.push(self.col_indices[i]);
374            }
375        }
376
377        SparseCOO {
378            indices: vec![rows, cols],
379            values: self.values.clone(),
380            shape: self.shape.clone(),
381            is_coalesced: true,
382        }
383    }
384}
385
386// =============================================================================
387// SparseTensor (Unified Interface)
388// =============================================================================
389
390/// Unified sparse tensor interface supporting multiple formats.
391#[derive(Debug, Clone)]
392pub enum SparseTensor {
393    /// COO format
394    COO(SparseCOO),
395    /// CSR format
396    CSR(SparseCSR),
397}
398
399impl SparseTensor {
400    /// Creates a sparse tensor from COO data.
401    pub fn from_coo(indices: Vec<Vec<usize>>, values: Vec<f32>, shape: Vec<usize>) -> Self {
402        Self::COO(SparseCOO::new(indices, values, shape))
403    }
404
405    /// Creates a 2D sparse tensor from coordinate list.
406    pub fn from_coords(coords: &[(usize, usize)], values: &[f32], shape: &[usize]) -> Self {
407        Self::COO(SparseCOO::from_coo_2d(coords, values, shape))
408    }
409
410    /// Creates from a dense tensor, keeping only non-zero elements.
411    pub fn from_dense(tensor: &Tensor<f32>, threshold: f32) -> Self {
412        let data = tensor.to_vec();
413        let shape = tensor.shape().to_vec();
414
415        let mut indices: Vec<Vec<usize>> = vec![Vec::new(); shape.len()];
416        let mut values = Vec::new();
417
418        let strides: Vec<usize> = {
419            let mut s = vec![1; shape.len()];
420            for i in (0..shape.len() - 1).rev() {
421                s[i] = s[i + 1] * shape[i + 1];
422            }
423            s
424        };
425
426        for (flat_idx, &val) in data.iter().enumerate() {
427            if val.abs() > threshold {
428                let mut idx = flat_idx;
429                for (d, &stride) in strides.iter().enumerate() {
430                    indices[d].push(idx / stride);
431                    idx %= stride;
432                }
433                values.push(val);
434            }
435        }
436
437        Self::COO(SparseCOO::new(indices, values, shape))
438    }
439
440    /// Creates an identity matrix in sparse format.
441    pub fn eye(n: usize) -> Self {
442        let indices: Vec<usize> = (0..n).collect();
443        let values = vec![1.0f32; n];
444        Self::COO(SparseCOO::new(
445            vec![indices.clone(), indices],
446            values,
447            vec![n, n],
448        ))
449    }
450
451    /// Creates a sparse diagonal matrix.
452    pub fn diag(values: &[f32]) -> Self {
453        let n = values.len();
454        let indices: Vec<usize> = (0..n).collect();
455        Self::COO(SparseCOO::new(
456            vec![indices.clone(), indices],
457            values.to_vec(),
458            vec![n, n],
459        ))
460    }
461
462    /// Returns number of non-zero elements.
463    pub fn nnz(&self) -> usize {
464        match self {
465            Self::COO(coo) => coo.nnz(),
466            Self::CSR(csr) => csr.nnz(),
467        }
468    }
469
470    /// Returns the shape.
471    pub fn shape(&self) -> &[usize] {
472        match self {
473            Self::COO(coo) => &coo.shape,
474            Self::CSR(csr) => &csr.shape,
475        }
476    }
477
478    /// Returns the density.
479    pub fn density(&self) -> f32 {
480        match self {
481            Self::COO(coo) => coo.density(),
482            Self::CSR(csr) => csr.density(),
483        }
484    }
485
486    /// Converts to dense tensor.
487    pub fn to_dense(&self) -> Tensor<f32> {
488        match self {
489            Self::COO(coo) => coo.to_dense(),
490            Self::CSR(csr) => csr.to_dense(),
491        }
492    }
493
494    /// Converts to CSR format.
495    pub fn to_csr(&self) -> SparseCSR {
496        match self {
497            Self::COO(coo) => coo.to_csr(),
498            Self::CSR(csr) => csr.clone(),
499        }
500    }
501
502    /// Converts to COO format.
503    pub fn to_coo(&self) -> SparseCOO {
504        match self {
505            Self::COO(coo) => coo.clone(),
506            Self::CSR(csr) => csr.to_coo(),
507        }
508    }
509
510    /// Sparse matrix-vector multiplication.
511    pub fn matvec(&self, x: &[f32]) -> Vec<f32> {
512        match self {
513            Self::COO(coo) => coo.to_csr().matvec(x),
514            Self::CSR(csr) => csr.matvec(x),
515        }
516    }
517
518    /// Sparse-dense matrix multiplication.
519    pub fn matmul(&self, dense: &Tensor<f32>) -> Tensor<f32> {
520        match self {
521            Self::COO(coo) => coo.to_csr().matmul_dense(dense),
522            Self::CSR(csr) => csr.matmul_dense(dense),
523        }
524    }
525
526    /// Element-wise multiplication with a scalar.
527    pub fn mul_scalar(&self, scalar: f32) -> Self {
528        match self {
529            Self::COO(coo) => {
530                let values: Vec<f32> = coo.values.iter().map(|v| v * scalar).collect();
531                Self::COO(SparseCOO::new(
532                    coo.indices.clone(),
533                    values,
534                    coo.shape.clone(),
535                ))
536            }
537            Self::CSR(csr) => {
538                let values: Vec<f32> = csr.values.iter().map(|v| v * scalar).collect();
539                Self::CSR(SparseCSR::new(
540                    csr.row_ptr.clone(),
541                    csr.col_indices.clone(),
542                    values,
543                    csr.shape.clone(),
544                ))
545            }
546        }
547    }
548}
549
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_coo_creation() {
        let indices = vec![vec![0, 1, 2], vec![1, 0, 2]];
        let values = vec![1.0, 2.0, 3.0];
        let sparse = SparseCOO::new(indices, values, vec![3, 3]);

        assert_eq!(sparse.nnz(), 3);
        assert_eq!(sparse.shape(), &[3, 3]);
        assert!(!sparse.is_coalesced);
    }

    #[test]
    fn test_sparse_coo_to_dense() {
        let coords = [(0, 1), (1, 0), (2, 2)];
        let values = [1.0, 2.0, 3.0];
        let sparse = SparseCOO::from_coo_2d(&coords, &values, &[3, 3]);

        // Row-major layout of [[0, 1, 0], [2, 0, 0], [0, 0, 3]].
        assert_eq!(
            sparse.to_dense().to_vec(),
            vec![0.0, 1.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 3.0]
        );
    }

    #[test]
    fn test_sparse_coo_coalesce() {
        let indices = vec![vec![0, 0, 1], vec![0, 0, 1]];
        let values = vec![1.0, 2.0, 3.0];
        let mut sparse = SparseCOO::new(indices, values, vec![2, 2]);

        sparse.coalesce();

        assert!(sparse.is_coalesced);
        assert_eq!(sparse.nnz(), 2); // Duplicates combined
        assert_eq!(sparse.indices, vec![vec![0, 1], vec![0, 1]]);
        assert_eq!(sparse.values, vec![3.0, 3.0]); // (0,0): 1.0 + 2.0, (1,1): 3.0
        assert_eq!(sparse.to_dense().to_vec()[0], 3.0);
    }

    #[test]
    fn test_sparse_coo_coalesce_sorts_row_major() {
        // Unsorted input with a duplicate at (1, 0).
        let indices = vec![vec![1, 0, 1], vec![0, 1, 0]];
        let values = vec![1.0, 2.0, 3.0];
        let mut sparse = SparseCOO::new(indices, values, vec![2, 2]);

        sparse.coalesce();

        assert_eq!(sparse.indices, vec![vec![0, 1], vec![1, 0]]);
        assert_eq!(sparse.values, vec![2.0, 4.0]);
    }

    #[test]
    fn test_sparse_csr_creation() {
        let row_ptr = vec![0, 1, 2, 3];
        let col_indices = vec![1, 0, 2];
        let values = vec![1.0, 2.0, 3.0];
        let csr = SparseCSR::new(row_ptr, col_indices, values, vec![3, 3]);

        assert_eq!(csr.nnz(), 3);
        assert_eq!(csr.nrows(), 3);
        assert_eq!(csr.ncols(), 3);
    }

    #[test]
    fn test_sparse_csr_matvec() {
        // Matrix: [[1, 0], [0, 2]]
        let row_ptr = vec![0, 1, 2];
        let col_indices = vec![0, 1];
        let values = vec![1.0, 2.0];
        let csr = SparseCSR::new(row_ptr, col_indices, values, vec![2, 2]);

        let x = [1.0, 2.0];
        let result = csr.matvec(&x);

        assert_eq!(result, vec![1.0, 4.0]);
    }

    #[test]
    fn test_sparse_coo_to_csr() {
        let coords = [(0, 1), (1, 0), (2, 2)];
        let values = [1.0, 2.0, 3.0];
        let coo = SparseCOO::from_coo_2d(&coords, &values, &[3, 3]);

        let csr = coo.to_csr();

        assert_eq!(csr.nnz(), 3);
        assert_eq!(csr.nrows(), 3);
        assert_eq!(csr.row_ptr, vec![0, 1, 2, 3]);
        assert_eq!(csr.col_indices, vec![1, 0, 2]);
        assert_eq!(csr.values, vec![1.0, 2.0, 3.0]);
        assert_eq!(csr.to_dense().to_vec(), coo.to_dense().to_vec());
    }

    #[test]
    fn test_sparse_tensor_from_dense() {
        let dense =
            Tensor::from_vec(vec![0.0, 1.0, 0.0, 2.0], &[2, 2]).expect("tensor creation failed");
        let sparse = SparseTensor::from_dense(&dense, 0.0);

        assert_eq!(sparse.nnz(), 2);
        let coo = sparse.to_coo();
        assert_eq!(coo.indices, vec![vec![0, 1], vec![1, 1]]);
        assert_eq!(coo.values, vec![1.0, 2.0]);
        assert_eq!(sparse.to_dense().to_vec(), vec![0.0, 1.0, 0.0, 2.0]);
    }

    #[test]
    fn test_sparse_tensor_eye() {
        let eye = SparseTensor::eye(3);

        assert_eq!(
            eye.to_dense().to_vec(),
            vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
        );
    }

    #[test]
    fn test_sparse_tensor_diag() {
        let diag = SparseTensor::diag(&[1.0, 2.0, 3.0]);

        assert_eq!(
            diag.to_dense().to_vec(),
            vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]
        );
    }

    #[test]
    fn test_sparse_density() {
        let coords = [(0, 0), (1, 1)];
        let values = [1.0, 2.0];
        let sparse = SparseTensor::from_coords(&coords, &values, &[4, 4]);

        assert!((sparse.density() - 0.125).abs() < 1e-6); // 2/16
    }

    #[test]
    fn test_sparse_mul_scalar() {
        let coords = [(0, 0)];
        let values = [2.0];
        let sparse = SparseTensor::from_coords(&coords, &values, &[2, 2]);

        let scaled = sparse.mul_scalar(3.0);

        assert_eq!(scaled.to_dense().to_vec(), vec![6.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_sparse_tensor_matvec_coo() {
        // Sparse: [[1, 0], [0, 2]]
        let coords = [(0, 0), (1, 1)];
        let values = [1.0, 2.0];
        let sparse = SparseTensor::from_coords(&coords, &values, &[2, 2]);

        assert_eq!(sparse.matvec(&[3.0, 4.0]), vec![3.0, 8.0]);
    }

    #[test]
    fn test_sparse_matmul() {
        // Sparse: [[1, 0], [0, 2]]
        let coords = [(0, 0), (1, 1)];
        let values = [1.0, 2.0];
        let sparse = SparseTensor::from_coords(&coords, &values, &[2, 2]);

        // Dense: [[1, 2], [3, 4]]
        let dense =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("tensor creation failed");

        let result = sparse.matmul(&dense);

        // [[1, 0], [0, 2]] @ [[1, 2], [3, 4]] = [[1, 2], [6, 8]]
        assert_eq!(result.to_vec(), vec![1.0, 2.0, 6.0, 8.0]);
    }
}