use nalgebra::{DMatrix, Dyn, PermutationSequence, Scalar};
use nalgebra_block_triangularization::{
lower_block_triangular_structure, lower_triangular_permutations,
};
/// Applies the row permutation `pr` and then the column permutation `pc`
/// to `m`, returning the permuted matrix.
///
/// The matrix is taken by value, permuted in place, and handed back to the
/// caller.
fn apply_perms<T: Scalar + Copy>(
    m: DMatrix<T>,
    pr: &PermutationSequence<Dyn>,
    pc: &PermutationSequence<Dyn>,
) -> DMatrix<T> {
    let mut permuted = m;
    pr.permute_rows(&mut permuted);
    pc.permute_columns(&mut permuted);
    permuted
}
/// Reports whether `m` is lower block triangular with respect to the diagonal
/// blocks described by `block_sizes`.
///
/// Returns `false` when `m` is not square or when the block sizes do not sum
/// to its dimension. Otherwise returns `true` exactly when every non-zero
/// entry `(i, j)` lies in a block row that is not above its block column.
fn is_lower_block_triangular_u8(m: &DMatrix<u8>, block_sizes: &[usize]) -> bool {
    let dim = m.nrows();
    if dim != m.ncols() || block_sizes.iter().sum::<usize>() != dim {
        return false;
    }

    // Block number of each index; rows and columns share the same partition.
    let block_of: Vec<usize> = block_sizes
        .iter()
        .enumerate()
        .flat_map(|(block, &size)| std::iter::repeat(block).take(size))
        .collect();

    // A non-zero entry is only allowed on or below the block diagonal.
    (0..dim).all(|i| (0..dim).all(|j| m[(i, j)] == 0 || block_of[i] >= block_of[j]))
}
/// An 8x8 example pattern must have a perfect matching, and the returned
/// permutations must bring it into lower block triangular form.
#[test]
fn example_matrix_produces_lower_block_triangular_form() {
    let rows: [[u8; 8]; 8] = [
        [1, 0, 1, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 0, 0, 0],
        [1, 1, 0, 1, 1, 0, 0, 0],
        [1, 1, 0, 1, 1, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 1, 1, 0],
        [1, 1, 1, 0, 0, 1, 1, 0],
        [1, 1, 0, 0, 0, 0, 1, 1],
    ];
    // Flatten row by row so the matrix can be built from a row-major slice.
    let flat: Vec<u8> = rows.concat();
    let pattern = DMatrix::from_row_slice(8, 8, flat.as_slice());

    let structure = lower_block_triangular_structure(&pattern);
    let (row_perm, col_perm) = lower_triangular_permutations(&pattern);
    let permuted = apply_perms(pattern.clone(), &row_perm, &col_perm);

    assert_eq!(structure.matching_size, 8);
    assert!(is_lower_block_triangular_u8(
        &permuted,
        &structure.block_sizes
    ));
}
/// A 0x0 matrix yields an empty structure, and permuting it keeps it 0x0.
#[test]
fn empty_matrix() {
    let empty = DMatrix::<u8>::zeros(0, 0);
    let structure = lower_block_triangular_structure(&empty);
    let (row_perm, col_perm) = lower_triangular_permutations(&empty);

    assert_eq!(structure.matching_size, 0);
    assert_eq!(structure.block_sizes.len(), 0);
    assert_eq!(structure.row_order.len(), 0);
    assert_eq!(structure.col_order.len(), 0);

    let permuted = apply_perms(empty.clone(), &row_perm, &col_perm);
    assert_eq!(permuted.shape(), (0, 0));
}
/// A 1x1 matrix holding a non-zero entry is a single matched 1x1 block and
/// is left unchanged by the permutations.
#[test]
fn single_element_nonzero() {
    let single = DMatrix::<u8>::from_element(1, 1, 1);
    let structure = lower_block_triangular_structure(&single);
    let (row_perm, col_perm) = lower_triangular_permutations(&single);

    assert_eq!(structure.matching_size, 1);
    assert_eq!(structure.block_sizes, vec![1]);
    assert_eq!(structure.row_order, vec![0]);
    assert_eq!(structure.col_order, vec![0]);

    let permuted = apply_perms(single.clone(), &row_perm, &col_perm);
    assert_eq!(permuted[(0, 0)], 1);
}
/// A 1x1 zero matrix has nothing to match, yet its single row and column
/// still appear in the reported orderings.
#[test]
fn single_element_zero() {
    let single = DMatrix::<u8>::from_element(1, 1, 0);
    let structure = lower_block_triangular_structure(&single);

    assert_eq!(structure.matching_size, 0);
    assert_eq!(structure.row_order, vec![0]);
    assert_eq!(structure.col_order, vec![0]);
}
/// The 5x5 identity decomposes into five blocks whose sizes sum to five, and
/// its permuted form is lower block triangular.
#[test]
fn identity_matrix() {
    let identity = DMatrix::<f64>::identity(5, 5);
    let structure = lower_block_triangular_structure(&identity);
    let (row_perm, col_perm) = lower_triangular_permutations(&identity);

    assert_eq!(structure.matching_size, 5);
    assert_eq!(structure.block_sizes.len(), 5);
    assert_eq!(structure.block_sizes.iter().sum::<usize>(), 5);

    let permuted = apply_perms(identity.clone(), &row_perm, &col_perm);
    // Reduce the floating-point matrix to its 0/1 sparsity pattern.
    let pattern = permuted.map(|x| u8::from(x != 0.0));
    assert!(is_lower_block_triangular_u8(
        &pattern,
        &structure.block_sizes
    ));
}
/// A 4x4 all-zero matrix has an empty matching, but every row and column is
/// still listed in the orderings.
#[test]
fn all_zeros_matrix() {
    let zeros = DMatrix::<u8>::zeros(4, 4);
    let structure = lower_block_triangular_structure(&zeros);

    assert_eq!(structure.matching_size, 0);
    assert_eq!(structure.row_order.len(), 4);
    assert_eq!(structure.col_order.len(), 4);
}
/// A dense 4x4 matrix forms one irreducible 4x4 block.
#[test]
fn all_ones_matrix() {
    let dense = DMatrix::<u8>::from_element(4, 4, 1);
    let structure = lower_block_triangular_structure(&dense);
    let (row_perm, col_perm) = lower_triangular_permutations(&dense);

    assert_eq!(structure.matching_size, 4);
    // Exactly one block, and it spans the whole matrix.
    assert_eq!(structure.block_sizes, vec![4]);

    let permuted = apply_perms(dense.clone(), &row_perm, &col_perm);
    assert!(is_lower_block_triangular_u8(
        &permuted,
        &structure.block_sizes
    ));
}
/// A tall 5x3 pattern matches all three columns, and the orderings cover
/// every row and every column.
#[test]
fn rectangular_more_rows() {
    let tall = DMatrix::from_row_slice(
        5,
        3,
        &[
            1, 0, 0, //
            0, 1, 0, //
            0, 0, 1, //
            1, 0, 0, //
            0, 1, 0, //
        ],
    );
    let structure = lower_block_triangular_structure(&tall);

    assert_eq!(structure.matching_size, 3);
    assert_eq!(structure.row_order.len(), 5);
    assert_eq!(structure.col_order.len(), 3);
}
/// A wide 3x5 pattern matches all three rows; the column ordering still has
/// five entries, two of which follow the first three.
#[test]
fn rectangular_more_cols() {
    let wide = DMatrix::from_row_slice(
        3,
        5,
        &[
            1, 0, 0, 0, 0, //
            0, 1, 0, 0, 0, //
            0, 0, 1, 0, 0, //
        ],
    );
    let structure = lower_block_triangular_structure(&wide);

    assert_eq!(structure.matching_size, 3);
    assert_eq!(structure.row_order.len(), 3);
    assert_eq!(structure.col_order.len(), 5);

    // Tail of the column ordering after the first three entries.
    let unmatched_cols = &structure.col_order[3..];
    assert_eq!(unmatched_cols.len(), 2);
}
/// A matrix that is already lower triangular has a perfect matching and is
/// lower block triangular after the returned permutations are applied.
#[test]
fn triangular_already_lower() {
    let lower = DMatrix::from_row_slice(
        4,
        4,
        &[
            1, 0, 0, 0, //
            1, 1, 0, 0, //
            1, 1, 1, 0, //
            1, 1, 1, 1, //
        ],
    );
    let structure = lower_block_triangular_structure(&lower);
    let (row_perm, col_perm) = lower_triangular_permutations(&lower);

    assert_eq!(structure.matching_size, 4);

    let permuted = apply_perms(lower.clone(), &row_perm, &col_perm);
    assert!(is_lower_block_triangular_u8(
        &permuted,
        &structure.block_sizes
    ));
}
/// An upper triangular matrix has a perfect matching and must be permuted
/// into lower block triangular form.
#[test]
fn triangular_upper() {
    let upper = DMatrix::from_row_slice(
        4,
        4,
        &[
            1, 1, 1, 1, //
            0, 1, 1, 1, //
            0, 0, 1, 1, //
            0, 0, 0, 1, //
        ],
    );
    let structure = lower_block_triangular_structure(&upper);
    let (row_perm, col_perm) = lower_triangular_permutations(&upper);

    assert_eq!(structure.matching_size, 4);

    let permuted = apply_perms(upper.clone(), &row_perm, &col_perm);
    assert!(is_lower_block_triangular_u8(
        &permuted,
        &structure.block_sizes
    ));
}
/// Two independent dense 2x2 blocks on the diagonal are reported as two
/// blocks, and the permuted matrix is lower block triangular.
#[test]
fn block_diagonal() {
    let diagonal_blocks = DMatrix::from_row_slice(
        4,
        4,
        &[
            1, 1, 0, 0, //
            1, 1, 0, 0, //
            0, 0, 1, 1, //
            0, 0, 1, 1, //
        ],
    );
    let structure = lower_block_triangular_structure(&diagonal_blocks);
    let (row_perm, col_perm) = lower_triangular_permutations(&diagonal_blocks);

    assert_eq!(structure.matching_size, 4);
    assert_eq!(structure.block_sizes.len(), 2);

    let permuted = apply_perms(diagonal_blocks.clone(), &row_perm, &col_perm);
    assert!(is_lower_block_triangular_u8(
        &permuted,
        &structure.block_sizes
    ));
}
/// The first two rows compete for column 0 and the last column is empty, so
/// only three rows can be matched.
#[test]
fn structurally_singular() {
    let singular = DMatrix::from_row_slice(
        4,
        4,
        &[
            1, 0, 0, 0, //
            1, 0, 0, 0, //
            0, 1, 0, 0, //
            0, 0, 1, 0, //
        ],
    );
    let structure = lower_block_triangular_structure(&singular);

    assert_eq!(structure.matching_size, 3);
}
/// A 3x3 pattern with a zero diagonal and ones elsewhere forms a single
/// block of size three.
#[test]
fn cyclic_dependency() {
    let cyclic = DMatrix::from_row_slice(
        3,
        3,
        &[
            0, 1, 1, //
            1, 0, 1, //
            1, 1, 0, //
        ],
    );
    let structure = lower_block_triangular_structure(&cyclic);
    let (row_perm, col_perm) = lower_triangular_permutations(&cyclic);

    assert_eq!(structure.matching_size, 3);
    // Exactly one block, and it spans the whole matrix.
    assert_eq!(structure.block_sizes, vec![3]);

    let permuted = apply_perms(cyclic.clone(), &row_perm, &col_perm);
    assert!(is_lower_block_triangular_u8(
        &permuted,
        &structure.block_sizes
    ));
}
/// A sparse 6x6 pattern has a perfect matching and is lower block triangular
/// after the returned permutations are applied.
#[test]
fn sparse_pattern() {
    let sparse = DMatrix::from_row_slice(
        6,
        6,
        &[
            1, 1, 0, 0, 0, 0, //
            1, 1, 0, 0, 0, 0, //
            1, 0, 1, 1, 0, 0, //
            0, 1, 1, 1, 0, 0, //
            0, 0, 1, 0, 1, 1, //
            0, 0, 0, 1, 1, 1, //
        ],
    );
    let structure = lower_block_triangular_structure(&sparse);
    let (row_perm, col_perm) = lower_triangular_permutations(&sparse);

    assert_eq!(structure.matching_size, 6);

    let permuted = apply_perms(sparse.clone(), &row_perm, &col_perm);
    assert!(is_lower_block_triangular_u8(
        &permuted,
        &structure.block_sizes
    ));
}
/// The same 3x3 sparsity pattern gives a matching of size three whether the
/// entries are floating-point values or integers.
#[test]
fn different_scalar_types() {
    let floats = DMatrix::from_row_slice(
        3,
        3,
        &[
            1.0, 2.0, 0.0, //
            3.0, 4.0, 0.0, //
            0.0, 5.0, 6.0, //
        ],
    );
    let float_structure = lower_block_triangular_structure(&floats);
    assert_eq!(float_structure.matching_size, 3);

    let ints = DMatrix::from_row_slice(
        3,
        3,
        &[
            1, 2, 0, //
            3, 4, 0, //
            0, 5, 6, //
        ],
    );
    let int_structure = lower_block_triangular_structure(&ints);
    assert_eq!(int_structure.matching_size, 3);
}
/// The permutations returned for a 4x4 pattern must be undoable, and the
/// reported row and column orderings must be permutations of `0..4`.
///
/// The permuted matrix is first checked for lower block triangular form.
/// Applying the inverse permutations must then restore the original matrix
/// exactly, and each ordering must list every index once.
#[test]
fn permutations_are_invertible() {
    let m = DMatrix::from_row_slice(
        4,
        4,
        &[
            0, 1, 1, 0, //
            1, 0, 1, 0, //
            1, 1, 0, 1, //
            0, 0, 1, 0, //
        ],
    );
    let structure = lower_block_triangular_structure(&m);
    let (pr, pc) = lower_triangular_permutations(&m);

    let l = apply_perms(m.clone(), &pr, &pc);
    assert!(is_lower_block_triangular_u8(&l, &structure.block_sizes));

    // Round trip: undoing both permutations must give back the input. Row and
    // column permutations act on different axes, so the undo order is free.
    let mut restored = l.clone();
    pc.inv_permute_columns(&mut restored);
    pr.inv_permute_rows(&mut restored);
    assert_eq!(restored, m);

    assert_eq!(structure.row_order.len(), 4);
    assert_eq!(structure.col_order.len(), 4);

    // Each ordering must contain every index exactly once.
    let mut rows = structure.row_order.clone();
    rows.sort_unstable();
    assert_eq!(rows, vec![0, 1, 2, 3]);

    let mut cols = structure.col_order.clone();
    cols.sort_unstable();
    assert_eq!(cols, vec![0, 1, 2, 3]);
}
/// `block_indices` must return one `(rows, cols)` pair per block, each sized
/// like the matching entry of `block_sizes`, and together the blocks must
/// cover every row and every column of the 6x6 matrix exactly once.
#[test]
fn block_indices_method() {
    let m = DMatrix::from_row_slice(
        6,
        6,
        &[
            1, 1, 0, 0, 0, 0, //
            1, 1, 0, 0, 0, 0, //
            1, 0, 1, 1, 0, 0, //
            0, 1, 1, 1, 0, 0, //
            0, 0, 1, 0, 1, 1, //
            0, 0, 0, 1, 1, 1, //
        ],
    );
    let structure = lower_block_triangular_structure(&m);
    let blocks = structure.block_indices();

    assert_eq!(blocks.len(), structure.block_sizes.len());
    for (i, (row_block, col_block)) in blocks.iter().enumerate() {
        assert_eq!(row_block.len(), structure.block_sizes[i]);
        assert_eq!(col_block.len(), structure.block_sizes[i]);
    }

    // Row indices across all blocks: sorted and deduplicated they must be
    // exactly 0..6, so no row is missing.
    let mut sorted_rows: Vec<usize> = blocks
        .iter()
        .flat_map(|(rows, _)| rows.iter().copied())
        .collect();
    sorted_rows.sort_unstable();
    sorted_rows.dedup();
    assert_eq!(sorted_rows.len(), 6);
    assert_eq!(sorted_rows, vec![0, 1, 2, 3, 4, 5]);

    // Same requirement for the column indices.
    let mut sorted_cols: Vec<usize> = blocks
        .iter()
        .flat_map(|(_, cols)| cols.iter().copied())
        .collect();
    sorted_cols.sort_unstable();
    sorted_cols.dedup();
    assert_eq!(sorted_cols.len(), 6);
    assert_eq!(sorted_cols, vec![0, 1, 2, 3, 4, 5]);
}