easy_ml/matrices/views/ranges.rs
1use crate::matrices::views::{DataLayout, MatrixMut, MatrixRef, NoInteriorMutability};
2use crate::matrices::{Column, Row};
3
4use std::marker::PhantomData;
5use std::num::NonZeroUsize;
6use std::ops::Range;
7
/**
 * A 2 dimensional range over a matrix, hiding the values **outside** the range from view.
 *
 * Only the rows inside the row range and the columns inside the column range are visible,
 * and they are re-indexed so that the first visible row and column are at index 0.
 *
 * The entire source is still owned by the MatrixRange however, so this does not permit
 * creating multiple mutable ranges into a single matrix even if they wouldn't overlap.
 *
 * For non overlapping mutable ranges into a single matrix see
 * [`partition`](crate::matrices::Matrix::partition).
 *
 * See also: [MatrixMask](MatrixMask)
 */
#[derive(Clone, Debug)]
pub struct MatrixRange<T, S> {
    // The full matrix being viewed, including the data outside the range.
    source: S,
    // The visible rows of the source, clipped to the source's row count by `MatrixRange::from`.
    rows: IndexRange,
    // The visible columns of the source, clipped to the source's column count by
    // `MatrixRange::from`.
    columns: IndexRange,
    // No T is stored directly, but T must appear in the type to implement MatrixRef<T>.
    _type: PhantomData<T>,
}
26
/**
 * A 2 dimensional mask over a matrix, hiding the values **inside** the range from view.
 *
 * Every row inside the row range and every column inside the column range is hidden
 * entirely. The remaining rows and columns are re-indexed so that they are contiguous
 * from index 0.
 *
 * The entire source is still owned by the MatrixMask however, so this does not permit
 * creating multiple mutable masks into a single matrix even if they wouldn't overlap.
 *
 * See also: [MatrixRange](MatrixRange)
 */
