scirs2_sparse/
coo.rs

1//! Coordinate (COO) matrix format
2//!
3//! This module provides the COO matrix format implementation, which is
4//! efficient for incremental matrix construction.
5
6use crate::error::{SparseError, SparseResult};
7use scirs2_core::numeric::{SparseElement, Zero};
8use std::cmp::PartialEq;
9
/// Coordinate (COO) matrix
///
/// Stores a sparse matrix as three parallel arrays of equal length: stored
/// entry `k` holds the value `data[k]` at position
/// `(row_indices[k], col_indices[k])`.
///
/// Entries can be appended in any order and the same position may appear more
/// than once (see `sum_duplicates`), which makes the format convenient for
/// incremental construction before converting to CSR/CSC.
pub struct CooMatrix<T> {
    /// Number of rows in the matrix (exclusive upper bound for `row_indices`)
    rows: usize,
    /// Number of columns in the matrix (exclusive upper bound for `col_indices`)
    cols: usize,
    /// Row index of each stored entry
    row_indices: Vec<usize>,
    /// Column index of each stored entry
    col_indices: Vec<usize>,
    /// Value of each stored entry
    data: Vec<T>,
}
26
27impl<T> CooMatrix<T>
28where
29    T: Clone + Copy + Zero + PartialEq + SparseElement,
30{
31    /// Get the triplets (row indices, column indices, data)
32    pub fn get_triplets(&self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
33        (
34            self.row_indices.clone(),
35            self.col_indices.clone(),
36            self.data.clone(),
37        )
38    }
39    /// Create a new COO matrix from raw data
40    ///
41    /// # Arguments
42    ///
43    /// * `data` - Vector of non-zero values
44    /// * `row_indices` - Vector of row indices for each non-zero value
45    /// * `col_indices` - Vector of column indices for each non-zero value
46    /// * `shape` - Tuple containing the matrix dimensions (rows, cols)
47    ///
48    /// # Returns
49    ///
50    /// * A new COO matrix
51    ///
52    /// # Examples
53    ///
54    /// ```
55    /// use scirs2_sparse::coo::CooMatrix;
56    ///
57    /// // Create a 3x3 sparse matrix with 5 non-zero elements
58    /// let rows = vec![0, 0, 1, 2, 2];
59    /// let cols = vec![0, 2, 2, 0, 1];
60    /// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
61    /// let shape = (3, 3);
62    ///
63    /// let matrix = CooMatrix::new(data, rows, cols, shape).unwrap();
64    /// ```
65    pub fn new(
66        data: Vec<T>,
67        row_indices: Vec<usize>,
68        col_indices: Vec<usize>,
69        shape: (usize, usize),
70    ) -> SparseResult<Self> {
71        // Validate input data
72        if data.len() != row_indices.len() || data.len() != col_indices.len() {
73            return Err(SparseError::DimensionMismatch {
74                expected: data.len(),
75                found: std::cmp::min(row_indices.len(), col_indices.len()),
76            });
77        }
78
79        let (rows, cols) = shape;
80
81        // Check _indices are within bounds
82        if row_indices.iter().any(|&i| i >= rows) {
83            return Err(SparseError::ValueError(
84                "Row index out of bounds".to_string(),
85            ));
86        }
87
88        if col_indices.iter().any(|&i| i >= cols) {
89            return Err(SparseError::ValueError(
90                "Column index out of bounds".to_string(),
91            ));
92        }
93
94        Ok(CooMatrix {
95            rows,
96            cols,
97            row_indices,
98            col_indices,
99            data,
100        })
101    }
102
103    /// Create a new empty COO matrix
104    ///
105    /// # Arguments
106    ///
107    /// * `shape` - Tuple containing the matrix dimensions (rows, cols)
108    ///
109    /// # Returns
110    ///
111    /// * A new empty COO matrix
112    pub fn empty(shape: (usize, usize)) -> Self {
113        let (rows, cols) = shape;
114
115        CooMatrix {
116            rows,
117            cols,
118            row_indices: Vec::new(),
119            col_indices: Vec::new(),
120            data: Vec::new(),
121        }
122    }
123
124    /// Add a value to the matrix at the specified position
125    ///
126    /// # Arguments
127    ///
128    /// * `row` - Row index
129    /// * `col` - Column index
130    /// * `value` - Value to add
131    ///
132    /// # Returns
133    ///
134    /// * Ok(()) if successful, Error otherwise
135    pub fn add_element(&mut self, row: usize, col: usize, value: T) -> SparseResult<()> {
136        if row >= self.rows || col >= self.cols {
137            return Err(SparseError::ValueError(
138                "Row or column index out of bounds".to_string(),
139            ));
140        }
141
142        self.row_indices.push(row);
143        self.col_indices.push(col);
144        self.data.push(value);
145
146        Ok(())
147    }
148
149    /// Get the number of rows in the matrix
150    pub fn rows(&self) -> usize {
151        self.rows
152    }
153
154    /// Get the number of columns in the matrix
155    pub fn cols(&self) -> usize {
156        self.cols
157    }
158
159    /// Get the shape (dimensions) of the matrix
160    pub fn shape(&self) -> (usize, usize) {
161        (self.rows, self.cols)
162    }
163
164    /// Get the number of non-zero elements in the matrix
165    pub fn nnz(&self) -> usize {
166        self.data.len()
167    }
168
169    /// Get the row indices array
170    pub fn row_indices(&self) -> &[usize] {
171        &self.row_indices
172    }
173
174    /// Get the column indices array
175    pub fn col_indices(&self) -> &[usize] {
176        &self.col_indices
177    }
178
179    /// Get the data array
180    pub fn data(&self) -> &[T] {
181        &self.data
182    }
183
184    /// Convert to dense matrix (as Vec<Vec<T>>)
185    pub fn to_dense(&self) -> Vec<Vec<T>>
186    where
187        T: Zero + Copy + SparseElement,
188    {
189        let mut result = vec![vec![T::sparse_zero(); self.cols]; self.rows];
190
191        for i in 0..self.data.len() {
192            let row = self.row_indices[i];
193            let col = self.col_indices[i];
194            result[row][col] = self.data[i];
195        }
196
197        result
198    }
199
200    /// Convert to CSR format
201    pub fn to_csr(&self) -> crate::csr::CsrMatrix<T> {
202        crate::csr::CsrMatrix::new(
203            self.data.clone(),
204            self.row_indices.clone(),
205            self.col_indices.clone(),
206            (self.rows, self.cols),
207        )
208        .unwrap()
209    }
210
211    /// Convert to CSC format
212    pub fn to_csc(&self) -> crate::csc::CscMatrix<T> {
213        crate::csc::CscMatrix::new(
214            self.data.clone(),
215            self.row_indices.clone(),
216            self.col_indices.clone(),
217            (self.rows, self.cols),
218        )
219        .unwrap()
220    }
221
222    /// Transpose the matrix
223    pub fn transpose(&self) -> Self {
224        let mut transposed_data = Vec::with_capacity(self.data.len());
225        let mut transposed_row_indices = Vec::with_capacity(self.row_indices.len());
226        let mut transposed_col_indices = Vec::with_capacity(self.col_indices.len());
227
228        for i in 0..self.data.len() {
229            transposed_data.push(self.data[i]);
230            transposed_row_indices.push(self.col_indices[i]);
231            transposed_col_indices.push(self.row_indices[i]);
232        }
233
234        CooMatrix {
235            rows: self.cols,
236            cols: self.rows,
237            row_indices: transposed_row_indices,
238            col_indices: transposed_col_indices,
239            data: transposed_data,
240        }
241    }
242
243    /// Sort the matrix elements by row, then column
244    pub fn sort_by_row_col(&mut self) {
245        let mut indices: Vec<usize> = (0..self.data.len()).collect();
246        indices.sort_by_key(|&i| (self.row_indices[i], self.col_indices[i]));
247
248        let row_indices = self.row_indices.clone();
249        let col_indices = self.col_indices.clone();
250        let data = self.data.clone();
251
252        for (i, &idx) in indices.iter().enumerate() {
253            self.row_indices[i] = row_indices[idx];
254            self.col_indices[i] = col_indices[idx];
255            self.data[i] = data[idx];
256        }
257    }
258
259    /// Sort the matrix elements by column, then row
260    pub fn sort_by_col_row(&mut self) {
261        let mut indices: Vec<usize> = (0..self.data.len()).collect();
262        indices.sort_by_key(|&i| (self.col_indices[i], self.row_indices[i]));
263
264        let row_indices = self.row_indices.clone();
265        let col_indices = self.col_indices.clone();
266        let data = self.data.clone();
267
268        for (i, &idx) in indices.iter().enumerate() {
269            self.row_indices[i] = row_indices[idx];
270            self.col_indices[i] = col_indices[idx];
271            self.data[i] = data[idx];
272        }
273    }
274
275    /// Get the value at the specified position
276    pub fn get(&self, row: usize, col: usize) -> T
277    where
278        T: Zero + SparseElement,
279    {
280        for i in 0..self.data.len() {
281            if self.row_indices[i] == row && self.col_indices[i] == col {
282                return self.data[i];
283            }
284        }
285        T::sparse_zero()
286    }
287
288    /// Sum duplicate entries (elements with the same row and column indices)
289    pub fn sum_duplicates(&mut self)
290    where
291        T: std::ops::Add<Output = T>,
292    {
293        if self.data.is_empty() {
294            return;
295        }
296
297        // Sort by row and column
298        self.sort_by_row_col();
299
300        let mut unique_row_indices = Vec::new();
301        let mut unique_col_indices = Vec::new();
302        let mut unique_data = Vec::new();
303
304        let mut current_row = self.row_indices[0];
305        let mut current_col = self.col_indices[0];
306        let mut current_val = self.data[0];
307
308        for i in 1..self.data.len() {
309            if self.row_indices[i] == current_row && self.col_indices[i] == current_col {
310                // Same position, add values
311                current_val = current_val + self.data[i];
312            } else {
313                // New position, store the previous one
314                unique_row_indices.push(current_row);
315                unique_col_indices.push(current_col);
316                unique_data.push(current_val);
317
318                // Update current position
319                current_row = self.row_indices[i];
320                current_col = self.col_indices[i];
321                current_val = self.data[i];
322            }
323        }
324
325        // Add the last element
326        unique_row_indices.push(current_row);
327        unique_col_indices.push(current_col);
328        unique_data.push(current_val);
329
330        // Update the matrix
331        self.row_indices = unique_row_indices;
332        self.col_indices = unique_col_indices;
333        self.data = unique_data;
334    }
335}
336
337impl CooMatrix<f64> {
338    /// Matrix-vector multiplication
339    ///
340    /// # Arguments
341    ///
342    /// * `vec` - Vector to multiply with
343    ///
344    /// # Returns
345    ///
346    /// * Result of matrix-vector multiplication
347    pub fn dot(&self, vec: &[f64]) -> SparseResult<Vec<f64>> {
348        if vec.len() != self.cols {
349            return Err(SparseError::DimensionMismatch {
350                expected: self.cols,
351                found: vec.len(),
352            });
353        }
354
355        let mut result = vec![0.0; self.rows];
356
357        for i in 0..self.data.len() {
358            let row = self.row_indices[i];
359            let col = self.col_indices[i];
360            result[row] += self.data[i] * vec[col];
361        }
362
363        Ok(result)
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Build the 3x3 reference matrix used by several tests:
    ///
    /// ```text
    /// [1 0 2]
    /// [0 0 3]
    /// [4 5 0]
    /// ```
    fn sample_matrix() -> CooMatrix<f64> {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 2, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        CooMatrix::new(data, rows, cols, (3, 3)).unwrap()
    }

    #[test]
    fn test_coo_create() {
        let matrix = sample_matrix();

        assert_eq!(matrix.shape(), (3, 3));
        assert_eq!(matrix.nnz(), 5);
    }

    #[test]
    fn test_coo_new_rejects_invalid_input() {
        // Triplet arrays of different lengths
        assert!(CooMatrix::<f64>::new(vec![1.0, 2.0], vec![0, 1], vec![0], (2, 2)).is_err());
        assert!(CooMatrix::<f64>::new(vec![1.0], vec![0, 1], vec![0], (2, 2)).is_err());
        // Indices outside the declared shape
        assert!(CooMatrix::<f64>::new(vec![1.0], vec![2], vec![0], (2, 2)).is_err());
        assert!(CooMatrix::<f64>::new(vec![1.0], vec![0], vec![2], (2, 2)).is_err());
    }

    #[test]
    fn test_coo_add_element() {
        // Create an empty matrix
        let mut matrix = CooMatrix::<f64>::empty((3, 3));

        // Add elements
        matrix.add_element(0, 0, 1.0).unwrap();
        matrix.add_element(0, 2, 2.0).unwrap();
        matrix.add_element(1, 2, 3.0).unwrap();
        matrix.add_element(2, 0, 4.0).unwrap();
        matrix.add_element(2, 1, 5.0).unwrap();

        assert_eq!(matrix.nnz(), 5);

        // Adding element out of bounds should fail
        assert!(matrix.add_element(3, 0, 6.0).is_err());
        assert!(matrix.add_element(0, 3, 6.0).is_err());
        // ...and must not have modified the matrix
        assert_eq!(matrix.nnz(), 5);
    }

    #[test]
    fn test_coo_get() {
        let matrix = sample_matrix();

        assert_eq!(matrix.get(0, 2), 2.0);
        assert_eq!(matrix.get(2, 1), 5.0);
        // Unstored position reads as zero
        assert_eq!(matrix.get(1, 1), 0.0);
    }

    #[test]
    fn test_coo_get_triplets() {
        let matrix = sample_matrix();
        let (rows, cols, data) = matrix.get_triplets();

        assert_eq!(rows, vec![0, 0, 1, 2, 2]);
        assert_eq!(cols, vec![0, 2, 2, 0, 1]);
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_coo_to_dense() {
        let dense = sample_matrix().to_dense();

        let expected = vec![
            vec![1.0, 0.0, 2.0],
            vec![0.0, 0.0, 3.0],
            vec![4.0, 5.0, 0.0],
        ];

        assert_eq!(dense, expected);
    }

    #[test]
    fn test_coo_dot() {
        let matrix = sample_matrix();

        let vec = vec![1.0, 2.0, 3.0];
        let result = matrix.dot(&vec).unwrap();

        // Expected:
        // 1*1 + 0*2 + 2*3 = 7
        // 0*1 + 0*2 + 3*3 = 9
        // 4*1 + 5*2 + 0*3 = 14
        let expected = [7.0, 9.0, 14.0];

        assert_eq!(result.len(), expected.len());
        for (a, b) in result.iter().zip(expected.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-10);
        }

        // Wrong vector length is rejected
        assert!(matrix.dot(&[1.0, 2.0]).is_err());
    }

    #[test]
    fn test_coo_transpose() {
        let transposed = sample_matrix().transpose();

        assert_eq!(transposed.shape(), (3, 3));
        assert_eq!(transposed.nnz(), 5);

        let dense = transposed.to_dense();
        let expected = vec![
            vec![1.0, 0.0, 4.0],
            vec![0.0, 0.0, 5.0],
            vec![2.0, 3.0, 0.0],
        ];

        assert_eq!(dense, expected);
    }

    #[test]
    fn test_coo_transpose_non_square_shape() {
        let transposed = CooMatrix::<f64>::empty((2, 3)).transpose();
        assert_eq!(transposed.shape(), (3, 2));
    }

    #[test]
    fn test_coo_sort_by_row_col() {
        let rows = vec![2, 0, 1, 0];
        let cols = vec![1, 2, 0, 0];
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let mut matrix = CooMatrix::new(data, rows, cols, (3, 3)).unwrap();

        matrix.sort_by_row_col();

        assert_eq!(matrix.row_indices(), &[0, 0, 1, 2]);
        assert_eq!(matrix.col_indices(), &[0, 2, 0, 1]);
        assert_eq!(matrix.data(), &[4.0, 2.0, 3.0, 1.0]);
    }

    #[test]
    fn test_coo_sort_by_col_row() {
        let mut matrix = sample_matrix();

        matrix.sort_by_col_row();

        assert_eq!(matrix.col_indices(), &[0, 0, 1, 2, 2]);
        assert_eq!(matrix.row_indices(), &[0, 2, 2, 0, 1]);
        assert_eq!(matrix.data(), &[1.0, 4.0, 5.0, 2.0, 3.0]);
    }

    #[test]
    fn test_coo_sort_and_sum_duplicates() {
        // Create a matrix with duplicate entries
        let rows = vec![0, 0, 0, 1, 1, 2];
        let cols = vec![0, 0, 1, 0, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = (3, 2);

        let mut matrix = CooMatrix::new(data, rows, cols, shape).unwrap();
        matrix.sum_duplicates();

        assert_eq!(matrix.nnz(), 4); // Should have 4 unique entries after summing

        let dense = matrix.to_dense();
        let expected = vec![vec![3.0, 3.0], vec![9.0, 0.0], vec![0.0, 6.0]];

        assert_eq!(dense, expected);
    }

    #[test]
    fn test_coo_sum_duplicates_empty() {
        let mut matrix = CooMatrix::<f64>::empty((2, 2));
        matrix.sum_duplicates();

        assert_eq!(matrix.nnz(), 0);
        assert_eq!(matrix.to_dense(), vec![vec![0.0, 0.0], vec![0.0, 0.0]]);
    }

    #[test]
    fn test_coo_to_csr_to_csc() {
        let coo_matrix = sample_matrix();

        // Convert to CSR and CSC
        let csr_matrix = coo_matrix.to_csr();
        let csc_matrix = coo_matrix.to_csc();

        // Convert back to dense and compare
        let dense_from_coo = coo_matrix.to_dense();
        let dense_from_csr = csr_matrix.to_dense();
        let dense_from_csc = csc_matrix.to_dense();

        assert_eq!(dense_from_coo, dense_from_csr);
        assert_eq!(dense_from_coo, dense_from_csc);
    }
}
520}