use super::SpecializedMatrix;
use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign, One, Zero};
use std::fmt::Debug;
use std::iter::Sum;
/// A square matrix built from square blocks placed along its diagonal, with
/// zeros everywhere outside those blocks.
///
/// Only the diagonal blocks are stored. The overall dimension and the starting
/// index of each block are derived from the block shapes at construction time.
#[derive(Debug, Clone)]
pub struct BlockDiagonalMatrix<
    A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
> {
    /// Diagonal blocks, ordered from top-left to bottom-right; each one is square.
    blocks: Vec<Array2<A>>,
    /// Dimension of the full matrix (the sum of all block dimensions).
    size: usize,
    /// Starting row/column of each block in the full matrix; entry `k` pairs
    /// with `blocks[k]`.
    block_offsets: Vec<usize>,
}
impl<A> BlockDiagonalMatrix<A>
where
    A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
{
    /// Creates a block-diagonal matrix from owned square blocks.
    ///
    /// The blocks are laid out along the diagonal in the order given. The full
    /// matrix is `size x size`, where `size` is the sum of the block dimensions.
    /// Zero-sized (`0x0`) blocks are accepted and occupy no rows or columns.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::InvalidInput`] if `blocks` is empty or if any block
    /// is not square.
    pub fn new(blocks: Vec<Array2<A>>) -> LinalgResult<Self> {
        if blocks.is_empty() {
            return Err(LinalgError::InvalidInput(
                "At least one block is required".to_string(),
            ));
        }
        for (i, block) in blocks.iter().enumerate() {
            if block.nrows() != block.ncols() {
                return Err(LinalgError::InvalidInput(format!(
                    "Block {} is not square: {}x{}",
                    i,
                    block.nrows(),
                    block.ncols()
                )));
            }
        }
        // Prefix sums of the block sizes: block k starts at row/column
        // `block_offsets[k]`, and the final running total is the matrix size.
        let mut block_offsets = Vec::with_capacity(blocks.len());
        let mut offset = 0;
        for block in &blocks {
            block_offsets.push(offset);
            offset += block.nrows();
        }
        let size = offset;
        Ok(Self {
            blocks,
            size,
            block_offsets,
        })
    }

    /// Creates a block-diagonal matrix from borrowed block views.
    ///
    /// Each view is copied into an owned array, then [`Self::new`] is called.
    ///
    /// # Errors
    ///
    /// Same as [`Self::new`].
    pub fn from_views(blocks: Vec<ArrayView2<A>>) -> LinalgResult<Self> {
        let owned_blocks: Vec<Array2<A>> = blocks.into_iter().map(|b| b.to_owned()).collect();
        Self::new(owned_blocks)
    }

    /// Number of diagonal blocks.
    pub fn num_blocks(&self) -> usize {
        self.blocks.len()
    }

    /// Borrows block `index`, or returns `None` if `index` is out of range.
    pub fn block(&self, index: usize) -> Option<&Array2<A>> {
        self.blocks.get(index)
    }

    /// Mutably borrows block `index`, or returns `None` if `index` is out of range.
    ///
    /// The caller must not change the block's shape. The matrix size and block
    /// offsets are computed once in [`Self::new`] and are not recomputed here.
    /// Replacing the block with an array of a different shape therefore leaves
    /// the matrix inconsistent, and later operations may panic or return wrong
    /// results. Modifying the block's entries in place is fine.
    pub fn block_mut(&mut self, index: usize) -> Option<&mut Array2<A>> {
        self.blocks.get_mut(index)
    }

    /// Dimension of block `index`, or `None` if `index` is out of range.
    pub fn blocksize(&self, index: usize) -> Option<usize> {
        self.blocks.get(index).map(|b| b.nrows())
    }

    /// Row/column at which block `index` starts in the full matrix, or `None`
    /// if `index` is out of range.
    pub fn block_offset(&self, index: usize) -> Option<usize> {
        self.block_offsets.get(index).copied()
    }

    /// Returns the index of the block that owns global row/column `index`, or
    /// `None` if `index` is outside the matrix.
    ///
    /// Runs in O(log num_blocks) using a binary search over the offsets.
    fn find_block(&self, index: usize) -> Option<usize> {
        if index >= self.size {
            return None;
        }
        // `block_offsets` is a non-decreasing prefix sum. The owning block is the
        // *last* one whose offset is <= index. Zero-sized blocks share their
        // successor's offset, so choosing the last match skips them.
        let after = self.block_offsets.partition_point(|&offset| offset <= index);
        // `block_offsets[0] == 0 <= index`, so `after >= 1` and this cannot underflow.
        Some(after - 1)
    }

    /// Solves `M x = b`, where `M` is this matrix, one diagonal block at a time.
    ///
    /// Each block's slice of `b` is solved with `crate::solve::solve`, and the
    /// result is written into the matching slice of `x`.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::ShapeError`] if `b.len()` differs from the matrix
    /// size. Any error from the per-block solver is propagated unchanged.
    pub fn solve(&self, b: &ArrayView1<A>) -> LinalgResult<Array1<A>> {
        if b.len() != self.size {
            return Err(LinalgError::ShapeError(format!(
                "Right-hand side has length {}, expected {}",
                b.len(),
                self.size
            )));
        }
        let mut x = Array1::zeros(self.size);
        for (block_idx, block) in self.blocks.iter().enumerate() {
            let offset = self.block_offsets[block_idx];
            let blocksize = block.nrows();
            let b_block = b.slice(scirs2_core::ndarray::s![offset..offset + blocksize]);
            let x_block = crate::solve::solve(
                &block.view(),
                &b_block,
                // NOTE(review): assumed to be a worker/thread count; confirm
                // against `crate::solve::solve`.
                Some(1),
            )?;
            x.slice_mut(scirs2_core::ndarray::s![offset..offset + blocksize])
                .assign(&x_block);
        }
        Ok(x)
    }

    /// Determinant of the matrix, computed as the product of the block
    /// determinants (each from `crate::basic::det`).
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `crate::basic::det` for a block.
    pub fn determinant(&self) -> LinalgResult<A> {
        let mut det = A::one();
        for block in &self.blocks {
            let block_det = crate::basic::det(&block.view(), Some(1))?;
            det *= block_det;
        }
        Ok(det)
    }

    /// Inverse of the matrix, returned as a new block-diagonal matrix whose
    /// blocks are the inverses of this matrix's blocks (each from
    /// `crate::basic::inv`).
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `crate::basic::inv` for a block.
    pub fn inverse(&self) -> LinalgResult<Self> {
        let mut inv_blocks = Vec::with_capacity(self.blocks.len());
        for block in &self.blocks {
            let inv_block = crate::basic::inv(&block.view(), Some(1))?;
            inv_blocks.push(inv_block);
        }
        Self::new(inv_blocks)
    }

    /// Trace of the matrix: the sum of the diagonal entries of every block.
    pub fn trace(&self) -> A {
        let mut trace = A::zero();
        for block in &self.blocks {
            for i in 0..block.nrows() {
                trace += block[[i, i]];
            }
        }
        trace
    }
}
impl<A> SpecializedMatrix<A> for BlockDiagonalMatrix<A>
where
    A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
{
    /// Number of rows: the sum of all block dimensions.
    fn nrows(&self) -> usize {
        self.size
    }

    /// Number of columns; the matrix is square, so this equals the row count.
    fn ncols(&self) -> usize {
        self.size
    }

    /// Entry at `(i, j)`. It is non-zero only when row and column fall in the
    /// same diagonal block.
    ///
    /// # Errors
    ///
    /// Returns `LinalgError::IndexError` when `i` or `j` is outside the matrix.
    fn get(&self, i: usize, j: usize) -> LinalgResult<A> {
        let n = self.size;
        if i >= n || j >= n {
            return Err(LinalgError::IndexError(format!(
                "Index ({}, {}) out of bounds for {}x{} matrix",
                i, j, n, n
            )));
        }
        if let (Some(row_block), Some(col_block)) = (self.find_block(i), self.find_block(j)) {
            if row_block == col_block {
                // Translate global coordinates into coordinates local to the block.
                let base = self.block_offsets[row_block];
                return Ok(self.blocks[row_block][[i - base, j - base]]);
            }
        }
        Ok(A::zero())
    }

    /// Computes `M x` by multiplying each block with its slice of `x`.
    ///
    /// # Errors
    ///
    /// Returns `LinalgError::ShapeError` if `x.len()` differs from the matrix size.
    fn matvec(&self, x: &ArrayView1<A>) -> LinalgResult<Array1<A>> {
        if x.len() != self.size {
            return Err(LinalgError::ShapeError(format!(
                "Vector has length {}, expected {}",
                x.len(),
                self.size
            )));
        }
        let mut y = Array1::zeros(self.size);
        for (block, &start) in self.blocks.iter().zip(self.block_offsets.iter()) {
            let end = start + block.nrows();
            let partial = block.dot(&x.slice(scirs2_core::ndarray::s![start..end]));
            y.slice_mut(scirs2_core::ndarray::s![start..end])
                .assign(&partial);
        }
        Ok(y)
    }

    /// Computes `Mᵀ x` by multiplying each transposed block with its slice of `x`.
    ///
    /// # Errors
    ///
    /// Returns `LinalgError::ShapeError` if `x.len()` differs from the matrix size.
    fn matvec_transpose(&self, x: &ArrayView1<A>) -> LinalgResult<Array1<A>> {
        if x.len() != self.size {
            return Err(LinalgError::ShapeError(format!(
                "Vector has length {}, expected {}",
                x.len(),
                self.size
            )));
        }
        let mut y = Array1::zeros(self.size);
        for (block, &start) in self.blocks.iter().zip(self.block_offsets.iter()) {
            let end = start + block.nrows();
            let partial = block
                .t()
                .dot(&x.slice(scirs2_core::ndarray::s![start..end]));
            y.slice_mut(scirs2_core::ndarray::s![start..end])
                .assign(&partial);
        }
        Ok(y)
    }

    /// Materializes the full `size x size` matrix, filled with zeros outside the
    /// diagonal blocks.
    fn to_dense(&self) -> LinalgResult<Array2<A>> {
        let mut dense = Array2::zeros((self.size, self.size));
        for (block, &start) in self.blocks.iter().zip(self.block_offsets.iter()) {
            let end = start + block.nrows();
            dense
                .slice_mut(scirs2_core::ndarray::s![start..end, start..end])
                .assign(block);
        }
        Ok(dense)
    }
}
#[allow(dead_code)]
pub fn solve_block_diagonal<A>(
matrix: &BlockDiagonalMatrix<A>,
b: &ArrayView1<A>,
) -> LinalgResult<Array1<A>>
where
A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
{
matrix.solve(b)
}
#[allow(dead_code)]
pub fn block_diagonal_determinant<A>(matrix: &BlockDiagonalMatrix<A>) -> LinalgResult<A>
where
A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
{
matrix.determinant()
}
/// Builds a block-diagonal matrix from owned square blocks; a free-function
/// form of `BlockDiagonalMatrix::new`.
///
/// # Errors
///
/// Whatever `BlockDiagonalMatrix::new` returns (empty input or a non-square block).
// NOTE(review): the reason for this `allow` is not visible here; confirm it is still needed.
#[allow(dead_code)]
pub fn create_block_diagonal<
    A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
