use crate::traits::IntoView;
use crate::{CsVecBuilder, CsrMatrixView, Real};
/// A collection of CSR matrices stored back to back in shared buffers.
///
/// Each matrix in the set is described by one entry of `partition` (which part of the shared
/// buffers belongs to it) and the entry of `ncols` at the same index (its column count).
#[derive(Clone)]
pub struct CsrMatrixSet<T> {
/// Column indices of the stored entries of all matrices, concatenated.
col_indices: Vec<usize>,
/// Stored values of all matrices, concatenated; sliced with the same range as `col_indices`.
values: Vec<T>,
/// Row offset arrays of all matrices, concatenated.
row_offsets: Vec<usize>,
/// Column count of each matrix, indexed like `partition`.
ncols: Vec<usize>,
/// For each matrix, the ranges of the shared buffers that belong to it.
partition: Vec<Partition>,
}
impl<T> Default for CsrMatrixSet<T> {
    /// Creates a set that contains no matrices.
    ///
    /// All buffers start empty except `row_offsets`, which is seeded with a single `0`.
    #[inline]
    fn default() -> Self {
        let row_offsets = vec![0];
        Self {
            partition: Vec::new(),
            ncols: Vec::new(),
            row_offsets,
            values: Vec::new(),
            col_indices: Vec::new(),
        }
    }
}
/// Location of a single matrix inside the shared buffers of a matrix set.
#[derive(Clone)]
struct Partition {
    /// Index of the matrix's first entry in the value / column-index buffers.
    pub value_offset: usize,
    /// Number of entries the matrix occupies in the value / column-index buffers.
    pub value_len: usize,
    /// Index of the matrix's first element in the row-offset buffer.
    pub row_offset: usize,
    /// Number of elements the matrix occupies in the row-offset buffer.
    pub row_len: usize,
}

impl Partition {
    /// Half-open range `value_offset..value_offset + value_len`.
    #[inline]
    pub fn value_range(&self) -> std::ops::Range<usize> {
        let start = self.value_offset;
        std::ops::Range {
            start,
            end: start + self.value_len,
        }
    }

    /// Half-open range `row_offset..row_offset + row_len`.
    #[inline]
    pub fn row_offset_range(&self) -> std::ops::Range<usize> {
        let start = self.row_offset;
        std::ops::Range {
            start,
            end: start + self.row_len,
        }
    }
}
impl<T: Real> CsrMatrixSet<T> {
    /// Removes every matrix from the set while keeping the allocated buffers for reuse.
    ///
    /// Note that `row_offsets` ends up completely empty here, whereas `Default` seeds it with a
    /// single `0`. Both states are usable because `new_matrix` always pushes its own leading `0`.
    pub fn clear(&mut self) {
        // The buffers are independent of each other, so the order of the clears is irrelevant.
        self.partition.clear();
        self.ncols.clear();
        self.row_offsets.clear();
        self.values.clear();
        self.col_indices.clear();
    }

    /// Starts a new matrix with `ncol` columns at the end of the set and returns its builder.
    ///
    /// The builder borrows the set mutably. The matrix only becomes part of the set (its
    /// `ncols` and `partition` entries are recorded) when the builder is dropped.
    ///
    /// `zero_threshold` is stored in the builder and handed to every row builder it creates.
    // NOTE(review): presumably entries below `zero_threshold` are skipped by `CsVecBuilder` —
    // confirm against its implementation.
    pub fn new_matrix(&mut self, ncol: usize, zero_threshold: T) -> CsrMatrixBuilder<T> {
        // Remember where this matrix begins before anything is appended for it.
        let (value_start, row_start) = (self.values.len(), self.row_offsets.len());
        // Leading element of this matrix's row-offset array.
        self.row_offsets.push(0);
        CsrMatrixBuilder {
            ncol,
            zero_threshold,
            value_start,
            row_start,
            set: self,
        }
    }
}
impl<T> CsrMatrixSet<T> {
    /// Returns a view of the matrix stored at position `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not smaller than the number of matrices in the set.
    #[inline]
    pub fn get(&self, index: usize) -> CsrMatrixView<T> {
        // The borrowed set view performs exactly the same slicing, so reuse it.
        self.as_view().get(index)
    }

