p3_matrix/
stack.rs

1use core::ops::Deref;
2
3use crate::Matrix;
4
/// Two matrices stacked on top of each other and viewed as a single matrix.
///
/// The two halves are expected to agree on `width` (this is what
/// [`VerticalPair::new`] checks). The combined matrix has:
/// - `width`: the shared width of `top` and `bottom`.
/// - `height`: `top.height() + bottom.height()`.
///
/// Rows are numbered so that every row of `top` comes first, followed by
/// every row of `bottom`; element access and iteration follow that order.
#[derive(Clone, Copy, Debug)]
pub struct VerticalPair<Top, Bottom> {
    /// The upper half of the stack (rows `0..top.height()`).
    pub top: Top,
    /// The lower half of the stack (the remaining rows).
    pub bottom: Bottom,
}
21
/// A matrix composed by placing two matrices side-by-side horizontally.
///
/// Both matrices must have the same `height`.
/// The resulting matrix has dimensions:
/// - `width`: The sum of the `width`s of the input matrices.
/// - `height`: The same as the inputs.
///
/// Element access and iteration for a given row `i` will first access the elements in the
/// `i`-th row of the left matrix, followed by the elements in the `i`-th row of the right matrix.
#[derive(Copy, Clone, Debug)]
pub struct HorizontalPair<Left, Right> {
    /// The left matrix in the horizontal composition.
    pub left: Left,
    /// The right matrix in the horizontal composition.
    pub right: Right,
}
38
39impl<Top, Bottom> VerticalPair<Top, Bottom> {
40    /// Create a new `VerticalPair` by stacking two matrices vertically.
41    ///
42    /// # Panics
43    /// Panics if the two matrices do not have the same width (i.e., number of columns),
44    /// since vertical composition requires column alignment.
45    ///
46    /// # Returns
47    /// A `VerticalPair` that represents the combined matrix.
48    pub fn new<T>(top: Top, bottom: Bottom) -> Self
49    where
50        T: Send + Sync + Clone,
51        Top: Matrix<T>,
52        Bottom: Matrix<T>,
53    {
54        assert_eq!(top.width(), bottom.width());
55        Self { top, bottom }
56    }
57}
58
59impl<Left, Right> HorizontalPair<Left, Right> {
60    /// Create a new `HorizontalPair` by joining two matrices side by side.
61    ///
62    /// # Panics
63    /// Panics if the two matrices do not have the same height (i.e., number of rows),
64    /// since horizontal composition requires row alignment.
65    ///
66    /// # Returns
67    /// A `HorizontalPair` that represents the combined matrix.
68    pub fn new<T>(left: Left, right: Right) -> Self
69    where
70        T: Send + Sync + Clone,
71        Left: Matrix<T>,
72        Right: Matrix<T>,
73    {
74        assert_eq!(left.height(), right.height());
75        Self { left, right }
76    }
77}
78
79impl<T: Send + Sync + Clone, Top: Matrix<T>, Bottom: Matrix<T>> Matrix<T>
80    for VerticalPair<Top, Bottom>
81{
82    fn width(&self) -> usize {
83        self.top.width()
84    }
85
86    fn height(&self) -> usize {
87        self.top.height() + self.bottom.height()
88    }
89
90    unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
91        unsafe {
92            // Safety: The caller must ensure that r < self.height() and c < self.width()
93            if r < self.top.height() {
94                self.top.get_unchecked(r, c)
95            } else {
96                self.bottom.get_unchecked(r - self.top.height(), c)
97            }
98        }
99    }
100
101    unsafe fn row_unchecked(
102        &self,
103        r: usize,
104    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
105        unsafe {
106            // Safety: The caller must ensure that r < self.height()
107            if r < self.top.height() {
108                EitherRow::Left(self.top.row_unchecked(r).into_iter())
109            } else {
110                EitherRow::Right(self.bottom.row_unchecked(r - self.top.height()).into_iter())
111            }
112        }
113    }
114
115    unsafe fn row_subseq_unchecked(
116        &self,
117        r: usize,
118        start: usize,
119        end: usize,
120    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
121        unsafe {
122            // Safety: The caller must ensure that r < self.height() and start <= end <= self.width()
123            if r < self.top.height() {
124                EitherRow::Left(self.top.row_subseq_unchecked(r, start, end).into_iter())
125            } else {
126                EitherRow::Right(
127                    self.bottom
128                        .row_subseq_unchecked(r - self.top.height(), start, end)
129                        .into_iter(),
130                )
131            }
132        }
133    }
134
135    unsafe fn row_slice_unchecked(&self, r: usize) -> impl Deref<Target = [T]> {
136        unsafe {
137            // Safety: The caller must ensure that r < self.height()
138            if r < self.top.height() {
139                EitherRow::Left(self.top.row_slice_unchecked(r))
140            } else {
141                EitherRow::Right(self.bottom.row_slice_unchecked(r - self.top.height()))
142            }
143        }
144    }
145
146    unsafe fn row_subslice_unchecked(
147        &self,
148        r: usize,
149        start: usize,
150        end: usize,
151    ) -> impl Deref<Target = [T]> {
152        unsafe {
153            // Safety: The caller must ensure that r < self.height() and start <= end <= self.width()
154            if r < self.top.height() {
155                EitherRow::Left(self.top.row_subslice_unchecked(r, start, end))
156            } else {
157                EitherRow::Right(self.bottom.row_subslice_unchecked(
158                    r - self.top.height(),
159                    start,
160                    end,
161                ))
162            }
163        }
164    }
165}
166
167impl<T: Send + Sync + Clone, Left: Matrix<T>, Right: Matrix<T>> Matrix<T>
168    for HorizontalPair<Left, Right>
169{
170    fn width(&self) -> usize {
171        self.left.width() + self.right.width()
172    }
173
174    fn height(&self) -> usize {
175        self.left.height()
176    }
177
178    unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
179        unsafe {
180            // Safety: The caller must ensure that r < self.height() and c < self.width()
181            if c < self.left.width() {
182                self.left.get_unchecked(r, c)
183            } else {
184                self.right.get_unchecked(r, c - self.left.width())
185            }
186        }
187    }
188
189    unsafe fn row_unchecked(
190        &self,
191        r: usize,
192    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
193        unsafe {
194            // Safety: The caller must ensure that r < self.height()
195            self.left
196                .row_unchecked(r)
197                .into_iter()
198                .chain(self.right.row_unchecked(r))
199        }
200    }
201}
202
/// A value produced by one of two possible sources, without boxing.
///
/// `VerticalPair` uses this to return a row from either its top matrix (`Left`) or its
/// bottom matrix (`Right`). It implements [`Iterator`] when both variants iterate over the
/// same item type, and [`Deref`] to `[T]` when both variants deref to the same slice type.
/// This lets it stand in for both a row iterator and a row slice.
#[derive(Debug)]
pub enum EitherRow<L, R> {
    /// Value coming from the first (left/top) source.
    Left(L),
    /// Value coming from the second (right/bottom) source.
    Right(R),
}

