1#[cfg(feature = "nalgebra")]
2use nalgebra::{DMatrix, DVector};
3
4#[cfg(feature = "nalgebra")]
5use nalgebra_sparse::{
6    CscMatrix, CsrMatrix,
7    ops::{
8        Op,
9        serial::{spmm_csc_dense, spmm_csr_dense},
10    },
11};
12
13#[cfg(feature = "faer")]
14use faer::{
15    Mat,
16    linalg::matmul::matmul,
17    matrix_free::LinOp,
18    sparse::{SparseColMat, SparseRowMat, linalg::matmul::sparse_dense_matmul},
19};
20#[cfg(feature = "ndarray")]
21use ndarray::{Array1, Array2};
22
23#[cfg(feature = "ndarray")]
24use sprs::CsMat;
25
// Backend-selected type aliases. Each alias has one definition per backend
// feature; the rest of this file is written against the alias names only.

/// Dense vector type when the `nalgebra` backend is enabled.
#[cfg(feature = "nalgebra")]
pub(crate) type Vector<T> = DVector<T>;

/// Dense vector type when the `faer` backend is enabled. Vectors are stored
/// in a general `Mat` (presumably a single column; confirm against callers).
#[cfg(feature = "faer")]
pub(crate) type Vector<T> = Mat<T>;

/// Dense vector type when the `ndarray` backend is enabled.
#[cfg(feature = "ndarray")]
pub(crate) type Vector<T> = Array1<T>;

/// Dense matrix type when the `nalgebra` backend is enabled.
#[cfg(feature = "nalgebra")]
pub(crate) type Matrix<T> = DMatrix<T>;

/// Dense matrix type when the `faer` backend is enabled.
#[cfg(feature = "faer")]
pub(crate) type Matrix<T> = Mat<T>;

/// Dense matrix type when the `ndarray` backend is enabled.
#[cfg(feature = "ndarray")]
pub(crate) type Matrix<T> = Array2<T>;

/// Compressed-sparse-row matrix type when the `nalgebra` backend is enabled.
#[cfg(feature = "nalgebra")]
pub(crate) type SparseCsrMatrix<T> = CsrMatrix<T>;

/// Compressed-sparse-row matrix type when the `faer` backend is enabled
/// (`usize` indices).
#[cfg(feature = "faer")]
pub(crate) type SparseCsrMatrix<T> = SparseRowMat<usize, T>;

/// Compressed-sparse-column matrix type when the `nalgebra` backend is enabled.
#[cfg(feature = "nalgebra")]
pub(crate) type SparseCscMatrix<T> = CscMatrix<T>;

