use na::{DMatrix, DMatrixView, RealField};
use crate::traits::IntoView;
/// A square, block-diagonal matrix stored compactly as the concatenation of its dense
/// diagonal blocks.
///
/// Block `i` occupies rows/columns `block_row_offsets[i]..block_row_offsets[i + 1]`, and its
/// `size * size` entries live in `values[block_element_offsets[i]..block_element_offsets[i + 1]]`.
pub struct DiagonalBlockMatrix<T> {
    /// Concatenated entries of all diagonal blocks.
    values: Vec<T>,
    /// Prefix sums of the block sizes: `num_blocks + 1` entries, starting at 0.
    block_row_offsets: Vec<usize>,
    /// Prefix sums of the squared block sizes: `num_blocks + 1` entries, starting at 0.
    block_element_offsets: Vec<usize>,
}

impl<T> DiagonalBlockMatrix<T> {
    /// Builds a block-diagonal matrix from the concatenated entries of its blocks.
    ///
    /// `block_sizes[i]` is the side length of the `i`-th diagonal block, so `values` must hold
    /// exactly `block_sizes.iter().map(|s| s * s).sum()` entries.
    ///
    /// # Panics
    ///
    /// Panics if `values.len()` does not equal the total number of block entries implied by
    /// `block_sizes`.
    #[inline]
    pub fn from_block_values(values: Vec<T>, block_sizes: &[usize]) -> Self {
        let mut block_row_offsets = Vec::with_capacity(block_sizes.len() + 1);
        let mut block_element_offsets = Vec::with_capacity(block_sizes.len() + 1);
        let mut block_row_offset = 0;
        let mut block_element_offset = 0;
        for &size in block_sizes {
            block_row_offsets.push(block_row_offset);
            block_element_offsets.push(block_element_offset);
            block_row_offset += size;
            block_element_offset += size * size;
        }
        // Trailing sentinels: the total dimension and the total number of stored entries.
        block_row_offsets.push(block_row_offset);
        block_element_offsets.push(block_element_offset);
        // Reject mismatched input here instead of failing later with an out-of-bounds slice
        // (too few values) or silently exposing wrong block contents (too many values).
        assert_eq!(
            values.len(),
            block_element_offset,
            "Number of values does not match the expected number based on block sizes"
        );
        Self {
            values,
            block_row_offsets,
            block_element_offsets,
        }
    }

    /// Number of rows, i.e. the sum of all block sizes.
    #[inline]
    pub fn nrows(&self) -> usize {
        self.block_row_offsets.last().copied().unwrap_or(0)
    }

