Skip to main content

p3_matrix/
stack.rs

1use core::ops::Deref;
2
3use crate::Matrix;
4use crate::bitrev::BitReversibleMatrix;
5use crate::dense::RowMajorMatrixView;
6
7/// A type alias representing a vertical composition of two row-major matrix views.
8///
9/// `ViewPair` combines two [`RowMajorMatrixView`]'s with the same element type `T`
10/// and lifetime `'a` into a single virtual matrix stacked vertically.
11///
12/// Both views must have the same width; the resulting view has a height equal
13/// to the sum of the two original heights.
14pub type ViewPair<'a, T> = VerticalPair<RowMajorMatrixView<'a, T>, RowMajorMatrixView<'a, T>>;
15
/// Two matrices stacked one above the other, viewed as a single taller matrix.
///
/// The two parts must agree on `width`. The combined matrix has:
/// - `width`: the width shared by `top` and `bottom`.
/// - `height`: the height of `top` plus the height of `bottom`.
///
/// Rows are addressed top-first: row indices smaller than the height of `top`
/// refer to rows of `top`, and every later index refers to a row of `bottom`.
#[derive(Clone, Copy, Debug)]
pub struct VerticalPair<Top, Bottom> {
    /// Matrix supplying the upper rows of the composition.
    pub top: Top,
    /// Matrix supplying the lower rows of the composition.
    pub bottom: Bottom,
}
32
/// A matrix composed by placing two matrices side-by-side horizontally.
///
/// Both matrices must have the same `height`.
/// The resulting matrix has dimensions:
/// - `width`: The sum of the widths of the input matrices.
/// - `height`: The same as the inputs.
///
/// Element access and iteration for a given row `i` will first access the elements in the
/// `i`'th row of the left matrix, followed by the elements in the `i`'th row of the right
/// matrix.
///
/// Equal heights are checked by [`HorizontalPair::new`]; building the struct directly
/// through its public fields skips that check.
#[derive(Copy, Clone, Debug)]
pub struct HorizontalPair<Left, Right> {
    /// The left matrix in the horizontal composition.
    pub left: Left,
    /// The right matrix in the horizontal composition.
    pub right: Right,
}
49
50impl<Top, Bottom> VerticalPair<Top, Bottom> {
51    /// Create a new `VerticalPair` by stacking two matrices vertically.
52    ///
53    /// # Panics
54    /// Panics if the two matrices do not have the same width (i.e., number of columns),
55    /// since vertical composition requires column alignment.
56    ///
57    /// # Returns
58    /// A `VerticalPair` that represents the combined matrix.
59    pub fn new<T>(top: Top, bottom: Bottom) -> Self
60    where
61        T: Send + Sync + Clone,
62        Top: Matrix<T>,
63        Bottom: Matrix<T>,
64    {
65        assert_eq!(top.width(), bottom.width());
66        Self { top, bottom }
67    }
68}
69
70impl<Left, Right> HorizontalPair<Left, Right> {
71    /// Create a new `HorizontalPair` by joining two matrices side by side.
72    ///
73    /// # Panics
74    /// Panics if the two matrices do not have the same height (i.e., number of rows),
75    /// since horizontal composition requires row alignment.
76    ///
77    /// # Returns
78    /// A `HorizontalPair` that represents the combined matrix.
79    pub fn new<T>(left: Left, right: Right) -> Self
80    where
81        T: Send + Sync + Clone,
82        Left: Matrix<T>,
83        Right: Matrix<T>,
84    {
85        assert_eq!(left.height(), right.height());
86        Self { left, right }
87    }
88}
89
90impl<T: Send + Sync + Clone, Top: Matrix<T>, Bottom: Matrix<T>> Matrix<T>
91    for VerticalPair<Top, Bottom>
92{
93    fn width(&self) -> usize {
94        self.top.width()
95    }
96
97    fn height(&self) -> usize {
98        self.top.height() + self.bottom.height()
99    }
100
101    unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
102        unsafe {
103            // Safety: The caller must ensure that r < self.height() and c < self.width()
104            if r < self.top.height() {
105                self.top.get_unchecked(r, c)
106            } else {
107                self.bottom.get_unchecked(r - self.top.height(), c)
108            }
109        }
110    }
111
112    unsafe fn row_unchecked(
113        &self,
114        r: usize,
115    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
116        unsafe {
117            // Safety: The caller must ensure that r < self.height()
118            if r < self.top.height() {
119                EitherRow::Left(self.top.row_unchecked(r).into_iter())
120            } else {
121                EitherRow::Right(self.bottom.row_unchecked(r - self.top.height()).into_iter())
122            }
123        }
124    }
125
126    unsafe fn row_subseq_unchecked(
127        &self,
128        r: usize,
129        start: usize,
130        end: usize,
131    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
132        unsafe {
133            // Safety: The caller must ensure that r < self.height() and start <= end <= self.width()
134            if r < self.top.height() {
135                EitherRow::Left(self.top.row_subseq_unchecked(r, start, end).into_iter())
136            } else {
137                EitherRow::Right(
138                    self.bottom
139                        .row_subseq_unchecked(r - self.top.height(), start, end)
140                        .into_iter(),
141                )
142            }
143        }
144    }
145
146    unsafe fn row_slice_unchecked(&self, r: usize) -> impl Deref<Target = [T]> {
147        unsafe {
148            // Safety: The caller must ensure that r < self.height()
149            if r < self.top.height() {
150                EitherRow::Left(self.top.row_slice_unchecked(r))
151            } else {
152                EitherRow::Right(self.bottom.row_slice_unchecked(r - self.top.height()))
153            }
154        }
155    }
156
157    unsafe fn row_subslice_unchecked(
158        &self,
159        r: usize,
160        start: usize,
161        end: usize,
162    ) -> impl Deref<Target = [T]> {
163        unsafe {
164            // Safety: The caller must ensure that r < self.height() and start <= end <= self.width()
165            if r < self.top.height() {
166                EitherRow::Left(self.top.row_subslice_unchecked(r, start, end))
167            } else {
168                EitherRow::Right(self.bottom.row_subslice_unchecked(
169                    r - self.top.height(),
170                    start,
171                    end,
172                ))
173            }
174        }
175    }
176}
177
178impl<T: Send + Sync + Clone, Left: Matrix<T>, Right: Matrix<T>> Matrix<T>
179    for HorizontalPair<Left, Right>
180{
181    fn width(&self) -> usize {
182        self.left.width() + self.right.width()
183    }
184
185    fn height(&self) -> usize {
186        self.left.height()
187    }
188
189    unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
190        unsafe {
191            // Safety: The caller must ensure that r < self.height() and c < self.width()
192            if c < self.left.width() {
193                self.left.get_unchecked(r, c)
194            } else {
195                self.right.get_unchecked(r, c - self.left.width())
196            }
197        }
198    }
199
200    unsafe fn row_unchecked(
201        &self,
202        r: usize,
203    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
204        unsafe {
205            // Safety: The caller must ensure that r < self.height()
206            self.left
207                .row_unchecked(r)
208                .into_iter()
209                .chain(self.right.row_unchecked(r))
210        }
211    }
212}
213
/// One of two row representations, selected at runtime.
///
/// `VerticalPair` uses this to return a row (either an iterator or a slice-like value)
/// that comes from its `top` matrix (`Left`) or from its `bottom` matrix (`Right`).
/// It implements [`Iterator`] when both variants are iterators over the same item type,
/// and [`Deref`] to `[T]` when both variants deref to `[T]`.
#[derive(Debug)]
pub enum EitherRow<L, R> {
    /// A row produced by the first (top) source.
    Left(L),
    /// A row produced by the second (bottom) source.
    Right(R),
}

