clarabel/algebra/
matrix_traits.rs

1#![allow(non_snake_case)]
2
3use crate::algebra::MatrixConcatenationError;
4use crate::algebra::MatrixShape;
5
6#[cfg_attr(not(feature = "sdp"), allow(dead_code))]
7pub(crate) trait ShapedMatrix {
8    fn shape(&self) -> MatrixShape;
9    fn size(&self) -> (usize, usize);
10    fn nrows(&self) -> usize {
11        self.size().0
12    }
13    fn ncols(&self) -> usize {
14        self.size().1
15    }
16    fn is_square(&self) -> bool {
17        self.nrows() == self.ncols()
18    }
19}
20
/// Checks for upper and lower triangular matrix data.
///
/// Method names follow the `triu`/`tril` convention: `u` for upper, `l` for lower.
pub trait TriangularMatrixChecks {
    /// Returns `true` if the matrix is upper triangular.
    fn is_triu(&self) -> bool;

    /// Returns `true` if the matrix is lower triangular.
    fn is_tril(&self) -> bool;
}
28
29/// Blockwise matrix concatenation
30pub trait BlockConcatenate: Sized {
31    /// horizontal matrix concatenation
32    ///
33    /// ```text
34    /// C = [A B]
35    /// ```
36    ///
37    /// Errors if row dimensions are incompatible
38    fn hcat(A: &Self, B: &Self) -> Result<Self, MatrixConcatenationError>;
39
40    /// vertical matrix concatenation
41    ///
42    /// ```text
43    /// C = [ A ]
44    ///     [ B ]
45    /// ```
46    ///
47    /// Errors if column dimensions are incompatible
48    fn vcat(A: &Self, B: &Self) -> Result<Self, MatrixConcatenationError>;
49
50    /// general block concatenation
51    fn hvcat(mats: &[&[&Self]]) -> Result<Self, MatrixConcatenationError>;
52
53    /// block diagonal concatenation
54    fn blockdiag(mats: &[&Self]) -> Result<Self, MatrixConcatenationError>;
55}
56
57pub(crate) fn hvcat_dim_check<MAT: ShapedMatrix>(
58    mats: &[&[&MAT]],
59) -> Result<(), MatrixConcatenationError> {
60    // error if no blocks
61    if mats.is_empty() || mats[0].is_empty() {
62        return Err(MatrixConcatenationError::IncompatibleDimension);
63    };
64
65    // error unless every block row has the same number of blocks
66    let len0 = mats[0].len();
67    for mat in mats.iter().skip(1) {
68        if mat.len() != len0 {
69            return Err(MatrixConcatenationError::IncompatibleDimension);
70        }
71    }
72
73    // check for block dimensional consistency across block row and columns
74
75    //row checks
76    for blockrow in mats {
77        let rows = blockrow[0].nrows();
78        for mat in blockrow.iter().skip(1) {
79            if mat.nrows() != rows {
80                return Err(MatrixConcatenationError::IncompatibleDimension);
81            }
82        }
83    }
84
85    // column checks
86    for (blockcol, topblock) in mats[0].iter().enumerate() {
87        let cols = topblock.ncols();
88        for matrow in mats.iter().skip(1) {
89            if matrow[blockcol].ncols() != cols {
90                return Err(MatrixConcatenationError::IncompatibleDimension);
91            }
92        }
93    }
94
95    Ok(())
96}