1//! Sparse Tensor Support
2//!
3//! # File
4//! `crates/axonml-tensor/src/sparse.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use crate::Tensor;
18
19// =============================================================================
20// Sparse Format
21// =============================================================================
22
/// Sparse tensor storage format.
///
/// NOTE(review): this enum is not referenced elsewhere in this file; it
/// appears intended as a public API selector for callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparseFormat {
    /// Coordinate format: (row, col, value) tuples (see [`SparseCOO`])
    COO,
    /// Compressed Sparse Row format (see [`SparseCSR`])
    CSR,
    /// Compressed Sparse Column format
    /// (no CSC storage type is defined in this file)
    CSC,
}
33
34// =============================================================================
35// COO Sparse Tensor
36// =============================================================================
37
/// Sparse tensor in COO (Coordinate) format.
///
/// Stores non-zero elements as a list of (index, value) pairs.
/// Efficient for construction but less efficient for arithmetic.
#[derive(Debug, Clone)]
pub struct SparseCOO {
    /// Per-dimension index arrays: `indices[d][i]` is the coordinate of the
    /// i-th stored entry along dimension `d` (one inner `Vec` per dimension)
    pub indices: Vec<Vec<usize>>,
    /// Values of non-zero elements, parallel to the index arrays
    pub values: Vec<f32>,
    /// Shape of the tensor
    pub shape: Vec<usize>,
    /// Whether entries are sorted by index with duplicates merged
    /// (`new` always starts this as `false`; `coalesce` sets it)
    pub is_coalesced: bool,
}
53
54impl SparseCOO {
55    /// Creates a new sparse COO tensor.
56    ///
57    /// # Arguments
58    /// * `indices` - List of index tuples, one per dimension
59    /// * `values` - Non-zero values
60    /// * `shape` - Shape of the tensor
61    pub fn new(indices: Vec<Vec<usize>>, values: Vec<f32>, shape: Vec<usize>) -> Self {
62        assert_eq!(
63            indices.len(),
64            shape.len(),
65            "indices dimensions must match shape"
66        );
67        if !indices.is_empty() {
68            let nnz = indices[0].len();
69            for idx in &indices {
70                assert_eq!(idx.len(), nnz, "all index arrays must have same length");
71            }
72            assert_eq!(
73                values.len(),
74                nnz,
75                "values length must match number of indices"
76            );
77        }
78
79        Self {
80            indices,
81            values,
82            shape,
83            is_coalesced: false,
84        }
85    }
86
87    /// Creates from a list of 2D coordinate tuples.
88    pub fn from_coo_2d(coords: &[(usize, usize)], values: &[f32], shape: &[usize]) -> Self {
89        assert_eq!(shape.len(), 2, "shape must be 2D");
90        assert_eq!(
91            coords.len(),
92            values.len(),
93            "coords and values must have same length"
94        );
95
96        let rows: Vec<usize> = coords.iter().map(|(r, _)| *r).collect();
97        let cols: Vec<usize> = coords.iter().map(|(_, c)| *c).collect();
98
99        Self::new(vec![rows, cols], values.to_vec(), shape.to_vec())
100    }
101
102    /// Returns number of non-zero elements.
103    pub fn nnz(&self) -> usize {
104        self.values.len()
105    }
106
107    /// Returns the density (ratio of non-zeros to total elements).
108    pub fn density(&self) -> f32 {
109        let total: usize = self.shape.iter().product();
110        if total == 0 {
111            0.0
112        } else {
113            self.nnz() as f32 / total as f32
114        }
115    }
116
117    /// Returns the shape.
118    pub fn shape(&self) -> &[usize] {
119        &self.shape
120    }
121
122    /// Coalesces the sparse tensor (combines duplicate indices).
123    pub fn coalesce(&mut self) {
124        if self.is_coalesced || self.nnz() == 0 {
125            self.is_coalesced = true;
126            return;
127        }
128
129        // Create (indices, value) pairs and sort
130        let mut entries: Vec<(Vec<usize>, f32)> = (0..self.nnz())
131            .map(|i| {
132                let idx: Vec<usize> = self.indices.iter().map(|dim| dim[i]).collect();
133                (idx, self.values[i])
134            })
135            .collect();
136
137        entries.sort_by(|a, b| a.0.cmp(&b.0));
138
139        // Combine duplicates
140        let mut new_indices: Vec<Vec<usize>> = vec![Vec::new(); self.shape.len()];
141        let mut new_values = Vec::new();
142
143        let mut prev_idx: Option<Vec<usize>> = None;
144
145        for (idx, val) in entries {
146            if prev_idx.as_ref() == Some(&idx) {
147                // Duplicate: add to previous value
148                if let Some(last) = new_values.last_mut() {
149                    *last += val;
150                }
151            } else {
152                // New index
153                for (d, i) in idx.iter().enumerate() {
154                    new_indices[d].push(*i);
155                }
156                new_values.push(val);
157                prev_idx = Some(idx);
158            }
159        }
160
161        self.indices = new_indices;
162        self.values = new_values;
163        self.is_coalesced = true;
164    }
165
166    /// Converts to dense tensor.
167    pub fn to_dense(&self) -> Tensor<f32> {
168        let total: usize = self.shape.iter().product();
169        let mut data = vec![0.0f32; total];
170
171        for i in 0..self.nnz() {
172            let mut flat_idx = 0;
173            let mut stride = 1;
174            for d in (0..self.shape.len()).rev() {
175                flat_idx += self.indices[d][i] * stride;
176                stride *= self.shape[d];
177            }
178            data[flat_idx] += self.values[i];
179        }
180
181        Tensor::from_vec(data, &self.shape).unwrap()
182    }
183
184    /// Converts to CSR format (for 2D matrices).
185    pub fn to_csr(&self) -> SparseCSR {
186        assert_eq!(self.shape.len(), 2, "CSR only supports 2D tensors");
187
188        let mut coo = self.clone();
189        coo.coalesce();
190
191        let nrows = self.shape[0];
192        let nnz = coo.nnz();
193
194        let mut row_ptr = vec![0usize; nrows + 1];
195        let mut col_indices = Vec::with_capacity(nnz);
196        let mut values = Vec::with_capacity(nnz);
197
198        // Count entries per row
199        for &row in &coo.indices[0] {
200            row_ptr[row + 1] += 1;
201        }
202
203        // Cumulative sum
204        for i in 1..=nrows {
205            row_ptr[i] += row_ptr[i - 1];
206        }
207
208        // Sort by row, then column
209        let mut entries: Vec<(usize, usize, f32)> = (0..nnz)
210            .map(|i| (coo.indices[0][i], coo.indices[1][i], coo.values[i]))
211            .collect();
212        entries.sort_by_key(|(r, c, _)| (*r, *c));
213
214        for (_, col, val) in entries {
215            col_indices.push(col);
216            values.push(val);
217        }
218
219        SparseCSR {
220            row_ptr,
221            col_indices,
222            values,
223            shape: self.shape.clone(),
224        }
225    }
226}
227
228// =============================================================================
229// CSR Sparse Tensor
230// =============================================================================
231
/// Sparse tensor in CSR (Compressed Sparse Row) format.
///
/// Efficient for row-wise operations and sparse matrix-vector products.
#[derive(Debug, Clone)]
pub struct SparseCSR {
    /// Row pointers (length = nrows + 1); the non-zeros of row `r` occupy
    /// positions `row_ptr[r]..row_ptr[r + 1]` of `col_indices`/`values`
    pub row_ptr: Vec<usize>,
    /// Column indices for each non-zero
    pub col_indices: Vec<usize>,
    /// Values for each non-zero, parallel to `col_indices`
    pub values: Vec<f32>,
    /// Shape [nrows, ncols]
    pub shape: Vec<usize>,
}
246
247impl SparseCSR {
248    /// Creates a new CSR sparse matrix.
249    pub fn new(
250        row_ptr: Vec<usize>,
251        col_indices: Vec<usize>,
252        values: Vec<f32>,
253        shape: Vec<usize>,
254    ) -> Self {
255        assert_eq!(shape.len(), 2, "CSR only supports 2D tensors");
256        assert_eq!(
257            row_ptr.len(),
258            shape[0] + 1,
259            "row_ptr length must be nrows + 1"
260        );
261        assert_eq!(
262            col_indices.len(),
263            values.len(),
264            "col_indices and values must match"
265        );
266
267        Self {
268            row_ptr,
269            col_indices,
270            values,
271            shape,
272        }
273    }
274
275    /// Returns number of non-zero elements.
276    pub fn nnz(&self) -> usize {
277        self.values.len()
278    }
279
280    /// Returns number of rows.
281    pub fn nrows(&self) -> usize {
282        self.shape[0]
283    }
284
285    /// Returns number of columns.
286    pub fn ncols(&self) -> usize {
287        self.shape[1]
288    }
289
290    /// Returns the density.
291    pub fn density(&self) -> f32 {
292        let total = self.nrows() * self.ncols();
293        if total == 0 {
294            0.0
295        } else {
296            self.nnz() as f32 / total as f32
297        }
298    }
299
300    /// Gets entries for a specific row.
301    pub fn row(&self, row_idx: usize) -> impl Iterator<Item = (usize, f32)> + '_ {
302        let start = self.row_ptr[row_idx];
303        let end = self.row_ptr[row_idx + 1];
304        (start..end).map(move |i| (self.col_indices[i], self.values[i]))
305    }
306
307    /// Sparse matrix-vector multiplication: A @ x.
308    pub fn matvec(&self, x: &[f32]) -> Vec<f32> {
309        assert_eq!(x.len(), self.ncols(), "vector length must match ncols");
310
311        let mut result = vec![0.0f32; self.nrows()];
312
313        for row in 0..self.nrows() {
314            let start = self.row_ptr[row];
315            let end = self.row_ptr[row + 1];
316
317            for i in start..end {
318                let col = self.col_indices[i];
319                let val = self.values[i];
320                result[row] += val * x[col];
321            }
322        }
323
324        result
325    }
326
327    /// Sparse matrix-matrix multiplication: A @ B (where B is dense).
328    pub fn matmul_dense(&self, b: &Tensor<f32>) -> Tensor<f32> {
329        let b_shape = b.shape();
330        assert_eq!(b_shape[0], self.ncols(), "inner dimensions must match");
331
332        let m = self.nrows();
333        let n = b_shape[1];
334        let b_data = b.to_vec();
335
336        let mut result = vec![0.0f32; m * n];
337
338        for row in 0..m {
339            for (col, val) in self.row(row) {
340                for j in 0..n {
341                    result[row * n + j] += val * b_data[col * n + j];
342                }
343            }
344        }
345
346        Tensor::from_vec(result, &[m, n]).unwrap()
347    }
348
349    /// Converts to dense tensor.
350    pub fn to_dense(&self) -> Tensor<f32> {
351        let mut data = vec![0.0f32; self.nrows() * self.ncols()];
352
353        for row in 0..self.nrows() {
354            for (col, val) in self.row(row) {
355                data[row * self.ncols() + col] = val;
356            }
357        }
358
359        Tensor::from_vec(data, &self.shape).unwrap()
360    }
361
362    /// Converts to COO format.
363    pub fn to_coo(&self) -> SparseCOO {
364        let mut rows = Vec::with_capacity(self.nnz());
365        let mut cols = Vec::with_capacity(self.nnz());
366
367        for row in 0..self.nrows() {
368            let start = self.row_ptr[row];
369            let end = self.row_ptr[row + 1];
370            for i in start..end {
371                rows.push(row);
372                cols.push(self.col_indices[i]);
373            }
374        }
375
376        SparseCOO {
377            indices: vec![rows, cols],
378            values: self.values.clone(),
379            shape: self.shape.clone(),
380            is_coalesced: true,
381        }
382    }
383}
384
385// =============================================================================
386// SparseTensor (Unified Interface)
387// =============================================================================
388
/// Unified sparse tensor interface supporting multiple formats.
///
/// Wraps either [`SparseCOO`] or [`SparseCSR`] and dispatches each operation
/// to the underlying representation.
#[derive(Debug, Clone)]
pub enum SparseTensor {
    /// COO format
    COO(SparseCOO),
    /// CSR format
    CSR(SparseCSR),
}
397
398impl SparseTensor {
399    /// Creates a sparse tensor from COO data.
400    pub fn from_coo(indices: Vec<Vec<usize>>, values: Vec<f32>, shape: Vec<usize>) -> Self {
401        Self::COO(SparseCOO::new(indices, values, shape))
402    }
403
404    /// Creates a 2D sparse tensor from coordinate list.
405    pub fn from_coords(coords: &[(usize, usize)], values: &[f32], shape: &[usize]) -> Self {
406        Self::COO(SparseCOO::from_coo_2d(coords, values, shape))
407    }
408
409    /// Creates from a dense tensor, keeping only non-zero elements.
410    pub fn from_dense(tensor: &Tensor<f32>, threshold: f32) -> Self {
411        let data = tensor.to_vec();
412        let shape = tensor.shape().to_vec();
413
414        let mut indices: Vec<Vec<usize>> = vec![Vec::new(); shape.len()];
415        let mut values = Vec::new();
416
417        let strides: Vec<usize> = {
418            let mut s = vec![1; shape.len()];
419            for i in (0..shape.len() - 1).rev() {
420                s[i] = s[i + 1] * shape[i + 1];
421            }
422            s
423        };
424
425        for (flat_idx, &val) in data.iter().enumerate() {
426            if val.abs() > threshold {
427                let mut idx = flat_idx;
428                for (d, &stride) in strides.iter().enumerate() {
429                    indices[d].push(idx / stride);
430                    idx %= stride;
431                }
432                values.push(val);
433            }
434        }
435
436        Self::COO(SparseCOO::new(indices, values, shape))
437    }
438
439    /// Creates an identity matrix in sparse format.
440    pub fn eye(n: usize) -> Self {
441        let indices: Vec<usize> = (0..n).collect();
442        let values = vec![1.0f32; n];
443        Self::COO(SparseCOO::new(
444            vec![indices.clone(), indices],
445            values,
446            vec![n, n],
447        ))
448    }
449
450    /// Creates a sparse diagonal matrix.
451    pub fn diag(values: &[f32]) -> Self {
452        let n = values.len();
453        let indices: Vec<usize> = (0..n).collect();
454        Self::COO(SparseCOO::new(
455            vec![indices.clone(), indices],
456            values.to_vec(),
457            vec![n, n],
458        ))
459    }
460
461    /// Returns number of non-zero elements.
462    pub fn nnz(&self) -> usize {
463        match self {
464            Self::COO(coo) => coo.nnz(),
465            Self::CSR(csr) => csr.nnz(),
466        }
467    }
468
469    /// Returns the shape.
470    pub fn shape(&self) -> &[usize] {
471        match self {
472            Self::COO(coo) => &coo.shape,
473            Self::CSR(csr) => &csr.shape,
474        }
475    }
476
477    /// Returns the density.
478    pub fn density(&self) -> f32 {
479        match self {
480            Self::COO(coo) => coo.density(),
481            Self::CSR(csr) => csr.density(),
482        }
483    }
484
485    /// Converts to dense tensor.
486    pub fn to_dense(&self) -> Tensor<f32> {
487        match self {
488            Self::COO(coo) => coo.to_dense(),
489            Self::CSR(csr) => csr.to_dense(),
490        }
491    }
492
493    /// Converts to CSR format.
494    pub fn to_csr(&self) -> SparseCSR {
495        match self {
496            Self::COO(coo) => coo.to_csr(),
497            Self::CSR(csr) => csr.clone(),
498        }
499    }
500
501    /// Converts to COO format.
502    pub fn to_coo(&self) -> SparseCOO {
503        match self {
504            Self::COO(coo) => coo.clone(),
505            Self::CSR(csr) => csr.to_coo(),
506        }
507    }
508
509    /// Sparse matrix-vector multiplication.
510    pub fn matvec(&self, x: &[f32]) -> Vec<f32> {
511        match self {
512            Self::COO(coo) => coo.to_csr().matvec(x),
513            Self::CSR(csr) => csr.matvec(x),
514        }
515    }
516
517    /// Sparse-dense matrix multiplication.
518    pub fn matmul(&self, dense: &Tensor<f32>) -> Tensor<f32> {
519        match self {
520            Self::COO(coo) => coo.to_csr().matmul_dense(dense),
521            Self::CSR(csr) => csr.matmul_dense(dense),
522        }
523    }
524
525    /// Element-wise multiplication with a scalar.
526    pub fn mul_scalar(&self, scalar: f32) -> Self {
527        match self {
528            Self::COO(coo) => {
529                let values: Vec<f32> = coo.values.iter().map(|v| v * scalar).collect();
530                Self::COO(SparseCOO::new(
531                    coo.indices.clone(),
532                    values,
533                    coo.shape.clone(),
534                ))
535            }
536            Self::CSR(csr) => {
537                let values: Vec<f32> = csr.values.iter().map(|v| v * scalar).collect();
538                Self::CSR(SparseCSR::new(
539                    csr.row_ptr.clone(),
540                    csr.col_indices.clone(),
541                    values,
542                    csr.shape.clone(),
543                ))
544            }
545        }
546    }
547}
548
549// =============================================================================
550// Tests
551// =============================================================================
552
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_coo_creation() {
        let sparse = SparseCOO::new(
            vec![vec![0, 1, 2], vec![1, 0, 2]],
            vec![1.0, 2.0, 3.0],
            vec![3, 3],
        );

        assert_eq!(sparse.nnz(), 3);
        assert_eq!(sparse.shape(), &[3, 3]);
    }

    #[test]
    fn test_sparse_coo_to_dense() {
        let sparse =
            SparseCOO::from_coo_2d(&[(0, 1), (1, 0), (2, 2)], &[1.0, 2.0, 3.0], &[3, 3]);

        let data = sparse.to_dense().to_vec();

        // Row-major layout: element (r, c) sits at flat index r * 3 + c.
        assert_eq!(data[1], 1.0); // (0, 1)
        assert_eq!(data[3], 2.0); // (1, 0)
        assert_eq!(data[8], 3.0); // (2, 2)
    }

    #[test]
    fn test_sparse_coo_coalesce() {
        // (0, 0) is stored twice; coalescing must merge the pair.
        let mut sparse = SparseCOO::new(
            vec![vec![0, 0, 1], vec![0, 0, 1]],
            vec![1.0, 2.0, 3.0],
            vec![2, 2],
        );

        sparse.coalesce();

        assert_eq!(sparse.nnz(), 2);
        assert_eq!(sparse.to_dense().to_vec()[0], 3.0); // 1.0 + 2.0
    }

    #[test]
    fn test_sparse_csr_creation() {
        let csr = SparseCSR::new(
            vec![0, 1, 2, 3],
            vec![1, 0, 2],
            vec![1.0, 2.0, 3.0],
            vec![3, 3],
        );

        assert_eq!(csr.nnz(), 3);
        assert_eq!(csr.nrows(), 3);
        assert_eq!(csr.ncols(), 3);
    }

    #[test]
    fn test_sparse_csr_matvec() {
        // A = [[1, 0], [0, 2]], x = [1, 2]  =>  A @ x = [1, 4]
        let csr = SparseCSR::new(vec![0, 1, 2], vec![0, 1], vec![1.0, 2.0], vec![2, 2]);

        assert_eq!(csr.matvec(&[1.0, 2.0]), vec![1.0, 4.0]);
    }

    #[test]
    fn test_sparse_coo_to_csr() {
        let coo =
            SparseCOO::from_coo_2d(&[(0, 1), (1, 0), (2, 2)], &[1.0, 2.0, 3.0], &[3, 3]);

        let csr = coo.to_csr();

        assert_eq!(csr.nnz(), 3);
        assert_eq!(csr.nrows(), 3);
    }

    #[test]
    fn test_sparse_tensor_from_dense() {
        let dense = Tensor::from_vec(vec![0.0, 1.0, 0.0, 2.0], &[2, 2]).unwrap();

        assert_eq!(SparseTensor::from_dense(&dense, 0.0).nnz(), 2);
    }

    #[test]
    fn test_sparse_tensor_eye() {
        let data = SparseTensor::eye(3).to_dense().to_vec();

        for diag_idx in [0, 4, 8] {
            assert_eq!(data[diag_idx], 1.0);
        }
        assert_eq!(data[1], 0.0);
    }

    #[test]
    fn test_sparse_tensor_diag() {
        let data = SparseTensor::diag(&[1.0, 2.0, 3.0]).to_dense().to_vec();

        assert_eq!(data[0], 1.0);
        assert_eq!(data[4], 2.0);
        assert_eq!(data[8], 3.0);
    }

    #[test]
    fn test_sparse_density() {
        let sparse = SparseTensor::from_coords(&[(0, 0), (1, 1)], &[1.0, 2.0], &[4, 4]);

        assert!((sparse.density() - 0.125).abs() < 1e-6); // 2/16
    }

    #[test]
    fn test_sparse_mul_scalar() {
        let sparse = SparseTensor::from_coords(&[(0, 0)], &[2.0], &[2, 2]);

        let dense = sparse.mul_scalar(3.0).to_dense();

        assert_eq!(dense.to_vec()[0], 6.0);
    }

    #[test]
    fn test_sparse_matmul() {
        // [[1, 0], [0, 2]] @ [[1, 2], [3, 4]] = [[1, 2], [6, 8]]
        let sparse = SparseTensor::from_coords(&[(0, 0), (1, 1)], &[1.0, 2.0], &[2, 2]);
        let dense = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        assert_eq!(sparse.matmul(&dense).to_vec(), vec![1.0, 2.0, 6.0, 8.0]);
    }
}