#[derive(Clone, Debug)]
pub struct MatrixMask<T, S> {
    // The full matrix being viewed, including the masked data.
    source: S,
    // The rows of the source that are hidden, clipped to the source's row count by
    // `MatrixMask::from`.
    rows: IndexRange,
    // The columns of the source that are hidden, clipped to the source's column count by
    // `MatrixMask::from`.
    columns: IndexRange,
    // No T is stored directly, but T must appear in the type to implement MatrixRef<T>.
    _type: PhantomData<T>,
}
42
43impl<T, S> MatrixRange<T, S>
44where
45 S: MatrixRef<T>,
46{
47 /**
48 * Creates a new MatrixRange giving a view of only the data within the row and column
49 * [IndexRange](IndexRange)s.
50 *
51 * # Examples
52 *
53 * Creating a view and manipulating a matrix from it.
54 * ```
55 * use easy_ml::matrices::Matrix;
56 * use easy_ml::matrices::views::{MatrixView, MatrixRange};
57 * let mut matrix = Matrix::from(vec![
58 * vec![ 2, 3, 4 ],
59 * vec![ 5, 1, 8 ]]);
60 * {
61 * let mut view = MatrixView::from(MatrixRange::from(&mut matrix, 0..1, 1..3));
62 * assert_eq!(vec![3, 4], view.row_major_iter().collect::<Vec<_>>());
63 * view.map_mut(|x| x + 10);
64 * }
65 * assert_eq!(matrix, Matrix::from(vec![
66 * vec![ 2, 13, 14 ],
67 * vec![ 5, 1, 8 ]]));
68 * ```
69 *
70 * Various ways to construct a MatrixRange
71 * ```
72 * use easy_ml::matrices::Matrix;
73 * use easy_ml::matrices::views::{IndexRange, MatrixRange};
74 * let matrix = Matrix::from(vec![vec![1]]);
75 * let index_range = MatrixRange::from(&matrix, IndexRange::new(0, 4), IndexRange::new(1, 3));
76 * let tuple = MatrixRange::from(&matrix, (0, 4), (1, 3));
77 * let array = MatrixRange::from(&matrix, [0, 4], [1, 3]);
78 * // Note std::ops::Range is start..end not start and length!
79 * let range = MatrixRange::from(&matrix, 0..4, 1..4);
80 * ```
81 *
82 * NOTE: In previous versions (<=1.8.1), this erroneously did not clip the IndexRange input to
83 * not exceed the rows and columns of the source, which led to the possibility to create
84 * MatrixRanges that reported a greater number of rows and columns in their shape than their
85 * actual data. This function will now correctly clip any ranges that exceed their sources.
86 */
87 pub fn from<R>(source: S, rows: R, columns: R) -> MatrixRange<T, S>
88 where
89 R: Into<IndexRange>,
90 {
91 let max_rows = source.view_rows();
92 let max_columns = source.view_columns();
93 MatrixRange {
94 source,
95 rows: {
96 let mut rows = rows.into();
97 rows.clip(max_rows);
98 rows
99 },
100 columns: {
101 let mut columns = columns.into();
102 columns.clip(max_columns);
103 columns
104 },
105 _type: PhantomData,
106 }
107 }
108
109 /**
110 * Consumes the MatrixRange, yielding the source it was created from.
111 */
112 #[allow(dead_code)]
113 pub fn source(self) -> S {
114 self.source
115 }
116
117 /**
118 * Gives a reference to the MatrixRange's source (in which the data is not clipped).
119 */
120 // # Safety
121 //
122 // Giving out a mutable reference to our source could allow it to be changed out from under us
123 // and make our range checks invalid. However, since the source implements MatrixRef
124 // interior mutability is not allowed, so we can give out shared references without breaking
125 // our own integrity.
126 #[allow(dead_code)]
127 pub fn source_ref(&self) -> &S {
128 &self.source
129 }
130}
131
132impl<T, S> MatrixMask<T, S>
133where
134 S: MatrixRef<T>,
135{
136 /**
137 * Creates a new MatrixMask giving a view of only the data outside the row and column
138 * [IndexRange](IndexRange)s. If the index range given for rows or columns exceeds the
139 * size of the matrix, they will be clipped to fit the actual size without an error.
140 *
141 * # Examples
142 *
143 * Creating a view and manipulating a matrix from it.
144 * ```
145 * use easy_ml::matrices::Matrix;
146 * use easy_ml::matrices::views::{MatrixView, MatrixMask};
147 * let mut matrix = Matrix::from(vec![
148 * vec![ 2, 3, 4 ],
149 * vec![ 5, 1, 8 ]]);
150 * {
151 * let mut view = MatrixView::from(MatrixMask::from(&mut matrix, 0..1, 2..3));
152 * assert_eq!(vec![5, 1], view.row_major_iter().collect::<Vec<_>>());
153 * view.map_mut(|x| x + 10);
154 * }
155 * assert_eq!(matrix, Matrix::from(vec![
156 * vec![ 2, 3, 4 ],
157 * vec![ 15, 11, 8 ]]));
158 * ```
159 *
160 * Various ways to construct a MatrixMask
161 * ```
162 * use easy_ml::matrices::Matrix;
163 * use easy_ml::matrices::views::{IndexRange, MatrixMask};
164 * let matrix = Matrix::from(vec![vec![1]]);
165 * let index_range = MatrixMask::from(&matrix, IndexRange::new(0, 4), IndexRange::new(1, 3));
166 * let tuple = MatrixMask::from(&matrix, (0, 4), (1, 3));
167 * let array = MatrixMask::from(&matrix, [0, 4], [1, 3]);
168 * // Note std::ops::Range is start..end not start and length!
169 * let range = MatrixMask::from(&matrix, 0..4, 1..4);
170 * ```
171 */
172 pub fn from<R>(source: S, rows: R, columns: R) -> MatrixMask<T, S>
173 where
174 R: Into<IndexRange>,
175 {
176 let max_rows = source.view_rows();
177 let max_columns = source.view_columns();
178 MatrixMask {
179 source,
180 rows: {
181 let mut rows = rows.into();
182 rows.clip(max_rows);
183 rows
184 },
185 columns: {
186 let mut columns = columns.into();
187 columns.clip(max_columns);
188 columns
189 },
190 _type: PhantomData,
191 }
192 }
193
194 /**
195 * Creates a MatrixMask of this source that retains only the specified
196 * number of elements at both the start and end of the rows.
197 * If twice the provided number of elements for the rows exceeds the
198 * number of rows in the matrix, then all elements are retained. Similarly,
199 * passing None retains all elements.
200 *
201 * ```
202 * use std::num::NonZeroUsize;
203 * use easy_ml::matrices::Matrix;
204 * use easy_ml::matrices::views::{MatrixView, MatrixMask};
205 * let matrix = Matrix::from_flat_row_major((5, 5), (0..25).collect());
206 * let start_and_end = MatrixView::from(
207 * MatrixMask::start_and_end_of_rows(
208 * matrix, NonZeroUsize::new(1)
209 * )
210 * );
211 * assert_eq!(
212 * start_and_end,
213 * Matrix::from_flat_row_major((2, 5), vec![
214 * 0, 1, 2, 3, 4,
215 * 20, 21, 22, 23, 24,
216 * ])
217 * );
218 * ```
219 */
220 pub fn start_and_end_of_rows(source: S, retain: Option<NonZeroUsize>) -> MatrixMask<T, S> {
221 let rows = match retain {
222 None => IndexRange::new(0, 0),
223 Some(x) => {
224 let x = x.get();
225 let length = source.view_rows();
226 let retain_start = std::cmp::min(x, length - 1);
227 let retain_end = length.saturating_sub(x);
228 let mut range: IndexRange = (retain_start..retain_end).into();
229 range.clip(length - 1);
230 range
231 }
232 };
233 let columns = IndexRange::new(0, 0);
234 MatrixMask::from(source, rows, columns)
235 }
236
237 /**
238 * Creates a MatrixMask of this source that retains only the specified
239 * number of elements at both the start and end of the columns.
240 * If twice the provided number of elements for the columns exceeds the
241 * number of columns in the matrix, then all elements are retained. Similarly,
242 * passing None retains all elements.
243 *
244 * ```
245 * use std::num::NonZeroUsize;
246 * use easy_ml::matrices::Matrix;
247 * use easy_ml::matrices::views::{MatrixView, MatrixMask};
248 * let matrix = Matrix::from_flat_row_major((5, 5), (0..25).collect());
249 * let start_and_end = MatrixView::from(
250 * MatrixMask::start_and_end_of_columns(
251 * matrix, NonZeroUsize::new(1)
252 * )
253 * );
254 * assert_eq!(
255 * start_and_end,
256 * Matrix::from_flat_row_major((5, 2), vec![
257 * 0, 4,
258 * 5, 9,
259 * 10, 14,
260 * 15, 19,
261 * 20, 24,
262 * ])
263 * );
264 * ```
265 */
266 pub fn start_and_end_of_columns(source: S, retain: Option<NonZeroUsize>) -> MatrixMask<T, S> {
267 let rows = IndexRange::new(0, 0);
268 let columns = match retain {
269 None => IndexRange::new(0, 0),
270 Some(x) => {
271 let x = x.get();
272 let length = source.view_columns();
273 let retain_start = std::cmp::min(x, length - 1);
274 let retain_end = length.saturating_sub(x);
275 let mut range: IndexRange = (retain_start..retain_end).into();
276 range.clip(length - 1);
277 range
278 }
279 };
280 MatrixMask::from(source, rows, columns)
281 }
282
283 /**
284 * Consumes the MatrixMask, yielding the source it was created from.
285 */
286 #[allow(dead_code)]
287 pub fn source(self) -> S {
288 self.source
289 }
290
291 /**
292 * Gives a reference to the MatrixMask's source (in which the data is not masked).
293 */
294 // # Safety
295 //
296 // Giving out a mutable reference to our source could allow it to be changed out from under us
297 // and make our mask checks invalid. However, since the source implements MatrixRef
298 // interior mutability is not allowed, so we can give out shared references without breaking
299 // our own integrity.
300 #[allow(dead_code)]
301 pub fn source_ref(&self) -> &S {
302 &self.source
303 }
304}
305
/**
 * A range bounded between `start` inclusive and `start + length` exclusive.
 *
 * # Examples
 *
 * Converting between [Range](std::ops::Range) and IndexRange.
 * ```
 * use std::ops::Range;
 * use easy_ml::matrices::views::IndexRange;
 * assert_eq!(IndexRange::new(3, 2), (3..5).into());
 * assert_eq!(IndexRange::new(1, 5), (1..6).into());
 * assert_eq!(IndexRange::new(0, 4), (0..4).into());
 * ```
 *
 * Creating an IndexRange
 *
 * ```
 * use easy_ml::matrices::views::IndexRange;
 * let range = IndexRange::new(3, 2);
 * let also_range: IndexRange = (3, 2).into();
 * let also_also_range: IndexRange = [3, 2].into();
 * ```
 *
 * NB: You can construct an IndexRange where start+length exceeds isize::MAX or even
 * usize::MAX. However, matrices and tensors themselves cannot contain more than isize::MAX
 * elements. On a 64 bit computer this maximum value is 9,223,372,036,854,775,807, so running
 * out of memory is likely to occur first.
 */
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct IndexRange {
    // The first index in the range (inclusive).
    pub(crate) start: usize,
    // How many consecutive indexes, counting up from `start`, the range covers.
    pub(crate) length: usize,
}
339
340impl IndexRange {
341 pub fn new(start: usize, length: usize) -> IndexRange {
342 IndexRange { start, length }
343 }
344
345 // TODO: If we make these public we need to disambiguate Range from Mask behaviour better
346 /**
347 * Maps from a coordinate space of the ith index accessible by this range to the actual index
348 * into the entire dimension's data.
349 */
350 #[inline]
351 pub(crate) fn map(&self, index: usize) -> Option<usize> {
352 if index < self.length {
353 Some(index + self.start)
354 } else {
355 None
356 }
357 }
358
359 // NOTE: This doesn't perform bounds checks, adding the length of the mask could push
360 // the index out of the valid bounds of the dimension it is for, but if we performed
361 // bounds checks here they would be redundant since performing the get with the masked index
362 // will bounds check if required
363 #[inline]
364 pub(crate) fn mask(&self, index: usize) -> usize {
365 if index < self.start {
366 index
367 } else {
368 index + self.length
369 }
370 }
371
372 // Clips the range or mask to not exceed an index. Note, this may yield 0 length ranges
373 // that have non zero starting positions, however map and mask will still calculate correctly.
374 pub(crate) fn clip(&mut self, max_index: usize) {
375 let end = self.start + self.length;
376 let end = std::cmp::min(end, max_index);
377 let length = end.saturating_sub(self.start);
378 self.length = length;
379 }
380}
381
382/**
383 * Converts from a range of start..end to an IndexRange of start and length
384 *
385 * NOTE: In previous versions (<=1.8.1) this did not saturate when attempting to subtract the
386 * start of the range from the end to calculate the length. It will now correctly produce an
387 * IndexRange with a length of 0 if the end is before or equal to the start.
388 */
389impl From<Range<usize>> for IndexRange {
390 fn from(range: Range<usize>) -> IndexRange {
391 IndexRange::new(range.start, range.end.saturating_sub(range.start))
392 }
393}
394
395/** Converts from an IndexRange of start and length to a range of start..end */
396impl From<IndexRange> for Range<usize> {
397 fn from(range: IndexRange) -> Range<usize> {
398 Range {
399 start: range.start,
400 end: range.start + range.length,
401 }
402 }
403}
404
405/**
406 * Converts from a tuple of start and length to an IndexRange
407 *
408 * NOTE: In previous versions (<=1.8.1), this was erroneously implemented as conversion from a
409 * tuple of start and end, not start and length as documented.
410 */
411impl From<(usize, usize)> for IndexRange {
412 fn from(range: (usize, usize)) -> IndexRange {
413 let (start, length) = range;
414 IndexRange::new(start, length)
415 }
416}
417
418/**
419 * Converts from an array of start and length to an IndexRange
420 *
421 * NOTE: In previous versions (<=1.8.1), this was erroneously implemented as conversion from an
422 * array of start and end, not start and length as documented.
423 */
424impl From<[usize; 2]> for IndexRange {
425 fn from(range: [usize; 2]) -> IndexRange {
426 let [start, length] = range;
427 IndexRange::new(start, length)
428 }
429}
430
#[test]
fn test_index_range_clipping() {
    // A range running past the limit is shortened so that it ends at the limit.
    let mut range = IndexRange::from(0..6);
    range.clip(4);
    assert_eq!(range, IndexRange::from(0..4));

    // A range already inside the limit is left alone, and can be clipped again later.
    let mut range = IndexRange::from(1..4);
    range.clip(5);
    assert_eq!(range, IndexRange::from(1..4));
    range.clip(2);
    assert_eq!(range, IndexRange::from(1..2));

    // A range starting past the limit becomes empty but keeps its start. Mapping then finds
    // nothing, and masking leaves indexes untouched.
    let mut range = IndexRange::from(3..5);
    range.clip(2);
    assert_eq!(range, IndexRange::from(3..2));
    for index in 0..2 {
        assert_eq!(range.map(index), None);
        assert_eq!(range.mask(index), index);
    }
}
449
450// # Safety
451//
452// Since the MatrixRef we own must implement MatrixRef correctly, so do we by delegating to it,
453// as we don't introduce any interior mutability.
454/**
455 * A MatrixRange of a MatrixRef type implements MatrixRef.
456 */
457unsafe impl<T, S> MatrixRef<T> for MatrixRange<T, S>
458where
459 S: MatrixRef<T>,
460{
461 fn try_get_reference(&self, row: Row, column: Column) -> Option<&T> {
462 let row = self.rows.map(row)?;
463 let column = self.columns.map(column)?;
464 self.source.try_get_reference(row, column)
465 }
466
467 fn view_rows(&self) -> Row {
468 self.rows.length
469 }
470
471 fn view_columns(&self) -> Column {
472 self.columns.length
473 }
474
475 unsafe fn get_reference_unchecked(&self, row: Row, column: Column) -> &T {
476 unsafe {
477 // It is the caller's responsibiltiy to always call with row/column indexes in range,
478 // therefore the unwrap() case should never happen because on an arbitary MatrixRef
479 // it would be undefined behavior.
480 let row = self.rows.map(row).unwrap();
481 let column = self.columns.map(column).unwrap();
482 self.source.get_reference_unchecked(row, column)
483 }
484 }
485
486 fn data_layout(&self) -> DataLayout {
487 self.source.data_layout()
488 }
489}
490
491// # Safety
492//
493// Since the MatrixMut we own must implement MatrixMut correctly, so do we by delegating to it,
494// as we don't introduce any interior mutability.
495/**
496 * A MatrixRange of a MatrixMut type implements MatrixMut.
497 */
498unsafe impl<T, S> MatrixMut<T> for MatrixRange<T, S>
499where
500 S: MatrixMut<T>,
501{
502 fn try_get_reference_mut(&mut self, row: Row, column: Column) -> Option<&mut T> {
503 let row = self.rows.map(row)?;
504 let column = self.columns.map(column)?;
505 self.source.try_get_reference_mut(row, column)
506 }
507
508 unsafe fn get_reference_unchecked_mut(&mut self, row: Row, column: Column) -> &mut T {
509 unsafe {
510 // It is the caller's responsibility to always call with row/column indexes in range,
511 // therefore the unwrap() case should never happen because on an arbitary MatrixRef
512 // it would be undefined behavior.
513 let row = self.rows.map(row).unwrap();
514 let column = self.columns.map(column).unwrap();
515 self.source.get_reference_unchecked_mut(row, column)
516 }
517 }
518}
519
520// # Safety
521//
522// Since the NoInteriorMutability we own must implement NoInteriorMutability correctly, so
523// do we by delegating to it, as we don't introduce any interior mutability.
524/**
525 * A MatrixRange of a NoInteriorMutability type implements NoInteriorMutability.
526 */
527unsafe impl<T, S> NoInteriorMutability for MatrixRange<T, S> where S: NoInteriorMutability {}
528
#[test]
fn test_matrix_range_shape_clips() {
    use crate::matrices::Matrix;
    // A 2x3 matrix viewed through ranges that overrun both of its dimensions.
    let matrix = Matrix::from(vec![vec![1, 2, 3], vec![4, 5, 6]]);
    let range = MatrixRange::from(&matrix, 0..7, 1..4);
    // Rows 0..7 clip to 0..2 and columns 1..4 clip to 1..3. Both dimensions therefore have
    // length 2, both as reported through MatrixRef and as stored internally.
    assert_eq!((2, 2), (range.view_rows(), range.view_columns()));
    assert_eq!((2, 2), (range.rows.length, range.columns.length));
}
539
540// # Safety
541//
542// Since the MatrixRef we own must implement MatrixRef correctly, so do we by delegating to it,
543// as we don't introduce any interior mutability.
544/**
545 * A MatrixMask of a MatrixRef type implements MatrixRef.
546 */
547unsafe impl<T, S> MatrixRef<T> for MatrixMask<T, S>
548where
549 S: MatrixRef<T>,
550{
551 fn try_get_reference(&self, row: Row, column: Column) -> Option<&T> {
552 let row = self.rows.mask(row);
553 let column = self.columns.mask(column);
554 self.source.try_get_reference(row, column)
555 }
556
557 fn view_rows(&self) -> Row {
558 // We enforce in the constructor that the mask is clipped to the size of our actual
559 // matrix, hence the mask cannot be longer than our data in either dimension. If the
560 // mask is the same length as our data, we'd return 0 which for MatrixRef is allowed.
561 self.source.view_rows() - self.rows.length
562 }
563
564 fn view_columns(&self) -> Column {
565 // We enforce in the constructor that the mask is clipped to the size of our actual
566 // matrix, hence the mask cannot be longer than our data in either dimension. If the
567 // mask is the same length as our data, we'd return 0 which for MatrixRef is allowed.
568 self.source.view_columns() - self.columns.length
569 }
570
571 unsafe fn get_reference_unchecked(&self, row: Row, column: Column) -> &T {
572 unsafe {
573 // It is the caller's responsibility to always call with row/column indexes in range,
574 // therefore calling get_reference_unchecked with indexes beyond the size of the matrix
575 // should never happen because on an arbitary MatrixRef it would be undefined behavior.
576 let row = self.rows.mask(row);
577 let column = self.columns.mask(column);
578 self.source.get_reference_unchecked(row, column)
579 }
580 }
581
582 fn data_layout(&self) -> DataLayout {
583 self.source.data_layout()
584 }
585}
586
587// # Safety
588//
589// Since the MatrixMut we own must implement MatrixMut correctly, so do we by delegating to it,
590// as we don't introduce any interior mutability.
591/**
592 * A MatrixMask of a MatrixMut type implements MatrixMut.
593 */
594unsafe impl<T, S> MatrixMut<T> for MatrixMask<T, S>
595where
596 S: MatrixMut<T>,
597{
598 fn try_get_reference_mut(&mut self, row: Row, column: Column) -> Option<&mut T> {
599 let row = self.rows.mask(row);
600 let column = self.columns.mask(column);
601 self.source.try_get_reference_mut(row, column)
602 }
603
604 unsafe fn get_reference_unchecked_mut(&mut self, row: Row, column: Column) -> &mut T {
605 unsafe {
606 // It is the caller's responsibility to always call with row/column indexes in range,
607 // therefore calling get_reference_unchecked with indexes beyond the size of the matrix
608 // should never happen because on an arbitary MatrixRef it would be undefined behavior.
609 let row = self.rows.mask(row);
610 let column = self.columns.mask(column);
611 self.source.get_reference_unchecked_mut(row, column)
612 }
613 }
614}
615
616// # Safety
617//
618// Since the NoInteriorMutability we own must implement NoInteriorMutability correctly, so
619// do we by delegating to it, as we don't introduce any interior mutability.
620/**
621 * A MatrixMask of a NoInteriorMutability type implements NoInteriorMutability.
622 */
623unsafe impl<T, S> NoInteriorMutability for MatrixMask<T, S> where S: NoInteriorMutability {}