use crate::Matrix;
use crate::row_index_mapped::{RowIndexMap, RowIndexMappedView};
/// Row-index map that selects every `stride`-th row of an inner matrix, starting at `offset`.
///
/// View row `r` corresponds to inner row `r * stride + offset`.
#[derive(Debug)]
pub struct VerticallyStridedRowIndexMap {
    /// Number of rows exposed by the strided view.
    height: usize,
    /// Distance, in inner rows, between two consecutive rows of the view.
    stride: usize,
    /// Inner row that view row `0` maps to.
    offset: usize,
}
/// A [`RowIndexMappedView`] over `Inner` whose rows are selected by a
/// [`VerticallyStridedRowIndexMap`].
pub type VerticallyStridedMatrixView<Inner> =
    RowIndexMappedView<VerticallyStridedRowIndexMap, Inner>;
impl VerticallyStridedRowIndexMap {
    /// Creates a view over `inner` that exposes every `stride`-th row, starting at row `offset`.
    ///
    /// View row `r` is inner row `r * stride + offset`. The view's height is the number of
    /// such rows that actually exist in `inner`: with `h = inner.height()`, it is
    /// `ceil((h - offset) / stride)` when `offset < h`, and `0` otherwise. In particular an
    /// `offset` at or beyond the inner height yields an empty view, and an `offset` that is
    /// greater than or equal to `stride` is handled without ever mapping past the last row.
    ///
    /// # Panics
    ///
    /// Panics if `stride` is zero.
    pub fn new_view<T: Send + Sync + Clone, Inner: Matrix<T>>(
        inner: Inner,
        stride: usize,
        offset: usize,
    ) -> VerticallyStridedMatrixView<Inner> {
        assert!(stride > 0, "stride must be non-zero");

        let h = inner.height();
        // Count the view rows `r` satisfying `r * stride + offset < h`. The largest valid `r`
        // is `(h - offset - 1) / stride`, so the count is one more than that. Deriving the
        // height from `h / stride` and `h % stride` instead over-counts when
        // `offset >= stride`, which would let the last view rows map outside `inner`.
        let height = if offset < h {
            (h - offset - 1) / stride + 1
        } else {
            0
        };

        RowIndexMappedView {
            index_map: Self {
                height,
                stride,
                offset,
            },
            inner,
        }
    }
}
impl RowIndexMap for VerticallyStridedRowIndexMap {
    /// Returns the number of rows exposed by the strided view.
    fn height(&self) -> usize {
        self.height
    }

