1use core::ops::Deref;
2
3use crate::Matrix;
4
/// Two matrices glued together vertically: `top` stacked above `bottom`.
///
/// Rows `0..top.height()` are served by `top`; the remaining rows are served by
/// `bottom`, re-indexed from zero. Build it with [`VerticalPair::new`] to have
/// the equal-width requirement checked.
#[derive(Clone, Copy, Debug)]
pub struct VerticalPair<Top, Bottom> {
    /// The upper matrix; it supplies the first `top.height()` rows.
    pub top: Top,
    /// The lower matrix; it supplies the rows that follow those of `top`.
    pub bottom: Bottom,
}
21
/// Two matrices glued together horizontally: `left` beside `right`.
///
/// Columns `0..left.width()` are served by `left`; the remaining columns are
/// served by `right`, re-indexed from zero. Build it with
/// [`HorizontalPair::new`] to have the equal-height requirement checked.
#[derive(Clone, Copy, Debug)]
pub struct HorizontalPair<Left, Right> {
    /// The left-hand matrix; it supplies the first `left.width()` columns.
    pub left: Left,
    /// The right-hand matrix; it supplies the columns that follow those of `left`.
    pub right: Right,
}
38
39impl<Top, Bottom> VerticalPair<Top, Bottom> {
40 pub fn new<T>(top: Top, bottom: Bottom) -> Self
49 where
50 T: Send + Sync + Clone,
51 Top: Matrix<T>,
52 Bottom: Matrix<T>,
53 {
54 assert_eq!(top.width(), bottom.width());
55 Self { top, bottom }
56 }
57}
58
59impl<Left, Right> HorizontalPair<Left, Right> {
60 pub fn new<T>(left: Left, right: Right) -> Self
69 where
70 T: Send + Sync + Clone,
71 Left: Matrix<T>,
72 Right: Matrix<T>,
73 {
74 assert_eq!(left.height(), right.height());
75 Self { left, right }
76 }
77}
78
79impl<T: Send + Sync + Clone, Top: Matrix<T>, Bottom: Matrix<T>> Matrix<T>
80 for VerticalPair<Top, Bottom>
81{
82 fn width(&self) -> usize {
83 self.top.width()
84 }
85
86 fn height(&self) -> usize {
87 self.top.height() + self.bottom.height()
88 }
89
90 unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
91 unsafe {
92 if r < self.top.height() {
94 self.top.get_unchecked(r, c)
95 } else {
96 self.bottom.get_unchecked(r - self.top.height(), c)
97 }
98 }
99 }
100
101 unsafe fn row_unchecked(
102 &self,
103 r: usize,
104 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
105 unsafe {
106 if r < self.top.height() {
108 EitherRow::Left(self.top.row_unchecked(r).into_iter())
109 } else {
110 EitherRow::Right(self.bottom.row_unchecked(r - self.top.height()).into_iter())
111 }
112 }
113 }
114
115 unsafe fn row_subseq_unchecked(
116 &self,
117 r: usize,
118 start: usize,
119 end: usize,
120 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
121 unsafe {
122 if r < self.top.height() {
124 EitherRow::Left(self.top.row_subseq_unchecked(r, start, end).into_iter())
125 } else {
126 EitherRow::Right(
127 self.bottom
128 .row_subseq_unchecked(r - self.top.height(), start, end)
129 .into_iter(),
130 )
131 }
132 }
133 }
134
135 unsafe fn row_slice_unchecked(&self, r: usize) -> impl Deref<Target = [T]> {
136 unsafe {
137 if r < self.top.height() {
139 EitherRow::Left(self.top.row_slice_unchecked(r))
140 } else {
141 EitherRow::Right(self.bottom.row_slice_unchecked(r - self.top.height()))
142 }
143 }
144 }
145
146 unsafe fn row_subslice_unchecked(
147 &self,
148 r: usize,
149 start: usize,
150 end: usize,
151 ) -> impl Deref<Target = [T]> {
152 unsafe {
153 if r < self.top.height() {
155 EitherRow::Left(self.top.row_subslice_unchecked(r, start, end))
156 } else {
157 EitherRow::Right(self.bottom.row_subslice_unchecked(
158 r - self.top.height(),
159 start,
160 end,
161 ))
162 }
163 }
164 }
165}
166
167impl<T: Send + Sync + Clone, Left: Matrix<T>, Right: Matrix<T>> Matrix<T>
168 for HorizontalPair<Left, Right>
169{
170 fn width(&self) -> usize {
171 self.left.width() + self.right.width()
172 }
173
174 fn height(&self) -> usize {
175 self.left.height()
176 }
177
178 unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
179 unsafe {
180 if c < self.left.width() {
182 self.left.get_unchecked(r, c)
183 } else {
184 self.right.get_unchecked(r, c - self.left.width())
185 }
186 }
187 }
188
189 unsafe fn row_unchecked(
190 &self,
191 r: usize,
192 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
193 unsafe {
194 self.left
196 .row_unchecked(r)
197 .into_iter()
198 .chain(self.right.row_unchecked(r))
199 }
200 }
201}
202
/// A value that holds one of two differently-typed row representations,
/// chosen at runtime.
///
/// This lets a function that must return a single concrete type hand back a row
/// from either of two sources. The enum implements [`Iterator`] when both
/// variants iterate over the same item type. It implements [`Deref`] to `[T]`
/// when both variants deref to the same slice type.
#[derive(Debug)]
pub enum EitherRow<L, R> {
    /// A row obtained from the first source.
    Left(L),
    /// A row obtained from the second source.
    Right(R),
}
209
210impl<T, L, R> Iterator for EitherRow<L, R>
211where
212 L: Iterator<Item = T>,
213 R: Iterator<Item = T>,
214{
215 type Item = T;
216
217 fn next(&mut self) -> Option<Self::Item> {
218 match self {
219 Self::Left(l) => l.next(),
220 Self::Right(r) => r.next(),
221 }
222 }
223}
224
225impl<T, L, R> Deref for EitherRow<L, R>
226where
227 L: Deref<Target = [T]>,
228 R: Deref<Target = [T]>,
229{
230 type Target = [T];
231 fn deref(&self) -> &Self::Target {
232 match self {
233 Self::Left(l) => l,
234 Self::Right(r) => r,
235 }
236 }
237}
238
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use itertools::Itertools;

    use super::*;
    use crate::RowMajorMatrix;

    #[test]
    fn test_vertical_pair_empty_top() {
        // With a zero-row upper half, every row index must route to `bottom`.
        let upper = RowMajorMatrix::new(vec![], 2);
        let lower = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        let stacked = VerticalPair::new::<i32>(upper, lower);

        assert_eq!(stacked.height(), 2);
        assert_eq!(stacked.get(1, 1), Some(4));
        // SAFETY: (0, 0) is inside the 2x2 stacked matrix.
        unsafe {
            assert_eq!(stacked.get_unchecked(0, 0), 1);
        }
    }

    #[test]
    fn test_vertical_pair_composition() {
        let upper = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        let lower = RowMajorMatrix::new(vec![5, 6, 7, 8], 2);
        let stacked = VerticalPair::new::<i32>(upper, lower);

        // Shape: width is shared, heights add up.
        assert_eq!(stacked.width(), 2);
        assert_eq!(stacked.height(), 4);

        // Checked element access inside the upper half.
        assert_eq!(stacked.get(0, 0), Some(1));
        assert_eq!(stacked.get(1, 1), Some(4));

        // SAFETY: rows 2 and 3 and columns 0 and 1 are all in bounds.
        unsafe {
            assert_eq!(stacked.get_unchecked(2, 0), 5);
            assert_eq!(stacked.get_unchecked(3, 1), 8);
        }

        let last_row = stacked.row(3).unwrap().into_iter().collect_vec();
        assert_eq!(last_row, vec![7, 8]);

        // SAFETY: row 1 and row 0 with column range 0..1 are in bounds.
        unsafe {
            let second_row = stacked.row_unchecked(1).into_iter().collect_vec();
            assert_eq!(second_row, vec![3, 4]);

            let prefix = stacked
                .row_subseq_unchecked(0, 0, 1)
                .into_iter()
                .collect_vec();
            assert_eq!(prefix, vec![1]);
        }

        assert_eq!(stacked.row_slice(2).unwrap().deref(), &[5, 6]);

        // SAFETY: row 3, and row 1 with column range 1..2, are in bounds.
        unsafe {
            assert_eq!(stacked.row_slice_unchecked(3).deref(), &[7, 8]);
            assert_eq!(stacked.row_subslice_unchecked(1, 1, 2).deref(), &[4]);
        }

        // Out-of-range access through the checked API yields `None`.
        assert_eq!(stacked.get(0, 2), None);
        assert_eq!(stacked.get(4, 0), None);
        assert!(stacked.row(4).is_none());
        assert!(stacked.row_slice(4).is_none());
    }

    #[test]
    fn test_horizontal_pair_composition() {
        let lhs = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        let rhs = RowMajorMatrix::new(vec![5, 6, 7, 8], 2);
        let joined = HorizontalPair::new::<i32>(lhs, rhs);

        // Shape: height is shared, widths add up.
        assert_eq!(joined.height(), 2);
        assert_eq!(joined.width(), 4);

        // Checked element access inside the left half.
        assert_eq!(joined.get(0, 0), Some(1));
        assert_eq!(joined.get(1, 1), Some(4));

        // SAFETY: columns 2 and 3 and rows 0 and 1 are all in bounds.
        unsafe {
            assert_eq!(joined.get_unchecked(0, 2), 5);
            assert_eq!(joined.get_unchecked(1, 3), 8);
        }

        let first_row = joined.row(0).unwrap().into_iter().collect_vec();
        assert_eq!(first_row, vec![1, 2, 5, 6]);

        // SAFETY: row 1 is in bounds.
        unsafe {
            let second_row = joined.row_unchecked(1).into_iter().collect_vec();
            assert_eq!(second_row, vec![3, 4, 7, 8]);
        }

        // Out-of-range access through the checked API yields `None`.
        assert_eq!(joined.get(0, 4), None);
        assert_eq!(joined.get(2, 0), None);
        assert!(joined.row(2).is_none());
    }

    #[test]
    fn test_either_row_iterator_behavior() {
        type Iter = alloc::vec::IntoIter<i32>;

        let from_left: EitherRow<Iter, Iter> = EitherRow::Left(vec![10, 20].into_iter());
        assert_eq!(from_left.collect::<Vec<_>>(), vec![10, 20]);

        let from_right: EitherRow<Iter, Iter> = EitherRow::Right(vec![30, 40].into_iter());
        assert_eq!(from_right.collect::<Vec<_>>(), vec![30, 40]);
    }

    #[test]
    fn test_either_row_deref_behavior() {
        let from_left: EitherRow<&[i32], &[i32]> = EitherRow::Left(&[1, 2, 3]);
        let from_right: EitherRow<&[i32], &[i32]> = EitherRow::Right(&[4, 5]);

        assert_eq!(&*from_left, &[1, 2, 3]);
        assert_eq!(&*from_right, &[4, 5]);
    }

    #[test]
    #[should_panic]
    fn test_vertical_pair_width_mismatch_should_panic() {
        // Width 1 over width 2: `VerticalPair::new` must reject this.
        let narrow = RowMajorMatrix::new(vec![1, 2, 3], 1);
        let wide = RowMajorMatrix::new(vec![4, 5], 2);
        let _ = VerticalPair::new::<i32>(narrow, wide);
    }

    #[test]
    #[should_panic]
    fn test_horizontal_pair_height_mismatch_should_panic() {
        // Height 1 beside height 2: `HorizontalPair::new` must reject this.
        let short = RowMajorMatrix::new(vec![1, 2, 3], 3);
        let tall = RowMajorMatrix::new(vec![4, 5], 1);
        let _ = HorizontalPair::new::<i32>(short, tall);
    }
}