easy_ml/matrices/views/
ranges.rs

1use crate::matrices::views::{DataLayout, MatrixMut, MatrixRef, NoInteriorMutability};
2use crate::matrices::{Column, Row};
3
4use std::marker::PhantomData;
5use std::num::NonZeroUsize;
6use std::ops::Range;
7
/**
 * A 2 dimensional range over a matrix, hiding the values **outside** the range from view.
 *
 * The entire source is still owned by the MatrixRange however, so this does not permit
 * creating multiple mutable ranges into a single matrix even if they wouldn't overlap.
 *
 * For non overlapping mutable ranges into a single matrix see
 * [`partition`](crate::matrices::Matrix::partition).
 *
 * See also: [MatrixMask](MatrixMask)
 */
#[derive(Clone, Debug)]
pub struct MatrixRange<T, S> {
    // The wrapped source that every lookup is forwarded to.
    source: S,
    // The rows of the source that stay visible; `MatrixRange::from` clips this to the
    // number of rows the source reports.
    rows: IndexRange,
    // The columns of the source that stay visible; `MatrixRange::from` clips this to the
    // number of columns the source reports.
    columns: IndexRange,
    // Records the element type T in the struct's type without storing a T.
    _type: PhantomData<T>,
}
26
/**
 * A 2 dimensional mask over a matrix, hiding the values **inside** the range from view.
 *
 * The entire source is still owned by the MatrixMask however, so this does not permit
 * creating multiple mutable masks into a single matrix even if they wouldn't overlap.
 *
 * See also: [MatrixRange](MatrixRange)
 */
