1use crate::cholesky::CholeskyFactorization;
8use crate::lu::LUFactorization;
9use crate::matrix::DenseMatrix;
10use crate::Scalar;
11use faer::sparse::SparseColMat;
12use faer::{ComplexField, Conjugate, Entity, SimpleEntity};
13use numra_core::LinalgError;
14
15pub struct SparseMatrix<S: Scalar + Entity> {
24 inner: SparseColMat<usize, S>,
25 nrows: usize,
26 ncols: usize,
27}
28
29impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> SparseMatrix<S> {
30 pub fn from_triplets(
34 nrows: usize,
35 ncols: usize,
36 triplets: &[(usize, usize, S)],
37 ) -> Result<Self, LinalgError> {
38 let faer_triplets: Vec<(usize, usize, S)> = triplets.to_vec();
39 let inner =
40 SparseColMat::try_new_from_triplets(nrows, ncols, &faer_triplets).map_err(|_| {
41 LinalgError::DimensionMismatch {
42 expected: (nrows, ncols),
43 actual: (0, 0),
44 }
45 })?;
46 Ok(Self {
47 inner,
48 nrows,
49 ncols,
50 })
51 }
52
53 pub fn nrows(&self) -> usize {
55 self.nrows
56 }
57
58 pub fn ncols(&self) -> usize {
60 self.ncols
61 }
62
63 pub fn nnz(&self) -> usize {
65 self.inner.compute_nnz()
66 }
67
68 pub fn get(&self, row: usize, col: usize) -> S {
70 let col_ptrs = self.inner.col_ptrs();
71 let row_indices = self.inner.row_indices();
72 let values = self.inner.values();
73
74 let start = col_ptrs[col];
75 let end = col_ptrs[col + 1];
76 for idx in start..end {
77 if row_indices[idx] == row {
78 return values[idx];
79 }
80 }
81 S::ZERO
82 }
83
84 pub fn to_dense(&self) -> DenseMatrix<S> {
86 DenseMatrix::from_faer(self.inner.to_dense())
87 }
88
89 pub fn col_ptrs(&self) -> Vec<usize> {
91 self.inner.col_ptrs().to_vec()
92 }
93
94 pub fn row_indices(&self) -> Vec<usize> {
96 self.inner.row_indices().to_vec()
97 }
98
99 pub fn values(&self) -> Vec<S> {
101 let vals = self.inner.values();
102 (0..vals.len()).map(|i| vals[i]).collect()
103 }
104
105 pub fn mul_vec(&self, x: &[S]) -> Result<Vec<S>, LinalgError> {
107 if x.len() != self.ncols {
108 return Err(LinalgError::DimensionMismatch {
109 expected: (self.ncols, 1),
110 actual: (x.len(), 1),
111 });
112 }
113
114 let col_ptrs = self.inner.col_ptrs();
115 let row_indices = self.inner.row_indices();
116 let values = self.inner.values();
117
118 let mut y = vec![S::ZERO; self.nrows];
119
120 for j in 0..self.ncols {
121 let start = col_ptrs[j];
122 let end = col_ptrs[j + 1];
123 for idx in start..end {
124 let i = row_indices[idx];
125 y[i] += values[idx] * x[j];
126 }
127 }
128
129 Ok(y)
130 }
131
132 pub fn transpose(&self) -> Result<SparseMatrix<S>, LinalgError> {
134 let col_ptrs = self.inner.col_ptrs();
136 let row_indices = self.inner.row_indices();
137 let values = self.inner.values();
138
139 let mut triplets = Vec::with_capacity(self.nnz());
140 for j in 0..self.ncols {
141 let start = col_ptrs[j];
142 let end = col_ptrs[j + 1];
143 for idx in start..end {
144 let i = row_indices[idx];
145 triplets.push((j, i, values[idx]));
146 }
147 }
148
149 SparseMatrix::from_triplets(self.ncols, self.nrows, &triplets)
150 }
151}
152
153pub struct SparseLU<S: Scalar + Entity> {
158 lu: LUFactorization<S>,
159 n: usize,
160}
161
162impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> SparseLU<S> {
163 pub fn new(matrix: &SparseMatrix<S>) -> Result<Self, LinalgError> {
165 if matrix.nrows() != matrix.ncols() {
166 return Err(LinalgError::NotSquare {
167 nrows: matrix.nrows(),
168 ncols: matrix.ncols(),
169 });
170 }
171 let n = matrix.nrows();
172 let dense = matrix.to_dense();
173 let lu = LUFactorization::new(&dense)?;
174 Ok(Self { lu, n })
175 }
176
177 pub fn solve(&self, b: &[S]) -> Result<Vec<S>, LinalgError> {
179 self.lu.solve(b)
180 }
181
182 pub fn dim(&self) -> usize {
184 self.n
185 }
186}
187
188pub struct SparseCholesky<S: Scalar + Entity> {
192 chol: CholeskyFactorization<S>,
193 n: usize,
194}
195
196impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> SparseCholesky<S> {
197 pub fn new(matrix: &SparseMatrix<S>) -> Result<Self, LinalgError> {
199 if matrix.nrows() != matrix.ncols() {
200 return Err(LinalgError::NotSquare {
201 nrows: matrix.nrows(),
202 ncols: matrix.ncols(),
203 });
204 }
205 let n = matrix.nrows();
206 let dense = matrix.to_dense();
207 let chol = CholeskyFactorization::new(&dense)?;
208 Ok(Self { chol, n })
209 }
210
211 pub fn solve(&self, b: &[S]) -> Result<Vec<S>, LinalgError> {
213 self.chol.solve(b)
214 }
215
216 pub fn dim(&self) -> usize {
218 self.n
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Matrix;

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} (tol {}), got {}",
            expected,
            tol,
            actual
        );
    }

    #[test]
    fn test_identity_from_triplets() {
        let triplets = vec![(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0)];
        let m = SparseMatrix::from_triplets(3, 3, &triplets).unwrap();

        assert_eq!(m.nrows(), 3);
        assert_eq!(m.ncols(), 3);
        assert_eq!(m.nnz(), 3);

        assert_close(m.get(0, 0), 1.0, 1e-15);
        assert_close(m.get(1, 1), 1.0, 1e-15);
        assert_close(m.get(2, 2), 1.0, 1e-15);
        assert_close(m.get(0, 1), 0.0, 1e-15);
    }

    #[test]
    fn test_tridiagonal() {
        let triplets = vec![
            (0, 0, -2.0),
            (0, 1, 1.0),
            (1, 0, 1.0),
            (1, 1, -2.0),
            (1, 2, 1.0),
            (2, 1, 1.0),
            (2, 2, -2.0),
        ];
        let m = SparseMatrix::from_triplets(3, 3, &triplets).unwrap();

        assert_eq!(m.nnz(), 7);
        assert_close(m.get(0, 0), -2.0, 1e-15);
        assert_close(m.get(0, 1), 1.0, 1e-15);
        assert_close(m.get(2, 2), -2.0, 1e-15);
        assert_close(m.get(0, 2), 0.0, 1e-15);
    }

    #[test]
    fn test_duplicate_entries_summed() {
        let triplets = vec![(0, 0, 3.0), (0, 0, 4.0), (1, 1, 1.0)];
        let m = SparseMatrix::from_triplets(2, 2, &triplets).unwrap();

        assert_close(m.get(0, 0), 7.0, 1e-15);
        assert_close(m.get(1, 1), 1.0, 1e-15);
    }

    #[test]
    fn test_from_triplets_out_of_bounds_is_err() {
        let bad_row = vec![(0, 0, 1.0), (2, 1, 5.0)];
        assert!(SparseMatrix::from_triplets(2, 2, &bad_row).is_err());

        let bad_col = vec![(0, 3, 1.0)];
        assert!(SparseMatrix::from_triplets(2, 2, &bad_col).is_err());
    }

    #[test]
    fn test_empty_matrix() {
        let triplets: Vec<(usize, usize, f64)> = Vec::new();
        let m = SparseMatrix::from_triplets(2, 3, &triplets).unwrap();

        assert_eq!(m.nnz(), 0);
        assert_eq!(m.col_ptrs(), vec![0, 0, 0, 0]);
        assert_close(m.get(1, 2), 0.0, 1e-15);

        let y = m.mul_vec(&[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(y.len(), 2);
        assert!(y.iter().all(|v| v.abs() < 1e-15));
    }

    #[test]
    fn test_csc_accessors() {
        // A = [[1, 0],
        //      [2, 3]]
        let triplets = vec![(0, 0, 1.0), (1, 0, 2.0), (1, 1, 3.0)];
        let m = SparseMatrix::from_triplets(2, 2, &triplets).unwrap();

        assert_eq!(m.col_ptrs(), vec![0, 2, 3]);
        assert_eq!(m.row_indices(), vec![0, 1, 1]);
        assert_eq!(m.values(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_to_dense_roundtrip() {
        let triplets = vec![(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (1, 1, 4.0)];
        let sparse = SparseMatrix::from_triplets(2, 2, &triplets).unwrap();
        let dense = sparse.to_dense();

        assert_close(dense.get(0, 0), 1.0, 1e-15);
        assert_close(dense.get(0, 1), 2.0, 1e-15);
        assert_close(dense.get(1, 0), 3.0, 1e-15);
        assert_close(dense.get(1, 1), 4.0, 1e-15);
    }

    #[test]
    fn test_spmv_tridiagonal() {
        let triplets = vec![
            (0, 0, 2.0),
            (0, 1, -1.0),
            (1, 0, -1.0),
            (1, 1, 2.0),
            (1, 2, -1.0),
            (2, 1, -1.0),
            (2, 2, 2.0),
        ];
        let m = SparseMatrix::from_triplets(3, 3, &triplets).unwrap();
        let y = m.mul_vec(&[1.0, 1.0, 1.0]).unwrap();

        assert_close(y[0], 1.0, 1e-10);
        assert_close(y[1], 0.0, 1e-10);
        assert_close(y[2], 1.0, 1e-10);
    }

    #[test]
    fn test_rectangular_mul_vec_and_transpose() {
        // A = [[1, 0, 2],
        //      [0, 3, 0]]
        let triplets = vec![(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0)];
        let m = SparseMatrix::from_triplets(2, 3, &triplets).unwrap();

        let y = m.mul_vec(&[1.0, 1.0, 1.0]).unwrap();
        assert_eq!(y.len(), 2);
        assert_close(y[0], 3.0, 1e-12);
        assert_close(y[1], 3.0, 1e-12);

        let mt = m.transpose().unwrap();
        assert_eq!((mt.nrows(), mt.ncols()), (3, 2));
        assert_close(mt.get(2, 0), 2.0, 1e-15);
        assert_close(mt.get(0, 1), 0.0, 1e-15);

        let z = mt.mul_vec(&[1.0, 1.0]).unwrap();
        assert_eq!(z.len(), 3);
        assert_close(z[0], 1.0, 1e-12);
        assert_close(z[1], 3.0, 1e-12);
        assert_close(z[2], 2.0, 1e-12);
    }

    #[test]
    fn test_sparse_lu_solve() {
        let triplets = vec![(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (1, 1, 4.0)];
        let sparse = SparseMatrix::from_triplets(2, 2, &triplets).unwrap();
        let lu = SparseLU::new(&sparse).unwrap();
        assert_eq!(lu.dim(), 2);

        let x = lu.solve(&[5.0, 11.0]).unwrap();

        assert_close(x[0], 1.0, 1e-10);
        assert_close(x[1], 2.0, 1e-10);
    }

    #[test]
    fn test_sparse_lu_matches_dense() {
        let triplets = vec![
            (0, 0, 4.0),
            (0, 1, -1.0),
            (1, 0, -1.0),
            (1, 1, 4.0),
            (1, 2, -1.0),
            (2, 1, -1.0),
            (2, 2, 4.0),
            (2, 3, -1.0),
            (3, 2, -1.0),
            (3, 3, 4.0),
        ];
        let sparse = SparseMatrix::from_triplets(4, 4, &triplets).unwrap();
        let dense = sparse.to_dense();

        let b = vec![1.0, 2.0, 3.0, 4.0];

        let x_dense = dense.solve(&b).unwrap();
        let lu = SparseLU::new(&sparse).unwrap();
        let x_sparse = lu.solve(&b).unwrap();

        for i in 0..4 {
            assert!(
                (x_dense[i] - x_sparse[i]).abs() < 1e-10,
                "Mismatch at {}: {} vs {}",
                i,
                x_dense[i],
                x_sparse[i]
            );
        }
    }

    #[test]
    fn test_factorizations_reject_non_square() {
        let triplets = vec![(0, 0, 1.0), (1, 2, 1.0)];
        let rect = SparseMatrix::from_triplets(2, 3, &triplets).unwrap();

        assert!(SparseLU::new(&rect).is_err());
        assert!(SparseCholesky::new(&rect).is_err());
    }

    #[test]
    fn test_sparse_cholesky_solve() {
        let triplets = vec![(0, 0, 4.0), (0, 1, 2.0), (1, 0, 2.0), (1, 1, 3.0)];
        let sparse = SparseMatrix::from_triplets(2, 2, &triplets).unwrap();
        let chol = SparseCholesky::new(&sparse).unwrap();
        assert_eq!(chol.dim(), 2);

        let x = chol.solve(&[6.0, 5.0]).unwrap();

        assert_close(x[0], 1.0, 1e-10);
        assert_close(x[1], 1.0, 1e-10);
    }

    #[test]
    fn test_transpose() {
        let triplets = vec![(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (1, 1, 4.0)];
        let m = SparseMatrix::from_triplets(2, 2, &triplets).unwrap();
        let mt = m.transpose().unwrap();

        assert_close(mt.get(0, 0), 1.0, 1e-15);
        assert_close(mt.get(0, 1), 3.0, 1e-15);
        assert_close(mt.get(1, 0), 2.0, 1e-15);
        assert_close(mt.get(1, 1), 4.0, 1e-15);
    }

    #[test]
    fn test_spmv_dimension_mismatch() {
        let triplets = vec![(0, 0, 1.0), (1, 1, 1.0)];
        let m = SparseMatrix::from_triplets(2, 2, &triplets).unwrap();

        let x = vec![1.0, 2.0, 3.0];
        assert!(m.mul_vec(&x).is_err());
    }

    #[test]
    fn test_sparse_f32() {
        let triplets: Vec<(usize, usize, f32)> = vec![(0, 0, 2.0), (1, 1, 3.0)];
        let m = SparseMatrix::from_triplets(2, 2, &triplets).unwrap();
        let lu = SparseLU::new(&m).unwrap();

        let b = vec![4.0f32, 9.0f32];
        let x = lu.solve(&b).unwrap();

        assert!((x[0] - 2.0).abs() < 1e-5);
        assert!((x[1] - 3.0).abs() < 1e-5);
    }
}