impl<T, L, R> Iterator for EitherRow<L, R>
where
    L: Iterator<Item = T>,
    R: Iterator<Item = T>,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Left(l) => l.next(),
            Self::Right(r) => r.next(),
        }
    }

    /// Forwards the wrapped iterator's bounds. Without this, the default `(0, None)`
    /// would hide the exact row length and prevent consumers such as `collect` from
    /// preallocating.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::Left(l) => l.size_hint(),
            Self::Right(r) => r.size_hint(),
        }
    }
}

impl<T, L, R> Deref for EitherRow<L, R>
where
    L: Deref<Target = [T]>,
    R: Deref<Target = [T]>,
{
    type Target = [T];

    fn deref(&self) -> &Self::Target {
        match self {
            Self::Left(l) => l,
            Self::Right(r) => r,
        }
    }
}
249
250impl<T: Clone + Send + Sync, Left: BitReversibleMatrix<T>, Right: BitReversibleMatrix<T>>
251    BitReversibleMatrix<T> for HorizontalPair<Left, Right>
252{
253    type BitRev = HorizontalPair<Left::BitRev, Right::BitRev>;
254
255    fn bit_reverse_rows(self) -> Self::BitRev {
256        HorizontalPair {
257            left: self.left.bit_reverse_rows(),
258            right: self.right.bit_reverse_rows(),
259        }
260    }
261}
262
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use itertools::Itertools;

    use super::*;
    use crate::RowMajorMatrix;

    #[test]
    fn test_vertical_pair_empty_top() {
        let top = RowMajorMatrix::new(vec![], 2); // 0x2
        let bottom = RowMajorMatrix::new(vec![1, 2, 3, 4], 2); // 2x2
        let vpair = VerticalPair::new::<i32>(top, bottom);
        assert_eq!(vpair.height(), 2);
        assert_eq!(vpair.get(1, 1), Some(4));
        unsafe {
            assert_eq!(vpair.get_unchecked(0, 0), 1);
        }
    }

    #[test]
    fn test_vertical_pair_empty_bottom() {
        let top = RowMajorMatrix::new(vec![1, 2, 3, 4], 2); // 2x2
        let bottom = RowMajorMatrix::new(vec![], 2); // 0x2
        let vpair = VerticalPair::new::<i32>(top, bottom);
        assert_eq!(vpair.height(), 2);
        assert_eq!(vpair.width(), 2);
        assert_eq!(vpair.get(1, 0), Some(3));
        assert_eq!(vpair.get(2, 0), None); // Past the last row of `top`
        let row = vpair.row(1).unwrap().into_iter().collect_vec();
        assert_eq!(row, vec![3, 4]);
    }

    #[test]
    fn test_vertical_pair_composition() {
        let top = RowMajorMatrix::new(vec![1, 2, 3, 4], 2); // 2x2
        let bottom = RowMajorMatrix::new(vec![5, 6, 7, 8], 2); // 2x2
        let vertical = VerticalPair::new::<i32>(top, bottom);

        // Dimensions
        assert_eq!(vertical.width(), 2);
        assert_eq!(vertical.height(), 4);

        // Values from top
        assert_eq!(vertical.get(0, 0), Some(1));
        assert_eq!(vertical.get(1, 1), Some(4));

        // Values from bottom
        unsafe {
            assert_eq!(vertical.get_unchecked(2, 0), 5);
            assert_eq!(vertical.get_unchecked(3, 1), 8);
        }

        // Row iter from bottom
        let row = vertical.row(3).unwrap().into_iter().collect_vec();
        assert_eq!(row, vec![7, 8]);

        unsafe {
            // Row iter from top
            let row = vertical.row_unchecked(1).into_iter().collect_vec();
            assert_eq!(row, vec![3, 4]);

            let row = vertical
                .row_subseq_unchecked(0, 0, 1)
                .into_iter()
                .collect_vec();
            assert_eq!(row, vec![1]);
        }

        // Row slice
        assert_eq!(vertical.row_slice(2).unwrap().deref(), &[5, 6]);

        unsafe {
            // Row slice unchecked
            assert_eq!(vertical.row_slice_unchecked(3).deref(), &[7, 8]);
            assert_eq!(vertical.row_subslice_unchecked(1, 1, 2).deref(), &[4]);
        }

        assert_eq!(vertical.get(0, 2), None); // Width out of bounds
        assert_eq!(vertical.get(4, 0), None); // Height out of bounds
        assert!(vertical.row(4).is_none()); // Height out of bounds
        assert!(vertical.row_slice(4).is_none()); // Height out of bounds
    }

    #[test]
    fn test_horizontal_pair_composition() {
        let left = RowMajorMatrix::new(vec![1, 2, 3, 4], 2); // 2x2
        let right = RowMajorMatrix::new(vec![5, 6, 7, 8], 2); // 2x2
        let horizontal = HorizontalPair::new::<i32>(left, right);

        // Dimensions
        assert_eq!(horizontal.height(), 2);
        assert_eq!(horizontal.width(), 4);

        // Left values
        assert_eq!(horizontal.get(0, 0), Some(1));
        assert_eq!(horizontal.get(1, 1), Some(4));

        // Right values
        unsafe {
            assert_eq!(horizontal.get_unchecked(0, 2), 5);
            assert_eq!(horizontal.get_unchecked(1, 3), 8);
        }

        // Row iter
        let row = horizontal.row(0).unwrap().into_iter().collect_vec();
        assert_eq!(row, vec![1, 2, 5, 6]);

        unsafe {
            let row = horizontal.row_unchecked(1).into_iter().collect_vec();
            assert_eq!(row, vec![3, 4, 7, 8]);
        }

        assert_eq!(horizontal.get(0, 4), None); // Width out of bounds
        assert_eq!(horizontal.get(2, 0), None); // Height out of bounds
        assert!(horizontal.row(2).is_none()); // Height out of bounds
    }

    #[test]
    fn test_horizontal_pair_row_subsequences_and_slices() {
        let left = RowMajorMatrix::new(vec![1, 2, 3, 4], 2); // 2x2
        let right = RowMajorMatrix::new(vec![5, 6, 7, 8], 2); // 2x2
        let horizontal = HorizontalPair::new::<i32>(left, right);

        unsafe {
            // Entirely inside the left half.
            let row = horizontal
                .row_subseq_unchecked(0, 0, 2)
                .into_iter()
                .collect_vec();
            assert_eq!(row, vec![1, 2]);

            // Straddling the boundary between the left and right halves.
            let row = horizontal
                .row_subseq_unchecked(1, 1, 3)
                .into_iter()
                .collect_vec();
            assert_eq!(row, vec![4, 7]);

            // Entirely inside the right half.
            let row = horizontal
                .row_subseq_unchecked(0, 2, 4)
                .into_iter()
                .collect_vec();
            assert_eq!(row, vec![5, 6]);

            // Empty range.
            let row = horizontal
                .row_subseq_unchecked(1, 2, 2)
                .into_iter()
                .collect_vec();
            assert!(row.is_empty());
        }

        // Row slices span both halves.
        assert_eq!(horizontal.row_slice(1).unwrap().deref(), &[3, 4, 7, 8]);
        assert!(horizontal.row_slice(2).is_none()); // Height out of bounds
    }

    #[test]
    fn test_either_row_iterator_behavior() {
        type Iter = alloc::vec::IntoIter<i32>;

        // Left variant
        let left: EitherRow<Iter, Iter> = EitherRow::Left(vec![10, 20].into_iter());
        assert_eq!(left.collect::<Vec<_>>(), vec![10, 20]);

        // Right variant
        let right: EitherRow<Iter, Iter> = EitherRow::Right(vec![30, 40].into_iter());
        assert_eq!(right.collect::<Vec<_>>(), vec![30, 40]);
    }

    #[test]
    fn test_either_row_deref_behavior() {
        let left: EitherRow<&[i32], &[i32]> = EitherRow::Left(&[1, 2, 3]);
        let right: EitherRow<&[i32], &[i32]> = EitherRow::Right(&[4, 5]);

        assert_eq!(&*left, &[1, 2, 3]);
        assert_eq!(&*right, &[4, 5]);
    }

    #[test]
    #[should_panic]
    fn test_vertical_pair_width_mismatch_should_panic() {
        let a = RowMajorMatrix::new(vec![1, 2, 3], 1); // 3x1
        let b = RowMajorMatrix::new(vec![4, 5], 2); // 1x2
        let _ = VerticalPair::new::<i32>(a, b);
    }

    #[test]
    #[should_panic]
    fn test_horizontal_pair_height_mismatch_should_panic() {
        let a = RowMajorMatrix::new(vec![1, 2, 3], 3); // 1x3
        let b = RowMajorMatrix::new(vec![4, 5], 1); // 2x1
        let _ = HorizontalPair::new::<i32>(a, b);
    }
}