diffsol/matrix/
mod.rs

1use std::fmt::Debug;
2use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign};
3
4use crate::error::DiffsolError;
5use crate::scalar::Scale;
6use crate::vector::VectorHost;
7use crate::{Context, IndexType, Scalar, Vector, VectorIndex};
8
9use extract_block::combine;
10use num_traits::{One, Zero};
11use sparsity::{Dense, MatrixSparsity, MatrixSparsityRef};
12
13#[cfg(feature = "cuda")]
14pub mod cuda;
15
16#[cfg(feature = "nalgebra")]
17pub mod dense_nalgebra_serial;
18
19#[cfg(feature = "faer")]
20pub mod dense_faer_serial;
21
22#[cfg(feature = "faer")]
23pub mod sparse_faer;
24
25pub mod default_solver;
26pub mod extract_block;
27pub mod sparsity;
28
29#[macro_use]
30mod utils;
31
32/// Common interface for matrix types, providing access to scalar type, context, and dimensions.
33pub trait MatrixCommon: Sized + Debug {
34    type V: Vector<T = Self::T, C = Self::C, Index: VectorIndex<C = Self::C>>;
35    type T: Scalar;
36    type C: Context;
37    type Inner;
38
39    /// Get the number of rows in this matrix.
40    fn nrows(&self) -> IndexType;
41    /// Get the number of columns in this matrix.
42    fn ncols(&self) -> IndexType;
43    /// Get a reference to the inner representation of the matrix.
44    fn inner(&self) -> &Self::Inner;
45}
46
47impl<M> MatrixCommon for &M
48where
49    M: MatrixCommon,
50{
51    type T = M::T;
52    type V = M::V;
53    type C = M::C;
54    type Inner = M::Inner;
55
56    fn nrows(&self) -> IndexType {
57        M::nrows(*self)
58    }
59    fn ncols(&self) -> IndexType {
60        M::ncols(*self)
61    }
62    fn inner(&self) -> &Self::Inner {
63        M::inner(*self)
64    }
65}
66
67impl<M> MatrixCommon for &mut M
68where
69    M: MatrixCommon,
70{
71    type T = M::T;
72    type V = M::V;
73    type C = M::C;
74    type Inner = M::Inner;
75
76    fn ncols(&self) -> IndexType {
77        M::ncols(*self)
78    }
79    fn nrows(&self) -> IndexType {
80        M::nrows(*self)
81    }
82    fn inner(&self) -> &Self::Inner {
83        M::inner(*self)
84    }
85}
86
87/// Operations on matrices by value (addition and subtraction).
88///
89/// This trait defines matrix addition and subtraction when both operands are owned or references.
90pub trait MatrixOpsByValue<Rhs = Self, Output = Self>:
91    MatrixCommon + Add<Rhs, Output = Output> + Sub<Rhs, Output = Output>
92{
93}
94
95impl<M, Rhs, Output> MatrixOpsByValue<Rhs, Output> for M where
96    M: MatrixCommon + Add<Rhs, Output = Output> + Sub<Rhs, Output = Output>
97{
98}
99
100/// In-place operations on matrices (addition and subtraction).
101///
102/// This trait defines in-place matrix addition and subtraction (self += rhs, self -= rhs).
103pub trait MatrixMutOpsByValue<Rhs = Self>: MatrixCommon + AddAssign<Rhs> + SubAssign<Rhs> {}
104
105impl<M, Rhs> MatrixMutOpsByValue<Rhs> for M where M: MatrixCommon + AddAssign<Rhs> + SubAssign<Rhs> {}
106
107/// A trait allowing for references to implement matrix operations
108pub trait MatrixRef<M: MatrixCommon>: Mul<Scale<M::T>, Output = M> {}
109impl<RefT, M: MatrixCommon> MatrixRef<M> for RefT where RefT: Mul<Scale<M::T>, Output = M> {}
110
111/// A mutable view of a dense matrix, supporting in-place operations and modifications.
112///
113/// This trait represents a temporary mutable reference to a matrix's data, allowing in-place
114/// arithmetic operations (+=, -=, *=) and matrix-matrix multiplication. Mutable views can be
115/// created via the `columns_mut()` or `column_mut()` methods on a `DenseMatrix`.
116pub trait MatrixViewMut<'a>:
117    for<'b> MatrixMutOpsByValue<&'b Self>
118    + for<'b> MatrixMutOpsByValue<&'b Self::View>
119    + MulAssign<Scale<Self::T>>
120{
121    type Owned;
122    type View;
123    /// Convert this mutable view into an owned matrix, cloning the data if necessary.
124    fn into_owned(self) -> Self::Owned;
125    /// Perform matrix-matrix multiplication with owned matrices: self = alpha * a * b + beta * self
126    fn gemm_oo(&mut self, alpha: Self::T, a: &Self::Owned, b: &Self::Owned, beta: Self::T);
127    /// Perform matrix-matrix multiplication with a view and owned matrix: self = alpha * a * b + beta * self
128    fn gemm_vo(&mut self, alpha: Self::T, a: &Self::View, b: &Self::Owned, beta: Self::T);
129}
130
131/// A borrowed immutable view of a dense matrix, supporting read-only arithmetic operations.
132///
133/// This trait represents a temporary immutable reference to a matrix's data, allowing read-only
134/// operations like addition, subtraction, scalar multiplication, and matrix-vector multiplication.
135/// Matrix views can be created via the `columns()` methods on a `DenseMatrix`.
136pub trait MatrixView<'a>:
137    for<'b> MatrixOpsByValue<&'b Self::Owned, Self::Owned> + Mul<Scale<Self::T>, Output = Self::Owned>
138{
139    type Owned;
140
141    /// Convert this view into an owned matrix, cloning the data if necessary.
142    fn into_owned(self) -> Self::Owned;
143
144    /// Perform a matrix-vector multiplication with a vector view: y = alpha * self * x + beta * y
145    fn gemv_v(
146        &self,
147        alpha: Self::T,
148        x: &<Self::V as Vector>::View<'_>,
149        beta: Self::T,
150        y: &mut Self::V,
151    );
152
153    /// Perform a matrix-vector multiplication with an owned vector: y = alpha * self * x + beta * y
154    fn gemv_o(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V);
155}
156
157/// A base matrix trait supporting both sparse and dense matrices.
158///
159/// This trait provides a complete interface for matrix operations including:
160/// - Matrix creation and memory management
161/// - Matrix-vector and matrix-matrix multiplication
162/// - Element access and modification
163/// - Sparsity information and handling
164/// - Matrix decomposition and combination operations
165/// - Triplet-based construction for sparse matrices
166///
167/// Implementing matrices can be dense or sparse, and may be hosted on CPU or GPU.
168/// Users typically do not need to implement this trait; use provided implementations.
169pub trait Matrix: MatrixCommon + Mul<Scale<Self::T>, Output = Self> + Clone + 'static {
170    type Sparsity: MatrixSparsity<Self>;
171    type SparsityRef<'a>: MatrixSparsityRef<'a, Self>
172    where
173        Self: 'a;
174
175    /// Return sparsity information, or `None` if the matrix is dense.
176    fn sparsity(&self) -> Option<Self::SparsityRef<'_>>;
177
178    /// Get the context associated with this matrix (for device placement, memory management, etc.).
179    fn context(&self) -> &Self::C;
180
181    /// Returns true if this matrix is stored in a sparse format
182    fn is_sparse() -> bool {
183        Self::zeros(1, 1, Default::default()).sparsity().is_some()
184    }
185
186    /// Partition the diagonal indices into two groups: those with zero diagonal elements and those with non-zero diagonal elements.
187    ///
188    /// This is useful for identifying algebraic constraints, which typically have zero diagonal elements in the mass matrix.
189    /// Returns a tuple of (zero_diagonal_indices, non_zero_diagonal_indices).
190    fn partition_indices_by_zero_diagonal(
191        &self,
192    ) -> (<Self::V as Vector>::Index, <Self::V as Vector>::Index);
193
194    /// Perform a matrix-vector multiplication: y = alpha * self * x + beta * y
195    fn gemv(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V);
196
197    /// Copy the contents of `other` into this matrix.
198    fn copy_from(&mut self, other: &Self);
199
200    /// Create a new matrix of shape `nrows` x `ncols` filled with zeros.
201    fn zeros(nrows: IndexType, ncols: IndexType, ctx: Self::C) -> Self;
202
203    /// Create a new matrix from a sparsity pattern. Non-zero elements are not initialized.
204    fn new_from_sparsity(
205        nrows: IndexType,
206        ncols: IndexType,
207        sparsity: Option<Self::Sparsity>,
208        ctx: Self::C,
209    ) -> Self;
210
211    /// Create a new diagonal matrix from a vector holding the diagonal elements.
212    fn from_diagonal(v: &Self::V) -> Self;
213
214    /// Set the values of column `j` to be equal to the values in `v`.
215    ///
216    /// For sparse matrices, only the existing non-zero elements are updated.
217    fn set_column(&mut self, j: IndexType, v: &Self::V);
218
219    /// Add a column of this matrix to a vector: v += self[:, j]
220    fn add_column_to_vector(&self, j: IndexType, v: &mut Self::V);
221
222    /// Assign the values in the `data` vector to this matrix at the indices in `dst_indices`
223    /// from the indices in `src_indices`.
224    ///
225    /// For dense matrices, the index is the data index in column-major order.
226    /// For sparse matrices, the index is the index into the data array.
227    fn set_data_with_indices(
228        &mut self,
229        dst_indices: &<Self::V as Vector>::Index,
230        src_indices: &<Self::V as Vector>::Index,
231        data: &Self::V,
232    );
233
234    /// Gather values from another matrix at specified indices into this matrix.
235    ///
236    /// For sparse matrices: the index `idx_i` in `indices` is an index into the data array for `other`,
237    /// and is copied to the index `idx_i` in the data array for this matrix.
238    /// For dense matrices: the index is the data index in column-major order.
239    fn gather(&mut self, other: &Self, indices: &<Self::V as Vector>::Index);
240
241    /// Split this matrix into four submatrices based on algebraic constraint indices.
242    ///
243    /// Partitions the matrix into blocks:
244    /// ```text
245    /// M = [UL, UR]
246    ///     [LL, LR]
247    /// ```
248    /// where:
249    /// - UL contains rows and columns NOT in `algebraic_indices`
250    /// - UR contains rows NOT in `algebraic_indices` and columns in `algebraic_indices`
251    /// - LL contains rows in `algebraic_indices` and columns NOT in `algebraic_indices`
252    /// - LR contains rows and columns in `algebraic_indices`
253    ///
254    /// Returns an array of tuples, where each tuple contains a submatrix and the indices that were used to create it.
255    /// These indices can be used with `gather()` to update the submatrix.
256    fn split(
257        &self,
258        algebraic_indices: &<Self::V as Vector>::Index,
259    ) -> [(Self, <Self::V as Vector>::Index); 4] {
260        match self.sparsity() {
261            Some(sp) => sp.split(algebraic_indices).map(|(sp, src_indices)| {
262                let mut m = Self::new_from_sparsity(
263                    sp.nrows(),
264                    sp.ncols(),
265                    Some(sp),
266                    self.context().clone(),
267                );
268                m.gather(self, &src_indices);
269                (m, src_indices)
270            }),
271            None => Dense::<Self>::new(self.nrows(), self.ncols())
272                .split(algebraic_indices)
273                .map(|(sp, src_indices)| {
274                    let mut m = Self::new_from_sparsity(
275                        sp.nrows(),
276                        sp.ncols(),
277                        None,
278                        self.context().clone(),
279                    );
280                    m.gather(self, &src_indices);
281                    (m, src_indices)
282                }),
283        }
284    }
285
286    /// Combine four submatrices back into a single matrix based on algebraic constraint indices.
287    ///
288    /// Inverse operation of `split()`. Takes submatrices `ul`, `ur`, `ll`, `lr` and combines them
289    /// back into the original matrix structure.
290    fn combine(
291        ul: &Self,
292        ur: &Self,
293        ll: &Self,
294        lr: &Self,
295        algebraic_indices: &<Self::V as Vector>::Index,
296    ) -> Self {
297        combine(ul, ur, ll, lr, algebraic_indices)
298    }
299
300    /// Perform the assignment: self = x + beta * y where x and y are matrices and beta is a scalar.
301    ///
302    /// Note: Panics if the sparsity patterns of self, x, and y do not match.
303    /// The sparsity of self must be the union of the sparsity of x and y.
304    fn scale_add_and_assign(&mut self, x: &Self, beta: Self::T, y: &Self);
305
306    /// Iterate over all non-zero elements in triplet format (row, column, value).
307    fn triplet_iter(&self) -> impl Iterator<Item = (IndexType, IndexType, Self::T)>;
308
309    /// Create a new matrix from a vector of triplets (row, column, value).
310    ///
311    /// This is useful for sparse matrix construction. The sparsity pattern is inferred from the triplets.
312    fn try_from_triplets(
313        nrows: IndexType,
314        ncols: IndexType,
315        triplets: Vec<(IndexType, IndexType, Self::T)>,
316        ctx: Self::C,
317    ) -> Result<Self, DiffsolError>;
318}
319
320/// A host matrix is a matrix type whose vector type is hosted on the CPU.
321///
322/// This trait extends `Matrix` to ensure the associated vector type implements `VectorHost`,
323/// enabling direct CPU-side access to data. GPU matrices typically do not implement this trait.
324pub trait MatrixHost: Matrix<V: VectorHost> {}
325
326impl<T: Matrix<V: VectorHost>> MatrixHost for T {}
327
328/// A dense column-major matrix with efficient column access operations.
329///
330/// This trait represents matrices stored in column-major order, where accessing matrix columns
331/// is efficient. It supports:
332/// - Matrix views and mutable views
333/// - Matrix-matrix multiplication (GEMM)
334/// - Column operations (axpy, access, modification)
335/// - Element access and modification
336/// - Matrix resizing
337///
338/// The column-major layout makes operations on individual or ranges of columns very efficient.
339pub trait DenseMatrix:
340    Matrix
341    + for<'b> MatrixOpsByValue<&'b Self, Self>
342    + for<'b> MatrixMutOpsByValue<&'b Self>
343    + for<'a, 'b> MatrixOpsByValue<&'b Self::View<'a>, Self>
344    + for<'a, 'b> MatrixMutOpsByValue<&'b Self::View<'a>>
345{
346    /// A view of the dense matrix type
347    type View<'a>: MatrixView<'a, Owned = Self, T = Self::T, V = Self::V>
348    where
349        Self: 'a;
350
351    /// A mutable view of the dense matrix type
352    type ViewMut<'a>: MatrixViewMut<
353        'a,
354        Owned = Self,
355        T = Self::T,
356        V = Self::V,
357        View = Self::View<'a>,
358    >
359    where
360        Self: 'a;
361
362    /// Perform a matrix-matrix multiplication: self = alpha * a * b + beta * self
363    fn gemm(&mut self, alpha: Self::T, a: &Self, b: &Self, beta: Self::T);
364
365    /// Perform a column AXPY operation: column i = alpha * column j + column i
366    ///
367    /// This is equivalent to: self[:, i] += alpha * self[:, j]
368    fn column_axpy(&mut self, alpha: Self::T, j: IndexType, i: IndexType);
369
370    /// Get an immutable view of columns from `start` (inclusive) to `end` (exclusive).
371    fn columns(&self, start: IndexType, end: IndexType) -> Self::View<'_>;
372
373    /// Get an immutable vector view of column `i`.
374    fn column(&self, i: IndexType) -> <Self::V as Vector>::View<'_>;
375
376    /// Get a mutable view of columns from `start` (inclusive) to `end` (exclusive).
377    fn columns_mut(&mut self, start: IndexType, end: IndexType) -> Self::ViewMut<'_>;
378
379    /// Get a mutable vector view of column `i`.
380    fn column_mut(&mut self, i: IndexType) -> <Self::V as Vector>::ViewMut<'_>;
381
382    /// Set the value at the given row and column indices.
383    fn set_index(&mut self, i: IndexType, j: IndexType, value: Self::T);
384
385    /// Get the value at the given row and column indices.
386    fn get_index(&self, i: IndexType, j: IndexType) -> Self::T;
387
388    /// Perform matrix-matrix multiplication using GEMM, allocating a new matrix for the result.
389    fn mat_mul(&self, b: &Self) -> Self {
390        let nrows = self.nrows();
391        let ncols = b.ncols();
392        let mut ret = Self::zeros(nrows, ncols, self.context().clone());
393        ret.gemm(Self::T::one(), self, b, Self::T::zero());
394        ret
395    }
396
397    /// Resize the number of columns in the matrix, preserving existing data.
398    ///
399    /// New elements (if added) are uninitialized. If the number of columns decreases, trailing columns are discarded.
400    fn resize_cols(&mut self, ncols: IndexType);
401
402    /// Create a new matrix from a vector of values in column-major order.
403    ///
404    /// The values are assumed to be stored in column-major order (first column, then second column, etc.).
405    fn from_vec(nrows: IndexType, ncols: IndexType, data: Vec<Self::T>, ctx: Self::C) -> Self;
406}
407
#[cfg(test)]
mod tests {
    use super::{DenseMatrix, Matrix};
    use crate::{scalar::IndexType, VectorIndex};
    use num_traits::{FromPrimitive, One, Zero};

