p3_matrix/
horizontally_truncated.rs

1use core::marker::PhantomData;
2use core::ops::Range;
3
4use crate::Matrix;
5
/// A column-window view over another matrix.
///
/// The wrapper owns `inner` and presents only the columns whose indices fall
/// inside `column_range`. Column indices seen through the view are relative to
/// `column_range.start`; rows are passed through unchanged.
pub struct HorizontallyTruncated<T, Inner> {
    /// The wrapped matrix that actually holds the data.
    inner: Inner,
    /// Half-open range of inner-matrix columns visible through this view.
    column_range: Range<usize>,
    /// Ties the element type `T` to the struct without storing a value of it.
    _phantom: PhantomData<T>,
}
19
20impl<T, Inner: Matrix<T>> HorizontallyTruncated<T, Inner>
21where
22    T: Send + Sync + Clone,
23{
24    /// Construct a new horizontally truncated view of a matrix.
25    ///
26    /// # Arguments
27    /// - `inner`: The full inner matrix to be wrapped.
28    /// - `truncated_width`: The number of columns to expose from the start (must be ≤ `inner.width()`).
29    ///
30    /// This is equivalent to `new_with_range(inner, 0..truncated_width)`.
31    ///
32    /// Returns `None` if `truncated_width` is greater than the width of the inner matrix.
33    pub fn new(inner: Inner, truncated_width: usize) -> Option<Self> {
34        Self::new_with_range(inner, 0..truncated_width)
35    }
36
37    /// Construct a new view exposing a specific column range of a matrix.
38    ///
39    /// # Arguments
40    /// - `inner`: The full inner matrix to be wrapped.
41    /// - `column_range`: The range of columns to expose (must satisfy `column_range.end <= inner.width()`).
42    ///
43    /// Returns `None` if the column range extends beyond the width of the inner matrix.
44    pub fn new_with_range(inner: Inner, column_range: Range<usize>) -> Option<Self> {
45        (column_range.end <= inner.width()).then(|| Self {
46            inner,
47            column_range,
48            _phantom: PhantomData,
49        })
50    }
51}
52
53impl<T, Inner> Matrix<T> for HorizontallyTruncated<T, Inner>
54where
55    T: Send + Sync + Clone,
56    Inner: Matrix<T>,
57{
58    /// Returns the number of columns exposed by the truncated matrix.
59    #[inline(always)]
60    fn width(&self) -> usize {
61        self.column_range.len()
62    }
63
64    /// Returns the number of rows in the matrix (same as the inner matrix).
65    #[inline(always)]
66    fn height(&self) -> usize {
67        self.inner.height()
68    }
69
70    #[inline(always)]
71    unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
72        unsafe {
73            // Safety: The caller must ensure that `c < self.width()` and `r < self.height()`.
74            //
75            // We translate the column index by adding `column_range.start`.
76            self.inner.get_unchecked(r, self.column_range.start + c)
77        }
78    }
79
80    unsafe fn row_unchecked(
81        &self,
82        r: usize,
83    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
84        unsafe {
85            // Safety: The caller must ensure that `r < self.height()`.
86            self.inner
87                .row_subseq_unchecked(r, self.column_range.start, self.column_range.end)
88        }
89    }
90
91    unsafe fn row_subseq_unchecked(
92        &self,
93        r: usize,
94        start: usize,
95        end: usize,
96    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
97        unsafe {
98            // Safety: The caller must ensure that r < self.height() and start <= end <= self.width().
99            //
100            // We translate the column indices by adding `column_range.start`.
101            self.inner.row_subseq_unchecked(
102                r,
103                self.column_range.start + start,
104                self.column_range.start + end,
105            )
106        }
107    }
108
109    unsafe fn row_subslice_unchecked(
110        &self,
111        r: usize,
112        start: usize,
113        end: usize,
114    ) -> impl core::ops::Deref<Target = [T]> {
115        unsafe {
116            // Safety: The caller must ensure that `r < self.height()` and `start <= end <= self.width()`.
117            //
118            // We translate the column indices by adding `column_range.start`.
119            self.inner.row_subslice_unchecked(
120                r,
121                self.column_range.start + start,
122                self.column_range.start + end,
123            )
124        }
125    }
126}
127
128#[cfg(test)]
129mod tests {
130    use alloc::vec;
131    use alloc::vec::Vec;
132
133    use super::*;
134    use crate::dense::RowMajorMatrix;
135
136    #[test]
137    fn test_truncate_width_by_one() {
138        // Create a 3x4 matrix:
139        // [ 1  2  3  4]
140        // [ 5  6  7  8]
141        // [ 9 10 11 12]
142        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 4);
143
144        // Truncate to width 3.
145        let truncated = HorizontallyTruncated::new(inner, 3).unwrap();
146
147        // Width should be 3.
148        assert_eq!(truncated.width(), 3);
149
150        // Height remains unchanged.
151        assert_eq!(truncated.height(), 3);
152
153        // Check individual elements.
154        assert_eq!(truncated.get(0, 0), Some(1)); // row 0, col 0
155        assert_eq!(truncated.get(1, 1), Some(6)); // row 1, col 1
156        unsafe {
157            assert_eq!(truncated.get_unchecked(0, 1), 2); // row 0, col 1
158            assert_eq!(truncated.get_unchecked(2, 2), 11); // row 1, col 0
159        }
160
161        // Row 0: should return [1, 2, 3]
162        let row0: Vec<_> = truncated.row(0).unwrap().into_iter().collect();
163        assert_eq!(row0, vec![1, 2, 3]);
164        unsafe {
165            // Row 2: should return [5, 6, 7]
166            let row1: Vec<_> = truncated.row_unchecked(1).into_iter().collect();
167            assert_eq!(row1, vec![5, 6, 7]);
168
169            // Row 3: is equal to return [9, 10, 11]
170            let row3_subset: Vec<_> = truncated
171                .row_subseq_unchecked(2, 1, 2)
172                .into_iter()
173                .collect();
174            assert_eq!(row3_subset, vec![10]);
175        }
176
177        unsafe {
178            let row1 = truncated.row_slice(1).unwrap();
179            assert_eq!(&*row1, &[5, 6, 7]);
180
181            let row2 = truncated.row_slice_unchecked(2);
182            assert_eq!(&*row2, &[9, 10, 11]);
183
184            let row0_subslice = truncated.row_subslice_unchecked(0, 0, 2);
185            assert_eq!(&*row0_subslice, &[1, 2]);
186        }
187
188        assert!(truncated.get(0, 3).is_none()); // Width out of bounds
189        assert!(truncated.get(3, 0).is_none()); // Height out of bounds
190        assert!(truncated.row(3).is_none()); // Height out of bounds
191        assert!(truncated.row_slice(3).is_none()); // Height out of bounds
192
193        // Convert the truncated view to a RowMajorMatrix and check contents.
194        let as_matrix = truncated.to_row_major_matrix();
195
196        // The expected matrix after truncation:
197        // [1  2  3]
198        // [5  6  7]
199        // [9 10 11]
200        let expected = RowMajorMatrix::new(vec![1, 2, 3, 5, 6, 7, 9, 10, 11], 3);
201
202        assert_eq!(as_matrix, expected);
203    }
204
205    #[test]
206    fn test_no_truncation() {
207        // 2x2 matrix:
208        // [ 7  8 ]
209        // [ 9 10 ]
210        let inner = RowMajorMatrix::new(vec![7, 8, 9, 10], 2);
211
212        // Truncate to full width (no change).
213        let truncated = HorizontallyTruncated::new(inner, 2).unwrap();
214
215        assert_eq!(truncated.width(), 2);
216        assert_eq!(truncated.height(), 2);
217        assert_eq!(truncated.get(0, 1).unwrap(), 8);
218        assert_eq!(truncated.get(1, 0).unwrap(), 9);
219
220        unsafe {
221            assert_eq!(truncated.get_unchecked(0, 0), 7);
222            assert_eq!(truncated.get_unchecked(1, 1), 10);
223        }
224
225        let row0: Vec<_> = truncated.row(0).unwrap().into_iter().collect();
226        assert_eq!(row0, vec![7, 8]);
227
228        let row1: Vec<_> = unsafe { truncated.row_unchecked(1).into_iter().collect() };
229        assert_eq!(row1, vec![9, 10]);
230
231        assert!(truncated.get(0, 2).is_none()); // Width out of bounds
232        assert!(truncated.get(2, 0).is_none()); // Height out of bounds
233        assert!(truncated.row(2).is_none()); // Height out of bounds
234        assert!(truncated.row_slice(2).is_none()); // Height out of bounds
235    }
236
237    #[test]
238    fn test_truncate_to_zero_width() {
239        // 1x3 matrix: [11 12 13]
240        let inner = RowMajorMatrix::new(vec![11, 12, 13], 3);
241
242        // Truncate to width 0.
243        let truncated = HorizontallyTruncated::new(inner, 0).unwrap();
244
245        assert_eq!(truncated.width(), 0);
246        assert_eq!(truncated.height(), 1);
247
248        // Row should be empty.
249        assert!(truncated.row(0).unwrap().into_iter().next().is_none());
250
251        assert!(truncated.get(0, 0).is_none()); // Width out of bounds
252        assert!(truncated.get(1, 0).is_none()); // Height out of bounds
253        assert!(truncated.row(1).is_none()); // Height out of bounds
254        assert!(truncated.row_slice(1).is_none()); // Height out of bounds
255    }
256
257    #[test]
258    fn test_invalid_truncation_width() {
259        // 2x2 matrix:
260        // [1 2]
261        // [3 4]
262        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
263
264        // Attempt to truncate beyond inner width (invalid).
265        assert!(HorizontallyTruncated::new(inner, 5).is_none());
266    }
267
268    #[test]
269    fn test_column_range_middle() {
270        // Create a 3x5 matrix:
271        // [ 1  2  3  4  5]
272        // [ 6  7  8  9 10]
273        // [11 12 13 14 15]
274        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], 5);
275
276        // Select columns 1..4 (columns 1, 2, 3).
277        let view = HorizontallyTruncated::new_with_range(inner, 1..4).unwrap();
278
279        // Width should be 3 (columns 1, 2, 3).
280        assert_eq!(view.width(), 3);
281
282        // Height remains unchanged.
283        assert_eq!(view.height(), 3);
284
285        // Check individual elements (column indices are relative to the view).
286        assert_eq!(view.get(0, 0), Some(2)); // row 0, col 0 -> inner col 1
287        assert_eq!(view.get(0, 1), Some(3)); // row 0, col 1 -> inner col 2
288        assert_eq!(view.get(0, 2), Some(4)); // row 0, col 2 -> inner col 3
289        assert_eq!(view.get(1, 0), Some(7)); // row 1, col 0 -> inner col 1
290        assert_eq!(view.get(2, 2), Some(14)); // row 2, col 2 -> inner col 3
291
292        unsafe {
293            assert_eq!(view.get_unchecked(1, 1), 8); // row 1, col 1 -> inner col 2
294            assert_eq!(view.get_unchecked(2, 0), 12); // row 2, col 0 -> inner col 1
295        }
296
297        // Row 0: should return [2, 3, 4]
298        let row0: Vec<_> = view.row(0).unwrap().into_iter().collect();
299        assert_eq!(row0, vec![2, 3, 4]);
300
301        // Row 1: should return [7, 8, 9]
302        let row1: Vec<_> = view.row(1).unwrap().into_iter().collect();
303        assert_eq!(row1, vec![7, 8, 9]);
304
305        unsafe {
306            // Row 2: should return [12, 13, 14]
307            let row2: Vec<_> = view.row_unchecked(2).into_iter().collect();
308            assert_eq!(row2, vec![12, 13, 14]);
309
310            // Subsequence of row 1, cols 1..3 (view indices) -> [8, 9]
311            let row1_subseq: Vec<_> = view.row_subseq_unchecked(1, 1, 3).into_iter().collect();
312            assert_eq!(row1_subseq, vec![8, 9]);
313        }
314
315        // Out of bounds checks.
316        assert!(view.get(0, 3).is_none()); // Width out of bounds
317        assert!(view.get(3, 0).is_none()); // Height out of bounds
318
319        // Convert the view to a RowMajorMatrix and check contents.
320        let as_matrix = view.to_row_major_matrix();
321
322        // The expected matrix after selecting columns 1..4:
323        // [2  3  4]
324        // [7  8  9]
325        // [12 13 14]
326        let expected = RowMajorMatrix::new(vec![2, 3, 4, 7, 8, 9, 12, 13, 14], 3);
327
328        assert_eq!(as_matrix, expected);
329    }
330
331    #[test]
332    fn test_column_range_end() {
333        // Create a 2x4 matrix:
334        // [1 2 3 4]
335        // [5 6 7 8]
336        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 4);
337
338        // Select columns 2..4 (columns 2, 3).
339        let view = HorizontallyTruncated::new_with_range(inner, 2..4).unwrap();
340
341        assert_eq!(view.width(), 2);
342        assert_eq!(view.height(), 2);
343
344        // Row 0: should return [3, 4]
345        let row0: Vec<_> = view.row(0).unwrap().into_iter().collect();
346        assert_eq!(row0, vec![3, 4]);
347
348        // Row 1: should return [7, 8]
349        let row1: Vec<_> = view.row(1).unwrap().into_iter().collect();
350        assert_eq!(row1, vec![7, 8]);
351
352        assert_eq!(view.get(0, 0), Some(3));
353        assert_eq!(view.get(1, 1), Some(8));
354    }
355
356    #[test]
357    fn test_column_range_single_column() {
358        // Create a 3x4 matrix:
359        // [1 2 3 4]
360        // [5 6 7 8]
361        // [9 10 11 12]
362        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 4);
363
364        // Select only column 2.
365        let view = HorizontallyTruncated::new_with_range(inner, 2..3).unwrap();
366
367        assert_eq!(view.width(), 1);
368        assert_eq!(view.height(), 3);
369
370        assert_eq!(view.get(0, 0), Some(3));
371        assert_eq!(view.get(1, 0), Some(7));
372        assert_eq!(view.get(2, 0), Some(11));
373
374        // Row 0: should return [3]
375        let row0: Vec<_> = view.row(0).unwrap().into_iter().collect();
376        assert_eq!(row0, vec![3]);
377    }
378
379    #[test]
380    fn test_column_range_empty() {
381        // Create a 2x3 matrix:
382        // [1 2 3]
383        // [4 5 6]
384        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);
385
386        // Select empty range (2..2).
387        let view = HorizontallyTruncated::new_with_range(inner, 2..2).unwrap();
388
389        assert_eq!(view.width(), 0);
390        assert_eq!(view.height(), 2);
391
392        // Row should be empty.
393        assert!(view.row(0).unwrap().into_iter().next().is_none());
394    }
395
396    #[test]
397    fn test_invalid_column_range() {
398        // Create a 2x3 matrix:
399        // [1 2 3]
400        // [4 5 6]
401        let inner = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);
402
403        // Attempt to select columns 1..5 (extends beyond width).
404        assert!(HorizontallyTruncated::new_with_range(inner, 1..5).is_none());
405    }
406}