use crate::comparators::ElementwiseComparator;
use crate::{
Access, Coordinate, DenseAccess, DimensionMismatch, ElementsMismatch, Matrix,
MatrixComparisonFailure, MatrixElementComparisonFailure, SparseAccess,
};
use num_traits::Zero;
use std::collections::{HashMap, HashSet};
use crate::Entry;
enum HashMapBuildError {
OutOfBoundsCoord(Coordinate),
DuplicateCoord(Coordinate),
}
fn try_build_sparse_hash_map<T>(
rows: usize,
cols: usize,
triplets: &[(usize, usize, T)],
) -> Result<HashMap<(usize, usize), T>, HashMapBuildError>
where
T: Clone,
{
let mut matrix = HashMap::new();
for (i, j, v) in triplets.iter().cloned() {
if i >= rows || j >= cols {
return Err(HashMapBuildError::OutOfBoundsCoord((i, j)));
} else if matrix.insert((i, j), v).is_some() {
return Err(HashMapBuildError::DuplicateCoord((i, j)));
}
}
Ok(matrix)
}
fn compare_sparse_sparse<T, C>(
left: &dyn SparseAccess<T>,
right: &dyn SparseAccess<T>,
comparator: &C,
) -> Result<(), MatrixComparisonFailure<T, C::Error>>
where
T: Zero + Clone,
C: ElementwiseComparator<T>,
{
assert!(left.rows() == right.rows() && left.cols() == right.cols());
let left_hash = try_build_sparse_hash_map(left.rows(), left.cols(), &left.fetch_triplets())
.map_err(|build_error| match build_error {
HashMapBuildError::OutOfBoundsCoord(coord) => {
MatrixComparisonFailure::SparseEntryOutOfBounds(Entry::Left(coord))
}
HashMapBuildError::DuplicateCoord(coord) => {
MatrixComparisonFailure::DuplicateSparseEntry(Entry::Left(coord))
}
})?;
let right_hash = try_build_sparse_hash_map(right.rows(), right.cols(), &right.fetch_triplets())
.map_err(|build_error| match build_error {
HashMapBuildError::OutOfBoundsCoord(coord) => {
MatrixComparisonFailure::SparseEntryOutOfBounds(Entry::Right(coord))
}
HashMapBuildError::DuplicateCoord(coord) => {
MatrixComparisonFailure::DuplicateSparseEntry(Entry::Right(coord))
}
})?;
let mut mismatches = Vec::new();
let left_keys: HashSet<_> = left_hash.keys().collect();
let right_keys: HashSet<_> = right_hash.keys().collect();
let zero = T::zero();
for coord in left_keys.union(&right_keys) {
let a = left_hash.get(coord).unwrap_or(&zero);
let b = right_hash.get(coord).unwrap_or(&zero);
if let Err(error) = comparator.compare(&a, &b) {
mismatches.push(MatrixElementComparisonFailure {
left: a.clone(),
right: b.clone(),
error,
row: coord.0,
col: coord.1,
});
}
}
mismatches.sort_by_key(|mismatch| (mismatch.row, mismatch.col));
if mismatches.is_empty() {
Ok(())
} else {
Err(MatrixComparisonFailure::MismatchedElements(
ElementsMismatch {
comparator_description: comparator.description(),
mismatches,
},
))
}
}
fn find_dense_sparse_mismatches<T, C>(
dense: &dyn DenseAccess<T>,
sparse: &HashMap<(usize, usize), T>,
comparator: &C,
swap_order: bool,
) -> Option<ElementsMismatch<T, C::Error>>
where
T: Zero + Clone,
C: ElementwiseComparator<T>,
{
let mut mismatches = Vec::new();
let zero = T::zero();
for i in 0..dense.rows() {
for j in 0..dense.cols() {
let a = &dense.fetch_single(i, j);
let b = sparse.get(&(i, j)).unwrap_or(&zero);
let (a, b) = if swap_order { (b, a) } else { (a, b) };
if let Err(error) = comparator.compare(a, b) {
mismatches.push(MatrixElementComparisonFailure {
left: a.clone(),
right: b.clone(),
error,
row: i,
col: j,
});
}
}
}
if mismatches.is_empty() {
None
} else {
Some(ElementsMismatch {
comparator_description: comparator.description(),
mismatches,
})
}
}
fn compare_dense_sparse<T, C>(
dense: &dyn DenseAccess<T>,
sparse: &dyn SparseAccess<T>,
comparator: &C,
swap_order: bool,
) -> Result<(), MatrixComparisonFailure<T, C::Error>>
where
T: Zero + Clone,
C: ElementwiseComparator<T>,
{
assert!(dense.rows() == sparse.rows() && dense.cols() == sparse.cols());
let triplets = sparse.fetch_triplets();
let sparse_hash = try_build_sparse_hash_map(sparse.rows(), sparse.cols(), &triplets);
match sparse_hash {
Ok(y_hash) => {
let mismatches = find_dense_sparse_mismatches(dense, &y_hash, comparator, swap_order);
if let Some(mismatches) = mismatches {
Err(MatrixComparisonFailure::MismatchedElements(mismatches))
} else {
Ok(())
}
}
Err(build_error) => {
let make_entry = |coord| {
if swap_order {
Entry::Left(coord)
} else {
Entry::Right(coord)
}
};
use MatrixComparisonFailure::*;
match build_error {
HashMapBuildError::OutOfBoundsCoord(coord) => {
Err(SparseEntryOutOfBounds(make_entry(coord)))
}
HashMapBuildError::DuplicateCoord(coord) => {
Err(DuplicateSparseEntry(make_entry(coord)))
}
}
}
}
}
fn compare_dense_dense<T, C>(
left: &dyn DenseAccess<T>,
right: &dyn DenseAccess<T>,
comparator: &C,
) -> Result<(), MatrixComparisonFailure<T, C::Error>>
where
T: Clone,
C: ElementwiseComparator<T>,
{
assert!(left.rows() == right.rows() && left.cols() == right.cols());
let mut mismatches = Vec::new();
for i in 0..left.rows() {
for j in 0..left.cols() {
let a = left.fetch_single(i, j);
let b = right.fetch_single(i, j);
if let Err(error) = comparator.compare(&a, &b) {
mismatches.push(MatrixElementComparisonFailure {
left: a.clone(),
right: b.clone(),
error,
row: i,
col: j,
});
}
}
}
if mismatches.is_empty() {
Ok(())
} else {
Err(MatrixComparisonFailure::MismatchedElements(
ElementsMismatch {
comparator_description: comparator.description(),
mismatches,
},
))
}
}
pub fn compare_matrices<T, C>(
left: impl Matrix<T>,
right: impl Matrix<T>,
comparator: &C,
) -> Result<(), MatrixComparisonFailure<T, C::Error>>
where
T: Zero + Clone,
C: ElementwiseComparator<T>,
{
let shapes_match = left.rows() == right.rows() && left.cols() == right.cols();
if shapes_match {
use Access::{Dense, Sparse};
match (left.access(), right.access()) {
(Dense(left_access), Dense(right_access)) => {
compare_dense_dense(left_access, right_access, comparator)
}
(Dense(left_access), Sparse(right_access)) => {
let swap = false;
compare_dense_sparse(left_access, right_access, comparator, swap)
}
(Sparse(left_access), Dense(right_access)) => {
let swap = true;
compare_dense_sparse(right_access, left_access, comparator, swap)
}
(Sparse(left_access), Sparse(right_access)) => {
compare_sparse_sparse(left_access, right_access, comparator)
}
}
} else {
Err(MatrixComparisonFailure::MismatchedDimensions(
DimensionMismatch {
dim_left: (left.rows(), left.cols()),
dim_right: (right.rows(), right.cols()),
},
))
}
}