easy_ml/tensors/views/ranges.rs

use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
use crate::tensors::{Dimension, InvalidDimensionsError, InvalidShapeError};
use std::error::Error;
use std::fmt;
use std::marker::PhantomData;

pub use crate::matrices::views::IndexRange;

/**
 * A range over a tensor in D dimensions, hiding the values **outside** the range from view.
 *
 * The entire source is still owned by the TensorRange however, so this does not permit
 * creating multiple mutable ranges into a single tensor even if they wouldn't overlap.
 *
 * See also: [TensorMask](TensorMask)
 *
 * ```
 * use easy_ml::tensors::Tensor;
 * use easy_ml::tensors::views::{TensorView, TensorRange};
 * let numbers = Tensor::from([("batch", 4), ("rows", 8), ("columns", 8)], vec![
 *     0, 0, 0, 1, 1, 0, 0, 0,
 *     0, 0, 1, 1, 1, 0, 0, 0,
 *     0, 0, 0, 1, 1, 0, 0, 0,
 *     0, 0, 0, 1, 1, 0, 0, 0,
 *     0, 0, 0, 1, 1, 0, 0, 0,
 *     0, 0, 0, 1, 1, 0, 0, 0,
 *     0, 0, 1, 1, 1, 1, 0, 0,
 *     0, 0, 1, 1, 1, 1, 0, 0,
 *
 *     0, 0, 0, 0, 0, 0, 0, 0,
 *     0, 0, 0, 2, 2, 0, 0, 0,
 *     0, 0, 2, 0, 0, 2, 0, 0,
 *     0, 0, 0, 0, 0, 2, 0, 0,
 *     0, 0, 0, 0, 2, 0, 0, 0,
 *     0, 0, 0, 2, 0, 0, 0, 0,
 *     0, 0, 2, 0, 0, 0, 0, 0,
 *     0, 0, 2, 2, 2, 2, 0, 0,
 *
 *     0, 0, 0, 3, 3, 0, 0, 0,
 *     0, 0, 3, 0, 0, 3, 0, 0,
 *     0, 0, 0, 0, 0, 3, 0, 0,
 *     0, 0, 0, 0, 3, 0, 0, 0,
 *     0, 0, 0, 0, 3, 0, 0, 0,
 *     0, 0, 0, 0, 0, 3, 0, 0,
 *     0, 0, 3, 0, 0, 3, 0, 0,
 *     0, 0, 0, 3, 3, 0, 0, 0,
 *
 *     0, 0, 0, 0, 0, 0, 0, 0,
 *     0, 0, 0, 0, 4, 0, 0, 0,
 *     0, 0, 0, 4, 4, 0, 0, 0,
 *     0, 0, 4, 0, 4, 0, 0, 0,
 *     0, 4, 4, 4, 4, 4, 0, 0,
 *     0, 0, 0, 0, 4, 0, 0, 0,
 *     0, 0, 0, 0, 4, 0, 0, 0,
 *     0, 0, 0, 0, 4, 0, 0, 0
 * ]);
 * let one_and_two = TensorView::from(
 *     TensorRange::from(&numbers, [("batch", 0..2)])
 *         .expect("Input is constructed so that our range is valid")
 * );
 * let framed = TensorView::from(
 *     TensorRange::from(&numbers, [("rows", [1, 6]), ("columns", [1, 6])])
 *         .expect("Input is constructed so that our range is valid")
 * );
 * assert_eq!(one_and_two.shape(), [("batch", 2), ("rows", 8), ("columns", 8)]);
 * assert_eq!(framed.shape(), [("batch", 4), ("rows", 6), ("columns", 6)]);
 * println!("{}", framed.select([("batch", 3)]));
 * // D = 2
 * // ("rows", 6), ("columns", 6)
 * // [ 0, 0, 0, 4, 0, 0
 * //   0, 0, 4, 4, 0, 0
 * //   0, 4, 0, 4, 0, 0
 * //   4, 4, 4, 4, 4, 0
 * //   0, 0, 0, 4, 0, 0
 * //   0, 0, 0, 4, 0, 0 ]
 * ```
 */
#[derive(Clone, Debug)]
pub struct TensorRange<T, S, const D: usize> {
    source: S,
    range: [IndexRange; D],
    _type: PhantomData<T>,
}

/**
 * A mask over a tensor in D dimensions, hiding the values **inside** the range from view.
 *
 * The entire source is still owned by the TensorMask however, so this does not permit
 * creating multiple mutable masks into a single tensor even if they wouldn't overlap.
 *
 * See also: [TensorRange](TensorRange)
 *
 * ```
 * use easy_ml::tensors::Tensor;
 * use easy_ml::tensors::views::{TensorView, TensorMask};
 * let numbers = Tensor::from([("batch", 4), ("rows", 8), ("columns", 8)], vec![
 *     0, 0, 0, 1, 1, 0, 0, 0,
 *     0, 0, 1, 1, 1, 0, 0, 0,
 *     0, 0, 0, 1, 1, 0, 0, 0,
 *     0, 0, 0, 1, 1, 0, 0, 0,
 *     0, 0, 0, 1, 1, 0, 0, 0,
 *     0, 0, 0, 1, 1, 0, 0, 0,
 *     0, 0, 1, 1, 1, 1, 0, 0,
 *     0, 0, 1, 1, 1, 1, 0, 0,
 *
 *     0, 0, 0, 0, 0, 0, 0, 0,
 *     0, 0, 0, 2, 2, 0, 0, 0,
 *     0, 0, 2, 0, 0, 2, 0, 0,
 *     0, 0, 0, 0, 0, 2, 0, 0,
 *     0, 0, 0, 0, 2, 0, 0, 0,
 *     0, 0, 0, 2, 0, 0, 0, 0,
 *     0, 0, 2, 0, 0, 0, 0, 0,
 *     0, 0, 2, 2, 2, 2, 0, 0,
 *
 *     0, 0, 0, 3, 3, 0, 0, 0,
 *     0, 0, 3, 0, 0, 3, 0, 0,
 *     0, 0, 0, 0, 0, 3, 0, 0,
 *     0, 0, 0, 0, 3, 0, 0, 0,
 *     0, 0, 0, 0, 3, 0, 0, 0,
 *     0, 0, 0, 0, 0, 3, 0, 0,
 *     0, 0, 3, 0, 0, 3, 0, 0,
 *     0, 0, 0, 3, 3, 0, 0, 0,
 *
 *     0, 0, 0, 0, 0, 0, 0, 0,
 *     0, 0, 0, 0, 4, 0, 0, 0,
 *     0, 0, 0, 4, 4, 0, 0, 0,
 *     0, 0, 4, 0, 4, 0, 0, 0,
 *     0, 4, 4, 4, 4, 4, 0, 0,
 *     0, 0, 0, 0, 4, 0, 0, 0,
 *     0, 0, 0, 0, 4, 0, 0, 0,
 *     0, 0, 0, 0, 4, 0, 0, 0
 * ]);
 * let one_and_four = TensorView::from(
 *     TensorMask::from(&numbers, [("batch", 1..3)])
 *         .expect("Input is constructed so that our mask is valid")
 * );
 * let corners = TensorView::from(
 *     TensorMask::from(&numbers, [("rows", [3, 2]), ("columns", [3, 2])])
 *         .expect("Input is constructed so that our mask is valid")
 * );
 * assert_eq!(one_and_four.shape(), [("batch", 2), ("rows", 8), ("columns", 8)]);
 * assert_eq!(corners.shape(), [("batch", 4), ("rows", 6), ("columns", 6)]);
 * println!("{}", corners.select([("batch", 2)]));
 * // D = 2
 * // ("rows", 6), ("columns", 6)
 * // [ 0, 0, 0, 0, 0, 0
 * //   0, 0, 3, 3, 0, 0
 * //   0, 0, 0, 3, 0, 0
 * //   0, 0, 0, 3, 0, 0
 * //   0, 0, 3, 3, 0, 0
 * //   0, 0, 0, 0, 0, 0 ]
 * ```
 */
#[derive(Clone, Debug)]
pub struct TensorMask<T, S, const D: usize> {
    source: S,
    mask: [IndexRange; D],
    _type: PhantomData<T>,
}

/**
 * An error in creating a [TensorRange](TensorRange) or a [TensorMask](TensorMask).
 */
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum IndexRangeValidationError<const D: usize, const P: usize> {
    /**
     * The shape that the resulting Tensor would have would not be valid.
     */
    InvalidShape(InvalidShapeError<D>),
    /**
     * Multiple of the same dimension name were provided, but we can only take one mask or range
     * for each dimension at a time.
     */
    InvalidDimensions(InvalidDimensionsError<D, P>),
}

impl<const D: usize, const P: usize> fmt::Display for IndexRangeValidationError<D, P> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            IndexRangeValidationError::InvalidShape(error) => write!(f, "{:?}", error),
            IndexRangeValidationError::InvalidDimensions(error) => write!(f, "{:?}", error),
        }
    }
}

impl<const D: usize, const P: usize> Error for IndexRangeValidationError<D, P> {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            IndexRangeValidationError::InvalidShape(error) => Some(error),
            IndexRangeValidationError::InvalidDimensions(error) => Some(error),
        }
    }
}

/**
 * An error in creating a [TensorRange](TensorRange) or a [TensorMask](TensorMask) using
 * strict validation.
 */
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum StrictIndexRangeValidationError<const D: usize, const P: usize> {
    /**
     * In at least one dimension, the mask or range provided exceeds the bounds of the shape
     * of the Tensor it was to be used on. This is not necessarily an issue as the mask or
     * range could be clipped to the bounds of the Tensor's shape, but a constructor which
     * rejects out of bounds input was used.
     */
    OutsideShape {
        shape: [(Dimension, usize); D],
        index_range: [Option<IndexRange>; D],
    },
    Error(IndexRangeValidationError<D, P>),
}

impl<const D: usize, const P: usize> fmt::Display for StrictIndexRangeValidationError<D, P> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use StrictIndexRangeValidationError as S;
        match self {
            S::OutsideShape { shape, index_range } => write!(
                f,
                "IndexRange array {:?} is out of bounds of shape {:?}",
                index_range, shape
            ),
            S::Error(error) => write!(f, "{:?}", error),
        }
    }
}

impl<const D: usize, const P: usize> Error for StrictIndexRangeValidationError<D, P> {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        use StrictIndexRangeValidationError as S;
        match self {
            S::OutsideShape {
                shape: _,
                index_range: _,
            } => None,
            S::Error(error) => Some(error),
        }
    }
}

fn from_named_to_all<T, S, R, const D: usize, const P: usize>(
    source: &S,
    ranges: [(Dimension, R); P],
) -> Result<[Option<IndexRange>; D], IndexRangeValidationError<D, P>>
where
    S: TensorRef<T, D>,
    R: Into<IndexRange>,
{
    let shape = source.view_shape();
    let ranges = ranges.map(|(d, r)| (d, r.into()));
    let dimensions = InvalidDimensionsError {
        provided: ranges.clone().map(|(d, _)| d),
        valid: shape.map(|(d, _)| d),
    };
    if dimensions.has_duplicates() {
        return Err(IndexRangeValidationError::InvalidDimensions(dimensions));
    }
    // Since we now know there are no duplicates, we can look up the dimension index for each name
    // in the shape and we know we'll get different indexes on each lookup.
    let mut all_ranges: [Option<IndexRange>; D] = std::array::from_fn(|_| None);
    for (name, range) in ranges.into_iter() {
        match crate::tensors::dimensions::position_of(&shape, name) {
            Some(d) => all_ranges[d] = Some(range),
            None => return Err(IndexRangeValidationError::InvalidDimensions(dimensions)),
        };
    }
    Ok(all_ranges)
}

impl<T, S, const D: usize> TensorRange<T, S, D>
where
    S: TensorRef<T, D>,
{
    /**
     * Constructs a TensorRange from a tensor and set of dimension name/range pairs.
     *
     * Returns the Err variant if any dimension would have a length of 0 after applying the
     * ranges, if multiple pairs with the same name are provided, or if any dimension names aren't
     * in the source.
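     *
     * A minimal usage sketch, mirroring the happy path cases in this module's tests:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::{TensorView, TensorRange};
     * let tensor = Tensor::from([("rows", 3), ("columns", 3)], (0..9).collect::<Vec<i32>>());
     * let bottom_rows = TensorView::from(
     *     TensorRange::from(&tensor, [("rows", 1..3)])
     *         .expect("'rows' is in the source and the range leaves a nonzero length")
     * );
     * assert_eq!(bottom_rows.shape(), [("rows", 2), ("columns", 3)]);
     * ```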
     */
    pub fn from<R, const P: usize>(
        source: S,
        ranges: [(Dimension, R); P],
    ) -> Result<TensorRange<T, S, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        let all_ranges = from_named_to_all(&source, ranges)?;
        match TensorRange::from_all(source, all_ranges) {
            Ok(tensor_range) => Ok(tensor_range),
            Err(invalid_shape) => Err(IndexRangeValidationError::InvalidShape(invalid_shape)),
        }
    }

    /**
     * Constructs a TensorRange from a tensor and set of dimension name/range pairs.
     *
     * Returns the Err variant if any dimension would have a length of 0 after applying the
     * ranges, if multiple pairs with the same name are provided, if any dimension names aren't
     * in the source, or if any range extends beyond the length of that dimension in the tensor.
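     *
     * A sketch of the stricter behavior, based on the out of bounds case in this module's
     * tests; the non strict `from` clips such a range instead of rejecting it:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorRange;
     * let tensor = Tensor::from([("rows", 3), ("columns", 3)], (0..9).collect::<Vec<i32>>());
     * // 0..4 extends past the 3 rows in the source, so strict construction is rejected
     * assert!(TensorRange::from_strict(&tensor, [("rows", 0..4)]).is_err());
     * assert!(TensorRange::from(&tensor, [("rows", 0..4)]).is_ok());
     * ```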
     */
    pub fn from_strict<R, const P: usize>(
        source: S,
        ranges: [(Dimension, R); P],
    ) -> Result<TensorRange<T, S, D>, StrictIndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        use StrictIndexRangeValidationError as S;
        let all_ranges = match from_named_to_all(&source, ranges) {
            Ok(all_ranges) => all_ranges,
            Err(error) => return Err(S::Error(error)),
        };
        match TensorRange::from_all_strict(source, all_ranges) {
            Ok(tensor_range) => Ok(tensor_range),
            Err(S::OutsideShape { shape, index_range }) => {
                Err(S::OutsideShape { shape, index_range })
            }
            Err(S::Error(IndexRangeValidationError::InvalidShape(error))) => {
                Err(S::Error(IndexRangeValidationError::InvalidShape(error)))
            }
            Err(S::Error(IndexRangeValidationError::InvalidDimensions(_))) => panic!(
                "Unexpected InvalidDimensions error case after validating for InvalidDimensions already"
            ),
        }
    }

    /**
     * Constructs a TensorRange from a tensor and a range for each dimension in the tensor
     * (provided in the same order as the tensor's shape).
     *
     * Returns the Err variant if any dimension would have a length of 0 after applying the ranges.
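     *
     * A small sketch; a `None` entry keeps the full length of that dimension:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::{TensorView, TensorRange};
     * let tensor = Tensor::from([("rows", 3), ("columns", 3)], (0..9).collect::<Vec<i32>>());
     * // keep every row, but only the last two columns
     * let right_columns = TensorView::from(
     *     TensorRange::from_all(&tensor, [None, Some(1..3)]).expect("no length becomes 0")
     * );
     * assert_eq!(right_columns.shape(), [("rows", 3), ("columns", 2)]);
     * ```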
     */
    pub fn from_all<R>(
        source: S,
        ranges: [Option<R>; D],
    ) -> Result<TensorRange<T, S, D>, InvalidShapeError<D>>
    where
        R: Into<IndexRange>,
    {
        TensorRange::clip_from(
            source,
            ranges.map(|option| option.map(|range| range.into())),
        )
    }

    fn clip_from(
        source: S,
        ranges: [Option<IndexRange>; D],
    ) -> Result<TensorRange<T, S, D>, InvalidShapeError<D>> {
        let shape = source.view_shape();
        let mut ranges = {
            // TODO: An iterator enumerate call would be much cleaner here but everything
            // except array::map is not stable yet.
            let mut d = 0;
            ranges.map(|option| {
                // convert None to ranges that select the entire length of the tensor
                let range = option.unwrap_or_else(|| IndexRange::new(0, shape[d].1));
                d += 1;
                range
            })
        };
        let shape = InvalidShapeError {
            shape: clip_range_shape(&shape, &mut ranges),
        };
        if !shape.is_valid() {
            return Err(shape);
        }

        Ok(TensorRange {
            source,
            range: ranges,
            _type: PhantomData,
        })
    }

    /**
     * Constructs a TensorRange from a tensor and a range for each dimension in the tensor
     * (provided in the same order as the tensor's shape), ensuring the range is within the
     * lengths of the tensor.
     *
     * Returns the Err variant if any dimension would have a length of 0 after applying the
     * ranges or any range extends beyond the length of that dimension in the tensor.
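     *
     * A sketch of the stricter behavior; unlike the non strict `from_all`, an out of bounds
     * range is rejected rather than clipped:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorRange;
     * let tensor = Tensor::from([("rows", 3), ("columns", 3)], (0..9).collect::<Vec<i32>>());
     * assert!(TensorRange::from_all_strict(&tensor, [None, Some(0..4)]).is_err());
     * assert!(TensorRange::from_all(&tensor, [None, Some(0..4)]).is_ok());
     * ```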
     */
    pub fn from_all_strict<R>(
        source: S,
        range: [Option<R>; D],
    ) -> Result<TensorRange<T, S, D>, StrictIndexRangeValidationError<D, D>>
    where
        R: Into<IndexRange>,
    {
        let shape = source.view_shape();
        let range = range.map(|option| option.map(|range| range.into()));
        if range_exceeds_bounds(&shape, &range) {
            return Err(StrictIndexRangeValidationError::OutsideShape {
                shape,
                index_range: range,
            });
        }

        match TensorRange::clip_from(source, range) {
            Ok(tensor_range) => Ok(tensor_range),
            Err(invalid_shape) => Err(StrictIndexRangeValidationError::Error(
                IndexRangeValidationError::InvalidShape(invalid_shape),
            )),
        }
    }

    /**
     * Consumes the TensorRange, yielding the source it was created from.
     */
    #[allow(dead_code)]
    pub fn source(self) -> S {
        self.source
    }

    /**
     * Gives a reference to the TensorRange's source (in which the data is not clipped).
     */
    // # Safety
    //
    // Giving out a mutable reference to our source could allow it to be changed out from under us
    // and make our range checks invalid. However, since the source implements TensorRef
    // interior mutability is not allowed, so we can give out shared references without breaking
    // our own integrity.
    #[allow(dead_code)]
    pub fn source_ref(&self) -> &S {
        &self.source
    }
}

fn range_exceeds_bounds<const D: usize>(
    source: &[(Dimension, usize); D],
    range: &[Option<IndexRange>; D],
) -> bool {
    for (d, (_, end)) in source.iter().enumerate() {
        let end = *end;
        match &range[d] {
            None => continue,
            Some(range) => {
                let range_end = range.start + range.length;
                match range_end > end {
                    true => return true,
                    false => (),
                };
            }
        }
    }
    false
}

// Returns the shape the tensor will be left with once the range is applied, clipping any
// ranges that exceed the bounds of the tensor's shape.
fn clip_range_shape<const D: usize>(
    source: &[(Dimension, usize); D],
    range: &mut [IndexRange; D],
) -> [(Dimension, usize); D] {
    let mut shape = *source;
    for (d, (_, length)) in shape.iter_mut().enumerate() {
        let range = &mut range[d];
        range.clip(*length);
        // the length that remains is the length of the range
        *length = range.length;
    }
    shape
}

impl<T, S, const D: usize> TensorMask<T, S, D>
where
    S: TensorRef<T, D>,
{
    /**
     * Constructs a TensorMask from a tensor and set of dimension name/mask pairs.
     *
     * Returns the Err variant if any masked dimension would have a length of 0, if multiple
     * pairs with the same name are provided, or if any dimension names aren't in the source.
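     *
     * A minimal usage sketch, mirroring the mask cases in this module's tests:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::{TensorView, TensorMask};
     * let tensor = Tensor::from([("rows", 3), ("columns", 3)], (0..9).collect::<Vec<i32>>());
     * // hide the middle row, leaving the first and last rows in view
     * let outer_rows = TensorView::from(
     *     TensorMask::from(&tensor, [("rows", 1..2)])
     *         .expect("'rows' is in the source and the mask leaves a nonzero length")
     * );
     * assert_eq!(outer_rows.shape(), [("rows", 2), ("columns", 3)]);
     * ```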
     */
    pub fn from<R, const P: usize>(
        source: S,
        masks: [(Dimension, R); P],
    ) -> Result<TensorMask<T, S, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        let all_masks = from_named_to_all(&source, masks)?;
        match TensorMask::from_all(source, all_masks) {
            Ok(tensor_mask) => Ok(tensor_mask),
            Err(invalid_shape) => Err(IndexRangeValidationError::InvalidShape(invalid_shape)),
        }
    }

    /**
     * Constructs a TensorMask from a tensor and set of dimension name/mask pairs.
     *
     * Returns the Err variant if any masked dimension would have a length of 0, if multiple
     * pairs with the same name are provided, if any dimension names aren't in the source,
     * or if any mask extends beyond the length of that dimension in the tensor.
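     *
     * A sketch of the stricter behavior, based on the out of bounds case in this module's
     * tests; the non strict `from` clips such a mask instead of rejecting it:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorMask;
     * let tensor = Tensor::from([("rows", 3), ("columns", 3)], (0..9).collect::<Vec<i32>>());
     * // 1..4 extends past the 3 columns in the source, so strict construction is rejected
     * assert!(TensorMask::from_strict(&tensor, [("columns", 1..4)]).is_err());
     * assert!(TensorMask::from(&tensor, [("columns", 1..4)]).is_ok());
     * ```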
     */
    pub fn from_strict<R, const P: usize>(
        source: S,
        masks: [(Dimension, R); P],
    ) -> Result<TensorMask<T, S, D>, StrictIndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        use StrictIndexRangeValidationError as S;
        let all_masks = match from_named_to_all(&source, masks) {
            Ok(all_masks) => all_masks,
            Err(error) => return Err(S::Error(error)),
        };
        match TensorMask::from_all_strict(source, all_masks) {
            Ok(tensor_mask) => Ok(tensor_mask),
            Err(S::OutsideShape { shape, index_range }) => {
                Err(S::OutsideShape { shape, index_range })
            }
            Err(S::Error(IndexRangeValidationError::InvalidShape(error))) => {
                Err(S::Error(IndexRangeValidationError::InvalidShape(error)))
            }
            Err(S::Error(IndexRangeValidationError::InvalidDimensions(_))) => panic!(
                "Unexpected InvalidDimensions error case after validating for InvalidDimensions already"
            ),
        }
    }

    /**
     * Constructs a TensorMask from a tensor and a mask for each dimension in the tensor
     * (provided in the same order as the tensor's shape).
     *
     * Returns the Err variant if any masked dimension would have a length of 0.
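     *
     * A small sketch; a `None` entry leaves that dimension unmasked:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::{TensorView, TensorMask};
     * let tensor = Tensor::from([("rows", 3), ("columns", 3)], (0..9).collect::<Vec<i32>>());
     * // hide the middle row, leave the columns untouched
     * let outer_rows = TensorView::from(
     *     TensorMask::from_all(&tensor, [Some(1..2), None]).expect("no length becomes 0")
     * );
     * assert_eq!(outer_rows.shape(), [("rows", 2), ("columns", 3)]);
     * ```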
     */
    pub fn from_all<R>(
        source: S,
        mask: [Option<R>; D],
    ) -> Result<TensorMask<T, S, D>, InvalidShapeError<D>>
    where
        R: Into<IndexRange>,
    {
        TensorMask::clip_from(source, mask.map(|option| option.map(|mask| mask.into())))
    }

    fn clip_from(
        source: S,
        masks: [Option<IndexRange>; D],
    ) -> Result<TensorMask<T, S, D>, InvalidShapeError<D>> {
        let shape = source.view_shape();
        let mut masks = masks.map(|option| option.unwrap_or_else(|| IndexRange::new(0, 0)));
        let shape = InvalidShapeError {
            shape: clip_masked_shape(&shape, &mut masks),
        };
        if !shape.is_valid() {
            return Err(shape);
        }

        Ok(TensorMask {
            source,
            mask: masks,
            _type: PhantomData,
        })
    }

    /**
     * Constructs a TensorMask from a tensor and a mask for each dimension in the tensor
     * (provided in the same order as the tensor's shape), ensuring the mask is within the
     * lengths of the tensor.
     *
     * Returns the Err variant if any masked dimension would have a length of 0 or any mask
     * extends beyond the length of that dimension in the tensor.
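     *
     * A sketch of the stricter behavior; unlike the non strict `from_all`, an out of bounds
     * mask is rejected rather than clipped:
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorMask;
     * let tensor = Tensor::from([("rows", 3), ("columns", 3)], (0..9).collect::<Vec<i32>>());
     * assert!(TensorMask::from_all_strict(&tensor, [None, Some(1..4)]).is_err());
     * assert!(TensorMask::from_all(&tensor, [None, Some(1..4)]).is_ok());
     * ```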
     */
    pub fn from_all_strict<R>(
        source: S,
        masks: [Option<R>; D],
    ) -> Result<TensorMask<T, S, D>, StrictIndexRangeValidationError<D, D>>
    where
        R: Into<IndexRange>,
    {
        let shape = source.view_shape();
        let masks = masks.map(|option| option.map(|mask| mask.into()));
        if mask_exceeds_bounds(&shape, &masks) {
            return Err(StrictIndexRangeValidationError::OutsideShape {
                shape,
                index_range: masks,
            });
        }

        match TensorMask::clip_from(source, masks) {
            Ok(tensor_mask) => Ok(tensor_mask),
            Err(invalid_shape) => Err(StrictIndexRangeValidationError::Error(
                IndexRangeValidationError::InvalidShape(invalid_shape),
            )),
        }
    }

    /**
     * Consumes the TensorMask, yielding the source it was created from.
     */
    #[allow(dead_code)]
    pub fn source(self) -> S {
        self.source
    }

    /**
     * Gives a reference to the TensorMask's source (in which the data is not masked).
     */
    // # Safety
    //
    // Giving out a mutable reference to our source could allow it to be changed out from under us
    // and make our mask checks invalid. However, since the source implements TensorRef
    // interior mutability is not allowed, so we can give out shared references without breaking
    // our own integrity.
    #[allow(dead_code)]
    pub fn source_ref(&self) -> &S {
        &self.source
    }
}

// Returns the shape the tensor will be left with once the mask is applied, clipping any
// masks that exceed the bounds of the tensor's shape.
fn clip_masked_shape<const D: usize>(
    source: &[(Dimension, usize); D],
    mask: &mut [IndexRange; D],
) -> [(Dimension, usize); D] {
    let mut shape = *source;
    for (d, (_, length)) in shape.iter_mut().enumerate() {
        let mask = &mut mask[d];
        mask.clip(*length);
        // the length that remains is what is not included along the mask
        *length -= mask.length;
    }
    shape
}

fn mask_exceeds_bounds<const D: usize>(
    source: &[(Dimension, usize); D],
    mask: &[Option<IndexRange>; D],
) -> bool {
    // same test for a mask extending past a shape as for a range
    range_exceeds_bounds(source, mask)
}

fn map_indexes_by_range<const D: usize>(
    indexes: [usize; D],
    ranges: &[IndexRange; D],
) -> Option<[usize; D]> {
    let mut mapped = [0; D];
    for (d, (r, i)) in ranges.iter().zip(indexes.into_iter()).enumerate() {
        mapped[d] = r.map(i)?;
    }
    Some(mapped)
}

// # Safety
//
// The type implementing TensorRef must implement it correctly, so by delegating to it
// and just hiding some of the valid indexes from view, we implement TensorRef correctly as well.
/**
 * A TensorRange implements TensorRef, with the dimension lengths reduced to the range the
 * TensorRange was created with.
 */
unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorRange<T, S, D>
where
    S: TensorRef<T, D>,
{
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        self.source
            .get_reference(map_indexes_by_range(indexes, &self.range)?)
    }

    fn view_shape(&self) -> [(Dimension, usize); D] {
        // Since when we were constructed we clipped the length of each range to no more than
        // our source, we can just return the length of each range now
        let mut shape = self.source.view_shape();
        // TODO: zip would work really nicely here but it's not stable yet
        for (pair, range) in shape.iter_mut().zip(self.range.iter()) {
            pair.1 = range.length;
        }
        shape
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        unsafe {
            // It is the caller's responsibility to always call with indexes in range,
            // therefore the unwrap() case should never happen because on an arbitrary TensorRef
            // it would be undefined behavior.
            // TODO: Can we use unwrap_unchecked here?
            self.source
                .get_reference_unchecked(map_indexes_by_range(indexes, &self.range).unwrap())
        }
    }

    fn data_layout(&self) -> DataLayout<D> {
        // Our range means the view shape no longer matches up to a single
        // line of data in memory in the general case (ranges in 1D could still be linear
        // but DataLayout is not very meaningful till we get to 2D anyway).
        DataLayout::NonLinear
    }
}

// # Safety
//
// The type implementing TensorMut must implement it correctly, so by delegating to it
// and just hiding some of the valid indexes from view, we implement TensorMut correctly as well.
/**
 * A TensorRange implements TensorMut, with the dimension lengths reduced to the range the
 * TensorRange was created with.
 */
unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorRange<T, S, D>
where
    S: TensorMut<T, D>,
{
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        self.source
            .get_reference_mut(map_indexes_by_range(indexes, &self.range)?)
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        unsafe {
            // It is the caller's responsibility to always call with indexes in range,
            // therefore the unwrap() case should never happen because on an arbitrary TensorMut
            // it would be undefined behavior.
            // TODO: Can we use unwrap_unchecked here?
            self.source
                .get_reference_unchecked_mut(map_indexes_by_range(indexes, &self.range).unwrap())
        }
    }
}

fn map_indexes_by_mask<const D: usize>(indexes: [usize; D], masks: &[IndexRange; D]) -> [usize; D] {
    let mut mapped = [0; D];
    for (d, (r, i)) in masks.iter().zip(indexes.into_iter()).enumerate() {
        mapped[d] = r.mask(i);
    }
    mapped
}

// # Safety
//
// The type implementing TensorRef must implement it correctly, so by delegating to it
// and just hiding some of the valid indexes from view, we implement TensorRef correctly as well.
/**
 * A TensorMask implements TensorRef, with the dimension lengths reduced by the mask the
 * TensorMask was created with.
 */
unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorMask<T, S, D>
where
    S: TensorRef<T, D>,
{
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        self.source
            .get_reference(map_indexes_by_mask(indexes, &self.mask))
    }

    fn view_shape(&self) -> [(Dimension, usize); D] {
        // Since when we were constructed we clipped the length of each mask to no more than
        // our source, we can just subtract the length of each mask now
        let mut shape = self.source.view_shape();
        // TODO: zip would work really nicely here but it's not stable yet
        for (pair, mask) in shape.iter_mut().zip(self.mask.iter()) {
            pair.1 -= mask.length;
        }
        shape
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        unsafe {
            // It is the caller's responsibility to always call with indexes in range,
            // therefore out of bounds lookups created by map_indexes_by_mask should never happen.
            self.source
                .get_reference_unchecked(map_indexes_by_mask(indexes, &self.mask))
        }
    }

    fn data_layout(&self) -> DataLayout<D> {
        // Our mask means the view shape no longer matches up to a single
        // line of data in memory.
        DataLayout::NonLinear
    }
}

// # Safety
//
// The type implementing TensorMut must implement it correctly, so by delegating to it
// and just hiding some of the valid indexes from view, we implement TensorMut correctly as well.
/**
 * A TensorMask implements TensorMut, with the dimension lengths reduced by the mask the
 * TensorMask was created with.
 */
unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorMask<T, S, D>
where
    S: TensorMut<T, D>,
{
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        self.source
            .get_reference_mut(map_indexes_by_mask(indexes, &self.mask))
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        unsafe {
            // It is the caller's responsibility to always call with indexes in range,
            // therefore out of bounds lookups created by map_indexes_by_mask should never happen.
            self.source
                .get_reference_unchecked_mut(map_indexes_by_mask(indexes, &self.mask))
        }
    }
}

#[test]
#[rustfmt::skip]
fn test_constructors() {
    use crate::tensors::Tensor;
    use crate::tensors::views::TensorView;
    let tensor = Tensor::from([("rows", 3), ("columns", 3)], (0..9).collect());
    // Happy path
    assert_eq!(
        TensorView::from(TensorRange::from(&tensor, [("rows", IndexRange::new(1, 2))]).unwrap()),
        Tensor::from([("rows", 2), ("columns", 3)], vec![
            3, 4, 5,
            6, 7, 8
        ])
    );
    assert_eq!(
        TensorView::from(TensorRange::from(&tensor, [("columns", 2..3)]).unwrap()),
        Tensor::from([("rows", 3), ("columns", 1)], vec![
            2,
            5,
            8
        ])
    );
    assert_eq!(
        TensorView::from(TensorRange::from(&tensor, [("rows", (1, 1)), ("columns", (2, 1))]).unwrap()),
        Tensor::from([("rows", 1), ("columns", 1)], vec![5])
    );
    assert_eq!(
        TensorView::from(TensorRange::from(&tensor, [("columns", 1..3)]).unwrap()),
        Tensor::from([("rows", 3), ("columns", 2)], vec![
            1, 2,
            4, 5,
            7, 8
        ])
    );

    assert_eq!(
        TensorView::from(TensorMask::from(&tensor, [("rows", IndexRange::new(1, 1))]).unwrap()),
        Tensor::from([("rows", 2), ("columns", 3)], vec![
            0, 1, 2,
            6, 7, 8
        ])
    );
    assert_eq!(
        TensorView::from(TensorMask::from(&tensor, [("rows", 2..3), ("columns", 0..1)]).unwrap()),
        Tensor::from([("rows", 2), ("columns", 2)], vec![
            1, 2,
            4, 5
        ])
    );

    use IndexRangeValidationError as IRVError;
    use InvalidShapeError as ShapeError;
    use StrictIndexRangeValidationError::Error as SError;
    use StrictIndexRangeValidationError::OutsideShape as OutsideShape;
    use InvalidDimensionsError as DError;
    // Dimension names that aren't present
    assert_eq!(
        TensorRange::from(&tensor, [("invalid", 1..2)]).unwrap_err(),
        IRVError::InvalidDimensions(DError::new(["invalid"], ["rows", "columns"]))
    );
    assert_eq!(
        TensorMask::from(&tensor, [("wrong", 0..1)]).unwrap_err(),
        IRVError::InvalidDimensions(DError::new(["wrong"], ["rows", "columns"]))
    );
    assert_eq!(
        TensorRange::from_strict(&tensor, [("invalid", 1..2)]).unwrap_err(),
        SError(IRVError::InvalidDimensions(DError::new(["invalid"], ["rows", "columns"])))
    );
    assert_eq!(
        TensorMask::from_strict(&tensor, [("wrong", 0..1)]).unwrap_err(),
        SError(IRVError::InvalidDimensions(DError::new(["wrong"], ["rows", "columns"])))
    );

    // Mask / Range creates a 0 length dimension
    assert_eq!(
        TensorRange::from(&tensor, [("rows", 0..0)]).unwrap_err(),
        IRVError::InvalidShape(ShapeError::new([("rows", 0), ("columns", 3)]))
    );
    assert_eq!(
        TensorMask::from(&tensor, [("columns", 0..3)]).unwrap_err(),
        IRVError::InvalidShape(ShapeError::new([("rows", 3), ("columns", 0)]))
    );
    assert_eq!(
        TensorRange::from_strict(&tensor, [("rows", 0..0)]).unwrap_err(),
        SError(IRVError::InvalidShape(ShapeError::new([("rows", 0), ("columns", 3)])))
    );
    assert_eq!(
        TensorMask::from_strict(&tensor, [("columns", 0..3)]).unwrap_err(),
        SError(IRVError::InvalidShape(ShapeError::new([("rows", 3), ("columns", 0)])))
    );

    // Dimension name specified twice
    assert_eq!(
        TensorRange::from(&tensor, [("rows", 1..2), ("rows", 2..3)]).unwrap_err(),
        IRVError::InvalidDimensions(DError::new(["rows", "rows"], ["rows", "columns"]))
    );
    assert_eq!(
        TensorMask::from(&tensor, [("columns", 1..2), ("columns", 2..3)]).unwrap_err(),
        IRVError::InvalidDimensions(DError::new(["columns", "columns"], ["rows", "columns"]))
    );
    assert_eq!(
        TensorRange::from_strict(&tensor, [("rows", 1..2), ("rows", 2..3)]).unwrap_err(),
        SError(IRVError::InvalidDimensions(DError::new(["rows", "rows"], ["rows", "columns"])))
    );
    assert_eq!(
        TensorMask::from_strict(&tensor, [("columns", 1..2), ("columns", 2..3)]).unwrap_err(),
        SError(IRVError::InvalidDimensions(DError::new(["columns", "columns"], ["rows", "columns"])))
    );

    // Mask / Range needs clipping
    assert!(
        TensorView::from(TensorRange::from(&tensor, [("rows", 0..4)]).unwrap()).eq(&tensor),
    );
    assert_eq!(
        TensorRange::from_strict(&tensor, [("rows", 0..4)]).unwrap_err(),
        OutsideShape {
            shape: [("rows", 3), ("columns", 3)],
            index_range: [Some(IndexRange::new(0, 4)), None],
        }
    );
    assert_eq!(
        TensorView::from(TensorMask::from(&tensor, [("columns", 1..4)]).unwrap()),
        Tensor::from([("rows", 3), ("columns", 1)], vec![
            0,
            3,
            6,
        ])
    );
    assert_eq!(
        TensorMask::from_strict(&tensor, [("columns", 1..4)]).unwrap_err(),
        OutsideShape {
            shape: [("rows", 3), ("columns", 3)],
            index_range: [None, Some(IndexRange::new(1, 3))],
        }
    );
}