use core::ops::Deref;
use crate::Matrix;
use crate::dense::RowMajorMatrixView;
/// A [`VerticalPair`] whose top and bottom halves are both [`RowMajorMatrixView`]s with
/// lifetime `'a` over elements of type `T`.
pub type ViewPair<'a, T> = VerticalPair<RowMajorMatrixView<'a, T>, RowMajorMatrixView<'a, T>>;
/// Two matrices stacked vertically: the rows of `top` followed by the rows of `bottom`.
///
/// As a `Matrix`, the pair reports `top`'s width and a height of
/// `top.height() + bottom.height()`. [`VerticalPair::new`] asserts that both halves have
/// the same width.
///
/// NOTE(review): the fields are public, so a `VerticalPair` can also be built directly
/// without that width check. The `Matrix` impl reports only `top.width()` and forwards
/// column indices to `bottom` unchanged. Mismatched widths would therefore hand
/// out-of-range columns to `bottom`'s unchecked accessors. Confirm whether direct
/// construction is intended to be allowed.
#[derive(Copy, Clone, Debug)]
pub struct VerticalPair<Top, Bottom> {
    /// Matrix supplying the first `top.height()` rows.
    pub top: Top,
    /// Matrix supplying the remaining rows, placed below `top`.
    pub bottom: Bottom,
}
/// Two matrices placed side by side: each row is `left`'s row followed by `right`'s row.
///
/// As a `Matrix`, the pair reports `left`'s height and a width of
/// `left.width() + right.width()`. [`HorizontalPair::new`] asserts that both halves have
/// the same height.
///
/// NOTE(review): the fields are public, so a `HorizontalPair` can also be built directly
/// without that height check. The `Matrix` impl reports only `left.height()` and forwards
/// row indices to `right` unchanged. Mismatched heights would therefore hand out-of-range
/// rows to `right`'s unchecked accessors. Confirm whether direct construction is intended
/// to be allowed.
#[derive(Copy, Clone, Debug)]
pub struct HorizontalPair<Left, Right> {
    /// Matrix supplying the first `left.width()` columns.
    pub left: Left,
    /// Matrix supplying the remaining columns, placed to the right of `left`.
    pub right: Right,
}
impl<Top, Bottom> VerticalPair<Top, Bottom> {
    /// Stacks `top` above `bottom`.
    ///
    /// `T` is the element type shared by both matrices. It usually has to be named
    /// explicitly (`VerticalPair::new::<i32>(a, b)`) because it is not determined by the
    /// return type.
    ///
    /// # Panics
    ///
    /// Panics if `top` and `bottom` do not have the same width.
    pub fn new<T>(top: Top, bottom: Bottom) -> Self
    where
        T: Send + Sync + Clone,
        Top: Matrix<T>,
        Bottom: Matrix<T>,
    {
        assert_eq!(
            top.width(),
            bottom.width(),
            "VerticalPair::new: top and bottom matrices must have the same width"
        );
        Self { top, bottom }
    }
}
impl<Left, Right> HorizontalPair<Left, Right> {
    /// Places `left` and `right` side by side.
    ///
    /// `T` is the element type shared by both matrices. It usually has to be named
    /// explicitly (`HorizontalPair::new::<i32>(a, b)`) because it is not determined by the
    /// return type.
    ///
    /// # Panics
    ///
    /// Panics if `left` and `right` do not have the same height.
    pub fn new<T>(left: Left, right: Right) -> Self
    where
        T: Send + Sync + Clone,
        Left: Matrix<T>,
        Right: Matrix<T>,
    {
        assert_eq!(
            left.height(),
            right.height(),
            "HorizontalPair::new: left and right matrices must have the same height"
        );
        Self { left, right }
    }
}
impl<T: Send + Sync + Clone, Top: Matrix<T>, Bottom: Matrix<T>> Matrix<T>
    for VerticalPair<Top, Bottom>
{
    // Every row-addressed method dispatches the same way. `r.checked_sub(top.height())`
    // is `None` for rows owned by `top`. For rows owned by `bottom` it is
    // `Some(local_r)`, the row index within `bottom`.
    //
    // SAFETY (all unsafe blocks below): callers of the unchecked methods guarantee that
    // `r < self.height()` and that column indices are within `self.width()`.
    // - In a `None` arm, `r < top.height()`.
    // - In a `Some(local_r)` arm, `local_r = r - top.height() < bottom.height()`.
    // - Column indices are forwarded unchanged. They are in range for `bottom` only when
    //   both halves share a width. `VerticalPair::new` checks this; building the struct
    //   directly through its public fields does not.

    fn width(&self) -> usize {
        // Both halves are expected to share this width; see the note above.
        self.top.width()
    }

    fn height(&self) -> usize {
        let top_rows = self.top.height();
        top_rows + self.bottom.height()
    }

    unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
        match r.checked_sub(self.top.height()) {
            None => unsafe { self.top.get_unchecked(r, c) },
            Some(local_r) => unsafe { self.bottom.get_unchecked(local_r, c) },
        }
    }

    unsafe fn row_unchecked(
        &self,
        r: usize,
    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
        match r.checked_sub(self.top.height()) {
            None => {
                let cells = unsafe { self.top.row_unchecked(r) };
                EitherRow::Left(cells.into_iter())
            }
            Some(local_r) => {
                let cells = unsafe { self.bottom.row_unchecked(local_r) };
                EitherRow::Right(cells.into_iter())
            }
        }
    }

    unsafe fn row_subseq_unchecked(
        &self,
        r: usize,
        start: usize,
        end: usize,
    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
        match r.checked_sub(self.top.height()) {
            None => {
                let cells = unsafe { self.top.row_subseq_unchecked(r, start, end) };
                EitherRow::Left(cells.into_iter())
            }
            Some(local_r) => {
                let cells = unsafe { self.bottom.row_subseq_unchecked(local_r, start, end) };
                EitherRow::Right(cells.into_iter())
            }
        }
    }

    unsafe fn row_slice_unchecked(&self, r: usize) -> impl Deref<Target = [T]> {
        match r.checked_sub(self.top.height()) {
            None => EitherRow::Left(unsafe { self.top.row_slice_unchecked(r) }),
            Some(local_r) => EitherRow::Right(unsafe { self.bottom.row_slice_unchecked(local_r) }),
        }
    }

    unsafe fn row_subslice_unchecked(
        &self,
        r: usize,
        start: usize,
        end: usize,
    ) -> impl Deref<Target = [T]> {
        match r.checked_sub(self.top.height()) {
            None => {
                let slice = unsafe { self.top.row_subslice_unchecked(r, start, end) };
                EitherRow::Left(slice)
            }
            Some(local_r) => {
                let slice = unsafe { self.bottom.row_subslice_unchecked(local_r, start, end) };
                EitherRow::Right(slice)
            }
        }
    }
}
impl<T: Send + Sync + Clone, Left: Matrix<T>, Right: Matrix<T>> Matrix<T>
    for HorizontalPair<Left, Right>
{
    // SAFETY (all unsafe blocks below): callers of the unchecked methods guarantee that
    // `r < self.height()` and that column indices (and `start <= end` ranges) lie within
    // `self.width() = left.width() + right.width()`.
    // - Columns below `left.width()` are passed to `left` unchanged.
    // - Other columns are shifted down by `left.width()` before being passed to `right`.
    // - `r` is forwarded to both halves unchanged. It is in range for `right` only when
    //   both halves share a height. `HorizontalPair::new` checks this; building the
    //   struct directly through its public fields does not.

    fn width(&self) -> usize {
        self.left.width() + self.right.width()
    }

    fn height(&self) -> usize {
        // Both halves are expected to share this height; see the note above.
        self.left.height()
    }

    unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
        let left_width = self.left.width();
        if c < left_width {
            unsafe { self.left.get_unchecked(r, c) }
        } else {
            unsafe { self.right.get_unchecked(r, c - left_width) }
        }
    }

    unsafe fn row_unchecked(
        &self,
        r: usize,
    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
        unsafe {
            self.left
                .row_unchecked(r)
                .into_iter()
                .chain(self.right.row_unchecked(r))
        }
    }

    /// Yields columns `start..end` of row `r` by asking each half for only its own part
    /// of the range.
    ///
    /// Without this override, the trait default skips through the whole concatenated
    /// row, which walks every left-hand element even when the range lies entirely in
    /// `right`.
    unsafe fn row_subseq_unchecked(
        &self,
        r: usize,
        start: usize,
        end: usize,
    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
        let left_width = self.left.width();
        // Clamp `[start, end)` to each half's column range. A half the range does not
        // touch gets an empty (but still in-bounds) subrange such as `left_width..left_width`
        // or `0..0`.
        let (left_start, left_end) = (start.min(left_width), end.min(left_width));
        let right_start = start.saturating_sub(left_width);
        let right_end = end.saturating_sub(left_width);
        unsafe {
            self.left
                .row_subseq_unchecked(r, left_start, left_end)
                .into_iter()
                .chain(self.right.row_subseq_unchecked(r, right_start, right_end))
        }
    }
}
/// One of two row representations, selected at runtime.
///
/// This lets a method that dispatches to one of two differently-typed sources (for
/// example the top or bottom half of a [`VerticalPair`]) return a single concrete type.
/// - It is an [`Iterator`] when both variants are iterators over the same item type.
/// - It dereferences to `[T]` when both variants dereference to `[T]`.
#[derive(Debug)]
pub enum EitherRow<L, R> {
    /// Value produced by the first source.
    Left(L),
    /// Value produced by the second source.
    Right(R),
}