#[derive(Clone, Debug)]
pub struct MatrixMask<T, S> {
    // The wrapped source that every lookup is forwarded to.
    source: S,
    // The rows of the source that are hidden; `MatrixMask::from` clips this to the
    // number of rows the source reports.
    rows: IndexRange,
    // The columns of the source that are hidden; `MatrixMask::from` clips this to the
    // number of columns the source reports.
    columns: IndexRange,
    // Records the element type T in the struct's type without storing a T.
    _type: PhantomData<T>,
}
42
43impl<T, S> MatrixRange<T, S>
44where
45    S: MatrixRef<T>,
46{
47    /**
48     * Creates a new MatrixRange giving a view of only the data within the row and column
49     * [IndexRange](IndexRange)s.
50     *
51     * # Examples
52     *
53     * Creating a view and manipulating a matrix from it.
54     * ```
55     * use easy_ml::matrices::Matrix;
56     * use easy_ml::matrices::views::{MatrixView, MatrixRange};
57     * let mut matrix = Matrix::from(vec![
58     *     vec![ 2, 3, 4 ],
59     *     vec![ 5, 1, 8 ]]);
60     * {
61     *     let mut view = MatrixView::from(MatrixRange::from(&mut matrix, 0..1, 1..3));
62     *     assert_eq!(vec![3, 4], view.row_major_iter().collect::<Vec<_>>());
63     *     view.map_mut(|x| x + 10);
64     * }
65     * assert_eq!(matrix, Matrix::from(vec![
66     *     vec![ 2, 13, 14 ],
67     *     vec![ 5,  1,  8 ]]));
68     * ```
69     *
70     * Various ways to construct a MatrixRange
71     * ```
72     * use easy_ml::matrices::Matrix;
73     * use easy_ml::matrices::views::{IndexRange, MatrixRange};
74     * let matrix = Matrix::from(vec![vec![1]]);
75     * let index_range = MatrixRange::from(&matrix, IndexRange::new(0, 4), IndexRange::new(1, 3));
76     * let tuple = MatrixRange::from(&matrix, (0, 4), (1, 3));
77     * let array = MatrixRange::from(&matrix, [0, 4], [1, 3]);
78     * // Note std::ops::Range is start..end not start and length!
79     * let range = MatrixRange::from(&matrix, 0..4, 1..4);
80     * ```
81     *
82     * NOTE: In previous versions (<=1.8.1), this erroneously did not clip the IndexRange input to
83     * not exceed the rows and columns of the source, which led to the possibility to create
84     * MatrixRanges that reported a greater number of rows and columns in their shape than their
85     * actual data. This function will now correctly clip any ranges that exceed their sources.
86     */
87    pub fn from<R>(source: S, rows: R, columns: R) -> MatrixRange<T, S>
88    where
89        R: Into<IndexRange>,
90    {
91        let max_rows = source.view_rows();
92        let max_columns = source.view_columns();
93        MatrixRange {
94            source,
95            rows: {
96                let mut rows = rows.into();
97                rows.clip(max_rows);
98                rows
99            },
100            columns: {
101                let mut columns = columns.into();
102                columns.clip(max_columns);
103                columns
104            },
105            _type: PhantomData,
106        }
107    }
108
109    /**
110     * Consumes the MatrixRange, yielding the source it was created from.
111     */
112    #[allow(dead_code)]
113    pub fn source(self) -> S {
114        self.source
115    }
116
117    /**
118     * Gives a reference to the MatrixRange's source (in which the data is not clipped).
119     */
120    // # Safety
121    //
122    // Giving out a mutable reference to our source could allow it to be changed out from under us
123    // and make our range checks invalid. However, since the source implements MatrixRef
124    // interior mutability is not allowed, so we can give out shared references without breaking
125    // our own integrity.
126    #[allow(dead_code)]
127    pub fn source_ref(&self) -> &S {
128        &self.source
129    }
130}
131
132impl<T, S> MatrixMask<T, S>
133where
134    S: MatrixRef<T>,
135{
136    /**
137     * Creates a new MatrixMask giving a view of only the data outside the row and column
138     * [IndexRange](IndexRange)s. If the index range given for rows or columns exceeds the
139     * size of the matrix, they will be clipped to fit the actual size without an error.
140     *
141     * # Examples
142     *
143     * Creating a view and manipulating a matrix from it.
144     * ```
145     * use easy_ml::matrices::Matrix;
146     * use easy_ml::matrices::views::{MatrixView, MatrixMask};
147     * let mut matrix = Matrix::from(vec![
148     *     vec![ 2, 3, 4 ],
149     *     vec![ 5, 1, 8 ]]);
150     * {
151     *     let mut view = MatrixView::from(MatrixMask::from(&mut matrix, 0..1, 2..3));
152     *     assert_eq!(vec![5, 1], view.row_major_iter().collect::<Vec<_>>());
153     *     view.map_mut(|x| x + 10);
154     * }
155     * assert_eq!(matrix, Matrix::from(vec![
156     *     vec![ 2,   3,  4 ],
157     *     vec![ 15, 11,  8 ]]));
158     * ```
159     *
160     * Various ways to construct a MatrixMask
161     * ```
162     * use easy_ml::matrices::Matrix;
163     * use easy_ml::matrices::views::{IndexRange, MatrixMask};
164     * let matrix = Matrix::from(vec![vec![1]]);
165     * let index_range = MatrixMask::from(&matrix, IndexRange::new(0, 4), IndexRange::new(1, 3));
166     * let tuple = MatrixMask::from(&matrix, (0, 4), (1, 3));
167     * let array = MatrixMask::from(&matrix, [0, 4], [1, 3]);
168     * // Note std::ops::Range is start..end not start and length!
169     * let range = MatrixMask::from(&matrix, 0..4, 1..4);
170     * ```
171     */
172    pub fn from<R>(source: S, rows: R, columns: R) -> MatrixMask<T, S>
173    where
174        R: Into<IndexRange>,
175    {
176        let max_rows = source.view_rows();
177        let max_columns = source.view_columns();
178        MatrixMask {
179            source,
180            rows: {
181                let mut rows = rows.into();
182                rows.clip(max_rows);
183                rows
184            },
185            columns: {
186                let mut columns = columns.into();
187                columns.clip(max_columns);
188                columns
189            },
190            _type: PhantomData,
191        }
192    }
193
194    /**
195     * Creates a MatrixMask of this source that retains only the specified
196     * number of elements at both the start and end of the rows.
197     * If twice the provided number of elements for the rows exceeds the
198     * number of rows in the matrix, then all elements are retained. Similarly,
199     * passing None retains all elements.
200     *
201     * ```
202     * use std::num::NonZeroUsize;
203     * use easy_ml::matrices::Matrix;
204     * use easy_ml::matrices::views::{MatrixView, MatrixMask};
205     * let matrix = Matrix::from_flat_row_major((5, 5), (0..25).collect());
206     * let start_and_end = MatrixView::from(
207     *     MatrixMask::start_and_end_of_rows(
208     *         matrix, NonZeroUsize::new(1)
209     *     )
210     * );
211     * assert_eq!(
212     *     start_and_end,
213     *     Matrix::from_flat_row_major((2, 5), vec![
214     *          0,  1,  2,  3,  4,
215     *         20, 21, 22, 23, 24,
216     *     ])
217     * );
218     * ```
219     */
220    pub fn start_and_end_of_rows(source: S, retain: Option<NonZeroUsize>) -> MatrixMask<T, S> {
221        let rows = match retain {
222            None => IndexRange::new(0, 0),
223            Some(x) => {
224                let x = x.get();
225                let length = source.view_rows();
226                let retain_start = std::cmp::min(x, length - 1);
227                let retain_end = length.saturating_sub(x);
228                let mut range: IndexRange = (retain_start..retain_end).into();
229                range.clip(length - 1);
230                range
231            }
232        };
233        let columns = IndexRange::new(0, 0);
234        MatrixMask::from(source, rows, columns)
235    }
236
237    /**
238     * Creates a MatrixMask of this source that retains only the specified
239     * number of elements at both the start and end of the columns.
240     * If twice the provided number of elements for the columns exceeds the
241     * number of columns in the matrix, then all elements are retained. Similarly,
242     * passing None retains all elements.
243     *
244     * ```
245     * use std::num::NonZeroUsize;
246     * use easy_ml::matrices::Matrix;
247     * use easy_ml::matrices::views::{MatrixView, MatrixMask};
248     * let matrix = Matrix::from_flat_row_major((5, 5), (0..25).collect());
249     * let start_and_end = MatrixView::from(
250     *     MatrixMask::start_and_end_of_columns(
251     *         matrix, NonZeroUsize::new(1)
252     *     )
253     * );
254     * assert_eq!(
255     *     start_and_end,
256     *     Matrix::from_flat_row_major((5, 2), vec![
257     *          0,  4,
258     *          5,  9,
259     *         10, 14,
260     *         15, 19,
261     *         20, 24,
262     *     ])
263     * );
264     * ```
265     */
266    pub fn start_and_end_of_columns(source: S, retain: Option<NonZeroUsize>) -> MatrixMask<T, S> {
267        let rows = IndexRange::new(0, 0);
268        let columns = match retain {
269            None => IndexRange::new(0, 0),
270            Some(x) => {
271                let x = x.get();
272                let length = source.view_columns();
273                let retain_start = std::cmp::min(x, length - 1);
274                let retain_end = length.saturating_sub(x);
275                let mut range: IndexRange = (retain_start..retain_end).into();
276                range.clip(length - 1);
277                range
278            }
279        };
280        MatrixMask::from(source, rows, columns)
281    }
282
283    /**
284     * Consumes the MatrixMask, yielding the source it was created from.
285     */
286    #[allow(dead_code)]
287    pub fn source(self) -> S {
288        self.source
289    }
290
291    /**
292     * Gives a reference to the MatrixMask's source (in which the data is not masked).
293     */
294    // # Safety
295    //
296    // Giving out a mutable reference to our source could allow it to be changed out from under us
297    // and make our mask checks invalid. However, since the source implements MatrixRef
298    // interior mutability is not allowed, so we can give out shared references without breaking
299    // our own integrity.
300    #[allow(dead_code)]
301    pub fn source_ref(&self) -> &S {
302        &self.source
303    }
304}
305
/**
 * A range bounded between `start` inclusive and `start + length` exclusive.
 *
 * # Examples
 *
 * Converting between [Range](std::ops::Range) and IndexRange.
 * ```
 * use std::ops::Range;
 * use easy_ml::matrices::views::IndexRange;
 * assert_eq!(IndexRange::new(3, 2), (3..5).into());
 * assert_eq!(IndexRange::new(1, 5), (1..6).into());
 * assert_eq!(IndexRange::new(0, 4), (0..4).into());
 * ```
 *
 * Creating a Range
 *
 * ```
 * use easy_ml::matrices::views::IndexRange;
 * let range = IndexRange::new(3, 2);
 * let also_range: IndexRange = (3, 2).into();
 * let also_also_range: IndexRange = [3, 2].into();
 * ```
 *
 * NB: You can construct an IndexRange where start+length exceeds isize::MAX or even
 * usize::MAX, however matrices and tensors themselves cannot contain more than isize::MAX
 * elements. Concerned readers should note that on a 64 bit computer this maximum
 * value is 9,223,372,036,854,775,807 so running out of memory is likely to occur first.
 */
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct IndexRange {
    // First index covered by the range.
    pub(crate) start: usize,
    // Number of indexes covered, counted from `start`.
    pub(crate) length: usize,
}