>(
    blocks: Vec<Array2<A>>,
) -> LinalgResult<BlockDiagonalMatrix<A>> {
    BlockDiagonalMatrix::<A>::new(blocks)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_block_diagonal_creation() {
        let block1 = array![[1.0, 2.0], [3.0, 4.0]];
        let block2 = array![[5.0, 6.0], [7.0, 8.0]];
        let block3 = array![[9.0]];
        let blocks = vec![block1, block2, block3];
        let bdmatrix = BlockDiagonalMatrix::new(blocks).expect("Operation failed");
        assert_eq!(bdmatrix.size, 5);
        assert_eq!(bdmatrix.num_blocks(), 3);
        assert_eq!(bdmatrix.blocksize(0), Some(2));
        assert_eq!(bdmatrix.blocksize(1), Some(2));
        assert_eq!(bdmatrix.blocksize(2), Some(1));
    }

    #[test]
    fn test_block_diagonal_rejects_invalid_blocks() {
        // An empty block list is rejected.
        let no_blocks: Vec<Array2<f64>> = Vec::new();
        assert!(BlockDiagonalMatrix::new(no_blocks).is_err());
        // A non-square block is rejected.
        let rectangular = Array2::<f64>::zeros((2, 3));
        assert!(BlockDiagonalMatrix::new(vec![rectangular]).is_err());
    }

    #[test]
    fn test_block_diagonal_get() {
        let block1 = array![[1.0, 2.0], [3.0, 4.0]];
        let block2 = array![[5.0]];
        let blocks = vec![block1, block2];
        let bdmatrix = BlockDiagonalMatrix::new(blocks).expect("Operation failed");
        assert_eq!(bdmatrix.get(0, 0).expect("Operation failed"), 1.0);
        assert_eq!(bdmatrix.get(0, 1).expect("Operation failed"), 2.0);
        assert_eq!(bdmatrix.get(1, 0).expect("Operation failed"), 3.0);
        assert_eq!(bdmatrix.get(1, 1).expect("Operation failed"), 4.0);
        assert_eq!(bdmatrix.get(2, 2).expect("Operation failed"), 5.0);
        assert_eq!(bdmatrix.get(0, 2).expect("Operation failed"), 0.0);
        assert_eq!(bdmatrix.get(2, 0).expect("Operation failed"), 0.0);
    }

    #[test]
    fn test_block_diagonal_get_out_of_bounds() {
        let bdmatrix = BlockDiagonalMatrix::new(vec![array![[1.0, 2.0], [3.0, 4.0]]])
            .expect("Operation failed");
        assert!(bdmatrix.get(2, 0).is_err());
        assert!(bdmatrix.get(0, 2).is_err());
    }

    #[test]
    fn test_block_diagonal_get_with_empty_blocks() {
        // Zero-sized blocks occupy no rows or columns and must not confuse the
        // row/column -> block lookup.
        let blocks = vec![
            Array2::<f64>::zeros((0, 0)),
            array![[1.0, 2.0], [3.0, 4.0]],
            Array2::<f64>::zeros((0, 0)),
            array![[5.0]],
        ];
        let bdmatrix = BlockDiagonalMatrix::new(blocks).expect("Operation failed");
        assert_eq!(bdmatrix.size, 3);
        assert_eq!(bdmatrix.get(0, 0).expect("Operation failed"), 1.0);
        assert_eq!(bdmatrix.get(0, 1).expect("Operation failed"), 2.0);
        assert_eq!(bdmatrix.get(1, 1).expect("Operation failed"), 4.0);
        assert_eq!(bdmatrix.get(2, 2).expect("Operation failed"), 5.0);
        assert_eq!(bdmatrix.get(1, 2).expect("Operation failed"), 0.0);
        assert_eq!(bdmatrix.get(2, 1).expect("Operation failed"), 0.0);
    }

    #[test]
    fn test_block_diagonal_matvec() {
        let block1 = array![[2.0, 0.0], [0.0, 3.0]];
        let block2 = array![[4.0]];
        let blocks = vec![block1, block2];
        let bdmatrix = BlockDiagonalMatrix::new(blocks).expect("Operation failed");
        let x = array![1.0, 2.0, 3.0];
        let result = bdmatrix.matvec(&x.view()).expect("Operation failed");
        assert_eq!(result, array![2.0, 6.0, 12.0]);
        // A vector of the wrong length is rejected.
        let short = array![1.0, 2.0];
        assert!(bdmatrix.matvec(&short.view()).is_err());
    }

    #[test]
    fn test_block_diagonal_matvec_transpose() {
        let blocks = vec![array![[1.0, 2.0], [3.0, 4.0]], array![[5.0]]];
        let bdmatrix = BlockDiagonalMatrix::new(blocks).expect("Operation failed");
        let x = array![1.0, 1.0, 1.0];
        let result = bdmatrix
            .matvec_transpose(&x.view())
            .expect("Operation failed");
        // [[1, 3], [2, 4]] * [1, 1] = [4, 6]; [[5]] * [1] = [5].
        assert_eq!(result, array![4.0, 6.0, 5.0]);
    }

    #[test]
    fn test_block_diagonal_trace() {
        let blocks = vec![array![[1.0, 2.0], [3.0, 4.0]], array![[5.0]]];
        let bdmatrix = BlockDiagonalMatrix::new(blocks).expect("Operation failed");
        assert_eq!(bdmatrix.trace(), 10.0);
    }

    #[test]
    fn test_block_diagonal_determinant() {
        let block1 = array![[2.0, 0.0], [0.0, 3.0]];
        let block2 = array![[4.0]];
        let blocks = vec![block1, block2];
        let bdmatrix = BlockDiagonalMatrix::new(blocks).expect("Operation failed");
        let det = bdmatrix.determinant().expect("Operation failed");
        assert!((det - 24.0).abs() < 1e-10);
    }

    #[test]
    fn test_block_diagonal_to_dense() {
        let block1 = array![[1.0, 2.0], [3.0, 4.0]];
        let block2 = array![[5.0]];
        let blocks = vec![block1, block2];
        let bdmatrix = BlockDiagonalMatrix::new(blocks).expect("Operation failed");
        let dense = bdmatrix.to_dense().expect("Operation failed");
        let expected = array![[1.0, 2.0, 0.0], [3.0, 4.0, 0.0], [0.0, 0.0, 5.0]];
        assert_eq!(dense, expected);
    }
}