    /// Generic check of `Matrix::partition_indices_by_zero_diagonal` on three 4 x 4 diagonal
    /// matrices built from triplets.
    pub fn test_partition_indices_by_zero_diagonal<M: Matrix>() {
        let scalar = |x: f64| M::T::from_f64(x).unwrap();

        // Each case is (triplets, expected zero-diagonal indices, expected non-zero indices).
        let cases: [(
            Vec<(IndexType, IndexType, M::T)>,
            Vec<IndexType>,
            Vec<IndexType>,
        ); 3] = [
            // Diagonal entry (2, 2) is missing from the triplets.
            (
                vec![
                    (0, 0, M::T::one()),
                    (1, 1, scalar(2.0)),
                    (3, 3, M::T::one()),
                ],
                vec![2],
                vec![0, 1, 3],
            ),
            // Diagonal entry (2, 2) is present but holds an explicit zero.
            (
                vec![
                    (0, 0, M::T::one()),
                    (1, 1, scalar(2.0)),
                    (2, 2, M::T::zero()),
                    (3, 3, M::T::one()),
                ],
                vec![2],
                vec![0, 1, 3],
            ),
            // Every diagonal entry is non-zero.
            (
                vec![
                    (0, 0, M::T::one()),
                    (1, 1, scalar(2.0)),
                    (2, 2, scalar(3.0)),
                    (3, 3, M::T::one()),
                ],
                Vec::<IndexType>::new(),
                vec![0, 1, 2, 3],
            ),
        ];

        for (triplets, expected_zero, expected_non_zero) in cases {
            let m = M::try_from_triplets(4, 4, triplets, Default::default()).unwrap();
            let (zero_diagonal_indices, non_zero_diagonal_indices) =
                m.partition_indices_by_zero_diagonal();
            assert_eq!(zero_diagonal_indices.clone_as_vec(), expected_zero);
            assert_eq!(non_zero_diagonal_indices.clone_as_vec(), expected_non_zero);
        }
    }

