1use core::marker::PhantomData;
2use core::ops::Range;
3
4use crate::Matrix;
5
/// A view over an inner matrix that exposes only a contiguous range of its columns.
///
/// Every row of the inner matrix stays visible. Column `c` of the view corresponds to
/// column `column_range.start + c` of the inner matrix.
pub struct HorizontallyTruncated<T, Inner> {
    /// The wrapped matrix that supplies the data.
    inner: Inner,
    /// Half-open range of `inner`'s columns that this view exposes.
    column_range: Range<usize>,
    /// Ties the element type `T` to the struct without storing a value of it.
    _phantom: PhantomData<T>,
}
19
20impl<T, Inner: Matrix<T>> HorizontallyTruncated<T, Inner>
21where
22 T: Send + Sync + Clone,
23{
24 pub fn new(inner: Inner, truncated_width: usize) -> Option<Self> {
34 Self::new_with_range(inner, 0..truncated_width)
35 }
36
37 pub fn new_with_range(inner: Inner, column_range: Range<usize>) -> Option<Self> {
45 (column_range.end <= inner.width()).then(|| Self {
46 inner,
47 column_range,
48 _phantom: PhantomData,
49 })
50 }
51}
52
53impl<T, Inner> Matrix<T> for HorizontallyTruncated<T, Inner>
54where
55 T: Send + Sync + Clone,
56 Inner: Matrix<T>,
57{
58 #[inline(always)]
60 fn width(&self) -> usize {
61 self.column_range.len()
62 }
63
64 #[inline(always)]
66 fn height(&self) -> usize {
67 self.inner.height()
68 }
69
70 #[inline(always)]
71 unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
72 unsafe {
73 self.inner.get_unchecked(r, self.column_range.start + c)
77 }
78 }
79
80 unsafe fn row_unchecked(
81 &self,
82 r: usize,
83 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
84 unsafe {
85 self.inner
87 .row_subseq_unchecked(r, self.column_range.start, self.column_range.end)
88 }
89 }
90
91 unsafe fn row_subseq_unchecked(
92 &self,
93 r: usize,
94 start: usize,
95 end: usize,
96 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
97 unsafe {
98 self.inner.row_subseq_unchecked(
102 r,
103 self.column_range.start + start,
104 self.column_range.start + end,
105 )
106 }
107 }
108
109 unsafe fn row_subslice_unchecked(
110 &self,
111 r: usize,
112 start: usize,
113 end: usize,
114 ) -> impl core::ops::Deref<Target = [T]> {
115 unsafe {
116 self.inner.row_subslice_unchecked(
120 r,
121 self.column_range.start + start,
122 self.column_range.start + end,
123 )
124 }
125 }
126}
127
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use super::*;
    use crate::dense::RowMajorMatrix;

    /// Dropping the last column of a 3x4 matrix leaves a 3x3 view.
    #[test]
    fn test_truncate_width_by_one() {
        // Source layout (width 4):
        //   [ 1  2  3  4]
        //   [ 5  6  7  8]
        //   [ 9 10 11 12]
        let source = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 4);
        let view = HorizontallyTruncated::new(source, 3).unwrap();

        // Dimensions.
        assert_eq!(view.width(), 3);
        assert_eq!(view.height(), 3);

        // Checked element access.
        assert_eq!(view.get(0, 0), Some(1));
        assert_eq!(view.get(1, 1), Some(6));

        // Unchecked element access.
        unsafe {
            assert_eq!(view.get_unchecked(0, 1), 2);
            assert_eq!(view.get_unchecked(2, 2), 11);
        }

        // Checked row iteration.
        let first_row: Vec<_> = view.row(0).unwrap().into_iter().collect();
        assert_eq!(first_row, vec![1, 2, 3]);

        // Unchecked row iteration, full row and sub-sequence.
        unsafe {
            let second_row: Vec<_> = view.row_unchecked(1).into_iter().collect();
            assert_eq!(second_row, vec![5, 6, 7]);

            let third_row_middle: Vec<_> = view.row_subseq_unchecked(2, 1, 2).into_iter().collect();
            assert_eq!(third_row_middle, vec![10]);
        }

        // Slice-based row access; borrows are scoped so the view can be consumed below.
        {
            let second_slice = view.row_slice(1).unwrap();
            assert_eq!(&*second_slice, &[5, 6, 7]);
        }
        unsafe {
            let third_slice = view.row_slice_unchecked(2);
            assert_eq!(&*third_slice, &[9, 10, 11]);

            let first_slice_head = view.row_subslice_unchecked(0, 0, 2);
            assert_eq!(&*first_slice_head, &[1, 2]);
        }

        // Anything outside the 3x3 window is reported as absent.
        assert!(view.get(0, 3).is_none());
        assert!(view.get(3, 0).is_none());
        assert!(view.row(3).is_none());
        assert!(view.row_slice(3).is_none());

        // Materialising the view keeps only the first three columns of each row.
        let materialised = view.to_row_major_matrix();
        let wanted = RowMajorMatrix::new(vec![1, 2, 3, 5, 6, 7, 9, 10, 11], 3);
        assert_eq!(materialised, wanted);
    }

    /// Truncating to the full width leaves the matrix contents untouched.
    #[test]
    fn test_no_truncation() {
        // Source layout (width 2):
        //   [ 7  8]
        //   [ 9 10]
        let source = RowMajorMatrix::new(vec![7, 8, 9, 10], 2);
        let view = HorizontallyTruncated::new(source, 2).unwrap();

        assert_eq!(view.width(), 2);
        assert_eq!(view.height(), 2);
        assert_eq!(view.get(0, 1).unwrap(), 8);
        assert_eq!(view.get(1, 0).unwrap(), 9);

        unsafe {
            assert_eq!(view.get_unchecked(0, 0), 7);
            assert_eq!(view.get_unchecked(1, 1), 10);
        }

        let first_row: Vec<_> = view.row(0).unwrap().into_iter().collect();
        assert_eq!(first_row, vec![7, 8]);

        let second_row: Vec<_> = unsafe { view.row_unchecked(1).into_iter().collect() };
        assert_eq!(second_row, vec![9, 10]);

        // Out-of-range coordinates.
        assert!(view.get(0, 2).is_none());
        assert!(view.get(2, 0).is_none());
        assert!(view.row(2).is_none());
        assert!(view.row_slice(2).is_none());
    }

    /// A width of zero is valid and yields rows with no elements.
    #[test]
    fn test_truncate_to_zero_width() {
        // Single-row source of width 3.
        let source = RowMajorMatrix::new(vec![11, 12, 13], 3);
        let view = HorizontallyTruncated::new(source, 0).unwrap();

        assert_eq!(view.width(), 0);
        assert_eq!(view.height(), 1);

        // The only row exists but is empty.
        assert!(view.row(0).unwrap().into_iter().next().is_none());

        // No element or further row is reachable.
        assert!(view.get(0, 0).is_none());
        assert!(view.get(1, 0).is_none());
        assert!(view.row(1).is_none());
        assert!(view.row_slice(1).is_none());
    }

    /// Asking for more columns than the inner matrix has is rejected.
    #[test]
    fn test_invalid_truncation_width() {
        // Source has width 2; requesting 5 columns must fail.
        let source = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);

        assert!(HorizontallyTruncated::new(source, 5).is_none());
    }

    /// A range in the middle of the matrix exposes only the interior columns.
    #[test]
    fn test_column_range_middle() {
        // Source layout (width 5):
        //   [ 1  2  3  4  5]
        //   [ 6  7  8  9 10]
        //   [11 12 13 14 15]
        let source =
            RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], 5);
        let window = HorizontallyTruncated::new_with_range(source, 1..4).unwrap();

        // Dimensions.
        assert_eq!(window.width(), 3);
        assert_eq!(window.height(), 3);

        // Checked element access.
        assert_eq!(window.get(0, 0), Some(2));
        assert_eq!(window.get(0, 1), Some(3));
        assert_eq!(window.get(0, 2), Some(4));
        assert_eq!(window.get(1, 0), Some(7));
        assert_eq!(window.get(2, 2), Some(14));

        // Unchecked element access.
        unsafe {
            assert_eq!(window.get_unchecked(1, 1), 8);
            assert_eq!(window.get_unchecked(2, 0), 12);
        }

        // Checked row iteration.
        let first_row: Vec<_> = window.row(0).unwrap().into_iter().collect();
        assert_eq!(first_row, vec![2, 3, 4]);

        let second_row: Vec<_> = window.row(1).unwrap().into_iter().collect();
        assert_eq!(second_row, vec![7, 8, 9]);

        // Unchecked row iteration, full row and sub-sequence.
        unsafe {
            let third_row: Vec<_> = window.row_unchecked(2).into_iter().collect();
            assert_eq!(third_row, vec![12, 13, 14]);

            let second_row_tail: Vec<_> =
                window.row_subseq_unchecked(1, 1, 3).into_iter().collect();
            assert_eq!(second_row_tail, vec![8, 9]);
        }

        // Out-of-range coordinates.
        assert!(window.get(0, 3).is_none());
        assert!(window.get(3, 0).is_none());

        // Materialising the view keeps columns 1..4 of each row.
        let materialised = window.to_row_major_matrix();
        let wanted = RowMajorMatrix::new(vec![2, 3, 4, 7, 8, 9, 12, 13, 14], 3);
        assert_eq!(materialised, wanted);
    }

    /// A range that ends exactly at the inner width exposes the trailing columns.
    #[test]
    fn test_column_range_end() {
        // Source layout (width 4):
        //   [1 2 3 4]
        //   [5 6 7 8]
        let source = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 4);
        let window = HorizontallyTruncated::new_with_range(source, 2..4).unwrap();

        assert_eq!(window.width(), 2);
        assert_eq!(window.height(), 2);

        let first_row: Vec<_> = window.row(0).unwrap().into_iter().collect();
        assert_eq!(first_row, vec![3, 4]);

        let second_row: Vec<_> = window.row(1).unwrap().into_iter().collect();
        assert_eq!(second_row, vec![7, 8]);

        assert_eq!(window.get(0, 0), Some(3));
        assert_eq!(window.get(1, 1), Some(8));
    }

    /// A one-column range behaves like a column selector.
    #[test]
    fn test_column_range_single_column() {
        // Source layout (width 4):
        //   [ 1  2  3  4]
        //   [ 5  6  7  8]
        //   [ 9 10 11 12]
        let source = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 4);
        let window = HorizontallyTruncated::new_with_range(source, 2..3).unwrap();

        assert_eq!(window.width(), 1);
        assert_eq!(window.height(), 3);

        // Only inner column 2 is visible.
        assert_eq!(window.get(0, 0), Some(3));
        assert_eq!(window.get(1, 0), Some(7));
        assert_eq!(window.get(2, 0), Some(11));

        let first_row: Vec<_> = window.row(0).unwrap().into_iter().collect();
        assert_eq!(first_row, vec![3]);
    }

    /// An empty range is accepted and produces a zero-width view.
    #[test]
    fn test_column_range_empty() {
        // Source layout (width 3):
        //   [1 2 3]
        //   [4 5 6]
        let source = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);
        let window = HorizontallyTruncated::new_with_range(source, 2..2).unwrap();

        assert_eq!(window.width(), 0);
        assert_eq!(window.height(), 2);

        // Rows exist but contain nothing.
        assert!(window.row(0).unwrap().into_iter().next().is_none());
    }

    /// A range reaching past the inner width is rejected.
    #[test]
    fn test_invalid_column_range() {
        // Source has width 3; the range end of 5 is out of bounds.
        let source = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);

        assert!(HorizontallyTruncated::new_with_range(source, 1..5).is_none());
    }
}