    /// Borrows the whole set as a cheap view over its buffers.
    #[inline]
    pub fn as_view(&self) -> CsrMatrixSetView<'_, T> {
        let Self {
            col_indices,
            values,
            row_offsets,
            ncols,
            partition,
        } = self;
        CsrMatrixSetView {
            col_indices,
            values,
            row_offsets,
            ncols,
            partition,
        }
    }
}
/// Borrowed view over a contiguous run of matrices of a [`CsrMatrixSet`].
///
/// `col_indices`, `values` and `row_offsets` always refer to the complete buffers of the set;
/// `ncols` and `partition` contain one entry per matrix visible through this view.
///
/// The view is `Copy` for every `T`, because it only holds shared references.
pub struct CsrMatrixSetView<'a, T> {
    /// Column indices of the stored entries of all matrices of the underlying set.
    col_indices: &'a [usize],
    /// Stored values of all matrices of the underlying set.
    values: &'a [T],
    /// Row offset arrays of all matrices of the underlying set.
    row_offsets: &'a [usize],
    /// Column count of each matrix visible through this view.
    ncols: &'a [usize],
    /// Buffer ranges of each matrix visible through this view.
    partition: &'a [Partition],
}

// `#[derive(Clone, Copy)]` would add `T: Clone` / `T: Copy` bounds to the generated impls even
// though only references to `T` are stored. The impls are written by hand so that the view can
// be copied for any element type.
impl<T> Clone for CsrMatrixSetView<'_, T> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for CsrMatrixSetView<'_, T> {}
impl<'a, T> CsrMatrixSetView<'a, T> {
    /// Returns a view of the matrix at position `index` of this view.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not smaller than [`len`](Self::len).
    #[inline]
    pub fn get(self, index: usize) -> CsrMatrixView<'a, T> {
        let part = &self.partition[index];
        let row_span = part.row_offset_range();
        // Column indices and values are sliced with the same range.
        let entry_span = part.value_range();
        CsrMatrixView::from_parts_unchecked(
            &self.row_offsets[row_span],
            &self.col_indices[entry_span.clone()],
            &self.values[entry_span],
            self.ncols[index],
        )
    }

    /// Number of matrices visible through this view.
    #[inline]
    pub fn len(&self) -> usize {
        self.partition.len()
    }

    /// Returns `true` if no matrix is visible through this view.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.partition.is_empty()
    }

    /// Splits the view into the first `index` matrices and the remaining ones.
    ///
    /// Both halves keep referring to the complete shared buffers; only the per-matrix
    /// `ncols` and `partition` slices are divided.
    ///
    /// # Panics
    ///
    /// Panics if `index` is greater than [`len`](Self::len).
    #[inline]
    pub fn split_at(self, index: usize) -> (Self, Self) {
        let (head_partition, tail_partition) = self.partition.split_at(index);
        let (head_ncols, tail_ncols) = self.ncols.split_at(index);
        // The remaining fields are plain shared references, so `..self` just copies them.
        let head = Self {
            ncols: head_ncols,
            partition: head_partition,
            ..self
        };
        let tail = Self {
            ncols: tail_ncols,
            partition: tail_partition,
            ..self
        };
        (head, tail)
    }
}
impl<'a, T> IntoView for &'a CsrMatrixSet<T> {
type View = CsrMatrixSetView<'a, T>;
#[inline]
fn into_view(self) -> Self::View {
self.as_view()
}
}
/// Size queries shared by types that behave like a set of CSR matrices.
///
/// The type parameter `V` is not used by the methods themselves.
pub trait CsrMatrixSetMethods<V> {
    /// Number of matrices in the set.
    fn len(&self) -> usize;

