iterative_solvers/utils/helper.rs

1#[cfg(feature = "faer")]
2use faer::{unzip, zip};
3
4use crate::{IterSolverError, IterSolverResult, Matrix, Vector};
5#[cfg(not(feature = "ndarray"))]
6use crate::{SparseCscMatrix, SparseCsrMatrix};
7#[cfg(feature = "ndarray")]
8use sprs::CsMat;
9
10/// Check if the matrix is a vector.
11pub(crate) fn is_vector(_mat: &Vector<f64>) -> bool {
12    #[cfg(feature = "nalgebra")]
13    {
14        true
15    }
16    #[cfg(feature = "faer")]
17    {
18        _mat.ncols() == 1
19    }
20    #[cfg(feature = "ndarray")]
21    {
22        true
23    }
24}
25
26/// Compute the dot product of two vectors.
27pub(crate) fn dot(lhs: &Vector<f64>, rhs: &Vector<f64>) -> IterSolverResult<f64> {
28    if !is_vector(lhs) {
29        return Err(IterSolverError::InvalidInput(
30            "The input parameter is not a vector".to_string(),
31        ));
32    }
33    if !is_vector(rhs) {
34        return Err(IterSolverError::InvalidInput(
35            "The input parameter is not a vector".to_string(),
36        ));
37    }
38
39    #[cfg(not(feature = "ndarray"))]
40    if lhs.nrows() != rhs.nrows() {
41        return Err(IterSolverError::InvalidInput(
42            "The input parameter is not a vector".to_string(),
43        ));
44    }
45
46    #[cfg(feature = "ndarray")]
47    if lhs.len() != rhs.len() {
48        return Err(IterSolverError::InvalidInput(
49            "The input parameter is not a vector".to_string(),
50        ));
51    }
52
53    #[cfg(feature = "faer")]
54    {
55        let mut result = 0.0;
56        zip!(lhs, rhs).for_each(|unzip!(lhs_val, rhs_val)| {
57            result += lhs_val * rhs_val;
58        });
59        Ok(result)
60    }
61    #[cfg(feature = "nalgebra")]
62    {
63        Ok(lhs.dot(rhs))
64    }
65
66    #[cfg(feature = "ndarray")]
67    {
68        Ok(lhs.dot(rhs))
69    }
70}
71
72/// self = alpha * x + beta * self
73pub(crate) fn axpy(v: &mut Vector<f64>, alpha: f64, x: &Vector<f64>, beta: f64) {
74    #[cfg(feature = "nalgebra")]
75    {
76        v.axpy(alpha, x, beta);
77    }
78    #[cfg(feature = "faer")]
79    {
80        *v *= beta;
81        *v += alpha * x;
82    }
83    #[cfg(feature = "ndarray")]
84    {
85        *v *= beta;
86        *v += &(alpha * x);
87    }
88}
89
90pub fn zeros(n: usize) -> Vector<f64> {
91    #[cfg(feature = "nalgebra")]
92    {
93        Vector::zeros(n)
94    }
95    #[cfg(feature = "faer")]
96    {
97        Vector::zeros(n, 1)
98    }
99    #[cfg(feature = "ndarray")]
100    {
101        Vector::zeros(n)
102    }
103}
104
105pub fn norm_l2(mat: &Vector<f64>) -> f64 {
106    #[cfg(feature = "nalgebra")]
107    {
108        mat.norm()
109    }
110    #[cfg(feature = "faer")]
111    {
112        mat.norm_l2()
113    }
114    #[cfg(feature = "ndarray")]
115    {
116        use ndarray_linalg::Norm;
117
118        mat.norm_l2()
119    }
120}
121
122pub(crate) fn from_diagonal(data: &[f64]) -> Matrix<f64> {
123    #[cfg(feature = "faer")]
124    {
125        let n = data.len();
126        let mut mat = Matrix::zeros(n, n);
127
128        data.iter().enumerate().for_each(|(idx, &val)| unsafe {
129            *mat.get_mut_unchecked(idx, idx) = val;
130        });
131        mat
132    }
133    #[cfg(feature = "nalgebra")]
134    {
135        Matrix::from_diagonal(&Vector::from_column_slice(data))
136    }
137    #[cfg(feature = "ndarray")]
138    {
139        use ndarray::arr1;
140
141        Matrix::from_diag(&arr1(data))
142    }
143}
144
145/// # Safety
146///
147/// This function is unsafe because it does not check if the row and column indices are valid.
148pub(crate) unsafe fn get_mut_unchecked<T>(mat: &mut Matrix<T>, row: usize, col: usize) -> &mut T {
149    #[cfg(feature = "nalgebra")]
150    {
151        unsafe { mat.get_unchecked_mut((row, col)) }
152    }
153    #[cfg(feature = "faer")]
154    {
155        unsafe { mat.get_mut_unchecked(row, col) }
156    }
157    #[cfg(feature = "ndarray")]
158    {
159        mat.get_mut((row, col)).unwrap()
160    }
161}
162
163#[cfg(not(feature = "ndarray"))]
164pub(crate) fn empty_spcsr() -> SparseCsrMatrix<f64> {
165    #[cfg(feature = "nalgebra")]
166    {
167        SparseCsrMatrix::zeros(0, 0)
168    }
169    #[cfg(feature = "faer")]
170    {
171        SparseCsrMatrix::try_new_from_triplets(0, 0, &[]).unwrap()
172    }
173}
174
175#[cfg(not(feature = "ndarray"))]
176pub(crate) fn empty_spcsc() -> SparseCscMatrix<f64> {
177    #[cfg(feature = "nalgebra")]
178    {
179        SparseCscMatrix::zeros(0, 0)
180    }
181    #[cfg(feature = "faer")]
182    {
183        SparseCscMatrix::try_new_from_triplets(0, 0, &[]).unwrap()
184    }
185}
186
/// Returns an empty `sprs` matrix in CSR storage with an inner dimension of zero.
#[cfg(feature = "ndarray")]
pub(crate) fn empty_spcsr() -> CsMat<f64> {
    use sprs::CompressedStorage;

    let inner_dim = 0;
    CsMat::empty(CompressedStorage::CSR, inner_dim)
}
191
/// Returns an empty `sprs` matrix in CSC storage with an inner dimension of zero.
#[cfg(feature = "ndarray")]
pub(crate) fn empty_spcsc() -> CsMat<f64> {
    use sprs::CompressedStorage;

    let inner_dim = 0;
    CsMat::empty(CompressedStorage::CSC, inner_dim)
}