    /// Translates view row `r` into the inner row it refers to, `r * stride + offset`.
    fn map_row_index(&self, r: usize) -> usize {
        let Self { stride, offset, .. } = *self;
        let strided = r * stride;
        strided + offset
    }
}
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::*;
    use crate::{Matrix, RowMajorMatrix};

    /// Builds the 5x3 fixture used by every test below.
    ///
    /// The entry at row `r`, column `c` is `10 * (r + 1) + c`, so an expected value
    /// identifies the inner row it was read from.
    fn sample_matrix() -> RowMajorMatrix<i32> {
        let values = vec![
            10, 11, 12, // inner row 0
            20, 21, 22, // inner row 1
            30, 31, 32, // inner row 2
            40, 41, 42, // inner row 3
            50, 51, 52, // inner row 4
        ];
        RowMajorMatrix::new(values, 3)
    }

    #[test]
    fn test_vertically_strided_view_stride_1_offset_0() {
        // Stride 1 without an offset exposes every inner row unchanged.
        let strided = VerticallyStridedRowIndexMap::new_view(sample_matrix(), 1, 0);

        assert_eq!(strided.height(), 5);
        assert_eq!(strided.width(), 3);

        assert_eq!(strided.get(0, 0), Some(10));
        assert_eq!(strided.get(1, 1), Some(21));

        // SAFETY: (4, 2) is in bounds, the view has 5 rows and 3 columns.
        let bottom_right = unsafe { strided.get_unchecked(4, 2) };
        assert_eq!(bottom_right, 52);

        // One past the last row, then one past the last column.
        assert_eq!(strided.get(5, 0), None);
        assert_eq!(strided.get(0, 3), None);
    }

    #[test]
    fn test_vertically_strided_view_stride_2_offset_0() {
        // Selects inner rows 0, 2 and 4.
        let strided = VerticallyStridedRowIndexMap::new_view(sample_matrix(), 2, 0);

        assert_eq!(strided.height(), 3);
        assert_eq!(strided.get(0, 0), Some(10));

        // SAFETY: the view has 3 rows and 3 columns, so (1, 1) and (2, 2) are in bounds.
        let (middle, last) = unsafe { (strided.get_unchecked(1, 1), strided.get_unchecked(2, 2)) };
        assert_eq!(middle, 31);
        assert_eq!(last, 52);

        assert_eq!(strided.get(0, 3), None);
    }

    #[test]
    fn test_vertically_strided_view_stride_2_offset_1() {
        // Selects inner rows 1 and 3.
        let strided = VerticallyStridedRowIndexMap::new_view(sample_matrix(), 2, 1);

        assert_eq!(strided.height(), 2);
        assert_eq!(strided.get(0, 0), Some(20));

        // SAFETY: the view has 2 rows and 3 columns, so (1, 1) is in bounds.
        let second = unsafe { strided.get_unchecked(1, 1) };
        assert_eq!(second, 41);
    }

    #[test]
    fn test_vertically_strided_view_stride_3_offset_0() {
        // Selects inner rows 0 and 3.
        let strided = VerticallyStridedRowIndexMap::new_view(sample_matrix(), 3, 0);

        assert_eq!(strided.height(), 2);
        assert_eq!(strided.get(0, 0), Some(10));
        assert_eq!(strided.get(1, 1), Some(41));
    }

    #[test]
    fn test_vertically_strided_view_stride_3_offset_1() {
        // Selects inner rows 1 and 4.
        let strided = VerticallyStridedRowIndexMap::new_view(sample_matrix(), 3, 1);

        assert_eq!(strided.height(), 2);

        // SAFETY: the view has 2 rows and 3 columns, so (0, 0) and (1, 1) are in bounds.
        let (first, second) = unsafe { (strided.get_unchecked(0, 0), strided.get_unchecked(1, 1)) };
        assert_eq!(first, 20);
        assert_eq!(second, 51);
    }

    #[test]
    fn test_vertically_strided_view_stride_3_offset_2() {
        // Only inner row 2 is selected; the next candidate (row 5) does not exist.
        let strided = VerticallyStridedRowIndexMap::new_view(sample_matrix(), 3, 2);

        assert_eq!(strided.height(), 1);
        assert_eq!(strided.get(0, 2), Some(32));
    }

    #[test]
    fn test_vertically_strided_view_stride_greater_than_height() {
        // A stride larger than the matrix still exposes the row at the offset.
        let strided = VerticallyStridedRowIndexMap::new_view(sample_matrix(), 10, 0);

        assert_eq!(strided.height(), 1);
        assert_eq!(strided.get(0, 0), Some(10));
    }

    #[test]
    fn test_vertically_strided_view_stride_greater_than_height_with_valid_offset() {
        // Offset 4 points at the last inner row.
        let strided = VerticallyStridedRowIndexMap::new_view(sample_matrix(), 10, 4);

        assert_eq!(strided.height(), 1);
        assert_eq!(strided.get(0, 2), Some(52));
    }

    #[test]
    fn test_vertically_strided_view_stride_greater_than_height_with_offset_beyond_height() {
        // Offset 6 is past the last inner row, so the view is empty.
        let strided = VerticallyStridedRowIndexMap::new_view(sample_matrix(), 10, 6);

        assert_eq!(strided.height(), 0);
        assert_eq!(strided.get(0, 0), None);
    }
}