Skip to main content

p3_matrix/
horizontally_truncated.rs

1use core::marker::PhantomData;
2use core::ops::Range;
3
4use crate::Matrix;
5
/// A matrix wrapper that exposes a contiguous range of columns from an inner matrix.
///
/// This struct:
/// - wraps another matrix,
/// - restricts access to only the columns within the specified `column_range`.
///
/// Column `c` of the view corresponds to column `column_range.start + c` of the inner
/// matrix, while row indices are passed through unchanged.
#[derive(Clone, Debug)]
pub struct HorizontallyTruncated<T, Inner> {
    /// The underlying full matrix being wrapped.
    inner: Inner,
    /// The range of columns to expose from the inner matrix.
    ///
    /// The unchecked accessors rely on `column_range.start <= column_range.end <= inner.width()`.
    column_range: Range<usize>,
    /// Marker for the element type `T`; it carries no data at runtime.
    _phantom: PhantomData<T>,
}
20
21impl<T, Inner: Matrix<T>> HorizontallyTruncated<T, Inner>
22where
23    T: Send + Sync + Clone,
24{
25    /// Construct a new horizontally truncated view of a matrix.
26    ///
27    /// # Arguments
28    /// - `inner`: The full inner matrix to be wrapped.
29    /// - `truncated_width`: The number of columns to expose from the start (must be ≤ `inner.width()`).
30    ///
31    /// This is equivalent to `new_with_range(inner, 0..truncated_width)`.
32    ///
33    /// Returns `None` if `truncated_width` is greater than the width of the inner matrix.
34    pub fn new(inner: Inner, truncated_width: usize) -> Option<Self> {
35        Self::new_with_range(inner, 0..truncated_width)
36    }
37
38    /// Construct a new view exposing a specific column range of a matrix.
39    ///
40    /// # Arguments
41    /// - `inner`: The full inner matrix to be wrapped.
42    /// - `column_range`: The range of columns to expose (must satisfy `column_range.end <= inner.width()`).
43    ///
44    /// Returns `None` if the column range extends beyond the width of the inner matrix.
45    pub fn new_with_range(inner: Inner, column_range: Range<usize>) -> Option<Self> {
46        (column_range.end <= inner.width()).then(|| Self {
47            inner,
48            column_range,
49            _phantom: PhantomData,
50        })
51    }
52}
53
54impl<T, Inner> Matrix<T> for HorizontallyTruncated<T, Inner>
55where
56    T: Send + Sync + Clone,
57    Inner: Matrix<T>,
58{
59    /// Returns the number of columns exposed by the truncated matrix.
60    #[inline(always)]
61    fn width(&self) -> usize {
62        self.column_range.len()
63    }
64
65    /// Returns the number of rows in the matrix (same as the inner matrix).
66    #[inline(always)]
67    fn height(&self) -> usize {
68        self.inner.height()
69    }
70
71    #[inline(always)]
72    unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
73        unsafe {
74            // Safety: The caller must ensure that `c < self.width()` and `r < self.height()`.
75            //
76            // We translate the column index by adding `column_range.start`.
77            self.inner.get_unchecked(r, self.column_range.start + c)
78        }
79    }
80
81    unsafe fn row_unchecked(
82        &self,
83        r: usize,
84    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
85        unsafe {
86            // Safety: The caller must ensure that `r < self.height()`.
87            self.inner
88                .row_subseq_unchecked(r, self.column_range.start, self.column_range.end)
89        }
90    }
91
92    unsafe fn row_subseq_unchecked(
93        &self,
94        r: usize,
95        start: usize,
96        end: usize,
97    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
98        unsafe {
99            // Safety: The caller must ensure that r < self.height() and start <= end <= self.width().
100            //
101            // We translate the column indices by adding `column_range.start`.
102            self.inner.row_subseq_unchecked(
103                r,
104                self.column_range.start + start,
105                self.column_range.start + end,
106            )
107        }
108    }
109
110    unsafe fn row_subslice_unchecked(
111        &self,
112        r: usize,
113        start: usize,
114        end: usize,
115    ) -> impl core::ops::Deref<Target = [T]> {
116        unsafe {
117            // Safety: The caller must ensure that `r < self.height()` and `start <= end <= self.width()`.
118            //
119            // We translate the column indices by adding `column_range.start`.
120            self.inner.row_subslice_unchecked(
121                r,
122                self.column_range.start + start,
123                self.column_range.start + end,
124            )
125        }
126    }
127}
128
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use super::*;
    use crate::dense::RowMajorMatrix;

    #[test]
    fn test_truncate_width_by_one() {
        // Create a 3x4 matrix:
        // [ 1  2  3  4]
        // [ 5  6  7  8]
        // [ 9 10 11 12]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 4);

        // Truncate to width 3.
        let truncated = HorizontallyTruncated::new(inner, 3).unwrap();

        // Width should be 3.
        assert_eq!(truncated.width(), 3);

        // Height remains unchanged.
        assert_eq!(truncated.height(), 3);

        // Check individual elements.
        assert_eq!(truncated.get(0, 0), Some(1)); // row 0, col 0
        assert_eq!(truncated.get(1, 1), Some(6)); // row 1, col 1
        unsafe {
            assert_eq!(truncated.get_unchecked(0, 1), 2); // row 0, col 1
            assert_eq!(truncated.get_unchecked(2, 2), 11); // row 2, col 2
        }

        // Row 0: should return [1, 2, 3]
        let row0: Vec<_> = truncated.row(0).unwrap().into_iter().collect();
        assert_eq!(row0, vec![1, 2, 3]);
        unsafe {
            // Row 1: should return [5, 6, 7]
            let row1: Vec<_> = truncated.row_unchecked(1).into_iter().collect();
            assert_eq!(row1, vec![5, 6, 7]);

            // Row 2 is [9, 10, 11]; its columns 1..2 should return [10]
            let row2_subset: Vec<_> = truncated
                .row_subseq_unchecked(2, 1, 2)
                .into_iter()
                .collect();
            assert_eq!(row2_subset, vec![10]);
        }

        unsafe {
            let row1 = truncated.row_slice(1).unwrap();
            assert_eq!(&*row1, &[5, 6, 7]);

            let row2 = truncated.row_slice_unchecked(2);
            assert_eq!(&*row2, &[9, 10, 11]);

            let row0_subslice = truncated.row_subslice_unchecked(0, 0, 2);
            assert_eq!(&*row0_subslice, &[1, 2]);
        }

        assert!(truncated.get(0, 3).is_none()); // Width out of bounds
        assert!(truncated.get(3, 0).is_none()); // Height out of bounds
        assert!(truncated.row(3).is_none()); // Height out of bounds
        assert!(truncated.row_slice(3).is_none()); // Height out of bounds

        // Convert the truncated view to a RowMajorMatrix and check contents.
        let as_matrix = truncated.to_row_major_matrix();

        // The expected matrix after truncation:
        // [1  2  3]
        // [5  6  7]
        // [9 10 11]
        let expected = RowMajorMatrix::new(vec![1, 2, 3, 5, 6, 7, 9, 10, 11], 3);

        assert_eq!(as_matrix, expected);
    }

    #[test]
    fn test_no_truncation() {
        // 2x2 matrix:
        // [ 7  8 ]
        // [ 9 10 ]
        let inner = RowMajorMatrix::new(vec![7, 8, 9, 10], 2);

        // Truncate to full width (no change).
        let truncated = HorizontallyTruncated::new(inner, 2).unwrap();

        assert_eq!(truncated.width(), 2);
        assert_eq!(truncated.height(), 2);
        assert_eq!(truncated.get(0, 1).unwrap(), 8);
        assert_eq!(truncated.get(1, 0).unwrap(), 9);

        unsafe {
            assert_eq!(truncated.get_unchecked(0, 0), 7);
            assert_eq!(truncated.get_unchecked(1, 1), 10);
        }

        let row0: Vec<_> = truncated.row(0).unwrap().into_iter().collect();
        assert_eq!(row0, vec![7, 8]);

        let row1: Vec<_> = unsafe { truncated.row_unchecked(1).into_iter().collect() };
        assert_eq!(row1, vec![9, 10]);

        assert!(truncated.get(0, 2).is_none()); // Width out of bounds
        assert!(truncated.get(2, 0).is_none()); // Height out of bounds
        assert!(truncated.row(2).is_none()); // Height out of bounds
        assert!(truncated.row_slice(2).is_none()); // Height out of bounds
    }

    #[test]
    fn test_truncate_to_zero_width() {
        // 1x3 matrix: [11 12 13]
        let inner = RowMajorMatrix::new(vec![11, 12, 13], 3);

        // Truncate to width 0.
        let truncated = HorizontallyTruncated::new(inner, 0).unwrap();

        assert_eq!(truncated.width(), 0);
        assert_eq!(truncated.height(), 1);

        // Row should be empty.
        assert!(truncated.row(0).unwrap().into_iter().next().is_none());

        assert!(truncated.get(0, 0).is_none()); // Width out of bounds
        assert!(truncated.get(1, 0).is_none()); // Height out of bounds
        assert!(truncated.row(1).is_none()); // Height out of bounds
        assert!(truncated.row_slice(1).is_none()); // Height out of bounds
    }

    #[test]
    fn test_invalid_truncation_width() {
        // 2x2 matrix:
        // [1 2]
        // [3 4]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);

        // Attempt to truncate beyond inner width (invalid).
        assert!(HorizontallyTruncated::new(inner, 5).is_none());
    }

    #[test]
    fn test_column_range_middle() {
        // Create a 3x5 matrix:
        // [ 1  2  3  4  5]
        // [ 6  7  8  9 10]
        // [11 12 13 14 15]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], 5);

        // Select columns 1..4 (columns 1, 2, 3).
        let view = HorizontallyTruncated::new_with_range(inner, 1..4).unwrap();

        // Width should be 3 (columns 1, 2, 3).
        assert_eq!(view.width(), 3);

        // Height remains unchanged.
        assert_eq!(view.height(), 3);

        // Check individual elements (column indices are relative to the view).
        assert_eq!(view.get(0, 0), Some(2)); // row 0, col 0 -> inner col 1
        assert_eq!(view.get(0, 1), Some(3)); // row 0, col 1 -> inner col 2
        assert_eq!(view.get(0, 2), Some(4)); // row 0, col 2 -> inner col 3
        assert_eq!(view.get(1, 0), Some(7)); // row 1, col 0 -> inner col 1
        assert_eq!(view.get(2, 2), Some(14)); // row 2, col 2 -> inner col 3

        unsafe {
            assert_eq!(view.get_unchecked(1, 1), 8); // row 1, col 1 -> inner col 2
            assert_eq!(view.get_unchecked(2, 0), 12); // row 2, col 0 -> inner col 1
        }

        // Row 0: should return [2, 3, 4]
        let row0: Vec<_> = view.row(0).unwrap().into_iter().collect();
        assert_eq!(row0, vec![2, 3, 4]);

        // Row 1: should return [7, 8, 9]
        let row1: Vec<_> = view.row(1).unwrap().into_iter().collect();
        assert_eq!(row1, vec![7, 8, 9]);

        unsafe {
            // Row 2: should return [12, 13, 14]
            let row2: Vec<_> = view.row_unchecked(2).into_iter().collect();
            assert_eq!(row2, vec![12, 13, 14]);

            // Subsequence of row 1, cols 1..3 (view indices) -> [8, 9]
            let row1_subseq: Vec<_> = view.row_subseq_unchecked(1, 1, 3).into_iter().collect();
            assert_eq!(row1_subseq, vec![8, 9]);
        }

        // Slice access must apply the same column offset.
        let row0_slice = view.row_slice(0).unwrap();
        assert_eq!(&*row0_slice, &[2, 3, 4]);
        unsafe {
            // Subslice of row 2, cols 1..3 (view indices) -> inner cols 2..4 -> [13, 14]
            let row2_subslice = view.row_subslice_unchecked(2, 1, 3);
            assert_eq!(&*row2_subslice, &[13, 14]);
        }

        // Out of bounds checks.
        assert!(view.get(0, 3).is_none()); // Width out of bounds
        assert!(view.get(3, 0).is_none()); // Height out of bounds
        assert!(view.row(3).is_none()); // Height out of bounds

        // Convert the view to a RowMajorMatrix and check contents.
        let as_matrix = view.to_row_major_matrix();

        // The expected matrix after selecting columns 1..4:
        // [2  3  4]
        // [7  8  9]
        // [12 13 14]
        let expected = RowMajorMatrix::new(vec![2, 3, 4, 7, 8, 9, 12, 13, 14], 3);

        assert_eq!(as_matrix, expected);
    }

    #[test]
    fn test_column_range_end() {
        // Create a 2x4 matrix:
        // [1 2 3 4]
        // [5 6 7 8]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 4);

        // Select columns 2..4 (columns 2, 3).
        let view = HorizontallyTruncated::new_with_range(inner, 2..4).unwrap();

        assert_eq!(view.width(), 2);
        assert_eq!(view.height(), 2);

        // Row 0: should return [3, 4]
        let row0: Vec<_> = view.row(0).unwrap().into_iter().collect();
        assert_eq!(row0, vec![3, 4]);

        // Row 1: should return [7, 8]
        let row1: Vec<_> = view.row(1).unwrap().into_iter().collect();
        assert_eq!(row1, vec![7, 8]);

        assert_eq!(view.get(0, 0), Some(3));
        assert_eq!(view.get(1, 1), Some(8));
    }

    #[test]
    fn test_column_range_single_column() {
        // Create a 3x4 matrix:
        // [1 2 3 4]
        // [5 6 7 8]
        // [9 10 11 12]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 4);

        // Select only column 2.
        let view = HorizontallyTruncated::new_with_range(inner, 2..3).unwrap();

        assert_eq!(view.width(), 1);
        assert_eq!(view.height(), 3);

        assert_eq!(view.get(0, 0), Some(3));
        assert_eq!(view.get(1, 0), Some(7));
        assert_eq!(view.get(2, 0), Some(11));

        // Row 0: should return [3]
        let row0: Vec<_> = view.row(0).unwrap().into_iter().collect();
        assert_eq!(row0, vec![3]);
    }

    #[test]
    fn test_column_range_empty() {
        // Create a 2x3 matrix:
        // [1 2 3]
        // [4 5 6]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);

        // Select empty range (2..2).
        let view = HorizontallyTruncated::new_with_range(inner, 2..2).unwrap();

        assert_eq!(view.width(), 0);
        assert_eq!(view.height(), 2);

        // Row should be empty.
        assert!(view.row(0).unwrap().into_iter().next().is_none());
    }

    #[test]
    fn test_invalid_column_range() {
        // Create a 2x3 matrix:
        // [1 2 3]
        // [4 5 6]
        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);

        // Attempt to select columns 1..5 (extends beyond width).
        assert!(HorizontallyTruncated::new_with_range(inner, 1..5).is_none());
    }
}