use thiserror::Error;
use crate::{math::Point, FloatNumber};
/// Errors reported when building a [`MatrixView`].
#[derive(Debug, PartialEq, Error)]
pub enum MatrixError {
/// The points slice does not hold exactly `cols * rows` elements.
///
/// The two fields are the requested `cols` and `rows`, in that order.
#[error("Invalid Points: The points slice is not in the expected shape: {0}x{1}.")]
InvalidPoints(usize, usize),
}
/// A borrowed two-dimensional view over a flat slice of points.
///
/// The view does not own its data; it only records the shape used to translate a
/// `(col, row)` coordinate into the flat index `col + row * cols`.
#[derive(Debug, PartialEq)]
pub struct MatrixView<'a, T, const N: usize>
where
T: FloatNumber,
{
// Number of columns (width) of the view.
pub(crate) cols: usize,
// Number of rows (height) of the view.
pub(crate) rows: usize,
// The borrowed points; `MatrixView::new` only accepts a slice of `cols * rows` elements.
points: &'a [Point<T, N>],
}
impl<'a, T, const N: usize> MatrixView<'a, T, N>
where
    T: FloatNumber,
{
    /// Creates a view that interprets `points` as a grid of `cols` columns and `rows` rows.
    ///
    /// # Errors
    ///
    /// Returns [`MatrixError::InvalidPoints`] if `cols * rows` differs from `points.len()`,
    /// or if that product does not fit in a `usize`.
    #[inline]
    pub fn new(cols: usize, rows: usize, points: &'a [Point<T, N>]) -> Result<Self, MatrixError> {
        // A wrapped product could match `points.len()` by accident (for instance wrapping to 0
        // for an empty slice), so an overflowing shape is rejected instead of multiplied blindly.
        match cols.checked_mul(rows) {
            Some(len) if len == points.len() => Ok(Self { cols, rows, points }),
            _ => Err(MatrixError::InvalidPoints(cols, rows)),
        }
    }

    /// Returns the total number of points in the view.
    #[inline(always)]
    #[must_use]
    pub fn size(&self) -> usize {
        self.points.len()
    }

    /// Returns the shape of the view as `(cols, rows)`.
    #[inline]
    #[must_use]
    pub fn shape(&self) -> (usize, usize) {
        (self.cols, self.rows)
    }

    /// Converts a `(col, row)` coordinate into the flat index `col + row * cols`.
    ///
    /// Returns `None` when the coordinate lies outside the view.
    #[inline(always)]
    #[must_use]
    pub fn index(&self, col: usize, row: usize) -> Option<usize> {
        if col < self.cols && row < self.rows {
            Some(col + row * self.cols)
        } else {
            None
        }
    }

    /// Returns the point stored at `(col, row)`, or `None` when the coordinate lies outside
    /// the view.
    #[inline(always)]
    #[must_use]
    pub fn get(&self, col: usize, row: usize) -> Option<&Point<T, N>> {
        // Checked slice access keeps this method panic-free.
        self.index(col, row)
            .and_then(|index| self.points.get(index))
    }

    /// Iterates over the points directly surrounding `(col, row)` (radius 1).
    ///
    /// Each item is `(flat index, point)`. The point at `(col, row)` itself is not yielded,
    /// and nothing is yielded when `(col, row)` lies outside the view.
    #[inline]
    #[must_use]
    pub fn neighbors(&self, col: usize, row: usize) -> NeighborIterator<'_, T, N> {
        NeighborIterator::new(self, col, row, 1)
    }

    /// Iterates over the points within `radius` cells of `(col, row)` in both directions.
    ///
    /// Each item is `(flat index, point)`. The point at `(col, row)` itself is not yielded,
    /// and nothing is yielded when `(col, row)` lies outside the view.
    #[inline]
    #[must_use]
    pub fn neighbors_with_size(
        &self,
        col: usize,
        row: usize,
        radius: usize,
    ) -> NeighborIterator<'_, T, N> {
        NeighborIterator::new(self, col, row, radius)
    }
}
/// Iterator over the points surrounding a coordinate of a [`MatrixView`].
///
/// It scans the square window of offsets `-radius..=radius` in both directions around
/// `(col, row)`, skips the `(0, 0)` offset and every offset that falls outside the matrix,
/// and yields `(flat index, point)` pairs.
#[derive(Debug, PartialEq)]
pub struct NeighborIterator<'a, T, const N: usize>
where
T: FloatNumber,
{
// The matrix whose points are being visited.
matrix: &'a MatrixView<'a, T, N>,
// Column of the coordinate whose neighbors are enumerated.
col: usize,
// Row of the coordinate whose neighbors are enumerated.
row: usize,
// Half-width of the scanned window, kept signed so it can be negated for the start offset.
radius: isize,
// Column offset of the next candidate; advanced by one for every candidate examined.
dx: isize,
// Row offset of the next candidate; advanced once a full row of offsets has been examined.
dy: isize,
}
impl<'a, T, const N: usize> NeighborIterator<'a, T, N>
where
    T: FloatNumber,
{
    /// Creates an iterator over the points of `matrix` that lie within `radius` cells of
    /// `(col, row)` horizontally and vertically.
    ///
    /// The scan starts at the offset `(-radius, -radius)`. A `radius` larger than the
    /// matrix is reduced to the larger of its two dimensions, which does not change the
    /// items produced.
    #[inline]
    #[must_use]
    pub fn new(matrix: &'a MatrixView<'a, T, N>, col: usize, row: usize, radius: usize) -> Self {
        // An offset whose magnitude reaches the larger dimension can never land inside the
        // matrix, so capping the radius there only skips offsets that would be discarded.
        // It also keeps an oversized radius from wrapping negative in the signed conversion.
        let extent = matrix.cols.max(matrix.rows);
        // Lossless cast: the value has just been limited to `isize::MAX`.
        let radius = radius.min(extent).min(isize::MAX as usize) as isize;
        Self {
            matrix,
            col,
            row,
            radius,
            dx: -radius,
            dy: -radius,
        }
    }
}
impl<'a, T, const N: usize> Iterator for NeighborIterator<'a, T, N>
where
    T: FloatNumber,
{
    /// The flat index of a neighbor together with a reference to its point.
    type Item = (usize, &'a Point<T, N>);

    /// Advances the offset cursor to the next in-bounds neighbor and returns it.
    ///
    /// Offsets are visited row by row (`dy` outer, `dx` inner); the center offset and any
    /// offset that leaves the matrix are skipped. Returns `None` once every offset has been
    /// examined, or immediately when the origin itself lies outside the matrix.
    fn next(&mut self) -> Option<Self::Item> {
        let width = self.matrix.cols;
        let height = self.matrix.rows;

        // An origin outside the matrix has no neighbors.
        if self.col >= width || self.row >= height {
            return None;
        }

        loop {
            // Every row of offsets has been consumed.
            if self.dy > self.radius {
                return None;
            }

            // Current row of offsets is finished: rewind the column cursor and step down.
            if self.dx > self.radius {
                self.dx = -self.radius;
                self.dy += 1;
                continue;
            }

            // Take the current candidate and advance the cursor before inspecting it, so the
            // next call resumes after this offset regardless of the outcome.
            let (step_x, step_y) = (self.dx, self.dy);
            self.dx += 1;

            // The origin is not its own neighbor.
            if (step_x, step_y) == (0, 0) {
                continue;
            }

            // `checked_add_signed` rejects coordinates left of / above the matrix; the filter
            // rejects those right of / below it.
            let candidate = self
                .col
                .checked_add_signed(step_x)
                .zip(self.row.checked_add_signed(step_y))
                .filter(|&(c, r)| c < width && r < height);

            if let Some((c, r)) = candidate {
                let index = c + r * width;
                return Some((index, &self.matrix.points[index]));
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    // A view built with a matching shape keeps the shape and borrows the slice.
    #[test]
    fn test_new() {
        let cols = 16;
        let rows = 9;
        let points = vec![[0.0; 3]; cols * rows];
        let actual = MatrixView::new(cols, rows, &points);

        assert!(actual.is_ok());
        assert_eq!(
            actual.unwrap(),
            MatrixView {
                cols,
                rows,
                points: &points,
            }
        );
    }

    // A 0x0 view over an empty slice is valid.
    #[test]
    fn test_new_empty() {
        let points = Vec::<[f64; 3]>::new();
        let matrix = MatrixView::new(0, 0, &points);

        assert!(matrix.is_ok());

        let matrix = matrix.unwrap();
        assert_eq!(matrix.cols, 0);
        assert_eq!(matrix.rows, 0);
        assert_eq!(matrix.points.len(), 0);
    }

    // A slice that is one element short of `cols * rows` is rejected.
    #[test]
    fn test_new_invalid_points() {
        let cols = 16;
        let rows = 9;
        let points = vec![[0.0; 3]; cols * rows - 1];
        let matrix = MatrixView::new(cols, rows, &points);

        assert!(matrix.is_err());
        assert_eq!(matrix.unwrap_err(), MatrixError::InvalidPoints(cols, rows));
    }

    #[rstest]
    #[case(1, 1, 1)]
    #[case(2, 3, 6)]
    #[case(16, 9, 144)]
    fn test_size(#[case] cols: usize, #[case] rows: usize, #[case] expected: usize) {
        let points = vec![[0.0; 3]; cols * rows];
        let matrix = MatrixView::new(cols, rows, &points).unwrap();

        let actual = matrix.size();

        assert_eq!(actual, expected);
    }

    #[rstest]
    #[case(1, 1, (1, 1))]
    #[case(4, 9, (4, 9))]
    #[case(9, 4, (9, 4))]
    #[case(1, 1024, (1, 1024))]
    #[case(1024, 1, (1024, 1))]
    fn test_shape(#[case] cols: usize, #[case] rows: usize, #[case] expected: (usize, usize)) {
        let points = vec![[0.0; 3]; cols * rows];
        let matrix = MatrixView::new(cols, rows, &points).unwrap();

        let actual = matrix.shape();

        assert_eq!(actual, expected);
    }

    #[rstest]
    #[case(0, 0, Some(0))]
    #[case(1, 0, Some(1))]
    #[case(15, 0, Some(15))]
    #[case(0, 1, Some(16))]
    #[case(1, 1, Some(17))]
    #[case(0, 8, Some(128))]
    #[case(15, 8, Some(143))]
    #[case(16, 0, None)]
    #[case(0, 9, None)]
    #[case(16, 9, None)]
    fn test_index(#[case] col: usize, #[case] row: usize, #[case] expected: Option<usize>) {
        let cols = 16;
        let rows = 9;
        let points = vec![[0.0; 3]; cols * rows];
        let matrix = MatrixView::new(cols, rows, &points).unwrap();

        let actual = matrix.index(col, row);

        assert_eq!(actual, expected);
    }

    #[rstest]
    #[case::center(8, 4, 72)]
    #[case::left_top(0, 0, 0)]
    #[case::left_bottom(0, 8, 128)]
    #[case::right_top(15, 0, 15)]
    #[case::right_bottom(15, 8, 143)]
    fn test_get(#[case] col: usize, #[case] row: usize, #[case] index: usize) {
        let cols = 16;
        let rows = 9;
        // Every point carries its own flat index so a wrong lookup is detectable.
        let points: Vec<[f64; 3]> = (0..cols * rows).map(|i| [i as f64; 3]).collect();
        let matrix = MatrixView::new(cols, rows, &points).unwrap();

        let actual = matrix.get(col, row);

        assert!(actual.is_some());
        assert_eq!(actual.unwrap(), &points[index]);
    }

    // Labels name the corner just past which the coordinate lies: (16, 0) is beyond the
    // right edge on the top row, (0, 9) is beyond the bottom edge in the left column.
    #[rstest]
    #[case::right_top(16, 0)]
    #[case::left_bottom(0, 9)]
    #[case::right_bottom(16, 9)]
    fn test_get_out_of_bounds(#[case] col: usize, #[case] row: usize) {
        let cols = 16;
        let rows = 9;
        let points = vec![[0.0; 3]; cols * rows];
        let matrix = MatrixView::new(cols, rows, &points).unwrap();

        let actual = matrix.get(col, row);

        assert!(actual.is_none());
    }

    #[rstest]
    #[case((0, 0), vec![1, 16, 17])]
    #[case((0, 1), vec![0, 1, 17, 32, 33])]
    #[case((1, 0), vec![0, 2, 16, 17, 18])]
    #[case((1, 1), vec![0, 1, 2, 16, 18, 32, 33, 34])]
    #[case((0, 8), vec![112, 113, 129])]
    #[case((1, 8), vec![112, 113, 114, 128, 130])]
    #[case((15, 0), vec![14, 30, 31])]
    #[case((15, 7), vec![110, 111, 126, 142, 143])]
    #[case((15, 8), vec![126, 127, 142])]
    fn test_neighbors(#[case] (col, row): (usize, usize), #[case] expected: Vec<usize>) {
        let cols = 16;
        let rows = 9;
        let points = vec![[0.0; 3]; cols * rows];
        let matrix = MatrixView::new(cols, rows, &points).unwrap();

        let actual: Vec<usize> = matrix
            .neighbors(col, row)
            .map(|(index, _)| index)
            .collect();

        assert_eq!(actual.len(), expected.len());
        assert_eq!(actual, expected);
    }

    // An origin outside the matrix has no neighbors.
    #[rstest]
    #[case::left_bottom(0, 9)]
    #[case::right_top(16, 0)]
    #[case::right_bottom(16, 9)]
    fn test_neighbors_empty(#[case] col: usize, #[case] row: usize) {
        let cols = 16;
        let rows = 9;
        let points = vec![[0.0; 3]; cols * rows];
        let matrix = MatrixView::new(cols, rows, &points).unwrap();

        let actual: Vec<usize> = matrix
            .neighbors(col, row)
            .map(|(index, _)| index)
            .collect();

        assert!(actual.is_empty());
    }

    #[rstest]
    #[case(0, (0, 0), vec![])]
    #[case(1, (0, 0), vec![1, 16, 17])]
    #[case(2, (0, 0), vec![1, 2, 16, 17, 18, 32, 33, 34])]
    #[case(0, (8, 4), vec![])]
    #[case(1, (8, 4), vec![55, 56, 57, 71, 73, 87, 88, 89])]
    #[case(2, (8, 4), vec![38, 39, 40, 41, 42, 54, 55, 56, 57, 58, 70, 71, 73, 74, 86, 87, 88, 89, 90, 102, 103, 104, 105, 106])]
    #[case(0, (15, 8), vec![])]
    #[case(1, (15, 8), vec![126, 127, 142])]
    #[case(2, (15, 8), vec![109, 110, 111, 125, 126, 127, 141, 142])]
    fn test_neighbors_with_size(
        #[case] radius: usize,
        #[case] (col, row): (usize, usize),
        #[case] expected: Vec<usize>,
    ) {
        let cols = 16;
        let rows = 9;
        let points = vec![[0.0; 3]; cols * rows];
        let matrix = MatrixView::new(cols, rows, &points).unwrap();

        let actual: Vec<usize> = matrix
            .neighbors_with_size(col, row, radius)
            .map(|(index, _)| index)
            .collect();

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_neighbor_iterator_new() {
        let cols = 16;
        let rows = 9;
        let points = vec![[0.0; 3]; cols * rows];
        let matrix = MatrixView::new(cols, rows, &points).unwrap();

        let actual = NeighborIterator::new(&matrix, 8, 4, 1);

        assert_eq!(
            actual,
            NeighborIterator {
                matrix: &matrix,
                col: 8,
                row: 4,
                radius: 1,
                dx: -1,
                dy: -1,
            }
        );
    }

    #[test]
    fn test_neighbor_iterator_next() {
        let cols = 3;
        let rows = 2;
        // Every point carries its own flat index so the yielded references can be checked.
        let points: Vec<[f64; 2]> = (0..cols * rows).map(|i| [i as f64; 2]).collect();
        let matrix = MatrixView::new(cols, rows, &points).unwrap();

        let mut iterator = NeighborIterator::new(&matrix, 1, 1, 1);

        assert_eq!(iterator.next(), Some((0, &points[0])));
        assert_eq!(iterator.next(), Some((1, &points[1])));
        assert_eq!(iterator.next(), Some((2, &points[2])));
        assert_eq!(iterator.next(), Some((3, &points[3])));
        assert_eq!(iterator.next(), Some((5, &points[5])));
        assert_eq!(iterator.next(), None);
    }

    #[rstest]
    #[case::cols(3, 1)]
    #[case::rows(2, 2)]
    #[case::cols_rows(3, 2)]
    fn test_neighbor_iterator_next_out_of_bounds(#[case] col: usize, #[case] row: usize) {
        let cols = 3;
        let rows = 2;
        let points = vec![[0.0; 2]; cols * rows];
        let matrix = MatrixView::new(cols, rows, &points).unwrap();

        let mut iterator = NeighborIterator::new(&matrix, col, row, 1);

        assert_eq!(iterator.next(), None);
    }
}