1use std::fmt::Debug;
2use std::ops::{Add, Mul, Sub};
3
4use super::extract_block::CscBlock;
5use super::sparsity::MatrixSparsityRef;
6use super::utils::*;
7use super::{Matrix, MatrixCommon, MatrixSparsity};
8use crate::error::{DiffsolError, MatrixError};
9use crate::{DefaultSolver, FaerSparseLU, IndexType, Scalar, Scale};
10use crate::{FaerContext, FaerVec, FaerVecIndex, Vector, VectorIndex};
11
12use faer::reborrow::{Reborrow, ReborrowMut};
13use faer::sparse::ops::{ternary_op_assign_into, union_symbolic};
14use faer::sparse::{Pair, SparseColMat, SymbolicSparseColMat, SymbolicSparseColMatRef, Triplet};
15
/// Sparse matrix stored column-wise, backed by faer's [`SparseColMat`] with
/// `IndexType` indices and values of type `T`.
#[derive(Clone, Debug)]
pub struct FaerSparseMat<T: Scalar> {
    /// Underlying faer sparse matrix (sparsity pattern plus stored values).
    pub(crate) data: SparseColMat<IndexType, T>,
    /// Context stored next to the data; it is copied into the matrices and
    /// vectors that operations in this file create from this matrix.
    pub(crate) context: FaerContext,
}
21
/// Selects [`FaerSparseLU`] as the `LS` (linear solver) type associated with
/// [`FaerSparseMat`].
impl<T: Scalar> DefaultSolver for FaerSparseMat<T> {
    type LS = FaerSparseLU<T>;
}
25
// NOTE(review): `impl_matrix_common!` is defined outside this file. Presumably it
// generates the `MatrixCommon` impl for `FaerSparseMat<T>` with `FaerVec<T>` as the
// vector type, `FaerContext` as the context type and `SparseColMat<IndexType, T>` as
// the inner storage type -- confirm against the macro definition.
impl_matrix_common!(FaerSparseMat<T>, FaerVec<T>, FaerContext, SparseColMat<IndexType, T>);
27
/// Implements `Mul<Scale<T>>` for the matrix type `$lhs`, producing `$result`.
///
/// The crate-level `Scale<T>` is converted into `faer::Scale<T>`, the stored
/// faer matrix is multiplied by it, and the result is wrapped together with a
/// copy of the left-hand side's context.
macro_rules! impl_mul_scalar {
    ($lhs:ty, $result:ty) => {
        impl<'a, T: Scalar> Mul<Scale<T>> for $lhs {
            type Output = $result;

            fn mul(self, rhs: Scale<T>) -> Self::Output {
                let factor: faer::Scale<T> = rhs.into();
                // Multiply through a borrow so the same body works for both
                // owned and borrowed left-hand sides.
                let data = &self.data * factor;
                let context = self.context;
                Self::Output { data, context }
            }
        }
    };
}
43
// Multiplication by `Scale<T>` for both an owned and a borrowed matrix; each
// produces a new owned `FaerSparseMat<T>`.
impl_mul_scalar!(FaerSparseMat<T>, FaerSparseMat<T>);
impl_mul_scalar!(&FaerSparseMat<T>, FaerSparseMat<T>);

// NOTE(review): `impl_add!` / `impl_sub!` are defined outside this file. Presumably
// they implement `Add` / `Sub` for an owned left-hand matrix with a borrowed
// right-hand matrix, yielding an owned matrix -- confirm against the macro definitions.
impl_add!(FaerSparseMat<T>, &FaerSparseMat<T>, FaerSparseMat<T>);

