p3_matrix/strided.rs

1use crate::Matrix;
2use crate::row_index_mapped::{RowIndexMap, RowIndexMappedView};
3
4/// A vertical row-mapping strategy that selects every `stride`-th row from an inner matrix,
5/// starting at a fixed `offset`.
6///
7/// This enables vertical striding like selecting rows: `offset`, `offset + stride`, etc.
8#[derive(Debug)]
9pub struct VerticallyStridedRowIndexMap {
10    /// The number of rows in the resulting view.
11    height: usize,
12    /// The step size between selected rows in the inner matrix.
13    stride: usize,
14    /// The offset to start the stride from.
15    offset: usize,
16}
17
18pub type VerticallyStridedMatrixView<Inner> =
19    RowIndexMappedView<VerticallyStridedRowIndexMap, Inner>;
20
21impl VerticallyStridedRowIndexMap {
22    /// Create a new vertically strided view over a matrix.
23    ///
24    /// This selects rows in the inner matrix starting from `offset`, and then every `stride` rows after.
25    ///
26    /// # Arguments
27    /// - `inner`: The inner matrix to view.
28    /// - `stride`: The number of rows between each selected row.
29    /// - `offset`: The initial row to start from.
30    pub fn new_view<T: Send + Sync + Clone, Inner: Matrix<T>>(
31        inner: Inner,
32        stride: usize,
33        offset: usize,
34    ) -> VerticallyStridedMatrixView<Inner> {
35        let h = inner.height();
36        let full_strides = h / stride;
37        let remainder = h % stride;
38        let final_stride = offset < remainder;
39        let height = full_strides + final_stride as usize;
40        RowIndexMappedView {
41            index_map: Self {
42                height,
43                stride,
44                offset,
45            },
46            inner,
47        }
48    }
49}
50
51impl RowIndexMap for VerticallyStridedRowIndexMap {
52    fn height(&self) -> usize {
53        self.height
54    }
55
56    fn map_row_index(&self, r: usize) -> usize {
57        r * self.stride + self.offset
58    }
59}
60
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::*;
    use crate::{Matrix, RowMajorMatrix};

    /// Builds the 5x3 matrix shared by the tests below; entry `(r, c)` is `10 * (r + 1) + c`.
    fn sample_matrix() -> RowMajorMatrix<i32> {
        // A 5x3 matrix:
        // [10, 11, 12]
        // [20, 21, 22]
        // [30, 31, 32]
        // [40, 41, 42]
        // [50, 51, 52]
        RowMajorMatrix::new(
            vec![10, 11, 12, 20, 21, 22, 30, 31, 32, 40, 41, 42, 50, 51, 52],
            3,
        )
    }

    #[test]
    fn test_vertically_strided_view_stride_1_offset_0() {
        let matrix = sample_matrix();
        let view = VerticallyStridedRowIndexMap::new_view(matrix, 1, 0);

        assert_eq!(view.height(), 5);
        assert_eq!(view.width(), 3);

        assert_eq!(view.get(0, 0), Some(10));
        assert_eq!(view.get(1, 1), Some(21));
        // SAFETY: row 4 < view height 5 and column 2 < width 3.
        unsafe {
            assert_eq!(view.get_unchecked(4, 2), 52);
        }
        assert_eq!(view.get(5, 0), None); // row out of bounds
        assert_eq!(view.get(0, 3), None); // column out of bounds
    }

    #[test]
    fn test_vertically_strided_view_stride_2_offset_0() {
        let matrix = sample_matrix();
        let view = VerticallyStridedRowIndexMap::new_view(matrix, 2, 0);

        assert_eq!(view.height(), 3);
        assert_eq!(view.get(0, 0), Some(10)); // row 0
        // SAFETY: rows 1 and 2 < view height 3; columns 1 and 2 < width 3.
        unsafe {
            assert_eq!(view.get_unchecked(1, 1), 31); // row 2
            assert_eq!(view.get_unchecked(2, 2), 52); // row 4
        }
        assert_eq!(view.get(3, 0), None); // row out of bounds
        assert_eq!(view.get(0, 3), None); // column out of bounds
    }

    #[test]
    fn test_vertically_strided_view_stride_2_offset_1() {
        let matrix = sample_matrix();
        let view = VerticallyStridedRowIndexMap::new_view(matrix, 2, 1);

        assert_eq!(view.height(), 2);
        assert_eq!(view.get(0, 0), Some(20)); // row 1
        // SAFETY: row 1 < view height 2 and column 1 < width 3.
        unsafe {
            assert_eq!(view.get_unchecked(1, 1), 41); // row 3
        }
        assert_eq!(view.get(2, 0), None); // row out of bounds
    }

    #[test]
    fn test_vertically_strided_view_stride_3_offset_0() {
        let matrix = sample_matrix();
        let view = VerticallyStridedRowIndexMap::new_view(matrix, 3, 0);

        assert_eq!(view.height(), 2);
        assert_eq!(view.get(0, 0), Some(10)); // row 0
        assert_eq!(view.get(1, 1), Some(41)); // row 3
    }

    #[test]
    fn test_vertically_strided_view_stride_3_offset_1() {
        let matrix = sample_matrix();
        let view = VerticallyStridedRowIndexMap::new_view(matrix, 3, 1);

        assert_eq!(view.height(), 2);
        // SAFETY: rows 0 and 1 < view height 2; columns 0 and 1 < width 3.
        unsafe {
            assert_eq!(view.get_unchecked(0, 0), 20); // row 1
            assert_eq!(view.get_unchecked(1, 1), 51); // row 4
        }
    }

    #[test]
    fn test_vertically_strided_view_stride_3_offset_2() {
        let matrix = sample_matrix();
        let view = VerticallyStridedRowIndexMap::new_view(matrix, 3, 2);

        assert_eq!(view.height(), 1);
        assert_eq!(view.get(0, 2), Some(32)); // row 2
    }

    #[test]
    fn test_vertically_strided_view_stride_greater_than_height() {
        let matrix = sample_matrix();
        let view = VerticallyStridedRowIndexMap::new_view(matrix, 10, 0);

        assert_eq!(view.height(), 1);
        assert_eq!(view.get(0, 0), Some(10)); // row 0
    }

    #[test]
    fn test_vertically_strided_view_stride_greater_than_height_with_valid_offset() {
        let matrix = sample_matrix(); // height = 5
        let view = VerticallyStridedRowIndexMap::new_view(matrix, 10, 4);

        // offset == 4 < height == 5 → view selects row 4
        assert_eq!(view.height(), 1);
        assert_eq!(view.get(0, 2), Some(52)); // row 4
    }

    #[test]
    fn test_vertically_strided_view_stride_greater_than_height_with_offset_beyond_height() {
        let matrix = sample_matrix(); // height = 5
        let view = VerticallyStridedRowIndexMap::new_view(matrix, 10, 6);

        // offset == 6 > height == 5 → no valid row
        assert_eq!(view.height(), 0);
        assert_eq!(view.get(0, 0), None); // row out of bounds
    }

    #[test]
    fn test_vertically_strided_view_empty_inner() {
        // A 0x3 matrix: no rows to select regardless of stride and offset.
        let matrix: RowMajorMatrix<i32> = RowMajorMatrix::new(vec![], 3);
        let view = VerticallyStridedRowIndexMap::new_view(matrix, 2, 0);

        assert_eq!(view.height(), 0);
        assert_eq!(view.get(0, 0), None); // row out of bounds
    }

    #[test]
    #[should_panic]
    fn test_vertically_strided_view_zero_stride_panics() {
        // A stride of zero cannot select a sequence of rows, so construction must panic.
        let _view = VerticallyStridedRowIndexMap::new_view(sample_matrix(), 0, 0);
    }
}