1use core::ops::Deref;
2
3use crate::Matrix;
4use crate::bitrev::BitReversibleMatrix;
5use crate::dense::RowMajorMatrixView;
6
/// A [`VerticalPair`] of two borrowed [`RowMajorMatrixView`]s that share the
/// element type `T` and the lifetime `'a`.
///
/// The alias itself performs no checks; building the value through
/// [`VerticalPair::new`] asserts that both views have the same width.
pub type ViewPair<'a, T> = VerticalPair<RowMajorMatrixView<'a, T>, RowMajorMatrixView<'a, T>>;
15
/// Two matrices stacked vertically: every row of `top`, followed by every row
/// of `bottom`.
///
/// [`VerticalPair::new`] asserts that both halves have the same width. The
/// fields are public, so a value built with a struct literal bypasses that
/// check and the caller is then responsible for keeping the widths equal.
#[derive(Copy, Clone, Debug)]
pub struct VerticalPair<Top, Bottom> {
    /// The upper matrix; it supplies rows `0..top.height()` of the pair.
    pub top: Top,
    /// The lower matrix; its row `i` appears as row `top.height() + i` of the pair.
    pub bottom: Bottom,
}
32
/// Two matrices placed side by side: each row is the corresponding row of
/// `left` followed by the corresponding row of `right`.
///
/// [`HorizontalPair::new`] asserts that both halves have the same height. The
/// fields are public, so a value built with a struct literal bypasses that
/// check and the caller is then responsible for keeping the heights equal.
#[derive(Copy, Clone, Debug)]
pub struct HorizontalPair<Left, Right> {
    /// The left matrix; it supplies columns `0..left.width()` of the pair.
    pub left: Left,
    /// The right matrix; its column `j` appears as column `left.width() + j` of the pair.
    pub right: Right,
}
49
50impl<Top, Bottom> VerticalPair<Top, Bottom> {
51 pub fn new<T>(top: Top, bottom: Bottom) -> Self
60 where
61 T: Send + Sync + Clone,
62 Top: Matrix<T>,
63 Bottom: Matrix<T>,
64 {
65 assert_eq!(top.width(), bottom.width());
66 Self { top, bottom }
67 }
68}
69
70impl<Left, Right> HorizontalPair<Left, Right> {
71 pub fn new<T>(left: Left, right: Right) -> Self
80 where
81 T: Send + Sync + Clone,
82 Left: Matrix<T>,
83 Right: Matrix<T>,
84 {
85 assert_eq!(left.height(), right.height());
86 Self { left, right }
87 }
88}
89
90impl<T: Send + Sync + Clone, Top: Matrix<T>, Bottom: Matrix<T>> Matrix<T>
91 for VerticalPair<Top, Bottom>
92{
93 fn width(&self) -> usize {
94 self.top.width()
95 }
96
97 fn height(&self) -> usize {
98 self.top.height() + self.bottom.height()
99 }
100
101 unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
102 unsafe {
103 if r < self.top.height() {
105 self.top.get_unchecked(r, c)
106 } else {
107 self.bottom.get_unchecked(r - self.top.height(), c)
108 }
109 }
110 }
111
112 unsafe fn row_unchecked(
113 &self,
114 r: usize,
115 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
116 unsafe {
117 if r < self.top.height() {
119 EitherRow::Left(self.top.row_unchecked(r).into_iter())
120 } else {
121 EitherRow::Right(self.bottom.row_unchecked(r - self.top.height()).into_iter())
122 }
123 }
124 }
125
126 unsafe fn row_subseq_unchecked(
127 &self,
128 r: usize,
129 start: usize,
130 end: usize,
131 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
132 unsafe {
133 if r < self.top.height() {
135 EitherRow::Left(self.top.row_subseq_unchecked(r, start, end).into_iter())
136 } else {
137 EitherRow::Right(
138 self.bottom
139 .row_subseq_unchecked(r - self.top.height(), start, end)
140 .into_iter(),
141 )
142 }
143 }
144 }
145
146 unsafe fn row_slice_unchecked(&self, r: usize) -> impl Deref<Target = [T]> {
147 unsafe {
148 if r < self.top.height() {
150 EitherRow::Left(self.top.row_slice_unchecked(r))
151 } else {
152 EitherRow::Right(self.bottom.row_slice_unchecked(r - self.top.height()))
153 }
154 }
155 }
156
157 unsafe fn row_subslice_unchecked(
158 &self,
159 r: usize,
160 start: usize,
161 end: usize,
162 ) -> impl Deref<Target = [T]> {
163 unsafe {
164 if r < self.top.height() {
166 EitherRow::Left(self.top.row_subslice_unchecked(r, start, end))
167 } else {
168 EitherRow::Right(self.bottom.row_subslice_unchecked(
169 r - self.top.height(),
170 start,
171 end,
172 ))
173 }
174 }
175 }
176}
177
178impl<T: Send + Sync + Clone, Left: Matrix<T>, Right: Matrix<T>> Matrix<T>
179 for HorizontalPair<Left, Right>
180{
181 fn width(&self) -> usize {
182 self.left.width() + self.right.width()
183 }
184
185 fn height(&self) -> usize {
186 self.left.height()
187 }
188
189 unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
190 unsafe {
191 if c < self.left.width() {
193 self.left.get_unchecked(r, c)
194 } else {
195 self.right.get_unchecked(r, c - self.left.width())
196 }
197 }
198 }
199
200 unsafe fn row_unchecked(
201 &self,
202 r: usize,
203 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
204 unsafe {
205 self.left
207 .row_unchecked(r)
208 .into_iter()
209 .chain(self.right.row_unchecked(r))
210 }
211 }
212}
213
/// One of two row representations behind a single concrete type.
///
/// This lets a function return rows from two differently-typed sources (for
/// example the top or bottom half of a `VerticalPair`) from different branches.
/// It is an [`Iterator`] when both variants iterate over the same item type,
/// and it derefs to `[T]` when both variants deref to `[T]`.
#[derive(Debug)]
pub enum EitherRow<L, R> {
    /// A row coming from the first source.
    Left(L),
    /// A row coming from the second source.
    Right(R),
}

