1use std::fmt::Debug;
2use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign};
3
4use crate::error::DiffsolError;
5use crate::scalar::Scale;
6use crate::vector::VectorHost;
7use crate::{Context, IndexType, Scalar, Vector, VectorIndex};
8
9use extract_block::combine;
10use num_traits::{One, Zero};
11use sparsity::{Dense, MatrixSparsity, MatrixSparsityRef};
12
13#[cfg(feature = "cuda")]
14pub mod cuda;
15
16#[cfg(feature = "nalgebra")]
17pub mod dense_nalgebra_serial;
18
19#[cfg(feature = "faer")]
20pub mod dense_faer_serial;
21
22#[cfg(feature = "faer")]
23pub mod sparse_faer;
24
25pub mod default_solver;
26pub mod extract_block;
27pub mod sparsity;
28
29#[macro_use]
30mod utils;
31
32pub trait MatrixCommon: Sized + Debug {
33 type V: Vector<T = Self::T, C = Self::C, Index: VectorIndex<C = Self::C>>;
34 type T: Scalar;
35 type C: Context;
36 type Inner;
37
38 fn nrows(&self) -> IndexType;
39 fn ncols(&self) -> IndexType;
40 fn inner(&self) -> &Self::Inner;
41}
42
43impl<M> MatrixCommon for &M
44where
45 M: MatrixCommon,
46{
47 type T = M::T;
48 type V = M::V;
49 type C = M::C;
50 type Inner = M::Inner;
51
52 fn nrows(&self) -> IndexType {
53 M::nrows(*self)
54 }
55 fn ncols(&self) -> IndexType {
56 M::ncols(*self)
57 }
58 fn inner(&self) -> &Self::Inner {
59 M::inner(*self)
60 }
61}
62
63impl<M> MatrixCommon for &mut M
64where
65 M: MatrixCommon,
66{
67 type T = M::T;
68 type V = M::V;
69 type C = M::C;
70 type Inner = M::Inner;
71
72 fn ncols(&self) -> IndexType {
73 M::ncols(*self)
74 }
75 fn nrows(&self) -> IndexType {
76 M::nrows(*self)
77 }
78 fn inner(&self) -> &Self::Inner {
79 M::inner(*self)
80 }
81}
82
83pub trait MatrixOpsByValue<Rhs = Self, Output = Self>:
84 MatrixCommon + Add<Rhs, Output = Output> + Sub<Rhs, Output = Output>
85{
86}
87
88impl<M, Rhs, Output> MatrixOpsByValue<Rhs, Output> for M where
89 M: MatrixCommon + Add<Rhs, Output = Output> + Sub<Rhs, Output = Output>
90{
91}
92
93pub trait MatrixMutOpsByValue<Rhs = Self>: MatrixCommon + AddAssign<Rhs> + SubAssign<Rhs> {}
94
95impl<M, Rhs> MatrixMutOpsByValue<Rhs> for M where M: MatrixCommon + AddAssign<Rhs> + SubAssign<Rhs> {}
96
97pub trait MatrixRef<M: MatrixCommon>: Mul<Scale<M::T>, Output = M> {}
99impl<RefT, M: MatrixCommon> MatrixRef<M> for RefT where RefT: Mul<Scale<M::T>, Output = M> {}
100
101pub trait MatrixViewMut<'a>:
103 for<'b> MatrixMutOpsByValue<&'b Self>
104 + for<'b> MatrixMutOpsByValue<&'b Self::View>
105 + MulAssign<Scale<Self::T>>
106{
107 type Owned;
108 type View;
109 fn into_owned(self) -> Self::Owned;
110 fn gemm_oo(&mut self, alpha: Self::T, a: &Self::Owned, b: &Self::Owned, beta: Self::T);
111 fn gemm_vo(&mut self, alpha: Self::T, a: &Self::View, b: &Self::Owned, beta: Self::T);
112}
113
114pub trait MatrixView<'a>:
116 for<'b> MatrixOpsByValue<&'b Self::Owned, Self::Owned> + Mul<Scale<Self::T>, Output = Self::Owned>
117{
118 type Owned;
119
120 fn into_owned(self) -> Self::Owned;
121
122 fn gemv_v(
124 &self,
125 alpha: Self::T,
126 x: &<Self::V as Vector>::View<'_>,
127 beta: Self::T,
128 y: &mut Self::V,
129 );
130
131 fn gemv_o(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V);
132}
133
134pub trait Matrix: MatrixCommon + Mul<Scale<Self::T>, Output = Self> + Clone + 'static {
136 type Sparsity: MatrixSparsity<Self>;
137 type SparsityRef<'a>: MatrixSparsityRef<'a, Self>
138 where
139 Self: 'a;
140
141 fn sparsity(&self) -> Option<Self::SparsityRef<'_>>;
143
144 fn context(&self) -> &Self::C;
145
146 fn is_sparse() -> bool {
147 Self::zeros(1, 1, Default::default()).sparsity().is_some()
148 }
149
150 fn partition_indices_by_zero_diagonal(
151 &self,
152 ) -> (<Self::V as Vector>::Index, <Self::V as Vector>::Index);
153
154 fn gemv(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V);
156
157 fn copy_from(&mut self, other: &Self);
159
160 fn zeros(nrows: IndexType, ncols: IndexType, ctx: Self::C) -> Self;
162
163 fn new_from_sparsity(
165 nrows: IndexType,
166 ncols: IndexType,
167 sparsity: Option<Self::Sparsity>,
168 ctx: Self::C,
169 ) -> Self;
170
171 fn from_diagonal(v: &Self::V) -> Self;
173
174 fn set_column(&mut self, j: IndexType, v: &Self::V);
177
178 fn add_column_to_vector(&self, j: IndexType, v: &mut Self::V);
179
180 fn set_data_with_indices(
185 &mut self,
186 dst_indices: &<Self::V as Vector>::Index,
187 src_indices: &<Self::V as Vector>::Index,
188 data: &Self::V,
189 );
190
191 fn gather(&mut self, other: &Self, indices: &<Self::V as Vector>::Index);
197
198 fn split(
213 &self,
214 algebraic_indices: &<Self::V as Vector>::Index,
215 ) -> [(Self, <Self::V as Vector>::Index); 4] {
216 match self.sparsity() {
217 Some(sp) => sp.split(algebraic_indices).map(|(sp, src_indices)| {
218 let mut m = Self::new_from_sparsity(
219 sp.nrows(),
220 sp.ncols(),
221 Some(sp),
222 self.context().clone(),
223 );
224 m.gather(self, &src_indices);
225 (m, src_indices)
226 }),
227 None => Dense::<Self>::new(self.nrows(), self.ncols())
228 .split(algebraic_indices)
229 .map(|(sp, src_indices)| {
230 let mut m = Self::new_from_sparsity(
231 sp.nrows(),
232 sp.ncols(),
233 None,
234 self.context().clone(),
235 );
236 m.gather(self, &src_indices);
237 (m, src_indices)
238 }),
239 }
240 }
241
242 fn combine(
243 ul: &Self,
244 ur: &Self,
245 ll: &Self,
246 lr: &Self,
247 algebraic_indices: &<Self::V as Vector>::Index,
248 ) -> Self {
249 combine(ul, ur, ll, lr, algebraic_indices)
250 }
251
252 fn scale_add_and_assign(&mut self, x: &Self, beta: Self::T, y: &Self);
255
256 fn triplet_iter(&self) -> impl Iterator<Item = (IndexType, IndexType, Self::T)>;
257
258 fn try_from_triplets(
260 nrows: IndexType,
261 ncols: IndexType,
262 triplets: Vec<(IndexType, IndexType, Self::T)>,
263 ctx: Self::C,
264 ) -> Result<Self, DiffsolError>;
265}
266
267pub trait MatrixHost: Matrix<V: VectorHost> {}
269
270impl<T: Matrix<V: VectorHost>> MatrixHost for T {}
271
272pub trait DenseMatrix:
274 Matrix
275 + for<'b> MatrixOpsByValue<&'b Self, Self>
276 + for<'b> MatrixMutOpsByValue<&'b Self>
277 + for<'a, 'b> MatrixOpsByValue<&'b Self::View<'a>, Self>
278 + for<'a, 'b> MatrixMutOpsByValue<&'b Self::View<'a>>
279{
280 type View<'a>: MatrixView<'a, Owned = Self, T = Self::T, V = Self::V>
282 where
283 Self: 'a;
284
285 type ViewMut<'a>: MatrixViewMut<
287 'a,
288 Owned = Self,
289 T = Self::T,
290 V = Self::V,
291 View = Self::View<'a>,
292 >
293 where
294 Self: 'a;
295
296 fn gemm(&mut self, alpha: Self::T, a: &Self, b: &Self, beta: Self::T);
298
299 fn column_axpy(&mut self, alpha: Self::T, j: IndexType, i: IndexType);
301
302 fn columns(&self, start: IndexType, end: IndexType) -> Self::View<'_>;
304
305 fn column(&self, i: IndexType) -> <Self::V as Vector>::View<'_>;
307
308 fn columns_mut(&mut self, start: IndexType, end: IndexType) -> Self::ViewMut<'_>;
310
311 fn column_mut(&mut self, i: IndexType) -> <Self::V as Vector>::ViewMut<'_>;
313
314 fn set_index(&mut self, i: IndexType, j: IndexType, value: Self::T);
316
317 fn get_index(&self, i: IndexType, j: IndexType) -> Self::T;
319
320 fn mat_mul(&self, b: &Self) -> Self {
322 let nrows = self.nrows();
323 let ncols = b.ncols();
324 let mut ret = Self::zeros(nrows, ncols, self.context().clone());
325 ret.gemm(Self::T::one(), self, b, Self::T::zero());
326 ret
327 }
328
329 fn from_vec(nrows: IndexType, ncols: IndexType, data: Vec<Self::T>, ctx: Self::C) -> Self;
332}
333
#[cfg(test)]
mod tests {
    use super::{DenseMatrix, Matrix};
    use crate::{scalar::IndexType, VectorIndex};