impl IndexRange {
    /**
     * Creates an IndexRange covering `length` indexes beginning at `start`.
     */
    pub fn new(start: usize, length: usize) -> IndexRange {
        IndexRange { start, length }
    }

    // TODO: If we make these public we need to disambiguate Range from Mask behaviour better
    /**
     * Maps from a coordinate space of the ith index accessible by this range to the actual index
     * into the entire dimension's data.
     *
     * Returns None if `index` is not within the length of this range, or if the actual
     * index would not fit in a usize.
     */
    #[inline]
    pub(crate) fn map(&self, index: usize) -> Option<usize> {
        if index < self.length {
            // An unclipped range may reach past usize::MAX. No data can exist at such an
            // index, so report it as out of range instead of overflowing.
            self.start.checked_add(index)
        } else {
            None
        }
    }

    // NOTE: This doesn't perform bounds checks, adding the length of the mask could push
    // the index out of the valid bounds of the dimension it is for, but if we performed
    // bounds checks here they would be redundant since performing the get with the masked index
    // will bounds check if required
    /**
     * Maps from the ith index that remains visible around this mask to the actual index
     * into the entire dimension's data, by skipping over the masked indexes.
     */
    #[inline]
    pub(crate) fn mask(&self, index: usize) -> usize {
        if index < self.start {
            index
        } else {
            // Saturate rather than wrap. A wrapped sum would turn a huge, out of range
            // index into a small one that looks valid, whereas usize::MAX can never be a
            // valid index and so still fails the bounds check of the subsequent get.
            index.saturating_add(self.length)
        }
    }