impl<T, L, R> Iterator for EitherRow<L, R>
where
    L: Iterator<Item = T>,
    R: Iterator<Item = T>,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Left(l) => l.next(),
            Self::Right(r) => r.next(),
        }
    }

    // Forward the wrapped iterator's bounds. The default `(0, None)` would
    // prevent `collect` and similar consumers from preallocating.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::Left(l) => l.size_hint(),
            Self::Right(r) => r.size_hint(),
        }
    }
}

// Exact because `size_hint` forwards the exact bounds of whichever variant is
// active.
impl<T, L, R> ExactSizeIterator for EitherRow<L, R>
where
    L: ExactSizeIterator<Item = T>,
    R: ExactSizeIterator<Item = T>,
{
}

impl<T, L, R> Deref for EitherRow<L, R>
where
    L: Deref<Target = [T]>,
    R: Deref<Target = [T]>,
{
    type Target = [T];

    fn deref(&self) -> &Self::Target {
        match self {
            Self::Left(l) => l,
            Self::Right(r) => r,
        }
    }
}
249
250impl<T: Clone + Send + Sync, Left: BitReversibleMatrix<T>, Right: BitReversibleMatrix<T>>
251 BitReversibleMatrix<T> for HorizontalPair<Left, Right>
252{
253 type BitRev = HorizontalPair<Left::BitRev, Right::BitRev>;
254
255 fn bit_reverse_rows(self) -> Self::BitRev {
256 HorizontalPair {
257 left: self.left.bit_reverse_rows(),
258 right: self.right.bit_reverse_rows(),
259 }
260 }
261}
262
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use itertools::Itertools;

    use super::*;
    use crate::RowMajorMatrix;

    #[test]
    fn test_vertical_pair_empty_top() {
        // With a zero-height top matrix, every row index resolves into the bottom one.
        let top = RowMajorMatrix::new(vec![], 2);
        let bottom = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        let vpair = VerticalPair::new::<i32>(top, bottom);

        assert_eq!(vpair.height(), 2);
        assert_eq!(vpair.get(1, 1), Some(4));
        unsafe {
            assert_eq!(vpair.get_unchecked(0, 0), 1);
        }
    }

    #[test]
    fn test_vertical_pair_composition() {
        let top = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        let bottom = RowMajorMatrix::new(vec![5, 6, 7, 8], 2);
        let vertical = VerticalPair::new::<i32>(top, bottom);

        // The width is shared; the height is the sum of both halves.
        assert_eq!(vertical.width(), 2);
        assert_eq!(vertical.height(), 4);

        // Rows 0..2 come from the top matrix.
        assert_eq!(vertical.get(0, 0), Some(1));
        assert_eq!(vertical.get(1, 1), Some(4));

        // Rows 2..4 come from the bottom matrix.
        unsafe {
            assert_eq!(vertical.get_unchecked(2, 0), 5);
            assert_eq!(vertical.get_unchecked(3, 1), 8);
        }

        let row = vertical.row(3).unwrap().into_iter().collect_vec();
        assert_eq!(row, vec![7, 8]);

        unsafe {
            let row = vertical.row_unchecked(1).into_iter().collect_vec();
            assert_eq!(row, vec![3, 4]);

            let row = vertical
                .row_subseq_unchecked(0, 0, 1)
                .into_iter()
                .collect_vec();
            assert_eq!(row, vec![1]);

            // A sub-range of a row that lives in the bottom half.
            let row = vertical
                .row_subseq_unchecked(2, 1, 2)
                .into_iter()
                .collect_vec();
            assert_eq!(row, vec![6]);
        }

        assert_eq!(vertical.row_slice(2).unwrap().deref(), &[5, 6]);

        unsafe {
            assert_eq!(vertical.row_slice_unchecked(3).deref(), &[7, 8]);
            assert_eq!(vertical.row_subslice_unchecked(1, 1, 2).deref(), &[4]);
        }

        // Out-of-range accesses through the checked API return `None`.
        assert_eq!(vertical.get(0, 2), None);
        assert_eq!(vertical.get(4, 0), None);
        assert!(vertical.row(4).is_none());
        assert!(vertical.row_slice(4).is_none());
    }

    #[test]
    fn test_horizontal_pair_composition() {
        let left = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        let right = RowMajorMatrix::new(vec![5, 6, 7, 8], 2);
        let horizontal = HorizontalPair::new::<i32>(left, right);

        // The height is shared; the width is the sum of both halves.
        assert_eq!(horizontal.height(), 2);
        assert_eq!(horizontal.width(), 4);

        // Columns 0..2 come from the left matrix.
        assert_eq!(horizontal.get(0, 0), Some(1));
        assert_eq!(horizontal.get(1, 1), Some(4));

        // Columns 2..4 come from the right matrix.
        unsafe {
            assert_eq!(horizontal.get_unchecked(0, 2), 5);
            assert_eq!(horizontal.get_unchecked(1, 3), 8);
        }

        let row = horizontal.row(0).unwrap().into_iter().collect_vec();
        assert_eq!(row, vec![1, 2, 5, 6]);

        unsafe {
            let row = horizontal.row_unchecked(1).into_iter().collect_vec();
            assert_eq!(row, vec![3, 4, 7, 8]);
        }

        assert_eq!(horizontal.row_slice(1).unwrap().deref(), &[3, 4, 7, 8]);

        // Out-of-range accesses through the checked API return `None`.
        assert_eq!(horizontal.get(0, 4), None);
        assert_eq!(horizontal.get(2, 0), None);
        assert!(horizontal.row(2).is_none());
    }

    #[test]
    fn test_horizontal_pair_row_subseq_across_boundary() {
        let left = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        let right = RowMajorMatrix::new(vec![5, 6, 7, 8], 2);
        let horizontal = HorizontalPair::new::<i32>(left, right);

        unsafe {
            // Window entirely inside the left half.
            let row = horizontal
                .row_subseq_unchecked(1, 0, 2)
                .into_iter()
                .collect_vec();
            assert_eq!(row, vec![3, 4]);

            // Window straddling the boundary between the halves.
            let row = horizontal
                .row_subseq_unchecked(0, 1, 3)
                .into_iter()
                .collect_vec();
            assert_eq!(row, vec![2, 5]);

            // Window entirely inside the right half.
            let row = horizontal
                .row_subseq_unchecked(1, 2, 4)
                .into_iter()
                .collect_vec();
            assert_eq!(row, vec![7, 8]);

            // Empty window located exactly on the boundary.
            let row = horizontal
                .row_subseq_unchecked(0, 2, 2)
                .into_iter()
                .collect_vec();
            assert!(row.is_empty());
        }
    }

    #[test]
    fn test_either_row_iterator_behavior() {
        type Iter = alloc::vec::IntoIter<i32>;

        let left: EitherRow<Iter, Iter> = EitherRow::Left(vec![10, 20].into_iter());
        assert_eq!(left.collect::<Vec<_>>(), vec![10, 20]);

        let right: EitherRow<Iter, Iter> = EitherRow::Right(vec![30, 40].into_iter());
        assert_eq!(right.collect::<Vec<_>>(), vec![30, 40]);
    }

    #[test]
    fn test_either_row_deref_behavior() {
        let left: EitherRow<&[i32], &[i32]> = EitherRow::Left(&[1, 2, 3]);
        let right: EitherRow<&[i32], &[i32]> = EitherRow::Right(&[4, 5]);

        assert_eq!(&*left, &[1, 2, 3]);
        assert_eq!(&*right, &[4, 5]);
    }

    #[test]
    #[should_panic]
    fn test_vertical_pair_width_mismatch_should_panic() {
        // Widths 1 and 2 differ, so construction must be rejected.
        let a = RowMajorMatrix::new(vec![1, 2, 3], 1);
        let b = RowMajorMatrix::new(vec![4, 5], 2);
        let _ = VerticalPair::new::<i32>(a, b);
    }

    #[test]
    #[should_panic]
    fn test_horizontal_pair_height_mismatch_should_panic() {
        // Heights 1 and 2 differ, so construction must be rejected.
        let a = RowMajorMatrix::new(vec![1, 2, 3], 3);
        let b = RowMajorMatrix::new(vec![4, 5], 1);
        let _ = HorizontalPair::new::<i32>(a, b);
    }
}