    /// Number of columns; equal to [`Self::nrows`] because the matrix is square.
    #[inline]
    pub fn ncols(&self) -> usize {
        self.nrows()
    }
}
impl<'a, T> IntoView for &'a DiagonalBlockMatrix<T> {
type View = DiagonalBlockMatrixView<'a, T>;
#[inline]
fn into_view(self) -> Self::View {
DiagonalBlockMatrixView {
values: &self.values,
block_element_offsets: &self.block_element_offsets,
block_row_offsets: &self.block_row_offsets,
}
}
}
/// Borrowed view of a block-diagonal matrix.
///
/// Holds the concatenated block entries together with the two offset tables that locate
/// each diagonal block: one in row/column space, one in the `values` buffer.
#[derive(Copy, Clone)]
pub struct DiagonalBlockMatrixView<'a, T> {
    /// Concatenated entries of all diagonal blocks.
    values: &'a [T],
    /// Row (and column) offset at which each block starts; the last entry is the dimension.
    block_row_offsets: &'a [usize],
    /// Offset into `values` at which each block's entries start; the last entry is the total.
    block_element_offsets: &'a [usize],
}
impl<'a, T> DiagonalBlockMatrixView<'a, T>
where
    T: RealField,
{
    /// Creates a view from raw parts without validating them in release builds.
    ///
    /// Block `i` covers rows/columns `block_row_offsets[i]..block_row_offsets[i + 1]`.
    /// Its entries are `values[block_element_offsets[i]..block_element_offsets[i + 1]]`.
    /// When `debug_assertions` are enabled, the parts are checked by `assert_valid`, which
    /// panics on any inconsistency.
    #[inline]
    pub fn from_parts_unchecked(
        values: &'a [T],
        block_row_offsets: &'a [usize],
        block_element_offsets: &'a [usize],
    ) -> Self {
        let slf = Self {
            values,
            block_row_offsets,
            block_element_offsets,
        };
        #[cfg(debug_assertions)]
        slf.assert_valid();
        slf
    }

    /// Debug-only consistency check of the offset tables against `values`.
    ///
    /// Panics unless all of the following hold:
    /// - both offset tables have the same length;
    /// - element offsets start at zero;
    /// - every block has a positive size;
    /// - each block stores exactly `size * size` entries;
    /// - the blocks account for exactly `values.len()` entries.
    #[cfg(debug_assertions)]
    fn assert_valid(&self) {
        assert_eq!(
            self.block_row_offsets.len(),
            self.block_element_offsets.len(),
            "Block row offsets and block element offsets must have the same length"
        );
        // `view_block` slices `values` with these offsets directly. Starting at 0 is what makes
        // the total-count check below imply that every block slice is in bounds.
        if let Some(&first) = self.block_element_offsets.first() {
            assert_eq!(first, 0, "Block element offsets must start at zero");
        }
        let mut expected_num_values = 0;
        let num_blocks = self.num_blocks();
        for block_index in 0..num_blocks {
            assert!(
                self.block_row_offsets[block_index] < self.block_row_offsets[block_index + 1],
                "Block sizes must be positive"
            );
            let block_size =
                self.block_row_offsets[block_index + 1] - self.block_row_offsets[block_index];
            let num_block_elements = self.block_element_offsets[block_index + 1]
                - self.block_element_offsets[block_index];
            assert!(
                num_block_elements == block_size * block_size,
                "Block element offsets do not match the expected size based on block sizes, expected {}, got {}, block_index = {}",
                block_size * block_size,
                num_block_elements,
                block_index
            );
            expected_num_values += block_size * block_size;
        }
        assert!(
            expected_num_values == self.values.len(),
            "Number of values does not match the expected number based on block sizes , expected {}, got {}",
            expected_num_values,
            self.values.len()
        );
    }

    /// Returns the concatenated entries of all blocks.
    ///
    /// The slice borrows for the view's lifetime `'a`, not merely for `&self`.
    #[inline]
    pub fn values(&self) -> &'a [T] {
        self.values
    }

    /// Number of diagonal blocks.
    ///
    /// Returns 0 for an empty offset table instead of underflowing.
    #[inline]
    pub fn num_blocks(&self) -> usize {
        // `from_parts_unchecked` accepts an empty table (and `nrows` already treats that as a
        // 0x0 matrix). A plain `len() - 1` would panic in debug or wrap to `usize::MAX` in release.
        self.block_row_offsets.len().saturating_sub(1)
    }

    /// Number of rows, i.e. the last row offset (0 for an empty offset table).
    #[inline]
    pub fn nrows(&self) -> usize {
        self.block_row_offsets.last().copied().unwrap_or(0)
    }

    /// Number of columns; equal to [`Self::nrows`] because the matrix is square.
    #[inline]
    pub fn ncols(&self) -> usize {
        self.nrows()
    }

    /// Side length of diagonal block `block_index`.
    #[inline]
    pub fn get_block_size(&self, block_index: usize) -> usize {
        debug_assert!(block_index < self.num_blocks(), "Block index out of bounds");
        self.block_row_offsets[block_index + 1] - self.block_row_offsets[block_index]
    }

    /// First row (and column) covered by block `block_index`.
    #[inline]
    pub fn get_block_row_start(&self, block_index: usize) -> usize {
        self.block_row_offsets[block_index]
    }

    /// Row (and column) range covered by block `block_index`.
    #[inline]
    pub fn get_block_row_range(&self, block_index: usize) -> std::ops::Range<usize> {
        debug_assert!(block_index < self.num_blocks(), "Block index out of bounds");
        let start = self.block_row_offsets[block_index];
        let end = self.block_row_offsets[block_index + 1];
        start..end
    }

    /// Returns a dense `size x size` view of diagonal block `block_index`.
    ///
    /// The block's entries are interpreted by nalgebra's `DMatrixView::from_slice`, which is
    /// column-major. The returned view borrows for `'a`, so it may outlive `&self`.
    ///
    /// # Panics
    ///
    /// Panics if `block_index >= self.num_blocks()`. Debug builds report this with a
    /// descriptive message.
    pub fn view_block(&self, block_index: usize) -> DMatrixView<'a, T> {
        debug_assert!(block_index < self.num_blocks(), "Block index out of bounds");
        // Copy out the `'a`-lived slice so the result is not tied to the `&self` borrow.
        let values: &'a [T] = self.values;
        let start = self.block_element_offsets[block_index];
        let end = self.block_element_offsets[block_index + 1];
        let size = self.block_row_offsets[block_index + 1] - self.block_row_offsets[block_index];
        debug_assert!(size > 0);
        DMatrixView::from_slice(&values[start..end], size, size)
    }

    /// Materializes the full dense matrix: each block is placed on the diagonal, and every
    /// entry outside the blocks is zero.
    pub fn to_dense(&self) -> DMatrix<T> {
        let nrows = self.nrows();
        let ncols = self.ncols();
        let mut dense = DMatrix::zeros(nrows, ncols);
        for b_idx in 0..self.num_blocks() {
            let block = self.view_block(b_idx);
            let range = self.get_block_row_range(b_idx);
            let mut dst = dense.view_range_mut(range.clone(), range);
            dst.copy_from(&block);
        }
        dense
    }
}