scirs2_sparse/
dia.rs

1//! Diagonal (DIA) matrix format
2//!
3//! This module provides the DIA matrix format implementation, which is
4//! efficient for matrices with values concentrated on a small number of diagonals.
5
6use crate::error::{SparseError, SparseResult};
7use scirs2_core::numeric::{SparseElement, Zero};
8
/// Diagonal (DIA) matrix
///
/// A sparse matrix format that stores diagonals, making it efficient for
/// matrices with values concentrated on a small number of diagonals.
///
/// Each stored diagonal is paired with an offset from the main diagonal:
/// `0` is the main diagonal, positive offsets lie above it and negative
/// offsets lie below it.
#[derive(Debug, Clone)]
pub struct DiaMatrix<T> {
    /// Number of rows
    rows: usize,
    /// Number of columns
    cols: usize,
    /// Diagonal storage: one `Vec` per entry of `offsets`, each of length
    /// `max(rows, cols)` (n_diags x max(rows, cols))
    data: Vec<Vec<T>>,
    /// Offset of each stored diagonal from the main diagonal
    offsets: Vec<isize>,
}
23
24impl<T> DiaMatrix<T>
25where
26    T: Clone + Copy + Zero + std::cmp::PartialEq + SparseElement,
27{
28    /// Create a new DIA matrix from raw data
29    ///
30    /// # Arguments
31    ///
32    /// * `data` - Diagonals data (n_diags x max(rows, cols))
33    /// * `offsets` - Diagonal offsets from the main diagonal
34    /// * `shape` - Tuple containing the matrix dimensions (rows, cols)
35    ///
36    /// # Returns
37    ///
38    /// * A new DIA matrix
39    ///
40    /// # Examples
41    ///
42    /// ```
43    /// use scirs2_sparse::dia::DiaMatrix;
44    ///
45    /// // Create a 3x3 sparse matrix with main diagonal and upper diagonal
46    /// let data = vec![
47    ///     vec![1.0, 2.0, 3.0], // Main diagonal
48    ///     vec![4.0, 5.0, 0.0], // Upper diagonal (k=1)
49    /// ];
50    /// let offsets = vec![0, 1]; // Main diagonal and k=1
51    /// let shape = (3, 3);
52    ///
53    /// let matrix = DiaMatrix::new(data, offsets, shape).unwrap();
54    /// ```
55    pub fn new(
56        data: Vec<Vec<T>>,
57        offsets: Vec<isize>,
58        shape: (usize, usize),
59    ) -> SparseResult<Self> {
60        let (rows, cols) = shape;
61        let max_dim = rows.max(cols);
62
63        // Validate input data
64        if data.len() != offsets.len() {
65            return Err(SparseError::DimensionMismatch {
66                expected: data.len(),
67                found: offsets.len(),
68            });
69        }
70
71        for diag in data.iter() {
72            if diag.len() != max_dim {
73                return Err(SparseError::DimensionMismatch {
74                    expected: max_dim,
75                    found: diag.len(),
76                });
77            }
78        }
79
80        Ok(DiaMatrix {
81            rows,
82            cols,
83            data,
84            offsets,
85        })
86    }
87
88    /// Create a new empty DIA matrix
89    ///
90    /// # Arguments
91    ///
92    /// * `shape` - Tuple containing the matrix dimensions (rows, cols)
93    ///
94    /// # Returns
95    ///
96    /// * A new empty DIA matrix
97    pub fn empty(shape: (usize, usize)) -> Self {
98        let (rows, cols) = shape;
99
100        DiaMatrix {
101            rows,
102            cols,
103            data: Vec::new(),
104            offsets: Vec::new(),
105        }
106    }
107
108    /// Get the number of rows in the matrix
109    pub fn rows(&self) -> usize {
110        self.rows
111    }
112
113    /// Get the number of columns in the matrix
114    pub fn cols(&self) -> usize {
115        self.cols
116    }
117
118    /// Get the shape (dimensions) of the matrix
119    pub fn shape(&self) -> (usize, usize) {
120        (self.rows, self.cols)
121    }
122
123    /// Get the number of non-zero elements in the matrix
124    pub fn nnz(&self) -> usize {
125        let mut count = 0;
126
127        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
128            let diag = &self.data[diag_idx];
129
130            // Calculate valid range for this diagonal
131            let mut start = 0;
132            let mut end = self.rows.min(self.cols);
133
134            if offset < 0 {
135                start = (-offset) as usize;
136            }
137
138            if offset > 0 {
139                end = (self.rows as isize - offset) as usize;
140            }
141
142            // Count non-zeros in the valid range
143            for val in diag.iter().skip(start).take(end - start) {
144                if *val != T::sparse_zero() {
145                    count += 1;
146                }
147            }
148        }
149
150        count
151    }
152
153    /// Convert to dense matrix (as Vec<Vec<T>>)
154    pub fn to_dense(&self) -> Vec<Vec<T>>
155    where
156        T: Zero + Copy + SparseElement,
157    {
158        let mut result = vec![vec![T::sparse_zero(); self.cols]; self.rows];
159
160        for (diag_idx, &offset) in self.offsets.iter().enumerate() {
161            let diag = &self.data[diag_idx];
162
163            if offset >= 0 {
164                // Upper diagonal
165                let offset = offset as usize;
166                for i in 0..self.rows.min(self.cols.saturating_sub(offset)) {
167                    result[i][i + offset] = diag[i];
168                }
169            } else {
170                // Lower diagonal
171                let offset = (-offset) as usize;
172                for i in 0..self.cols.min(self.rows.saturating_sub(offset)) {
173                    result[i + offset][i] = diag[i];
174                }
175            }
176        }
177
178        result
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 3x3 fixture shared by the tests below: the main diagonal
    /// `[1, 2, 3]` (offset 0) plus the superdiagonal `[4, 5]` (offset 1),
    /// whose trailing `0.0` slot is padding.
    fn upper_bidiagonal_3x3() -> DiaMatrix<f64> {
        let main = vec![1.0, 2.0, 3.0];
        let upper = vec![4.0, 5.0, 0.0];
        DiaMatrix::new(vec![main, upper], vec![0, 1], (3, 3))
            .expect("fixture diagonals all have length max(rows, cols)")
    }

    #[test]
    fn test_dia_create() {
        let matrix = upper_bidiagonal_3x3();

        assert_eq!(matrix.shape(), (3, 3));
        // Three entries on the main diagonal plus two on the superdiagonal.
        assert_eq!(matrix.nnz(), 5);
    }

    #[test]
    fn test_dia_to_dense() {
        let dense = upper_bidiagonal_3x3().to_dense();

        let expected: Vec<Vec<f64>> = vec![
            vec![1.0, 4.0, 0.0],
            vec![0.0, 2.0, 5.0],
            vec![0.0, 0.0, 3.0],
        ];
        assert_eq!(dense, expected);
    }
}