    // Clips the range or mask to not exceed an index. Note, this may yield 0 length ranges
    // that have non zero starting positions, however map and mask will still calculate correctly.
    pub(crate) fn clip(&mut self, max_index: usize) {
        // start + length is allowed to exceed usize::MAX, so the end has to saturate
        // before it is compared against the limit.
        let end = self.start.saturating_add(self.length);
        let end = std::cmp::min(end, max_index);
        self.length = end.saturating_sub(self.start);
    }
}
381
382/**
383 * Converts from a range of start..end to an IndexRange of start and length
384 *
385 * NOTE: In previous versions (<=1.8.1) this did not saturate when attempting to subtract the
386 * start of the range from the end to calculate the length. It will now correctly produce an
387 * IndexRange with a length of 0 if the end is before or equal to the start.
388 */
389impl From<Range<usize>> for IndexRange {
390    fn from(range: Range<usize>) -> IndexRange {
391        IndexRange::new(range.start, range.end.saturating_sub(range.start))
392    }
393}
394
395/** Converts from an IndexRange of start and length to a range of start..end */
396impl From<IndexRange> for Range<usize> {
397    fn from(range: IndexRange) -> Range<usize> {
398        Range {
399            start: range.start,
400            end: range.start + range.length,
401        }
402    }
403}
404
405/**
406 * Converts from a tuple of start and length to an IndexRange
407 *
408 * NOTE: In previous versions (<=1.8.1), this was erroneously implemented as conversion from a
409 * tuple of start and end, not start and length as documented.
410 */
411impl From<(usize, usize)> for IndexRange {
412    fn from(range: (usize, usize)) -> IndexRange {
413        let (start, length) = range;
414        IndexRange::new(start, length)
415    }
416}
417
418/**
419 * Converts from an array of start and length to an IndexRange
420 *
421 * NOTE: In previous versions (<=1.8.1), this was erroneously implemented as conversion from an
422 * array of start and end, not start and length as documented.
423 */
424impl From<[usize; 2]> for IndexRange {
425    fn from(range: [usize; 2]) -> IndexRange {
426        let [start, length] = range;
427        IndexRange::new(start, length)
428    }
429}
430
#[test]
fn test_index_range_clipping() {
    // A range that runs past the limit is shortened to end at the limit.
    let mut clipped = IndexRange::from(0..6);
    clipped.clip(4);
    assert_eq!(clipped, IndexRange::from(0..4));

    // A range already inside the limit is untouched, then shortened by a tighter limit.
    let mut clipped = IndexRange::from(1..4);
    clipped.clip(5);
    assert_eq!(clipped, IndexRange::from(1..4));
    clipped.clip(2);
    assert_eq!(clipped, IndexRange::from(1..2));

    // A range starting beyond the limit keeps its start but ends up with 0 length, so it
    // maps nothing and masks nothing.
    let mut clipped = IndexRange::from(3..5);
    clipped.clip(2);
    assert_eq!(clipped, IndexRange::from(3..2));
    for index in 0..2 {
        assert_eq!(clipped.map(index), None);
    }
    for index in 0..2 {
        assert_eq!(clipped.mask(index), index);
    }
}
449
450// # Safety
451//
452// Since the MatrixRef we own must implement MatrixRef correctly, so do we by delegating to it,
453// as we don't introduce any interior mutability.
454/**
455 * A MatrixRange of a MatrixRef type implements MatrixRef.
456 */
457unsafe impl<T, S> MatrixRef<T> for MatrixRange<T, S>
458where
459    S: MatrixRef<T>,
460{
461    fn try_get_reference(&self, row: Row, column: Column) -> Option<&T> {
462        let row = self.rows.map(row)?;
463        let column = self.columns.map(column)?;
464        self.source.try_get_reference(row, column)
465    }
466
467    fn view_rows(&self) -> Row {
468        self.rows.length
469    }
470
471    fn view_columns(&self) -> Column {
472        self.columns.length
473    }
474
475    unsafe fn get_reference_unchecked(&self, row: Row, column: Column) -> &T {
476        unsafe {
477            // It is the caller's responsibiltiy to always call with row/column indexes in range,
478            // therefore the unwrap() case should never happen because on an arbitary MatrixRef
479            // it would be undefined behavior.
480            let row = self.rows.map(row).unwrap();
481            let column = self.columns.map(column).unwrap();
482            self.source.get_reference_unchecked(row, column)
483        }
484    }
485
486    fn data_layout(&self) -> DataLayout {
487        self.source.data_layout()
488    }
489}
490
491// # Safety
492//
493// Since the MatrixMut we own must implement MatrixMut correctly, so do we by delegating to it,
494// as we don't introduce any interior mutability.
495/**
496 * A MatrixRange of a MatrixMut type implements MatrixMut.
497 */
498unsafe impl<T, S> MatrixMut<T> for MatrixRange<T, S>
499where
500    S: MatrixMut<T>,
501{
502    fn try_get_reference_mut(&mut self, row: Row, column: Column) -> Option<&mut T> {
503        let row = self.rows.map(row)?;
504        let column = self.columns.map(column)?;
505        self.source.try_get_reference_mut(row, column)
506    }
507
508    unsafe fn get_reference_unchecked_mut(&mut self, row: Row, column: Column) -> &mut T {
509        unsafe {
510            // It is the caller's responsibility to always call with row/column indexes in range,
511            // therefore the unwrap() case should never happen because on an arbitary MatrixRef
512            // it would be undefined behavior.
513            let row = self.rows.map(row).unwrap();
514            let column = self.columns.map(column).unwrap();
515            self.source.get_reference_unchecked_mut(row, column)
516        }
517    }
518}
519
// # Safety
//
// Since the NoInteriorMutability we own must implement NoInteriorMutability correctly, so
// do we by delegating to it, as we don't introduce any interior mutability. The only other
// fields a MatrixRange holds are two IndexRanges and a PhantomData marker.
/**
 * A MatrixRange of a NoInteriorMutability type implements NoInteriorMutability.
 */
