#![allow(non_snake_case)]
use crate::algebra::MatrixConcatenationError;
use crate::algebra::MatrixShape;
/// Shape queries shared by the crate's matrix types.
///
/// Implementors provide [`shape`](ShapedMatrix::shape) and
/// [`size`](ShapedMatrix::size). The row count, column count and squareness
/// helpers have default implementations derived from `size`.
// The dead-code lint is silenced whenever the `sdp` feature is off: parts of
// this trait appear to be used only when that feature is enabled.
#[cfg_attr(not(feature = "sdp"), allow(dead_code))]
pub(crate) trait ShapedMatrix {
    /// Reports this matrix's [`MatrixShape`] descriptor.
    fn shape(&self) -> MatrixShape;

    /// Returns the dimensions as a `(rows, columns)` pair.
    fn size(&self) -> (usize, usize);

    /// Number of rows: the first component of [`size`](ShapedMatrix::size).
    fn nrows(&self) -> usize {
        let (rows, _) = self.size();
        rows
    }

    /// Number of columns: the second component of [`size`](ShapedMatrix::size).
    fn ncols(&self) -> usize {
        let (_, cols) = self.size();
        cols
    }

    /// `true` when [`nrows`](ShapedMatrix::nrows) equals
    /// [`ncols`](ShapedMatrix::ncols).
    fn is_square(&self) -> bool {
        // Go through nrows/ncols (not size) so overriding implementors are honoured.
        let (rows, cols) = (self.nrows(), self.ncols());
        rows == cols
    }
}
/// Structural triangularity queries for a matrix.
///
/// The method names follow the common `triu`/`tril` (upper/lower triangular)
/// convention. The precise definition, such as how the diagonal or stored
/// zeros are treated, is left to each implementation.
pub trait TriangularMatrixChecks {
    /// Whether the matrix is upper triangular (semantics per implementation).
    fn is_triu(&self) -> bool;

    /// Whether the matrix is lower triangular (semantics per implementation).
    fn is_tril(&self) -> bool;
}
/// Block-wise assembly of matrices of one concrete type.
///
/// Every operation is fallible and reports an unusable layout as a
/// [`MatrixConcatenationError`]. The names follow the familiar
/// `hcat`/`vcat`/`hvcat`/`blockdiag` convention. Each implementation defines
/// the exact dimension rules.
///
/// `Self: Sized` is required because every method returns `Self` inside a
/// `Result`.
pub trait BlockConcatenate
where
    Self: Sized,
{
    /// Joins `A` and `B` side by side, presumably as `[A B]`.
    ///
    /// # Errors
    ///
    /// Returns a [`MatrixConcatenationError`] when the operands cannot be
    /// joined. This is presumably a row-count mismatch; confirm per
    /// implementation.
    fn hcat(A: &Self, B: &Self) -> Result<Self, MatrixConcatenationError>;

    /// Stacks `A` on top of `B`, presumably as `[A; B]`.
    ///
    /// # Errors
    ///
    /// Returns a [`MatrixConcatenationError`] when the operands cannot be
    /// stacked. This is presumably a column-count mismatch; confirm per
    /// implementation.
    fn vcat(A: &Self, B: &Self) -> Result<Self, MatrixConcatenationError>;

    /// Assembles one matrix from a grid of blocks.
    ///
    /// `mats[i][j]` is presumably the block at block-row `i`, block-column
    /// `j`, which is the layout that `hvcat_dim_check` in this module
    /// validates.
    ///
    /// # Errors
    ///
    /// Returns a [`MatrixConcatenationError`] when the grid is not a
    /// consistent block layout.
    fn hvcat(mats: &[&[&Self]]) -> Result<Self, MatrixConcatenationError>;

    /// Builds a block-diagonal matrix with `mats` along the diagonal
    /// (by the usual meaning of the name; confirm per implementation).
    ///
    /// # Errors
    ///
    /// Returns a [`MatrixConcatenationError`] when the blocks cannot be
    /// assembled.
    fn blockdiag(mats: &[&Self]) -> Result<Self, MatrixConcatenationError>;
}
/// Checks that a grid of blocks forms a consistent block-matrix layout.
///
/// `mats[i][j]` is taken to be the block at block-row `i` and block-column
/// `j`. The layout is accepted only if all of the following hold:
///
/// * there is at least one block-row and the first block-row is non-empty,
/// * every block-row holds the same number of blocks,
/// * the blocks within each block-row all have the same number of rows, and
/// * the blocks within each block-column all have the same number of columns.
///
/// The checks run in that order and stop at the first failure. Each check
/// therefore only indexes rows that the earlier checks have shown to be long
/// enough.
///
/// # Errors
///
/// Returns [`MatrixConcatenationError::IncompatibleDimension`] when any of the
/// conditions above does not hold.
pub(crate) fn hvcat_dim_check<MAT: ShapedMatrix>(
    mats: &[&[&MAT]],
) -> Result<(), MatrixConcatenationError> {
    // The first block-row fixes how many block-columns the grid must have.
    // An empty grid, or an empty first row, gives nothing to size the result by.
    let first_row = match mats.first() {
        Some(row) if !row.is_empty() => *row,
        _ => return Err(MatrixConcatenationError::IncompatibleDimension),
    };
    let block_cols = first_row.len();

    // Every block-row must contain the same number of blocks.
    let rectangular = || mats.iter().all(|row| row.len() == block_cols);

    // Blocks sharing a block-row must agree on their row count. This only runs
    // once the grid is known to be rectangular, so `row[0]` exists.
    let heights_agree = || {
        mats.iter().all(|row| {
            let height = row[0].nrows();
            row[1..].iter().all(|blk| blk.nrows() == height)
        })
    };

    // Blocks sharing a block-column must agree on their column count,
    // measured against the block at the top of that column.
    let widths_agree = || {
        first_row.iter().enumerate().all(|(j, top)| {
            let width = top.ncols();
            mats[1..].iter().all(|row| row[j].ncols() == width)
        })
    };

    // `&&` short-circuits, so each check runs only after the previous ones pass.
    if rectangular() && heights_agree() && widths_agree() {
        Ok(())
    } else {
        Err(MatrixConcatenationError::IncompatibleDimension)
    }
}