iterative_solvers/utils/
helper.rs1#[cfg(feature = "faer")]
2use faer::{unzip, zip};
3
4use crate::{IterSolverError, IterSolverResult, Matrix, Vector};
5#[cfg(not(feature = "ndarray"))]
6use crate::{SparseCscMatrix, SparseCsrMatrix};
7#[cfg(feature = "ndarray")]
8use sprs::CsMat;
9
10pub(crate) fn is_vector(_mat: &Vector<f64>) -> bool {
12 #[cfg(feature = "nalgebra")]
13 {
14 true
15 }
16 #[cfg(feature = "faer")]
17 {
18 _mat.ncols() == 1
19 }
20 #[cfg(feature = "ndarray")]
21 {
22 true
23 }
24}
25
26pub(crate) fn dot(lhs: &Vector<f64>, rhs: &Vector<f64>) -> IterSolverResult<f64> {
28 if !is_vector(lhs) {
29 return Err(IterSolverError::InvalidInput(
30 "The input parameter is not a vector".to_string(),
31 ));
32 }
33 if !is_vector(rhs) {
34 return Err(IterSolverError::InvalidInput(
35 "The input parameter is not a vector".to_string(),
36 ));
37 }
38
39 #[cfg(not(feature = "ndarray"))]
40 if lhs.nrows() != rhs.nrows() {
41 return Err(IterSolverError::InvalidInput(
42 "The input parameter is not a vector".to_string(),
43 ));
44 }
45
46 #[cfg(feature = "ndarray")]
47 if lhs.len() != rhs.len() {
48 return Err(IterSolverError::InvalidInput(
49 "The input parameter is not a vector".to_string(),
50 ));
51 }
52
53 #[cfg(feature = "faer")]
54 {
55 let mut result = 0.0;
56 zip!(lhs, rhs).for_each(|unzip!(lhs_val, rhs_val)| {
57 result += lhs_val * rhs_val;
58 });
59 Ok(result)
60 }
61 #[cfg(feature = "nalgebra")]
62 {
63 Ok(lhs.dot(rhs))
64 }
65
66 #[cfg(feature = "ndarray")]
67 {
68 Ok(lhs.dot(rhs))
69 }
70}
71
/// In-place scaled vector update: `v = alpha * x + beta * v`.
///
/// NOTE(review): nalgebra's `axpy` is documented not to read `v` when
/// `beta == 0`, whereas the faer and ndarray branches below always compute
/// `beta * v` first, so NaN/inf already stored in `v` would propagate with
/// those backends. Confirm no caller passes `beta == 0.0` with a
/// non-finite `v`.
pub(crate) fn axpy(v: &mut Vector<f64>, alpha: f64, x: &Vector<f64>, beta: f64) {
    #[cfg(feature = "nalgebra")]
    {
        v.axpy(alpha, x, beta);
    }
    #[cfg(feature = "faer")]
    {
        // Two steps: v <- beta * v, then v <- v + alpha * x.
        *v *= beta;
        *v += alpha * x;
    }
    #[cfg(feature = "ndarray")]
    {
        // Same two-step update; `alpha * x` builds a temporary array that is
        // then added element-wise into `v` by reference.
        *v *= beta;
        *v += &(alpha * x);
    }
}
89
90pub fn zeros(n: usize) -> Vector<f64> {
91 #[cfg(feature = "nalgebra")]
92 {
93 Vector::zeros(n)
94 }
95 #[cfg(feature = "faer")]
96 {
97 Vector::zeros(n, 1)
98 }
99 #[cfg(feature = "ndarray")]
100 {
101 Vector::zeros(n)
102 }
103}
104
105pub fn norm_l2(mat: &Vector<f64>) -> f64 {
106 #[cfg(feature = "nalgebra")]
107 {
108 mat.norm()
109 }
110 #[cfg(feature = "faer")]
111 {
112 mat.norm_l2()
113 }
114 #[cfg(feature = "ndarray")]
115 {
116 use ndarray_linalg::Norm;
117
118 mat.norm_l2()
119 }
120}
121
122pub(crate) fn from_diagonal(data: &[f64]) -> Matrix<f64> {
123 #[cfg(feature = "faer")]
124 {
125 let n = data.len();
126 let mut mat = Matrix::zeros(n, n);
127
128 data.iter().enumerate().for_each(|(idx, &val)| unsafe {
129 *mat.get_mut_unchecked(idx, idx) = val;
130 });
131 mat
132 }
133 #[cfg(feature = "nalgebra")]
134 {
135 Matrix::from_diagonal(&Vector::from_column_slice(data))
136 }
137 #[cfg(feature = "ndarray")]
138 {
139 use ndarray::arr1;
140
141 Matrix::from_diag(&arr1(data))
142 }
143}
144
/// Returns a mutable reference to the entry at (`row`, `col`) of `mat`,
/// dispatching to the active linear-algebra backend.
///
/// # Safety
///
/// With the `nalgebra` and `faer` backends, this calls the backends'
/// unchecked accessors inside `unsafe` blocks. The caller must guarantee that
/// `row` and `col` lie within the matrix's dimensions.
///
/// # Panics
///
/// With the `ndarray` backend, the access goes through the checked `get_mut`
/// and panics (via `unwrap`) if (`row`, `col`) is out of range.
pub(crate) unsafe fn get_mut_unchecked<T>(mat: &mut Matrix<T>, row: usize, col: usize) -> &mut T {
    #[cfg(feature = "nalgebra")]
    {
        // SAFETY: the caller upholds the in-bounds contract documented above.
        unsafe { mat.get_unchecked_mut((row, col)) }
    }
    #[cfg(feature = "faer")]
    {
        // SAFETY: the caller upholds the in-bounds contract documented above.
        unsafe { mat.get_mut_unchecked(row, col) }
    }
    #[cfg(feature = "ndarray")]
    {
        // Checked access: `get_mut` yields `None` for out-of-range indices,
        // which `unwrap` turns into a panic.
        mat.get_mut((row, col)).unwrap()
    }
}
162
163#[cfg(not(feature = "ndarray"))]
164pub(crate) fn empty_spcsr() -> SparseCsrMatrix<f64> {
165 #[cfg(feature = "nalgebra")]
166 {
167 SparseCsrMatrix::zeros(0, 0)
168 }
169 #[cfg(feature = "faer")]
170 {
171 SparseCsrMatrix::try_new_from_triplets(0, 0, &[]).unwrap()
172 }
173}
174
175#[cfg(not(feature = "ndarray"))]
176pub(crate) fn empty_spcsc() -> SparseCscMatrix<f64> {
177 #[cfg(feature = "nalgebra")]
178 {
179 SparseCscMatrix::zeros(0, 0)
180 }
181 #[cfg(feature = "faer")]
182 {
183 SparseCscMatrix::try_new_from_triplets(0, 0, &[]).unwrap()
184 }
185}
186
/// Returns an empty sparse matrix in CSR storage (ndarray/sprs backend).
#[cfg(feature = "ndarray")]
pub(crate) fn empty_spcsr() -> CsMat<f64> {
    // Zero outer rows and an inner (column) dimension of zero.
    let layout = sprs::CompressedStorage::CSR;
    CsMat::empty(layout, 0)
}
191
/// Returns an empty sparse matrix in CSC storage (ndarray/sprs backend).
#[cfg(feature = "ndarray")]
pub(crate) fn empty_spcsc() -> CsMat<f64> {
    // Zero outer columns and an inner (row) dimension of zero.
    let layout = sprs::CompressedStorage::CSC;
    CsMat::empty(layout, 0)
}