/// Compressed-sparse-column matrix type when the `faer` backend is enabled
/// (`usize` indices).
#[cfg(feature = "faer")]
pub(crate) type SparseCscMatrix<T> = SparseColMat<usize, T>;
55
56pub trait MatrixOp {
58    fn nrows(&self) -> usize;
60
61    fn ncols(&self) -> usize;
63
64    fn is_square(&self) -> bool;
66
67    fn gemv(&self, alpha: f64, x: &Vector<f64>, beta: f64, y: &mut Vector<f64>);
69
70    fn is_empty(&self) -> bool;
72
73    #[cfg(feature = "faer")]
74    fn len(&self) -> usize {
75        self.nrows() * self.ncols()
76    }
77}
78
79impl MatrixOp for Matrix<f64> {
80    fn nrows(&self) -> usize {
81        self.nrows()
82    }
83
84    fn ncols(&self) -> usize {
85        self.ncols()
86    }
87    fn is_square(&self) -> bool {
88        #[cfg(feature = "nalgebra")]
89        {
90            self.is_square()
91        }
92        #[cfg(feature = "faer")]
93        {
94            self.nrows() == self.ncols()
95        }
96        #[cfg(feature = "ndarray")]
97        {
98            self.nrows() == self.ncols()
99        }
100    }
101
102    fn gemv(&self, alpha: f64, x: &Vector<f64>, beta: f64, y: &mut Vector<f64>) {
103        #[cfg(feature = "nalgebra")]
104        {
105            y.gemv(alpha, self, x, beta)
106        }
107        #[cfg(feature = "faer")]
108        {
109            *y *= beta;
110            matmul(y, faer::Accum::Add, self, x, alpha, faer::Par::Seq);
111        }
112        #[cfg(feature = "ndarray")]
113        {
114            use ndarray::linalg::general_mat_vec_mul;
115
116            general_mat_vec_mul(alpha, self, x, beta, y);
117        }
118    }
119
120    fn is_empty(&self) -> bool {
121        #[cfg(feature = "nalgebra")]
122        {
123            self.is_empty()
124        }
125        #[cfg(feature = "faer")]
126        {
127            self.nrows() == 0 || self.ncols() == 0
128        }
129        #[cfg(feature = "ndarray")]
130        {
131            self.nrows() == 0 || self.ncols() == 0
132        }
133    }
134}
135
136#[cfg(not(feature = "ndarray"))]
137impl MatrixOp for SparseCsrMatrix<f64> {
138    fn nrows(&self) -> usize {
139        #[cfg(feature = "nalgebra")]
140        {
141            self.nrows()
142        }
143        #[cfg(feature = "faer")]
144        {
145            <Self as LinOp<f64>>::nrows(self)
146        }
147    }
148
149    fn ncols(&self) -> usize {
150        #[cfg(feature = "nalgebra")]
151        {
152            self.ncols()
153        }
154        #[cfg(feature = "faer")]
155        {
156            <Self as LinOp<f64>>::ncols(self)
157        }
158    }
159
160    fn is_square(&self) -> bool {
161        #[cfg(feature = "nalgebra")]
162        {
163            self.nrows() == self.ncols()
164        }
165        #[cfg(feature = "faer")]
166        {
167            <Self as LinOp<f64>>::nrows(self) == <Self as LinOp<f64>>::ncols(self)
168        }
169    }
170
171    fn gemv(&self, alpha: f64, x: &Vector<f64>, beta: f64, y: &mut Vector<f64>) {
172        #[cfg(feature = "nalgebra")]
173        {
174            spmm_csr_dense(beta, y, alpha, Op::NoOp(self), Op::NoOp(x))
175        }
176        #[cfg(feature = "faer")]
177        {
178            *y *= beta;
179            let col_mat = self.to_col_major().unwrap();
180            sparse_dense_matmul(
181                y.as_mut(),
182                faer::Accum::Add,
183                col_mat.as_ref(),
184                x.as_ref(),
185                alpha,
186                faer::Par::Seq,
187            );
188        }
189    }
190
191    fn is_empty(&self) -> bool {
192        #[cfg(feature = "nalgebra")]
193        {
194            self.nrows() == 0 || self.ncols() == 0
195        }
196        #[cfg(feature = "faer")]
197        {
198            <Self as LinOp<f64>>::nrows(self) == 0 || <Self as LinOp<f64>>::ncols(self) == 0
199        }
200    }
201}
202
203#[cfg(not(feature = "ndarray"))]
204impl MatrixOp for SparseCscMatrix<f64> {
205    fn nrows(&self) -> usize {
206        #[cfg(feature = "nalgebra")]
207        {
208            self.nrows()
209        }
210        #[cfg(feature = "faer")]
211        {
212            <Self as LinOp<f64>>::nrows(self)
213        }
214    }
215
216    fn ncols(&self) -> usize {
217        #[cfg(feature = "nalgebra")]
218        {
219            self.ncols()
220        }
221        #[cfg(feature = "faer")]
222        {
223            <Self as LinOp<f64>>::ncols(self)
224        }
225    }
226
227    fn is_square(&self) -> bool {
228        #[cfg(feature = "nalgebra")]
229        {
230            self.nrows() == self.ncols()
231        }
232        #[cfg(feature = "faer")]
233        {
234            <Self as LinOp<f64>>::nrows(self) == <Self as LinOp<f64>>::ncols(self)
235        }
236    }
237
238    fn gemv(&self, alpha: f64, x: &Vector<f64>, beta: f64, y: &mut Vector<f64>) {
239        #[cfg(feature = "nalgebra")]
240        {
241            spmm_csc_dense(beta, y, alpha, Op::NoOp(self), Op::NoOp(x))
242        }
243        #[cfg(feature = "faer")]
244        {
245            *y *= beta;
246            sparse_dense_matmul(
247                y.as_mut(),
248                faer::Accum::Add,
249                self.as_ref(),
250                x.as_ref(),
251                alpha,
252                faer::Par::Seq,
253            );
254        }
255    }
256
257    fn is_empty(&self) -> bool {
258        #[cfg(feature = "nalgebra")]
259        {
260            self.nrows() == 0 || self.ncols() == 0
261        }
262        #[cfg(feature = "faer")]
263        {
264            <Self as LinOp<f64>>::nrows(self) == 0 || <Self as LinOp<f64>>::ncols(self) == 0
265        }
266    }
267}
268
/// [`MatrixOp`] for the `sprs` sparse matrix used with the `ndarray` backend.
#[cfg(feature = "ndarray")]
impl MatrixOp for CsMat<f64> {
    /// Computes `y <- alpha * self * x + beta * y` in place.
    ///
    /// When `beta` is exactly zero the previous contents of `y` are discarded
    /// rather than multiplied by zero, so NaN or infinite entries left in `y`
    /// cannot leak into the result.
    fn gemv(&self, alpha: f64, x: &Vector<f64>, beta: f64, y: &mut Vector<f64>) {
        if beta == 0.0 {
            // `0 * NaN` is NaN, so overwrite instead of scaling.
            y.fill(0.0);
        } else {
            *y *= beta;
        }
        // The sparse product allocates a temporary dense vector, which is
        // then scaled by `alpha` and accumulated into `y`.
        *y += &(alpha * (self * x));
    }

    /// Number of rows of the sparse matrix.
    fn nrows(&self) -> usize {
        self.rows()
    }

    /// Number of columns of the sparse matrix.
    fn ncols(&self) -> usize {
        self.cols()
    }

    /// Returns `true` when the matrix has zero rows or zero columns.
    fn is_empty(&self) -> bool {
        self.rows() == 0 || self.cols() == 0
    }

    /// Returns `true` when the row and column counts are equal.
    fn is_square(&self) -> bool {
        self.rows() == self.cols()
    }
}