    /// Generic check of `DenseMatrix::column_axpy` on a 2 x 2 matrix.
    pub fn test_column_axpy<M: DenseMatrix>() {
        let scalar = |x: f64| M::T::from_f64(x).unwrap();

        // M = [1 2]
        //     [3 4]
        let mut a = M::zeros(2, 2, Default::default());
        for (i, j, value) in [
            (0, 0, M::T::one()),
            (0, 1, scalar(2.0)),
            (1, 0, scalar(3.0)),
            (1, 1, scalar(4.0)),
        ] {
            a.set_index(i, j, value);
        }

        // op is M(:, 1) = 2 * M(:, 0) + M(:, 1)
        a.column_axpy(scalar(2.0), 0, 1);

        // M = [1 4]
        //     [3 10]
        for (i, j, expected) in [
            (0, 0, M::T::one()),
            (0, 1, scalar(4.0)),
            (1, 0, scalar(3.0)),
            (1, 1, scalar(10.0)),
        ] {
            assert_eq!(a.get_index(i, j), expected);
        }
    }

    /// Generic check that `DenseMatrix::resize_cols` keeps existing data when growing and
    /// shrinking the number of columns.
    pub fn test_resize_cols<M: DenseMatrix>() {
        let scalar = |x: f64| M::T::from_f64(x).unwrap();

        // Asserts the shape-independent part: the leading 2 x 2 block is [1 2; 3 4].
        let assert_leading_block = |m: &M| {
            assert_eq!(m.get_index(0, 0), M::T::one());
            assert_eq!(m.get_index(0, 1), scalar(2.0));
            assert_eq!(m.get_index(1, 0), scalar(3.0));
            assert_eq!(m.get_index(1, 1), scalar(4.0));
        };

        let mut a = M::zeros(2, 2, Default::default());
        a.set_index(0, 0, M::T::one());
        a.set_index(0, 1, scalar(2.0));
        a.set_index(1, 0, scalar(3.0));
        a.set_index(1, 1, scalar(4.0));

        // Grow to three columns: the original entries must survive.
        a.resize_cols(3);
        assert_eq!(a.ncols(), 3);
        assert_eq!(a.nrows(), 2);
        assert_leading_block(&a);

        // The new column must be writable and readable.
        a.set_index(0, 2, scalar(5.0));
        a.set_index(1, 2, scalar(6.0));
        assert_eq!(a.get_index(0, 2), scalar(5.0));
        assert_eq!(a.get_index(1, 2), scalar(6.0));

        // Shrink back to two columns: the leading block must still be intact.
        a.resize_cols(2);
        assert_eq!(a.ncols(), 2);
        assert_eq!(a.nrows(), 2);
        assert_leading_block(&a);
    }
}