    /// Build a 4x4 matrix from `triplets` and check the `(zero, nonzero)` diagonal partition.
    fn check_partition<M: Matrix>(
        triplets: Vec<(IndexType, IndexType, M::T)>,
        expected_zero: Vec<IndexType>,
        expected_nonzero: Vec<IndexType>,
    ) {
        let matrix = M::try_from_triplets(4, 4, triplets, Default::default()).unwrap();
        let (zero, nonzero) = matrix.partition_indices_by_zero_diagonal();
        assert_eq!(zero.clone_as_vec(), expected_zero);
        assert_eq!(nonzero.clone_as_vec(), expected_nonzero);
    }

    pub fn test_partition_indices_by_zero_diagonal<M: Matrix>() {
        // Diagonal entry 2 is not stored at all.
        check_partition::<M>(
            vec![(0, 0, 1.0.into()), (1, 1, 2.0.into()), (3, 3, 1.0.into())],
            vec![2],
            vec![0, 1, 3],
        );

        // Diagonal entry 2 is stored explicitly as zero.
        check_partition::<M>(
            vec![
                (0, 0, 1.0.into()),
                (1, 1, 2.0.into()),
                (2, 2, 0.0.into()),
                (3, 3, 1.0.into()),
            ],
            vec![2],
            vec![0, 1, 3],
        );

        // Every diagonal entry is non-zero.
        check_partition::<M>(
            vec![
                (0, 0, 1.0.into()),
                (1, 1, 2.0.into()),
                (2, 2, 3.0.into()),
                (3, 3, 1.0.into()),
            ],
            Vec::<IndexType>::new(),
            vec![0, 1, 2, 3],
        );
    }

    pub fn test_column_axpy<M: DenseMatrix>() {
        // a = [[1, 2],
        //      [3, 4]]
        let mut a = M::zeros(2, 2, Default::default());
        for (row, col, value) in [(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0), (1, 1, 4.0)] {
            a.set_index(row, col, M::T::from(value));
        }

        // column 1 += 2 * column 0
        a.column_axpy(M::T::from(2.0), 0, 1);

        // Column 0 is untouched; column 1 becomes [2 + 2*1, 4 + 2*3] = [4, 10].
        for (row, col, expected) in [(0, 0, 1.0), (0, 1, 4.0), (1, 0, 3.0), (1, 1, 10.0)] {
            assert_eq!(a.get_index(row, col), M::T::from(expected));
        }
    }
}