1use crate::Matrix;
2use crate::row_index_mapped::{RowIndexMap, RowIndexMappedView};
3
/// Row index map that translates a view row `r` into the inner row
/// `r * stride + offset`.
///
/// Instances are built by [`VerticallyStridedRowIndexMap::new_view`], which also
/// derives `height` from the inner matrix's height.
#[derive(Debug)]
pub struct VerticallyStridedRowIndexMap {
    /// Number of rows reported by the view.
    height: usize,
    /// Factor applied to a view row index when mapping it to an inner row index.
    stride: usize,
    /// Constant added to the scaled row index; the inner row that view row 0 maps to.
    offset: usize,
}
17
/// A [`RowIndexMappedView`] over `Inner` whose row index map is a
/// [`VerticallyStridedRowIndexMap`].
pub type VerticallyStridedMatrixView<Inner> =
    RowIndexMappedView<VerticallyStridedRowIndexMap, Inner>;
20
21impl VerticallyStridedRowIndexMap {
22 pub fn new_view<T: Send + Sync + Clone, Inner: Matrix<T>>(
31 inner: Inner,
32 stride: usize,
33 offset: usize,
34 ) -> VerticallyStridedMatrixView<Inner> {
35 let h = inner.height();
36 let full_strides = h / stride;
37 let remainder = h % stride;
38 let final_stride = offset < remainder;
39 let height = full_strides + final_stride as usize;
40 RowIndexMappedView {
41 index_map: Self {
42 height,
43 stride,
44 offset,
45 },
46 inner,
47 }
48 }
49}
50
51impl RowIndexMap for VerticallyStridedRowIndexMap {
52 fn height(&self) -> usize {
53 self.height
54 }
55
56 fn map_row_index(&self, r: usize) -> usize {
57 r * self.stride + self.offset
58 }
59}
60
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::*;
    use crate::{Matrix, RowMajorMatrix};

    // Fixture shared by all tests: 15 values laid out with width 3, i.e. the rows
    //   [10, 11, 12]
    //   [20, 21, 22]
    //   [30, 31, 32]
    //   [40, 41, 42]
    //   [50, 51, 52]
    fn sample_matrix() -> RowMajorMatrix<i32> {
        let values = vec![10, 11, 12, 20, 21, 22, 30, 31, 32, 40, 41, 42, 50, 51, 52];
        RowMajorMatrix::new(values, 3)
    }

    // Wraps the fixture in a strided view with the given parameters.
    fn strided_sample(
        stride: usize,
        offset: usize,
    ) -> VerticallyStridedMatrixView<RowMajorMatrix<i32>> {
        VerticallyStridedRowIndexMap::new_view(sample_matrix(), stride, offset)
    }

    #[test]
    fn test_vertically_strided_view_stride_1_offset_0() {
        // Stride 1 with no offset exposes every row of the fixture.
        let all_rows = strided_sample(1, 0);

        assert_eq!(all_rows.height(), 5);
        assert_eq!(all_rows.width(), 3);

        assert_eq!(all_rows.get(0, 0), Some(10));
        assert_eq!(all_rows.get(1, 1), Some(21));
        // SAFETY: (4, 2) lies inside the 5x3 view whose dimensions are asserted above.
        unsafe {
            assert_eq!(all_rows.get_unchecked(4, 2), 52);
        }

        // One past the last row, and one past the last column.
        assert_eq!(all_rows.get(5, 0), None);
        assert_eq!(all_rows.get(0, 3), None);
    }

    #[test]
    fn test_vertically_strided_view_stride_2_offset_0() {
        // Selects inner rows 0, 2 and 4.
        let even_rows = strided_sample(2, 0);

        assert_eq!(even_rows.height(), 3);
        assert_eq!(even_rows.get(0, 0), Some(10));
        // SAFETY: rows 1 and 2 are below the asserted height of 3, and columns 1 and 2
        // are below the fixture width of 3.
        unsafe {
            assert_eq!(even_rows.get_unchecked(1, 1), 31);
            assert_eq!(even_rows.get_unchecked(2, 2), 52);
        }
        assert_eq!(even_rows.get(0, 3), None);
    }

    #[test]
    fn test_vertically_strided_view_stride_2_offset_1() {
        // Selects inner rows 1 and 3.
        let odd_rows = strided_sample(2, 1);

        assert_eq!(odd_rows.height(), 2);
        assert_eq!(odd_rows.get(0, 0), Some(20));
        // SAFETY: row 1 is below the asserted height of 2, and column 1 is below the
        // fixture width of 3.
        unsafe {
            assert_eq!(odd_rows.get_unchecked(1, 1), 41);
        }
    }

    #[test]
    fn test_vertically_strided_view_stride_3_offset_0() {
        // Selects inner rows 0 and 3.
        let every_third = strided_sample(3, 0);

        assert_eq!(every_third.height(), 2);
        assert_eq!(every_third.get(0, 0), Some(10));
        assert_eq!(every_third.get(1, 1), Some(41));
    }

    #[test]
    fn test_vertically_strided_view_stride_3_offset_1() {
        // Selects inner rows 1 and 4.
        let every_third = strided_sample(3, 1);

        assert_eq!(every_third.height(), 2);
        // SAFETY: rows 0 and 1 are below the asserted height of 2, and columns 0 and 1
        // are below the fixture width of 3.
        unsafe {
            assert_eq!(every_third.get_unchecked(0, 0), 20);
            assert_eq!(every_third.get_unchecked(1, 1), 51);
        }
    }

    #[test]
    fn test_vertically_strided_view_stride_3_offset_2() {
        // Only inner row 2 is selected; the next candidate (row 5) does not exist.
        let every_third = strided_sample(3, 2);

        assert_eq!(every_third.height(), 1);
        assert_eq!(every_third.get(0, 2), Some(32));
    }

    #[test]
    fn test_vertically_strided_view_stride_greater_than_height() {
        // A stride larger than the matrix still selects the first row.
        let sparse = strided_sample(10, 0);

        assert_eq!(sparse.height(), 1);
        assert_eq!(sparse.get(0, 0), Some(10));
    }

    #[test]
    fn test_vertically_strided_view_stride_greater_than_height_with_valid_offset() {
        // Offset 4 points at the last inner row, which is the only one selected.
        let sparse = strided_sample(10, 4);

        assert_eq!(sparse.height(), 1);
        assert_eq!(sparse.get(0, 2), Some(52));
    }

    #[test]
    fn test_vertically_strided_view_stride_greater_than_height_with_offset_beyond_height() {
        // Offset 6 is past the end of the 5-row fixture, so the view is empty.
        let empty = strided_sample(10, 6);

        assert_eq!(empty.height(), 0);
        assert_eq!(empty.get(0, 0), None);
    }
}