impl<T, L, R> Iterator for EitherRow<L, R>
where
    L: Iterator<Item = T>,
    R: Iterator<Item = T>,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Left(l) => l.next(),
            Self::Right(r) => r.next(),
        }
    }

    /// Forwards the active variant's bounds.
    ///
    /// The default `(0, None)` would hide the row length from consumers such as
    /// `collect`, which use it to preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::Left(l) => l.size_hint(),
            Self::Right(r) => r.size_hint(),
        }
    }
}

impl<T, L, R> Deref for EitherRow<L, R>
where
    L: Deref<Target = [T]>,
    R: Deref<Target = [T]>,
{
    type Target = [T];

    /// Borrows the slice held by whichever variant is active.
    fn deref(&self) -> &Self::Target {
        match self {
            Self::Left(l) => l,
            Self::Right(r) => r,
        }
    }
}
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use itertools::Itertools;

    use super::*;
    use crate::RowMajorMatrix;

    #[test]
    fn test_vertical_pair_empty_top() {
        // A zero-height top means every row index resolves to `bottom`.
        let top = RowMajorMatrix::new(vec![], 2);
        let bottom = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        let vpair = VerticalPair::new::<i32>(top, bottom);

        assert_eq!(vpair.height(), 2);
        assert_eq!(vpair.get(1, 1), Some(4));
        unsafe {
            assert_eq!(vpair.get_unchecked(0, 0), 1);
        }
    }

    #[test]
    fn test_vertical_pair_composition() {
        let top = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        let bottom = RowMajorMatrix::new(vec![5, 6, 7, 8], 2);
        let vertical = VerticalPair::new::<i32>(top, bottom);

        assert_eq!(vertical.width(), 2);
        assert_eq!(vertical.height(), 4);

        // Element access on both sides of the seam at row 2.
        assert_eq!(vertical.get(0, 0), Some(1));
        assert_eq!(vertical.get(1, 1), Some(4));
        unsafe {
            assert_eq!(vertical.get_unchecked(2, 0), 5);
            assert_eq!(vertical.get_unchecked(3, 1), 8);
        }

        // Row iterators, from both halves.
        let row = vertical.row(3).unwrap().into_iter().collect_vec();
        assert_eq!(row, vec![7, 8]);
        unsafe {
            let row = vertical.row_unchecked(1).into_iter().collect_vec();
            assert_eq!(row, vec![3, 4]);
            let row = vertical
                .row_subseq_unchecked(0, 0, 1)
                .into_iter()
                .collect_vec();
            assert_eq!(row, vec![1]);
            let row = vertical
                .row_subseq_unchecked(2, 1, 2)
                .into_iter()
                .collect_vec();
            assert_eq!(row, vec![6]);
        }

        // Row slices, from both halves.
        assert_eq!(vertical.row_slice(2).unwrap().deref(), &[5, 6]);
        unsafe {
            assert_eq!(vertical.row_slice_unchecked(3).deref(), &[7, 8]);
            assert_eq!(vertical.row_subslice_unchecked(1, 1, 2).deref(), &[4]);
            assert_eq!(vertical.row_subslice_unchecked(3, 0, 1).deref(), &[7]);
        }

        // Checked accessors reject out-of-range indices.
        assert_eq!(vertical.get(0, 2), None);
        assert_eq!(vertical.get(4, 0), None);
        assert!(vertical.row(4).is_none());
        assert!(vertical.row_slice(4).is_none());
    }

    #[test]
    fn test_horizontal_pair_composition() {
        let left = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        let right = RowMajorMatrix::new(vec![5, 6, 7, 8], 2);
        let horizontal = HorizontalPair::new::<i32>(left, right);

        assert_eq!(horizontal.height(), 2);
        assert_eq!(horizontal.width(), 4);

        // Element access on both sides of the seam at column 2.
        assert_eq!(horizontal.get(0, 0), Some(1));
        assert_eq!(horizontal.get(1, 1), Some(4));
        unsafe {
            assert_eq!(horizontal.get_unchecked(0, 2), 5);
            assert_eq!(horizontal.get_unchecked(1, 3), 8);
        }

        // Full rows concatenate the left row with the right row.
        let row = horizontal.row(0).unwrap().into_iter().collect_vec();
        assert_eq!(row, vec![1, 2, 5, 6]);
        unsafe {
            let row = horizontal.row_unchecked(1).into_iter().collect_vec();
            assert_eq!(row, vec![3, 4, 7, 8]);
        }

        // Column sub-ranges: left half only, straddling the seam, right half only, empty.
        unsafe {
            let row = horizontal
                .row_subseq_unchecked(1, 0, 2)
                .into_iter()
                .collect_vec();
            assert_eq!(row, vec![3, 4]);
            let row = horizontal
                .row_subseq_unchecked(0, 1, 3)
                .into_iter()
                .collect_vec();
            assert_eq!(row, vec![2, 5]);
            let row = horizontal
                .row_subseq_unchecked(1, 2, 4)
                .into_iter()
                .collect_vec();
            assert_eq!(row, vec![7, 8]);
            let row = horizontal
                .row_subseq_unchecked(0, 2, 2)
                .into_iter()
                .collect_vec();
            assert!(row.is_empty());
        }
        assert_eq!(horizontal.row_slice(1).unwrap().deref(), &[3, 4, 7, 8]);

        // Checked accessors reject out-of-range indices.
        assert_eq!(horizontal.get(0, 4), None);
        assert_eq!(horizontal.get(2, 0), None);
        assert!(horizontal.row(2).is_none());
    }

    #[test]
    fn test_either_row_iterator_behavior() {
        type Iter = alloc::vec::IntoIter<i32>;

        let left: EitherRow<Iter, Iter> = EitherRow::Left(vec![10, 20].into_iter());
        assert_eq!(left.collect::<Vec<_>>(), vec![10, 20]);

        let right: EitherRow<Iter, Iter> = EitherRow::Right(vec![30, 40].into_iter());
        assert_eq!(right.collect::<Vec<_>>(), vec![30, 40]);
    }

    #[test]
    fn test_either_row_deref_behavior() {
        let left: EitherRow<&[i32], &[i32]> = EitherRow::Left(&[1, 2, 3]);
        let right: EitherRow<&[i32], &[i32]> = EitherRow::Right(&[4, 5]);
        assert_eq!(&*left, &[1, 2, 3]);
        assert_eq!(&*right, &[4, 5]);
    }

    #[test]
    #[should_panic]
    fn test_vertical_pair_width_mismatch_should_panic() {
        // Width 1 on top, width 2 on the bottom.
        let a = RowMajorMatrix::new(vec![1, 2, 3], 1);
        let b = RowMajorMatrix::new(vec![4, 5], 2);
        let _ = VerticalPair::new::<i32>(a, b);
    }

    #[test]
    #[should_panic]
    fn test_horizontal_pair_height_mismatch_should_panic() {
        // Height 1 on the left, height 2 on the right.
        let a = RowMajorMatrix::new(vec![1, 2, 3], 3);
        let b = RowMajorMatrix::new(vec![4, 5], 1);
        let _ = HorizontalPair::new::<i32>(a, b);
    }
}