1use core::marker::PhantomData;
2use core::ops::Range;
3
4use crate::Matrix;
5
/// A matrix view that exposes only a contiguous range of columns of an inner matrix.
///
/// Column `c` of the view corresponds to column `column_range.start + c` of `inner`.
/// Every row of `inner` is kept, so the view's height equals the inner height.
#[derive(Clone, Debug)]
pub struct HorizontallyTruncated<T, Inner> {
    /// The wrapped matrix; the view holds it by value.
    inner: Inner,
    /// The columns of `inner` that this view exposes.
    column_range: Range<usize>,
    /// Ties the element type `T` to the struct, since no field otherwise mentions it.
    _phantom: PhantomData<T>,
}
20
21impl<T, Inner: Matrix<T>> HorizontallyTruncated<T, Inner>
22where
23 T: Send + Sync + Clone,
24{
25 pub fn new(inner: Inner, truncated_width: usize) -> Option<Self> {
35 Self::new_with_range(inner, 0..truncated_width)
36 }
37
38 pub fn new_with_range(inner: Inner, column_range: Range<usize>) -> Option<Self> {
46 (column_range.end <= inner.width()).then(|| Self {
47 inner,
48 column_range,
49 _phantom: PhantomData,
50 })
51 }
52}
53
54impl<T, Inner> Matrix<T> for HorizontallyTruncated<T, Inner>
55where
56 T: Send + Sync + Clone,
57 Inner: Matrix<T>,
58{
59 #[inline(always)]
61 fn width(&self) -> usize {
62 self.column_range.len()
63 }
64
65 #[inline(always)]
67 fn height(&self) -> usize {
68 self.inner.height()
69 }
70
71 #[inline(always)]
72 unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
73 unsafe {
74 self.inner.get_unchecked(r, self.column_range.start + c)
78 }
79 }
80
81 unsafe fn row_unchecked(
82 &self,
83 r: usize,
84 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
85 unsafe {
86 self.inner
88 .row_subseq_unchecked(r, self.column_range.start, self.column_range.end)
89 }
90 }
91
92 unsafe fn row_subseq_unchecked(
93 &self,
94 r: usize,
95 start: usize,
96 end: usize,
97 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
98 unsafe {
99 self.inner.row_subseq_unchecked(
103 r,
104 self.column_range.start + start,
105 self.column_range.start + end,
106 )
107 }
108 }
109
110 unsafe fn row_subslice_unchecked(
111 &self,
112 r: usize,
113 start: usize,
114 end: usize,
115 ) -> impl core::ops::Deref<Target = [T]> {
116 unsafe {
117 self.inner.row_subslice_unchecked(
121 r,
122 self.column_range.start + start,
123 self.column_range.start + end,
124 )
125 }
126 }
127}
128
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use super::*;
    use crate::dense::RowMajorMatrix;

    #[test]
    fn test_truncate_width_by_one() {
        // 3x4 matrix:
        // [ 1  2  3  4 ]
        // [ 5  6  7  8 ]
        // [ 9 10 11 12 ]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 4);

        // Keep only the first three columns.
        let truncated = HorizontallyTruncated::new(inner, 3).unwrap();

        assert_eq!(truncated.width(), 3);
        assert_eq!(truncated.height(), 3);

        assert_eq!(truncated.get(0, 0), Some(1));
        assert_eq!(truncated.get(1, 1), Some(6));

        // SAFETY: (0, 1) and (2, 2) lie inside the 3x3 view.
        unsafe {
            assert_eq!(truncated.get_unchecked(0, 1), 2);
            assert_eq!(truncated.get_unchecked(2, 2), 11);
        }

        let row0: Vec<_> = truncated.row(0).unwrap().into_iter().collect();
        assert_eq!(row0, vec![1, 2, 3]);

        // SAFETY: rows 1 and 2 exist and columns 1..2 are within the view's width of 3.
        unsafe {
            let row1: Vec<_> = truncated.row_unchecked(1).into_iter().collect();
            assert_eq!(row1, vec![5, 6, 7]);

            let row2_subseq: Vec<_> = truncated
                .row_subseq_unchecked(2, 1, 2)
                .into_iter()
                .collect();
            assert_eq!(row2_subseq, vec![10]);
        }

        let row1 = truncated.row_slice(1).unwrap();
        assert_eq!(&*row1, &[5, 6, 7]);

        // SAFETY: rows 0 and 2 exist and columns 0..2 are within the view's width of 3.
        unsafe {
            let row2 = truncated.row_slice_unchecked(2);
            assert_eq!(&*row2, &[9, 10, 11]);

            let row0_subslice = truncated.row_subslice_unchecked(0, 0, 2);
            assert_eq!(&*row0_subslice, &[1, 2]);
        }

        // Checked accessors report out-of-bounds positions as `None`.
        assert!(truncated.get(0, 3).is_none());
        assert!(truncated.get(3, 0).is_none());
        assert!(truncated.row(3).is_none());
        assert!(truncated.row_slice(3).is_none());

        let as_matrix = truncated.to_row_major_matrix();
        let expected = RowMajorMatrix::new(vec![1, 2, 3, 5, 6, 7, 9, 10, 11], 3);
        assert_eq!(as_matrix, expected);
    }

    #[test]
    fn test_no_truncation() {
        // 2x2 matrix:
        // [ 7  8 ]
        // [ 9 10 ]
        let inner = RowMajorMatrix::new(vec![7, 8, 9, 10], 2);

        let truncated = HorizontallyTruncated::new(inner, 2).unwrap();

        assert_eq!(truncated.width(), 2);
        assert_eq!(truncated.height(), 2);
        assert_eq!(truncated.get(0, 1).unwrap(), 8);
        assert_eq!(truncated.get(1, 0).unwrap(), 9);

        // SAFETY: (0, 0) and (1, 1) lie inside the 2x2 view.
        unsafe {
            assert_eq!(truncated.get_unchecked(0, 0), 7);
            assert_eq!(truncated.get_unchecked(1, 1), 10);
        }

        let row0: Vec<_> = truncated.row(0).unwrap().into_iter().collect();
        assert_eq!(row0, vec![7, 8]);

        // SAFETY: row 1 exists.
        let row1: Vec<_> = unsafe { truncated.row_unchecked(1).into_iter().collect() };
        assert_eq!(row1, vec![9, 10]);

        assert!(truncated.get(0, 2).is_none());
        assert!(truncated.get(2, 0).is_none());
        assert!(truncated.row(2).is_none());
        assert!(truncated.row_slice(2).is_none());
    }

    #[test]
    fn test_truncate_to_zero_width() {
        // 1x3 matrix: [ 11 12 13 ]
        let inner = RowMajorMatrix::new(vec![11, 12, 13], 3);

        let truncated = HorizontallyTruncated::new(inner, 0).unwrap();

        assert_eq!(truncated.width(), 0);
        assert_eq!(truncated.height(), 1);

        // The single row exists but is empty.
        assert!(truncated.row(0).unwrap().into_iter().next().is_none());

        assert!(truncated.get(0, 0).is_none());
        assert!(truncated.get(1, 0).is_none());
        assert!(truncated.row(1).is_none());
        assert!(truncated.row_slice(1).is_none());
    }

    #[test]
    fn test_invalid_truncation_width() {
        // 2x2 matrix:
        // [ 1 2 ]
        // [ 3 4 ]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);

        // Asking for more columns than the matrix has must fail.
        assert!(HorizontallyTruncated::new(inner, 5).is_none());
    }

    #[test]
    fn test_column_range_middle() {
        // 3x5 matrix:
        // [  1  2  3  4  5 ]
        // [  6  7  8  9 10 ]
        // [ 11 12 13 14 15 ]
        let inner = RowMajorMatrix::new(
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
            5,
        );

        // Keep columns 1, 2 and 3.
        let view = HorizontallyTruncated::new_with_range(inner, 1..4).unwrap();

        assert_eq!(view.width(), 3);
        assert_eq!(view.height(), 3);

        assert_eq!(view.get(0, 0), Some(2));
        assert_eq!(view.get(0, 1), Some(3));
        assert_eq!(view.get(0, 2), Some(4));
        assert_eq!(view.get(1, 0), Some(7));
        assert_eq!(view.get(2, 2), Some(14));

        // SAFETY: (1, 1) and (2, 0) lie inside the 3x3 view.
        unsafe {
            assert_eq!(view.get_unchecked(1, 1), 8);
            assert_eq!(view.get_unchecked(2, 0), 12);
        }

        let row0: Vec<_> = view.row(0).unwrap().into_iter().collect();
        assert_eq!(row0, vec![2, 3, 4]);

        let row1: Vec<_> = view.row(1).unwrap().into_iter().collect();
        assert_eq!(row1, vec![7, 8, 9]);

        // SAFETY: rows 1 and 2 exist and columns 1..3 are within the view's width of 3.
        unsafe {
            let row2: Vec<_> = view.row_unchecked(2).into_iter().collect();
            assert_eq!(row2, vec![12, 13, 14]);

            let row1_subseq: Vec<_> = view.row_subseq_unchecked(1, 1, 3).into_iter().collect();
            assert_eq!(row1_subseq, vec![8, 9]);
        }

        // The slice-based accessors must apply the same column offset.
        let row2_slice = view.row_slice(2).unwrap();
        assert_eq!(&*row2_slice, &[12, 13, 14]);

        // SAFETY: row 1 exists and columns 1..3 are within the view's width of 3.
        let row1_subslice = unsafe { view.row_subslice_unchecked(1, 1, 3) };
        assert_eq!(&*row1_subslice, &[8, 9]);

        assert!(view.get(0, 3).is_none());
        assert!(view.get(3, 0).is_none());
        assert!(view.row(3).is_none());
        assert!(view.row_slice(3).is_none());

        let as_matrix = view.to_row_major_matrix();
        let expected = RowMajorMatrix::new(vec![2, 3, 4, 7, 8, 9, 12, 13, 14], 3);
        assert_eq!(as_matrix, expected);
    }

    #[test]
    fn test_column_range_end() {
        // 2x4 matrix:
        // [ 1 2 3 4 ]
        // [ 5 6 7 8 ]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 4);

        // Keep the last two columns.
        let view = HorizontallyTruncated::new_with_range(inner, 2..4).unwrap();

        assert_eq!(view.width(), 2);
        assert_eq!(view.height(), 2);

        let row0: Vec<_> = view.row(0).unwrap().into_iter().collect();
        assert_eq!(row0, vec![3, 4]);

        let row1: Vec<_> = view.row(1).unwrap().into_iter().collect();
        assert_eq!(row1, vec![7, 8]);

        assert_eq!(view.get(0, 0), Some(3));
        assert_eq!(view.get(1, 1), Some(8));
    }

    #[test]
    fn test_column_range_single_column() {
        // 3x4 matrix:
        // [ 1  2  3  4 ]
        // [ 5  6  7  8 ]
        // [ 9 10 11 12 ]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 4);

        // Keep only column 2.
        let view = HorizontallyTruncated::new_with_range(inner, 2..3).unwrap();

        assert_eq!(view.width(), 1);
        assert_eq!(view.height(), 3);

        assert_eq!(view.get(0, 0), Some(3));
        assert_eq!(view.get(1, 0), Some(7));
        assert_eq!(view.get(2, 0), Some(11));

        let row0: Vec<_> = view.row(0).unwrap().into_iter().collect();
        assert_eq!(row0, vec![3]);
    }

    #[test]
    fn test_column_range_empty() {
        // 2x3 matrix:
        // [ 1 2 3 ]
        // [ 4 5 6 ]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);

        // An empty range is valid and yields a zero-width view.
        let view = HorizontallyTruncated::new_with_range(inner, 2..2).unwrap();

        assert_eq!(view.width(), 0);
        assert_eq!(view.height(), 2);

        assert!(view.row(0).unwrap().into_iter().next().is_none());
        assert!(view.get(0, 0).is_none());
    }

    #[test]
    fn test_invalid_column_range() {
        // 2x3 matrix:
        // [ 1 2 3 ]
        // [ 4 5 6 ]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);

        // The range end (5) exceeds the matrix width (3).
        assert!(HorizontallyTruncated::new_with_range(inner, 1..5).is_none());
    }
}