impl_sub!(FaerSparseMat<T>, &FaerSparseMat<T>, FaerSparseMat<T>);
50
51impl<T: Scalar> MatrixSparsity<FaerSparseMat<T>> for SymbolicSparseColMat<IndexType> {
52 fn union(
53 self,
54 other: SymbolicSparseColMatRef<IndexType>,
55 ) -> Result<SymbolicSparseColMat<IndexType>, DiffsolError> {
56 union_symbolic(self.rb(), other).map_err(|e| DiffsolError::Other(e.to_string()))
57 }
58
59 fn as_ref(&self) -> SymbolicSparseColMatRef<'_, IndexType> {
60 self.rb()
61 }
62
63 fn nrows(&self) -> IndexType {
64 self.nrows()
65 }
66
67 fn ncols(&self) -> IndexType {
68 self.ncols()
69 }
70
71 fn is_sparse() -> bool {
72 true
73 }
74
75 fn indices(&self) -> Vec<(IndexType, IndexType)> {
76 let mut indices = Vec::with_capacity(self.compute_nnz());
77 for col_i in 0..self.ncols() {
78 for row_j in self.col_range(col_i) {
79 indices.push((row_j, col_i));
80 }
81 }
82 indices
83 }
84
85 fn new_diagonal(n: IndexType) -> Self {
86 let indices = (0..n).map(|i| Pair::new(i, i)).collect::<Vec<_>>();
87 SymbolicSparseColMat::try_new_from_indices(n, n, indices.as_slice())
88 .unwrap()
89 .0
90 }
91
92 fn try_from_indices(
93 nrows: IndexType,
94 ncols: IndexType,
95 indices: Vec<(IndexType, IndexType)>,
96 ) -> Result<Self, DiffsolError> {
97 let indices = indices
98 .iter()
99 .map(|(i, j)| Pair::new(*i, *j))
100 .collect::<Vec<_>>();
101 match Self::try_new_from_indices(nrows, ncols, indices.as_slice()) {
102 Ok((sparsity, _)) => Ok(sparsity),
103 Err(e) => Err(DiffsolError::Other(e.to_string())),
104 }
105 }
106
107 fn get_index(
108 &self,
109 indices: &[(IndexType, IndexType)],
110 ctx: FaerContext,
111 ) -> <<FaerSparseMat<T> as MatrixCommon>::V as Vector>::Index {
112 let col_ptrs = self.col_ptr();
113 let row_indices = self.row_idx();
114 let mut ret = Vec::with_capacity(indices.len());
115 for &(i, j) in indices.iter() {
116 let col_ptr = col_ptrs[j];
117 let next_col_ptr = col_ptrs[j + 1];
118 for (ii, &ri) in row_indices
119 .iter()
120 .enumerate()
121 .take(next_col_ptr)
122 .skip(col_ptr)
123 {
124 if ri == i {
125 ret.push(ii);
126 break;
127 }
128 }
129 }
130 FaerVecIndex {
131 data: ret,
132 context: ctx,
133 }
134 }
135}
136
137impl<'a, T: Scalar> MatrixSparsityRef<'a, FaerSparseMat<T>>
138 for SymbolicSparseColMatRef<'a, IndexType>
139{
140 fn to_owned(&self) -> SymbolicSparseColMat<IndexType> {
141 self.to_owned().unwrap()
142 }
143 fn nrows(&self) -> IndexType {
144 self.nrows()
145 }
146
147 fn ncols(&self) -> IndexType {
148 self.ncols()
149 }
150
151 fn is_sparse() -> bool {
152 true
153 }
154
155 fn split(
156 &self,
157 indices: &<<FaerSparseMat<T> as MatrixCommon>::V as Vector>::Index,
158 ) -> [(
159 SymbolicSparseColMat<IndexType>,
160 <<FaerSparseMat<T> as MatrixCommon>::V as Vector>::Index,
161 ); 4] {
162 let (_ni, _nj, col_ptrs, _col_nnz, row_idx) = self.parts();
163 let ctx = indices.context();
164 let (ul_blk, ur_blk, ll_blk, lr_blk) = CscBlock::split(row_idx, col_ptrs, indices);
165 let ul_sym = SymbolicSparseColMat::new_checked(
166 ul_blk.nrows,
167 ul_blk.ncols,
168 ul_blk.col_pointers,
169 None,
170 ul_blk.row_indices,
171 );
172 let ur_sym = SymbolicSparseColMat::new_checked(
173 ur_blk.nrows,
174 ur_blk.ncols,
175 ur_blk.col_pointers,
176 None,
177 ur_blk.row_indices,
178 );
179 let ll_sym = SymbolicSparseColMat::new_checked(
180 ll_blk.nrows,
181 ll_blk.ncols,
182 ll_blk.col_pointers,
183 None,
184 ll_blk.row_indices,
185 );
186 let lr_sym = SymbolicSparseColMat::new_checked(
187 lr_blk.nrows,
188 lr_blk.ncols,
189 lr_blk.col_pointers,
190 None,
191 lr_blk.row_indices,
192 );
193 [
194 (
195 ul_sym,
196 FaerVecIndex {
197 data: ul_blk.src_indices,
198 context: *ctx,
199 },
200 ),
201 (
202 ur_sym,
203 FaerVecIndex {
204 data: ur_blk.src_indices,
205 context: *ctx,
206 },
207 ),
208 (
209 ll_sym,
210 FaerVecIndex {
211 data: ll_blk.src_indices,
212 context: *ctx,
213 },
214 ),
215 (
216 lr_sym,
217 FaerVecIndex {
218 data: lr_blk.src_indices,
219 context: *ctx,
220 },
221 ),
222 ]
223 }
224
225 fn indices(&self) -> Vec<(IndexType, IndexType)> {
226 let mut indices = Vec::with_capacity(self.compute_nnz());
227 for col_i in 0..self.ncols() {
228 for row_j in self.col_range(col_i) {
229 indices.push((row_j, col_i));
230 }
231 }
232 indices
233 }
234}
235
236impl<T: Scalar> Matrix for FaerSparseMat<T> {
237 type Sparsity = SymbolicSparseColMat<IndexType>;
238 type SparsityRef<'a> = SymbolicSparseColMatRef<'a, IndexType>;
239
240 fn sparsity(&self) -> Option<Self::SparsityRef<'_>> {
241 Some(self.data.symbolic())
242 }
243 fn context(&self) -> &FaerContext {
244 &self.context
245 }
246
247 fn gather(&mut self, other: &Self, indices: &<Self::V as Vector>::Index) {
248 let dst_data = self.data.val_mut();
249 let src_data = other.data.val();
250 for (dst_i, idx) in dst_data.iter_mut().zip(indices.data.iter()) {
251 *dst_i = src_data[*idx];
252 }
253 }
254
255 fn set_data_with_indices(
256 &mut self,
257 dst_indices: &<Self::V as Vector>::Index,
258 src_indices: &<Self::V as Vector>::Index,
259 data: &Self::V,
260 ) {
261 let values = self.data.val_mut();
262 for (dst_i, src_i) in dst_indices.data.iter().zip(src_indices.data.iter()) {
263 values[*dst_i] = data[*src_i];
264 }
265 }
266
267 fn add_column_to_vector(&self, j: IndexType, v: &mut Self::V) {
268 for i in self.data.col_range(j) {
269 let row = self.data.row_idx()[i];
270 v[row] += self.data.val()[i];
271 }
272 }
273
274 fn triplet_iter(&self) -> impl Iterator<Item = (IndexType, IndexType, Self::T)> {
275 (0..self.ncols()).flat_map(move |j| {
276 self.data.col_range(j).map(move |i| {
277 let row = self.data.row_idx()[i];
278 (row, j, self.data.val()[i])
279 })
280 })
281 }
282
283 fn try_from_triplets(
284 nrows: IndexType,
285 ncols: IndexType,
286 triplets: Vec<(IndexType, IndexType, T)>,
287 ctx: Self::C,
288 ) -> Result<Self, DiffsolError> {
289 let triplets = triplets
290 .iter()
291 .map(|(i, j, v)| Triplet::new(*i, *j, *v))
292 .collect::<Vec<_>>();
293 match faer::sparse::SparseColMat::try_new_from_triplets(nrows, ncols, triplets.as_slice()) {
294 Ok(mat) => Ok(Self {
295 data: mat,
296 context: ctx,
297 }),
298 Err(e) => Err(DiffsolError::from(
299 MatrixError::FailedToCreateMatrixFromTriplets(e),
300 )),
301 }
302 }
303 fn gemv(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V) {
304 let tmp = Self::V {
305 data: &self.data * &x.data,
306 context: self.context,
307 };
308 y.axpy(alpha, &tmp, beta);
309 }
310 fn zeros(nrows: IndexType, ncols: IndexType, ctx: Self::C) -> Self {
311 Self {
312 data: SparseColMat::try_new_from_triplets(nrows, ncols, &[]).unwrap(),
313 context: ctx,
314 }
315 }
316 fn copy_from(&mut self, other: &Self) {
317 self.data = faer::sparse::SparseColMat::new(
318 other.data.symbolic().to_owned().unwrap(),
319 other.data.val().to_vec(),
320 )
321 }
322 fn from_diagonal(v: &FaerVec<T>) -> Self {
323 let dim = v.len();
324 let triplets = (0..dim)
325 .map(|i| Triplet::new(i, i, v[i]))
326 .collect::<Vec<_>>();
327 Self {
328 data: SparseColMat::try_new_from_triplets(dim, dim, &triplets).unwrap(),
329 context: *v.context(),
330 }
331 }
332
333 fn partition_indices_by_zero_diagonal(
334 &self,
335 ) -> (<Self::V as Vector>::Index, <Self::V as Vector>::Index) {
336 let mut indices_zero_diag = vec![];
337 let mut indices_non_zero_diag = vec![];
338 'outer: for j in 0..self.ncols() {
339 for (i, v) in self.data.row_idx_of_col(j).zip(self.data.val_of_col(j)) {
340 if i == j && *v != T::zero() {
341 indices_non_zero_diag.push(j);
342 continue 'outer;
343 } else if i > j {
344 break;
345 }
346 }
347 indices_zero_diag.push(j);
348 }
349 (
350 <Self::V as Vector>::Index::from_vec(indices_zero_diag, self.context),
351 <Self::V as Vector>::Index::from_vec(indices_non_zero_diag, self.context),
352 )
353 }
354
355 fn set_column(&mut self, j: IndexType, v: &Self::V) {
356 assert_eq!(v.len(), self.nrows());
357 for i in self.data.col_range(j) {
358 let row_i = self.data.row_idx()[i];
359 self.data.val_mut()[i] = v[row_i];
360 }
361 }
362
363 fn scale_add_and_assign(&mut self, x: &Self, beta: Self::T, y: &Self) {
364 ternary_op_assign_into(self.data.rb_mut(), x.data.rb(), y.data.rb(), |s, x, y| {
365 *s = *x.unwrap_or(&T::zero()) + beta * *y.unwrap_or(&T::zero())
366 });
367 }
368
369 fn new_from_sparsity(
370 nrows: IndexType,
371 ncols: IndexType,
372 sparsity: Option<Self::Sparsity>,
373 ctx: Self::C,
374 ) -> Self {
375 let sparsity = sparsity.expect("Sparsity pattern required for sparse matrix");
376 assert_eq!(sparsity.nrows(), nrows);
377 assert_eq!(sparsity.ncols(), ncols);
378 let nnz = sparsity.row_idx().len();
379 Self {
380 data: SparseColMat::new(sparsity, vec![T::zero(); nnz]),
381 context: ctx,
382 }
383 }
384}
385
#[cfg(test)]
mod tests {
    use crate::{FaerSparseMat, Matrix};

    /// `triplet_iter` must yield exactly the triplets the matrix was built
    /// from, in column-major order, and nothing else.
    #[test]
    fn test_triplet_iter() {
        let triplets = vec![(0, 0, 1.0), (1, 0, 2.0), (2, 2, 3.0), (3, 2, 4.0)];
        let mat =
            FaerSparseMat::<f64>::try_from_triplets(4, 3, triplets.clone(), Default::default())
                .unwrap();
        let mut iter = mat.triplet_iter();
        for expected in triplets {
            let actual = iter.next().expect("triplet_iter ended too early");
            assert_eq!(actual, expected);
        }
        // Without this check, spurious trailing entries would go unnoticed.
        assert!(
            iter.next().is_none(),
            "triplet_iter yielded more entries than were inserted"
        );
    }

    /// Runs the shared partition test from the parent module's test suite
    /// against `FaerSparseMat<f64>`.
    #[test]
    fn test_partition_indices_by_zero_diagonal() {
        super::super::tests::test_partition_indices_by_zero_diagonal::<FaerSparseMat<f64>>();
    }
}