    /// Returns `true` when [`len`](Self::len) is zero.
    #[inline]
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
impl<'a, T, V> CsrMatrixSetMethods<V> for &'a T
where
&'a T: IntoView<View = CsrMatrixSetView<'a, V>>,
V: Real,
{
#[inline]
fn len(&self) -> usize {
self.into_view().len()
}
}
/// Builder that appends one matrix to a [`CsrMatrixSet`].
///
/// Created by `CsrMatrixSet::new_matrix`. The matrix is registered in the set (its `ncols` and
/// `partition` entries are pushed) when the builder is dropped.
pub struct CsrMatrixBuilder<'a, T> {
/// The set the matrix is appended to.
set: &'a mut CsrMatrixSet<T>,
/// Length of the set's value buffer when this builder was created.
value_start: usize,
/// Length of the set's row-offset buffer when this builder was created.
row_start: usize,
/// Number of columns of the matrix being built.
ncol: usize,
/// Threshold handed to every row builder created by `new_row`.
zero_threshold: T,
}
impl<T> Drop for CsrMatrixBuilder<'_, T> {
    /// Registers the finished matrix in the set.
    ///
    /// Records the column count and the part of the shared buffers that was appended since
    /// this builder was created.
    fn drop(&mut self) {
        let (ncol, value_offset, row_offset) = (self.ncol, self.value_start, self.row_start);
        let set = &mut *self.set;
        set.ncols.push(ncol);
        // Everything appended after the recorded start positions belongs to this matrix.
        let value_len = set.values.len() - value_offset;
        let row_len = set.row_offsets.len() - row_offset;
        set.partition.push(Partition {
            value_offset,
            value_len,
            row_offset,
            row_len,
        });
    }
}
impl<T: Real> CsrMatrixBuilder<'_, T> {
#[inline]
pub fn ncol(&self) -> usize {
self.ncol
}
#[inline]
pub fn new_row(&mut self) -> CsVecBuilder<T> {
CsVecBuilder::from_parts_unchecked(
&mut self.set.col_indices,
&mut self.set.row_offsets,
&mut self.set.values,
self.value_start,
self.zero_threshold,
)
}
}
// Unit tests for `CsrMatrixSetView::split_at`: every test builds a set, splits a view of it
// and compares the dense form of the matrices reachable through the halves with the originals.
#[cfg(test)]
mod tests {
use approx::assert_relative_eq;
use super::*;
use crate::csm::CsrMatrixViewMethods;
use crate::traits::IntoView;
// Builds a set of four matrices:
//   0: 3 columns, 3 rows
//   1: 2 columns, 2 rows
//   2: 1 column,  1 row
//   3: 3 columns, 2 rows
// Each builder is confined to its own scope so that it is dropped (which registers the
// matrix in the set) before the next matrix is started.
fn create_test_matrix_set() -> CsrMatrixSet<f32> {
let mut set = CsrMatrixSet::default();
{
let mut builder = set.new_matrix(3, 1e-10);
{
let mut row = builder.new_row();
row.extend_with_nonzeros(vec![(0, 1.0), (2, 2.0)]);
}
{
let mut row = builder.new_row();
row.extend_with_nonzeros(vec![(1, 3.0)]);
}
{
let mut row = builder.new_row();
row.extend_with_nonzeros(vec![(0, 4.0), (1, 5.0)]);
}
}
{
let mut builder = set.new_matrix(2, 1e-10);
{
let mut row = builder.new_row();
row.extend_with_nonzeros(vec![(0, 6.0), (1, 7.0)]);
}
{
let mut row = builder.new_row();
row.extend_with_nonzeros(vec![(1, 8.0)]);
}
}
{
let mut builder = set.new_matrix(1, 1e-10);
{
let mut row = builder.new_row();
row.extend_with_nonzeros(vec![(0, 9.0)]);
}
}
{
let mut builder = set.new_matrix(3, 1e-10);
{
let mut row = builder.new_row();
row.extend_with_nonzeros(vec![(1, 10.0)]);
}
{
let mut row = builder.new_row();
row.extend_with_nonzeros(vec![(0, 11.0), (2, 12.0)]);
}
}
set
}
// Splitting at 0 yields an empty left half; the right half contains all four matrices.
#[test]
fn test_split_at_beginning() {
let set = create_test_matrix_set();
let view = set.as_view();
let (left, right) = view.split_at(0);
assert_eq!(left.len(), 0);
assert_eq!(right.len(), 4);
for i in 0..4 {
let original_matrix = set.get(i);
let split_matrix = right.get(i);
assert_relative_eq!(original_matrix.to_dense(), split_matrix.to_dense());
}
}
// Splitting at the length yields an empty right half; the left half contains all matrices.
#[test]
fn test_split_at_end() {
let set = create_test_matrix_set();
let view = set.as_view();
let (left, right) = view.split_at(4);
assert_eq!(left.len(), 4);
assert_eq!(right.len(), 0);
for i in 0..4 {
let original_matrix = set.get(i);
let split_matrix = left.get(i);
assert_relative_eq!(original_matrix.to_dense(), split_matrix.to_dense());
}
}
// Splitting in the middle: the right half is indexed from 0 again, so its matrix `i`
// corresponds to matrix `i + 2` of the set.
#[test]
fn test_split_at_middle() {
let set = create_test_matrix_set();
let view = set.as_view();
let (left, right) = view.split_at(2);
assert_eq!(left.len(), 2);
assert_eq!(right.len(), 2);
for i in 0..2 {
let original_matrix = set.get(i);
let split_matrix = left.get(i);
assert_relative_eq!(original_matrix.to_dense(), split_matrix.to_dense());
}
for i in 0..2 {
let original_matrix = set.get(i + 2);
let split_matrix = right.get(i);
assert_relative_eq!(original_matrix.to_dense(), split_matrix.to_dense());
}
}
// Checks every valid split index (0 through 4 inclusive) against the original set.
#[test]
fn test_split_at_various_positions() {
let set = create_test_matrix_set();
let view = set.as_view();
for split_index in 0..=4 {
let (left, right) = view.split_at(split_index);
assert_eq!(left.len(), split_index);
assert_eq!(right.len(), 4 - split_index);
for i in 0..4 {
let original_matrix = set.get(i);
// Pick the half that holds matrix `i` and translate the index for the right half.
let split_matrix = if i < split_index {
left.get(i)
} else {
right.get(i - split_index)
};
assert_relative_eq!(original_matrix.to_dense(), split_matrix.to_dense());
}
}
}
// Splitting the halves again yields four single-matrix views in the original order.
#[test]
fn test_split_multiple_times() {
let set = create_test_matrix_set();
let view = set.as_view();
let (left, right) = view.split_at(2);
let (left_left, left_right) = left.split_at(1);
let (right_left, right_right) = right.split_at(1);
assert_eq!(left_left.len(), 1);
assert_eq!(left_right.len(), 1);
assert_eq!(right_left.len(), 1);
assert_eq!(right_right.len(), 1);
assert_relative_eq!(set.get(0).to_dense(), left_left.get(0).to_dense());
assert_relative_eq!(set.get(1).to_dense(), left_right.get(0).to_dense());
assert_relative_eq!(set.get(2).to_dense(), right_left.get(0).to_dense());
assert_relative_eq!(set.get(3).to_dense(), right_right.get(0).to_dense());
}
// A set with one matrix can be split at both of its valid indices (0 and 1).
// The view is obtained through the `IntoView` impl here instead of `as_view`.
#[test]
fn test_split_single_matrix() {
let mut set = CsrMatrixSet::default();
{
let mut builder = set.new_matrix(2, 1e-10);
{
let mut row = builder.new_row();
row.extend_with_nonzeros(vec![(0, 1.0), (1, 2.0)]);
}
}
let view = (&set).into_view();
let (left, right) = view.split_at(0);
assert_eq!(left.len(), 0);
assert_eq!(right.len(), 1);
let (left, right) = view.split_at(1);
assert_eq!(left.len(), 1);
assert_eq!(right.len(), 0);
assert_relative_eq!(set.get(0).to_dense(), left.get(0).to_dense());
}
// Splitting the view of a default (empty) set yields two empty views.
#[test]
fn test_split_empty_view() {
let set = CsrMatrixSet::<f32>::default();
let view = set.as_view();
let (left, right) = view.split_at(0);
assert_eq!(left.len(), 0);
assert_eq!(right.len(), 0);
assert!(left.is_empty());
assert!(right.is_empty());
}
// Same comparison as the middle split, but on dense matrices collected up front.
#[test]
fn test_split_view_data_integrity() {
let set = create_test_matrix_set();
let view = set.as_view();
let (left, right) = view.split_at(2);
let original_matrices: Vec<_> = (0..4).map(|i| set.get(i).to_dense()).collect();
let left_matrices: Vec<_> = (0..2).map(|i| left.get(i).to_dense()).collect();
let right_matrices: Vec<_> = (0..2).map(|i| right.get(i).to_dense()).collect();
for (i, original) in original_matrices.iter().enumerate() {
let split_matrix = if i < 2 {
&left_matrices[i]
} else {
&right_matrices[i - 2]
};
assert_relative_eq!(original, split_matrix);
}
}
// Splitting the same view twice gives equal results; this relies on the view being `Copy`,
// since `split_at` takes `self` by value.
#[test]
fn test_split_view_independence() {
let set = create_test_matrix_set();
let view = set.as_view();
let (left1, right1) = view.split_at(2);
let (left2, right2) = view.split_at(2);
assert_eq!(left1.len(), left2.len());
assert_eq!(right1.len(), right2.len());
for i in 0..left1.len() {
assert_relative_eq!(left1.get(i).to_dense(), left2.get(i).to_dense());
}
for i in 0..right1.len() {
assert_relative_eq!(right1.get(i).to_dense(), right2.get(i).to_dense());
}
}
}