unsafe impl<T, S> NoInteriorMutability for MatrixRange<T, S> where S: NoInteriorMutability {}
528
#[test]
fn test_matrix_range_shape_clips() {
    use crate::matrices::Matrix;
    // The source has 2 rows and 3 columns.
    let source = Matrix::from(vec![vec![1, 2, 3], vec![4, 5, 6]]);
    // Both requested ranges run past the edges of the source.
    let clipped = MatrixRange::from(&source, 0..7, 1..4);
    // The reported shape must reflect the clipped ranges, not the requested ones.
    let reported_shape = (clipped.view_rows(), clipped.view_columns());
    assert_eq!((2, 2), reported_shape);
    let stored_lengths = (clipped.rows.length, clipped.columns.length);
    assert_eq!((2, 2), stored_lengths);
}
539
540// # Safety
541//
542// Since the MatrixRef we own must implement MatrixRef correctly, so do we by delegating to it,
543// as we don't introduce any interior mutability.
544/**
545 * A MatrixMask of a MatrixRef type implements MatrixRef.
546 */
547unsafe impl<T, S> MatrixRef<T> for MatrixMask<T, S>
548where
549    S: MatrixRef<T>,
550{
551    fn try_get_reference(&self, row: Row, column: Column) -> Option<&T> {
552        let row = self.rows.mask(row);
553        let column = self.columns.mask(column);
554        self.source.try_get_reference(row, column)
555    }
556
557    fn view_rows(&self) -> Row {
558        // We enforce in the constructor that the mask is clipped to the size of our actual
559        // matrix, hence the mask cannot be longer than our data in either dimension. If the
560        // mask is the same length as our data, we'd return 0 which for MatrixRef is allowed.
561        self.source.view_rows() - self.rows.length
562    }
563
564    fn view_columns(&self) -> Column {
565        // We enforce in the constructor that the mask is clipped to the size of our actual
566        // matrix, hence the mask cannot be longer than our data in either dimension. If the
567        // mask is the same length as our data, we'd return 0 which for MatrixRef is allowed.
568        self.source.view_columns() - self.columns.length
569    }
570
571    unsafe fn get_reference_unchecked(&self, row: Row, column: Column) -> &T {
572        unsafe {
573            // It is the caller's responsibility to always call with row/column indexes in range,
574            // therefore calling get_reference_unchecked with indexes beyond the size of the matrix
575            // should never happen because on an arbitary MatrixRef it would be undefined behavior.
576            let row = self.rows.mask(row);
577            let column = self.columns.mask(column);
578            self.source.get_reference_unchecked(row, column)
579        }
580    }
581
582    fn data_layout(&self) -> DataLayout {
583        self.source.data_layout()
584    }
585}
586
587// # Safety
588//
589// Since the MatrixMut we own must implement MatrixMut correctly, so do we by delegating to it,
590// as we don't introduce any interior mutability.
591/**
592 * A MatrixMask of a MatrixMut type implements MatrixMut.
593 */
594unsafe impl<T, S> MatrixMut<T> for MatrixMask<T, S>
595where
596    S: MatrixMut<T>,
597{
598    fn try_get_reference_mut(&mut self, row: Row, column: Column) -> Option<&mut T> {
599        let row = self.rows.mask(row);
600        let column = self.columns.mask(column);
601        self.source.try_get_reference_mut(row, column)
602    }
603
604    unsafe fn get_reference_unchecked_mut(&mut self, row: Row, column: Column) -> &mut T {
605        unsafe {
606            // It is the caller's responsibility to always call with row/column indexes in range,
607            // therefore calling get_reference_unchecked with indexes beyond the size of the matrix
608            // should never happen because on an arbitary MatrixRef it would be undefined behavior.
609            let row = self.rows.mask(row);
610            let column = self.columns.mask(column);
611            self.source.get_reference_unchecked_mut(row, column)
612        }
613    }
614}
615
// # Safety
//
// Since the NoInteriorMutability we own must implement NoInteriorMutability correctly, so
// do we by delegating to it, as we don't introduce any interior mutability. The only other
// fields a MatrixMask holds are two IndexRanges and a PhantomData marker.
/**
 * A MatrixMask of a NoInteriorMutability type implements NoInteriorMutability.
 */
unsafe impl<T, S> NoInteriorMutability for MatrixMask<T, S> where S: NoInteriorMutability {}