Skip to main content

p3_matrix/
horizontally_truncated.rs

1use core::marker::PhantomData;
2use core::ops::Range;
3
4use crate::Matrix;
5use crate::bitrev::BitReversibleMatrix;
6
/// A matrix view exposing a contiguous range of columns from an inner matrix.
///
/// Every row of the wrapped matrix is kept, but only the columns in `column_range`
/// are visible. Column `c` of the view corresponds to column `column_range.start + c`
/// of the inner matrix.
///
/// Build instances with `HorizontallyTruncated::new` or
/// `HorizontallyTruncated::new_with_range`. Both return `None` for ranges that are
/// inverted or extend past the inner width.
#[derive(Clone, Debug)]
pub struct HorizontallyTruncated<T, Inner> {
    /// The full matrix being viewed.
    inner: Inner,
    /// Columns of `inner` exposed by this view. When built through the constructors,
    /// this satisfies `start <= end <= inner.width()`.
    column_range: Range<usize>,
    /// Ties the element type `T` to the struct without storing any `T` values.
    _phantom: PhantomData<T>,
}
21
22impl<T, Inner: Matrix<T>> HorizontallyTruncated<T, Inner>
23where
24    T: Send + Sync + Clone,
25{
26    /// Construct a new horizontally truncated view of a matrix.
27    ///
28    /// # Arguments
29    /// - `inner`: The full inner matrix to be wrapped.
30    /// - `truncated_width`: The number of columns to expose from the start (must be ≤ `inner.width()`).
31    ///
32    /// This is equivalent to `new_with_range(inner, 0..truncated_width)`.
33    ///
34    /// Returns `None` if `truncated_width` is greater than the width of the inner matrix.
35    pub fn new(inner: Inner, truncated_width: usize) -> Option<Self> {
36        Self::new_with_range(inner, 0..truncated_width)
37    }
38
39    /// Construct a new view exposing a specific column range of a matrix.
40    ///
41    /// # Arguments
42    /// - `inner`: The full inner matrix to be wrapped.
43    /// - `column_range`: The columns to expose, satisfying `start <= end <= inner.width()`.
44    ///
45    /// Returns `None` if the range is inverted (`start > end`) or extends beyond the inner width.
46    pub fn new_with_range(inner: Inner, column_range: Range<usize>) -> Option<Self> {
47        // Accept the range only when:
48        // - `start <= end`: the range is not inverted.
49        // - `end <= inner.width()`: it stays within the inner matrix.
50        let valid = column_range.start <= column_range.end && column_range.end <= inner.width();
51        valid.then(|| Self {
52            inner,
53            column_range,
54            _phantom: PhantomData,
55        })
56    }
57}
58
59impl<T, Inner> Matrix<T> for HorizontallyTruncated<T, Inner>
60where
61    T: Send + Sync + Clone,
62    Inner: Matrix<T>,
63{
64    /// Returns the number of columns exposed by the truncated matrix.
65    #[inline(always)]
66    fn width(&self) -> usize {
67        self.column_range.len()
68    }
69
70    /// Returns the number of rows in the matrix (same as the inner matrix).
71    #[inline(always)]
72    fn height(&self) -> usize {
73        self.inner.height()
74    }
75
76    #[inline(always)]
77    unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
78        unsafe {
79            // Safety: The caller must ensure that `c < self.width()` and `r < self.height()`.
80            //
81            // We translate the column index by adding `column_range.start`.
82            self.inner.get_unchecked(r, self.column_range.start + c)
83        }
84    }
85
86    unsafe fn row_unchecked(
87        &self,
88        r: usize,
89    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
90        unsafe {
91            // Safety: The caller must ensure that `r < self.height()`.
92            self.inner
93                .row_subseq_unchecked(r, self.column_range.start, self.column_range.end)
94        }
95    }
96
97    unsafe fn row_subseq_unchecked(
98        &self,
99        r: usize,
100        start: usize,
101        end: usize,
102    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
103        unsafe {
104            // Safety: The caller must ensure that r < self.height() and start <= end <= self.width().
105            //
106            // We translate the column indices by adding `column_range.start`.
107            self.inner.row_subseq_unchecked(
108                r,
109                self.column_range.start + start,
110                self.column_range.start + end,
111            )
112        }
113    }
114
115    unsafe fn row_subslice_unchecked(
116        &self,
117        r: usize,
118        start: usize,
119        end: usize,
120    ) -> impl core::ops::Deref<Target = [T]> {
121        unsafe {
122            // Safety: The caller must ensure that `r < self.height()` and `start <= end <= self.width()`.
123            //
124            // We translate the column indices by adding `column_range.start`.
125            self.inner.row_subslice_unchecked(
126                r,
127                self.column_range.start + start,
128                self.column_range.start + end,
129            )
130        }
131    }
132}
133
134impl<T: Clone + Send + Sync, Inner: BitReversibleMatrix<T>> BitReversibleMatrix<T>
135    for HorizontallyTruncated<T, Inner>
136{
137    type BitRev = HorizontallyTruncated<T, Inner::BitRev>;
138
139    fn bit_reverse_rows(self) -> Self::BitRev {
140        HorizontallyTruncated {
141            inner: self.inner.bit_reverse_rows(),
142            column_range: self.column_range,
143            _phantom: PhantomData,
144        }
145    }
146}
147
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use super::*;
    use crate::dense::RowMajorMatrix;

    #[test]
    fn test_truncate_width_by_one() {
        // Create a 3x4 matrix:
        // [ 1  2  3  4]
        // [ 5  6  7  8]
        // [ 9 10 11 12]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 4);

        // Truncate to width 3.
        let truncated = HorizontallyTruncated::new(inner, 3).unwrap();

        // Width should be 3.
        assert_eq!(truncated.width(), 3);

        // Height remains unchanged.
        assert_eq!(truncated.height(), 3);

        // Check individual elements.
        assert_eq!(truncated.get(0, 0), Some(1)); // row 0, col 0
        assert_eq!(truncated.get(1, 1), Some(6)); // row 1, col 1
        unsafe {
            assert_eq!(truncated.get_unchecked(0, 1), 2); // row 0, col 1
            assert_eq!(truncated.get_unchecked(2, 2), 11); // row 2, col 2
        }

        // Row 0: should return [1, 2, 3]
        let row0: Vec<_> = truncated.row(0).unwrap().into_iter().collect();
        assert_eq!(row0, vec![1, 2, 3]);
        unsafe {
            // Row 1: should return [5, 6, 7]
            let row1: Vec<_> = truncated.row_unchecked(1).into_iter().collect();
            assert_eq!(row1, vec![5, 6, 7]);

            // Row 2 of the view is [9, 10, 11]; its columns 1..2 should return [10]
            let row2_subset: Vec<_> = truncated
                .row_subseq_unchecked(2, 1, 2)
                .into_iter()
                .collect();
            assert_eq!(row2_subset, vec![10]);
        }

        unsafe {
            let row1 = truncated.row_slice(1).unwrap();
            assert_eq!(&*row1, &[5, 6, 7]);

            let row2 = truncated.row_slice_unchecked(2);
            assert_eq!(&*row2, &[9, 10, 11]);

            let row0_subslice = truncated.row_subslice_unchecked(0, 0, 2);
            assert_eq!(&*row0_subslice, &[1, 2]);
        }

        assert!(truncated.get(0, 3).is_none()); // Width out of bounds
        assert!(truncated.get(3, 0).is_none()); // Height out of bounds
        assert!(truncated.row(3).is_none()); // Height out of bounds
        assert!(truncated.row_slice(3).is_none()); // Height out of bounds

        // Convert the truncated view to a RowMajorMatrix and check contents.
        let as_matrix = truncated.to_row_major_matrix();

        // The expected matrix after truncation:
        // [1  2  3]
        // [5  6  7]
        // [9 10 11]
        let expected = RowMajorMatrix::new(vec![1, 2, 3, 5, 6, 7, 9, 10, 11], 3);

        assert_eq!(as_matrix, expected);
    }

    #[test]
    fn test_no_truncation() {
        // 2x2 matrix:
        // [ 7  8 ]
        // [ 9 10 ]
        let inner = RowMajorMatrix::new(vec![7, 8, 9, 10], 2);

        // Truncate to full width (no change).
        let truncated = HorizontallyTruncated::new(inner, 2).unwrap();

        assert_eq!(truncated.width(), 2);
        assert_eq!(truncated.height(), 2);
        assert_eq!(truncated.get(0, 1).unwrap(), 8);
        assert_eq!(truncated.get(1, 0).unwrap(), 9);

        unsafe {
            assert_eq!(truncated.get_unchecked(0, 0), 7);
            assert_eq!(truncated.get_unchecked(1, 1), 10);
        }

        let row0: Vec<_> = truncated.row(0).unwrap().into_iter().collect();
        assert_eq!(row0, vec![7, 8]);

        let row1: Vec<_> = unsafe { truncated.row_unchecked(1).into_iter().collect() };
        assert_eq!(row1, vec![9, 10]);

        assert!(truncated.get(0, 2).is_none()); // Width out of bounds
        assert!(truncated.get(2, 0).is_none()); // Height out of bounds
        assert!(truncated.row(2).is_none()); // Height out of bounds
        assert!(truncated.row_slice(2).is_none()); // Height out of bounds
    }

    #[test]
    fn test_truncate_to_zero_width() {
        // 1x3 matrix: [11 12 13]
        let inner = RowMajorMatrix::new(vec![11, 12, 13], 3);

        // Truncate to width 0.
        let truncated = HorizontallyTruncated::new(inner, 0).unwrap();

        assert_eq!(truncated.width(), 0);
        assert_eq!(truncated.height(), 1);

        // Row should be empty.
        assert!(truncated.row(0).unwrap().into_iter().next().is_none());

        assert!(truncated.get(0, 0).is_none()); // Width out of bounds
        assert!(truncated.get(1, 0).is_none()); // Height out of bounds
        assert!(truncated.row(1).is_none()); // Height out of bounds
        assert!(truncated.row_slice(1).is_none()); // Height out of bounds
    }

    #[test]
    fn test_invalid_truncation_width() {
        // 2x2 matrix:
        // [1 2]
        // [3 4]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);

        // Attempt to truncate beyond inner width (invalid).
        assert!(HorizontallyTruncated::new(inner, 5).is_none());
    }

    #[test]
    fn test_column_range_middle() {
        // Create a 3x5 matrix:
        // [ 1  2  3  4  5]
        // [ 6  7  8  9 10]
        // [11 12 13 14 15]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], 5);

        // Select columns 1..4 (columns 1, 2, 3).
        let view = HorizontallyTruncated::new_with_range(inner, 1..4).unwrap();

        // Width should be 3 (columns 1, 2, 3).
        assert_eq!(view.width(), 3);

        // Height remains unchanged.
        assert_eq!(view.height(), 3);

        // Check individual elements (column indices are relative to the view).
        assert_eq!(view.get(0, 0), Some(2)); // row 0, col 0 -> inner col 1
        assert_eq!(view.get(0, 1), Some(3)); // row 0, col 1 -> inner col 2
        assert_eq!(view.get(0, 2), Some(4)); // row 0, col 2 -> inner col 3
        assert_eq!(view.get(1, 0), Some(7)); // row 1, col 0 -> inner col 1
        assert_eq!(view.get(2, 2), Some(14)); // row 2, col 2 -> inner col 3

        unsafe {
            assert_eq!(view.get_unchecked(1, 1), 8); // row 1, col 1 -> inner col 2
            assert_eq!(view.get_unchecked(2, 0), 12); // row 2, col 0 -> inner col 1
        }

        // Row 0: should return [2, 3, 4]
        let row0: Vec<_> = view.row(0).unwrap().into_iter().collect();
        assert_eq!(row0, vec![2, 3, 4]);

        // Row 1: should return [7, 8, 9]
        let row1: Vec<_> = view.row(1).unwrap().into_iter().collect();
        assert_eq!(row1, vec![7, 8, 9]);

        unsafe {
            // Row 2: should return [12, 13, 14]
            let row2: Vec<_> = view.row_unchecked(2).into_iter().collect();
            assert_eq!(row2, vec![12, 13, 14]);

            // Subsequence of row 1, cols 1..3 (view indices) -> [8, 9]
            let row1_subseq: Vec<_> = view.row_subseq_unchecked(1, 1, 3).into_iter().collect();
            assert_eq!(row1_subseq, vec![8, 9]);

            // Subslice of row 2, cols 1..3 (view indices) -> inner cols 2..4 -> [13, 14]
            let row2_subslice = view.row_subslice_unchecked(2, 1, 3);
            assert_eq!(&*row2_subslice, &[13, 14]);
        }

        // Whole-row slices must also be offset by the range start.
        let row1_slice = view.row_slice(1).unwrap();
        assert_eq!(&*row1_slice, &[7, 8, 9]);

        // Out of bounds checks.
        assert!(view.get(0, 3).is_none()); // Width out of bounds
        assert!(view.get(3, 0).is_none()); // Height out of bounds

        // Convert the view to a RowMajorMatrix and check contents.
        let as_matrix = view.to_row_major_matrix();

        // The expected matrix after selecting columns 1..4:
        // [2  3  4]
        // [7  8  9]
        // [12 13 14]
        let expected = RowMajorMatrix::new(vec![2, 3, 4, 7, 8, 9, 12, 13, 14], 3);

        assert_eq!(as_matrix, expected);
    }

    #[test]
    fn test_column_range_end() {
        // Create a 2x4 matrix:
        // [1 2 3 4]
        // [5 6 7 8]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 4);

        // Select columns 2..4 (columns 2, 3).
        let view = HorizontallyTruncated::new_with_range(inner, 2..4).unwrap();

        assert_eq!(view.width(), 2);
        assert_eq!(view.height(), 2);

        // Row 0: should return [3, 4]
        let row0: Vec<_> = view.row(0).unwrap().into_iter().collect();
        assert_eq!(row0, vec![3, 4]);

        // Row 1: should return [7, 8]
        let row1: Vec<_> = view.row(1).unwrap().into_iter().collect();
        assert_eq!(row1, vec![7, 8]);

        assert_eq!(view.get(0, 0), Some(3));
        assert_eq!(view.get(1, 1), Some(8));
    }

    #[test]
    fn test_column_range_single_column() {
        // Create a 3x4 matrix:
        // [1 2 3 4]
        // [5 6 7 8]
        // [9 10 11 12]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 4);

        // Select only column 2.
        let view = HorizontallyTruncated::new_with_range(inner, 2..3).unwrap();

        assert_eq!(view.width(), 1);
        assert_eq!(view.height(), 3);

        assert_eq!(view.get(0, 0), Some(3));
        assert_eq!(view.get(1, 0), Some(7));
        assert_eq!(view.get(2, 0), Some(11));

        // Row 0: should return [3]
        let row0: Vec<_> = view.row(0).unwrap().into_iter().collect();
        assert_eq!(row0, vec![3]);
    }

    #[test]
    fn test_column_range_empty() {
        // Create a 2x3 matrix:
        // [1 2 3]
        // [4 5 6]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);

        // Select empty range (2..2).
        let view = HorizontallyTruncated::new_with_range(inner, 2..2).unwrap();

        assert_eq!(view.width(), 0);
        assert_eq!(view.height(), 2);

        // Row should be empty.
        assert!(view.row(0).unwrap().into_iter().next().is_none());
    }

    #[test]
    fn test_invalid_column_range() {
        // Create a 2x3 matrix:
        // [1 2 3]
        // [4 5 6]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);

        // Attempt to select columns 1..5 (extends beyond width).
        assert!(HorizontallyTruncated::new_with_range(inner, 1..5).is_none());
    }

    #[test]
    fn test_inverted_column_range_is_rejected() {
        // 2x3 matrix:
        // [1 2 3]
        // [4 5 6]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);

        // An inverted range has `start > end` while `end` is still inside the inner width.
        // It is built with struct syntax so the literal does not trip the empty-range lint.
        //
        //     2..1: start = 2 > end = 1, yet end = 1 <= width = 3
        //
        // The constructor must reject it rather than return a view.
        let inverted = Range { start: 2, end: 1 };
        assert!(HorizontallyTruncated::new_with_range(inner, inverted).is_none());
    }
}