impl<T, L, R> Iterator for EitherRow<L, R>
where
    L: Iterator<Item = T>,
    R: Iterator<Item = T>,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Left(l) => l.next(),
            Self::Right(r) => r.next(),
        }
    }

    // Forward the wrapped iterator's bounds. Otherwise the default `(0, None)` hides
    // the row length from consumers such as `collect`, which size their allocation from it.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::Left(l) => l.size_hint(),
            Self::Right(r) => r.size_hint(),
        }
    }
}

impl<T, L, R> Deref for EitherRow<L, R>
where
    L: Deref<Target = [T]>,
    R: Deref<Target = [T]>,
{
    type Target = [T];

    fn deref(&self) -> &Self::Target {
        match self {
            Self::Left(l) => l,
            Self::Right(r) => r,
        }
    }
}
238
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use itertools::Itertools;

    use super::*;
    use crate::RowMajorMatrix;

    #[test]
    fn test_vertical_pair_empty_top() {
        let top = RowMajorMatrix::new(vec![], 2); // 0x2
        let bottom = RowMajorMatrix::new(vec![1, 2, 3, 4], 2); // 2x2
        let vpair = VerticalPair::new::<i32>(top, bottom);
        assert_eq!(vpair.height(), 2);
        assert_eq!(vpair.get(1, 1), Some(4));
        unsafe {
            assert_eq!(vpair.get_unchecked(0, 0), 1);
        }
    }

    #[test]
    fn test_vertical_pair_composition() {
        let top = RowMajorMatrix::new(vec![1, 2, 3, 4], 2); // 2x2
        let bottom = RowMajorMatrix::new(vec![5, 6, 7, 8], 2); // 2x2
        let vertical = VerticalPair::new::<i32>(top, bottom);

        // Dimensions
        assert_eq!(vertical.width(), 2);
        assert_eq!(vertical.height(), 4);

        // Values from top
        assert_eq!(vertical.get(0, 0), Some(1));
        assert_eq!(vertical.get(1, 1), Some(4));

        // Values from bottom
        unsafe {
            assert_eq!(vertical.get_unchecked(2, 0), 5);
            assert_eq!(vertical.get_unchecked(3, 1), 8);
        }

        // Row iter from bottom
        let row = vertical.row(3).unwrap().into_iter().collect_vec();
        assert_eq!(row, vec![7, 8]);

        unsafe {
            // Row iter from top
            let row = vertical.row_unchecked(1).into_iter().collect_vec();
            assert_eq!(row, vec![3, 4]);

            let row = vertical
                .row_subseq_unchecked(0, 0, 1)
                .into_iter()
                .collect_vec();
            assert_eq!(row, vec![1]);
        }

        // Row slice
        assert_eq!(vertical.row_slice(2).unwrap().deref(), &[5, 6]);

        unsafe {
            // Row slice unchecked
            assert_eq!(vertical.row_slice_unchecked(3).deref(), &[7, 8]);
            assert_eq!(vertical.row_subslice_unchecked(1, 1, 2).deref(), &[4]);
        }

        assert_eq!(vertical.get(0, 2), None); // Width out of bounds
        assert_eq!(vertical.get(4, 0), None); // Height out of bounds
        assert!(vertical.row(4).is_none()); // Height out of bounds
        assert!(vertical.row_slice(4).is_none()); // Height out of bounds
    }

    #[test]
    fn test_vertical_pair_row_subseq_from_bottom() {
        let top = RowMajorMatrix::new(vec![1, 2, 3], 3); // 1x3
        let bottom = RowMajorMatrix::new(vec![4, 5, 6, 7, 8, 9], 3); // 2x3
        let vertical = VerticalPair::new::<i32>(top, bottom);
        assert_eq!(vertical.height(), 3);

        unsafe {
            // Sub-sequence of a row that lives in the bottom matrix.
            let row = vertical
                .row_subseq_unchecked(2, 1, 3)
                .into_iter()
                .collect_vec();
            assert_eq!(row, vec![8, 9]);

            // Sub-slice of the first bottom row (offset by the top's height).
            assert_eq!(vertical.row_subslice_unchecked(1, 0, 2).deref(), &[4, 5]);
        }
    }

    #[test]
    fn test_horizontal_pair_composition() {
        let left = RowMajorMatrix::new(vec![1, 2, 3, 4], 2); // 2x2
        let right = RowMajorMatrix::new(vec![5, 6, 7, 8], 2); // 2x2
        let horizontal = HorizontalPair::new::<i32>(left, right);

        // Dimensions
        assert_eq!(horizontal.height(), 2);
        assert_eq!(horizontal.width(), 4);

        // Left values
        assert_eq!(horizontal.get(0, 0), Some(1));
        assert_eq!(horizontal.get(1, 1), Some(4));

        // Right values
        unsafe {
            assert_eq!(horizontal.get_unchecked(0, 2), 5);
            assert_eq!(horizontal.get_unchecked(1, 3), 8);
        }

        // Row iter
        let row = horizontal.row(0).unwrap().into_iter().collect_vec();
        assert_eq!(row, vec![1, 2, 5, 6]);

        unsafe {
            let row = horizontal.row_unchecked(1).into_iter().collect_vec();
            assert_eq!(row, vec![3, 4, 7, 8]);
        }

        assert_eq!(horizontal.get(0, 4), None); // Width out of bounds
        assert_eq!(horizontal.get(2, 0), None); // Height out of bounds
        assert!(horizontal.row(2).is_none()); // Height out of bounds
    }

    #[test]
    fn test_horizontal_pair_row_subseq() {
        let left = RowMajorMatrix::new(vec![1, 2, 3, 4], 2); // 2x2
        let right = RowMajorMatrix::new(vec![5, 6, 7, 8, 9, 10], 3); // 2x3
        let horizontal = HorizontalPair::new::<i32>(left, right);
        assert_eq!(horizontal.width(), 5);
        // Full rows are [1, 2, 5, 6, 7] and [3, 4, 8, 9, 10].

        unsafe {
            // Entirely within the left matrix.
            let row = horizontal
                .row_subseq_unchecked(0, 0, 2)
                .into_iter()
                .collect_vec();
            assert_eq!(row, vec![1, 2]);

            // Straddling the boundary between left and right.
            let row = horizontal
                .row_subseq_unchecked(1, 1, 4)
                .into_iter()
                .collect_vec();
            assert_eq!(row, vec![4, 8, 9]);

            // Entirely within the right matrix, up to the last column.
            let row = horizontal
                .row_subseq_unchecked(0, 3, 5)
                .into_iter()
                .collect_vec();
            assert_eq!(row, vec![6, 7]);

            // Empty range exactly at the boundary.
            let row = horizontal
                .row_subseq_unchecked(1, 2, 2)
                .into_iter()
                .collect_vec();
            assert!(row.is_empty());
        }
    }

    #[test]
    fn test_either_row_iterator_behavior() {
        type Iter = alloc::vec::IntoIter<i32>;

        // Left variant
        let left: EitherRow<Iter, Iter> = EitherRow::Left(vec![10, 20].into_iter());
        assert_eq!(left.collect::<Vec<_>>(), vec![10, 20]);

        // Right variant
        let right: EitherRow<Iter, Iter> = EitherRow::Right(vec![30, 40].into_iter());
        assert_eq!(right.collect::<Vec<_>>(), vec![30, 40]);
    }

    #[test]
    fn test_either_row_deref_behavior() {
        let left: EitherRow<&[i32], &[i32]> = EitherRow::Left(&[1, 2, 3]);
        let right: EitherRow<&[i32], &[i32]> = EitherRow::Right(&[4, 5]);

        assert_eq!(&*left, &[1, 2, 3]);
        assert_eq!(&*right, &[4, 5]);
    }

    #[test]
    #[should_panic]
    fn test_vertical_pair_width_mismatch_should_panic() {
        let a = RowMajorMatrix::new(vec![1, 2, 3], 1); // 3x1
        let b = RowMajorMatrix::new(vec![4, 5], 2); // 1x2
        let _ = VerticalPair::new::<i32>(a, b);
    }

    #[test]
    #[should_panic]
    fn test_horizontal_pair_height_mismatch_should_panic() {
        let a = RowMajorMatrix::new(vec![1, 2, 3], 3); // 1x3
        let b = RowMajorMatrix::new(vec![4, 5], 1); // 2x1
        let _ = HorizontalPair::new::<i32>(a, b);
    }
}