easy_ml/tensors/views/ranges.rs

use crate::tensors::dimensions;
use crate::tensors::views::{DataLayout, TensorMut, TensorRef, TensorView};
use crate::tensors::{Dimension, InvalidDimensionsError, InvalidShapeError};
use std::error::Error;
use std::fmt;
use std::marker::PhantomData;
use std::num::NonZeroUsize;

pub use crate::matrices::views::IndexRange;

/**
 * A range over a tensor in D dimensions, hiding the values **outside** the range from view.
 *
 * The entire source is still owned by the TensorRange however, so this does not permit
 * creating multiple mutable ranges into a single tensor even if they wouldn't overlap.
 *
 * See also: [TensorMask](TensorMask)
 *
 * ```
 * use easy_ml::tensors::Tensor;
 * use easy_ml::tensors::views::{TensorView, TensorRange};
 * let numbers = Tensor::from([("batch", 4), ("rows", 8), ("columns", 8)], vec![
 *     0, 0, 0, 1, 1, 0, 0, 0,
 *     0, 0, 1, 1, 1, 0, 0, 0,
 *     0, 0, 0, 1, 1, 0, 0, 0,
 *     0, 0, 0, 1, 1, 0, 0, 0,
 *     0, 0, 0, 1, 1, 0, 0, 0,
 *     0, 0, 0, 1, 1, 0, 0, 0,
 *     0, 0, 1, 1, 1, 1, 0, 0,
 *     0, 0, 1, 1, 1, 1, 0, 0,
 *
 *     0, 0, 0, 0, 0, 0, 0, 0,
 *     0, 0, 0, 2, 2, 0, 0, 0,
 *     0, 0, 2, 0, 0, 2, 0, 0,
 *     0, 0, 0, 0, 0, 2, 0, 0,
 *     0, 0, 0, 0, 2, 0, 0, 0,
 *     0, 0, 0, 2, 0, 0, 0, 0,
 *     0, 0, 2, 0, 0, 0, 0, 0,
 *     0, 0, 2, 2, 2, 2, 0, 0,
 *
 *     0, 0, 0, 3, 3, 0, 0, 0,
 *     0, 0, 3, 0, 0, 3, 0, 0,
 *     0, 0, 0, 0, 0, 3, 0, 0,
 *     0, 0, 0, 0, 3, 0, 0, 0,
 *     0, 0, 0, 0, 3, 0, 0, 0,
 *     0, 0, 0, 0, 0, 3, 0, 0,
 *     0, 0, 3, 0, 0, 3, 0, 0,
 *     0, 0, 0, 3, 3, 0, 0, 0,
 *
 *     0, 0, 0, 0, 0, 0, 0, 0,
 *     0, 0, 0, 0, 4, 0, 0, 0,
 *     0, 0, 0, 4, 4, 0, 0, 0,
 *     0, 0, 4, 0, 4, 0, 0, 0,
 *     0, 4, 4, 4, 4, 4, 0, 0,
 *     0, 0, 0, 0, 4, 0, 0, 0,
 *     0, 0, 0, 0, 4, 0, 0, 0,
 *     0, 0, 0, 0, 4, 0, 0, 0
 * ]);
 * let one_and_two = TensorView::from(
 *     TensorRange::from(&numbers, [("batch", 0..2)])
 *         .expect("Input is constructed so that our range is valid")
 * );
 * let framed = TensorView::from(
 *     TensorRange::from(&numbers, [("rows", [1, 6]), ("columns", [1, 6])])
 *         .expect("Input is constructed so that our range is valid")
 * );
 * assert_eq!(one_and_two.shape(), [("batch", 2), ("rows", 8), ("columns", 8)]);
 * assert_eq!(framed.shape(), [("batch", 4), ("rows", 6), ("columns", 6)]);
 * println!("{}", framed.select([("batch", 3)]));
 * // D = 2
 * // ("rows", 6), ("columns", 6)
 * // [ 0, 0, 0, 4, 0, 0
 * //   0, 0, 4, 4, 0, 0
 * //   0, 4, 0, 4, 0, 0
 * //   4, 4, 4, 4, 4, 0
 * //   0, 0, 0, 4, 0, 0
 * //   0, 0, 0, 4, 0, 0 ]
 * ```
 */
#[derive(Clone, Debug)]
pub struct TensorRange<T, S, const D: usize> {
    source: S,
    range: [IndexRange; D],
    _type: PhantomData<T>,
}

/**
 * A mask over a tensor in D dimensions, hiding the values **inside** the range from view.
 *
 * The entire source is still owned by the TensorMask however, so this does not permit
 * creating multiple mutable masks into a single tensor even if they wouldn't overlap.
 *
 * See also: [TensorRange](TensorRange)
 *
 * ```
 * use easy_ml::tensors::Tensor;
 * use easy_ml::tensors::views::{TensorView, TensorMask};
 * let numbers = Tensor::from([("batch", 4), ("rows", 8), ("columns", 8)], vec![
 *     0, 0, 0, 1, 1, 0, 0, 0,
 *     0, 0, 1, 1, 1, 0, 0, 0,
 *     0, 0, 0, 1, 1, 0, 0, 0,
 *     0, 0, 0, 1, 1, 0, 0, 0,
 *     0, 0, 0, 1, 1, 0, 0, 0,
 *     0, 0, 0, 1, 1, 0, 0, 0,
 *     0, 0, 1, 1, 1, 1, 0, 0,
 *     0, 0, 1, 1, 1, 1, 0, 0,
 *
 *     0, 0, 0, 0, 0, 0, 0, 0,
 *     0, 0, 0, 2, 2, 0, 0, 0,
 *     0, 0, 2, 0, 0, 2, 0, 0,
 *     0, 0, 0, 0, 0, 2, 0, 0,
 *     0, 0, 0, 0, 2, 0, 0, 0,
 *     0, 0, 0, 2, 0, 0, 0, 0,
 *     0, 0, 2, 0, 0, 0, 0, 0,
 *     0, 0, 2, 2, 2, 2, 0, 0,
 *
 *     0, 0, 0, 3, 3, 0, 0, 0,
 *     0, 0, 3, 0, 0, 3, 0, 0,
 *     0, 0, 0, 0, 0, 3, 0, 0,
 *     0, 0, 0, 0, 3, 0, 0, 0,
 *     0, 0, 0, 0, 3, 0, 0, 0,
 *     0, 0, 0, 0, 0, 3, 0, 0,
 *     0, 0, 3, 0, 0, 3, 0, 0,
 *     0, 0, 0, 3, 3, 0, 0, 0,
 *
 *     0, 0, 0, 0, 0, 0, 0, 0,
 *     0, 0, 0, 0, 4, 0, 0, 0,
 *     0, 0, 0, 4, 4, 0, 0, 0,
 *     0, 0, 4, 0, 4, 0, 0, 0,
 *     0, 4, 4, 4, 4, 4, 0, 0,
 *     0, 0, 0, 0, 4, 0, 0, 0,
 *     0, 0, 0, 0, 4, 0, 0, 0,
 *     0, 0, 0, 0, 4, 0, 0, 0
 * ]);
 * let one_and_four = TensorView::from(
 *     TensorMask::from(&numbers, [("batch", 1..3)])
 *         .expect("Input is constructed so that our mask is valid")
 * );
 * let corners = TensorView::from(
 *     TensorMask::from(&numbers, [("rows", [3, 2]), ("columns", [3, 2])])
 *         .expect("Input is constructed so that our mask is valid")
 * );
 * assert_eq!(one_and_four.shape(), [("batch", 2), ("rows", 8), ("columns", 8)]);
 * assert_eq!(corners.shape(), [("batch", 4), ("rows", 6), ("columns", 6)]);
 * println!("{}", corners.select([("batch", 2)]));
 * // D = 2
 * // ("rows", 6), ("columns", 6)
 * // [ 0, 0, 0, 0, 0, 0
 * //   0, 0, 3, 3, 0, 0
 * //   0, 0, 0, 3, 0, 0
 * //   0, 0, 0, 3, 0, 0
 * //   0, 0, 3, 3, 0, 0
 * //   0, 0, 0, 0, 0, 0 ]
 * ```
 */
#[derive(Clone, Debug)]
pub struct TensorMask<T, S, const D: usize> {
    source: S,
    mask: [IndexRange; D],
    _type: PhantomData<T>,
}

/**
 * An error in creating a [TensorRange](TensorRange) or a [TensorMask](TensorMask).
 */
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum IndexRangeValidationError<const D: usize, const P: usize> {
    /**
     * The shape that the resulting Tensor would have would not be valid.
     */
    InvalidShape(InvalidShapeError<D>),
    /**
     * The same dimension name was provided multiple times, but we can only take one mask or
     * range for each dimension at a time.
     */
    InvalidDimensions(InvalidDimensionsError<D, P>),
}

impl<const D: usize, const P: usize> fmt::Display for IndexRangeValidationError<D, P> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            IndexRangeValidationError::InvalidShape(error) => write!(f, "{:?}", error),
            IndexRangeValidationError::InvalidDimensions(error) => write!(f, "{:?}", error),
        }
    }
}

impl<const D: usize, const P: usize> Error for IndexRangeValidationError<D, P> {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            IndexRangeValidationError::InvalidShape(error) => Some(error),
            IndexRangeValidationError::InvalidDimensions(error) => Some(error),
        }
    }
}

/**
 * An error in creating a [TensorRange](TensorRange) or a [TensorMask](TensorMask) using
 * strict validation.
 */
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum StrictIndexRangeValidationError<const D: usize, const P: usize> {
    /**
     * In at least one dimension, the mask or range provided exceeds the bounds of the shape
     * of the Tensor it was to be used on. This is not necessarily an issue as the mask or
     * range could be clipped to the bounds of the Tensor's shape, but a constructor which
     * rejects out of bounds input was used.
     */
    OutsideShape {
        shape: [(Dimension, usize); D],
        index_range: [Option<IndexRange>; D],
    },
    Error(IndexRangeValidationError<D, P>),
}

impl<const D: usize, const P: usize> fmt::Display for StrictIndexRangeValidationError<D, P> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use StrictIndexRangeValidationError as S;
        match self {
            S::OutsideShape { shape, index_range } => write!(
                f,
                "IndexRange array {:?} is out of bounds of shape {:?}",
                index_range, shape
            ),
            S::Error(error) => write!(f, "{:?}", error),
        }
    }
}

impl<const D: usize, const P: usize> Error for StrictIndexRangeValidationError<D, P> {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        use StrictIndexRangeValidationError as S;
        match self {
            S::OutsideShape {
                shape: _,
                index_range: _,
            } => None,
            S::Error(error) => Some(error),
        }
    }
}

fn from_named_to_all_specific_error<T, S, R, const D: usize, const P: usize>(
    source: &S,
    ranges: [(Dimension, R); P],
) -> Result<[Option<IndexRange>; D], InvalidDimensionsError<D, P>>
where
    S: TensorRef<T, D>,
    R: Into<IndexRange>,
{
    let shape = source.view_shape();
    let ranges = ranges.map(|(d, r)| (d, r.into()));
    let dimensions = InvalidDimensionsError {
        provided: ranges.clone().map(|(d, _)| d),
        valid: shape.map(|(d, _)| d),
    };
    if dimensions.has_duplicates() {
        return Err(dimensions);
    }
    // Since we now know there are no duplicates, we can look up the dimension index for each
    // name in the shape and we know we'll get a different index on each lookup.
    let mut all_ranges: [Option<IndexRange>; D] = std::array::from_fn(|_| None);
    for (name, range) in ranges.into_iter() {
        match crate::tensors::dimensions::position_of(&shape, name) {
            Some(d) => all_ranges[d] = Some(range),
            None => return Err(dimensions),
        };
    }
    Ok(all_ranges)
}
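
// A minimal sketch added for illustration (not part of the original test suite) of how named
// ranges are spread into a per-dimension array that matches the order of the source's shape.
#[test]
fn test_from_named_to_all_orders_ranges_by_shape() {
    use crate::tensors::Tensor;
    let tensor = Tensor::from([("rows", 3), ("columns", 3)], vec![0; 9]);
    let all = from_named_to_all_specific_error(&tensor, [("columns", 1..3)]).unwrap();
    assert_eq!(all, [None, Some(IndexRange::new(1, 2))]);
}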

fn from_named_to_all<T, S, R, const D: usize, const P: usize>(
    source: &S,
    ranges: [(Dimension, R); P],
) -> Result<[Option<IndexRange>; D], IndexRangeValidationError<D, P>>
where
    S: TensorRef<T, D>,
    R: Into<IndexRange>,
{
    from_named_to_all_specific_error(source, ranges)
        .map_err(|error| IndexRangeValidationError::InvalidDimensions(error))
}

impl<T, S, const D: usize> TensorRange<T, S, D>
where
    S: TensorRef<T, D>,
{
    /**
     * Constructs a TensorRange from a tensor and a set of dimension name/range pairs.
     *
     * Returns the Err variant if any dimension would have a length of 0 after applying the
     * ranges, if multiple pairs with the same name are provided, or if any dimension names aren't
     * in the source.
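     *
     * A minimal usage sketch (the values here are purely illustrative):
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::{TensorView, TensorRange};
     * let tensor = Tensor::from([("rows", 3), ("columns", 3)], vec![
     *     0, 1, 2,
     *     3, 4, 5,
     *     6, 7, 8
     * ]);
     * let bottom_rows = TensorView::from(
     *     TensorRange::from(&tensor, [("rows", 1..3)]).expect("the dimension name is valid")
     * );
     * assert_eq!(
     *     bottom_rows,
     *     Tensor::from([("rows", 2), ("columns", 3)], vec![
     *         3, 4, 5,
     *         6, 7, 8
     *     ])
     * );
     * ```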
     */
    pub fn from<R, const P: usize>(
        source: S,
        ranges: [(Dimension, R); P],
    ) -> Result<TensorRange<T, S, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        let all_ranges = from_named_to_all(&source, ranges)?;
        match TensorRange::from_all(source, all_ranges) {
            Ok(tensor_range) => Ok(tensor_range),
            Err(invalid_shape) => Err(IndexRangeValidationError::InvalidShape(invalid_shape)),
        }
    }

    /**
     * Constructs a TensorRange from a tensor and a set of dimension name/range pairs.
     *
     * Returns the Err variant if any dimension would have a length of 0 after applying the
     * ranges, if multiple pairs with the same name are provided, if any dimension names aren't
     * in the source, or if any range extends beyond the length of that dimension in the tensor.
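     *
     * A minimal sketch of the stricter behavior (the values here are purely illustrative):
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorRange;
     * let tensor = Tensor::from([("rows", 3), ("columns", 3)], vec![
     *     0, 1, 2,
     *     3, 4, 5,
     *     6, 7, 8
     * ]);
     * // 0..4 extends past the 3 rows, so the strict constructor rejects it instead of clipping
     * assert!(TensorRange::from_strict(&tensor, [("rows", 0..4)]).is_err());
     * assert!(TensorRange::from_strict(&tensor, [("rows", 0..3)]).is_ok());
     * ```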
     */
    pub fn from_strict<R, const P: usize>(
        source: S,
        ranges: [(Dimension, R); P],
    ) -> Result<TensorRange<T, S, D>, StrictIndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        use StrictIndexRangeValidationError as S;
        let all_ranges = match from_named_to_all(&source, ranges) {
            Ok(all_ranges) => all_ranges,
            Err(error) => return Err(S::Error(error)),
        };
        match TensorRange::from_all_strict(source, all_ranges) {
            Ok(tensor_range) => Ok(tensor_range),
            Err(S::OutsideShape { shape, index_range }) => {
                Err(S::OutsideShape { shape, index_range })
            }
            Err(S::Error(IndexRangeValidationError::InvalidShape(error))) => {
                Err(S::Error(IndexRangeValidationError::InvalidShape(error)))
            }
            Err(S::Error(IndexRangeValidationError::InvalidDimensions(_))) => panic!(
                "Unexpected InvalidDimensions error case after validating for InvalidDimensions already"
            ),
        }
    }

    /**
     * Constructs a TensorRange from a tensor and a range for each dimension in the tensor
     * (provided in the same order as the tensor's shape).
     *
     * Returns the Err variant if any dimension would have a length of 0 after applying the ranges.
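     *
     * A minimal usage sketch (the values here are purely illustrative):
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::{TensorView, TensorRange};
     * let tensor = Tensor::from([("rows", 3), ("columns", 3)], vec![
     *     0, 1, 2,
     *     3, 4, 5,
     *     6, 7, 8
     * ]);
     * // None leaves the rows dimension untouched, 1..3 keeps only the last two columns
     * let right_columns = TensorView::from(
     *     TensorRange::from_all(&tensor, [None, Some(1..3)]).expect("the range is not empty")
     * );
     * assert_eq!(right_columns.shape(), [("rows", 3), ("columns", 2)]);
     * ```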
     */
    pub fn from_all<R>(
        source: S,
        ranges: [Option<R>; D],
    ) -> Result<TensorRange<T, S, D>, InvalidShapeError<D>>
    where
        R: Into<IndexRange>,
    {
        TensorRange::clip_from(
            source,
            ranges.map(|option| option.map(|range| range.into())),
        )
    }

    fn clip_from(
        source: S,
        ranges: [Option<IndexRange>; D],
    ) -> Result<TensorRange<T, S, D>, InvalidShapeError<D>> {
        let shape = source.view_shape();
        let mut ranges = std::array::from_fn(|d| {
            ranges[d]
                .clone()
                .unwrap_or_else(|| IndexRange::new(0, shape[d].1))
        });
        let shape = InvalidShapeError {
            shape: clip_range_shape(&shape, &mut ranges),
        };
        if !shape.is_valid() {
            return Err(shape);
        }

        Ok(TensorRange {
            source,
            range: ranges,
            _type: PhantomData,
        })
    }

    /**
     * Constructs a TensorRange from a tensor and a range for each dimension in the tensor
     * (provided in the same order as the tensor's shape), ensuring the range is within the
     * lengths of the tensor.
     *
     * Returns the Err variant if any dimension would have a length of 0 after applying the
     * ranges, or if any range extends beyond the length of that dimension in the tensor.
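     *
     * A minimal sketch of the stricter behavior (the values here are purely illustrative):
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorRange;
     * let tensor = Tensor::from([("rows", 3), ("columns", 3)], vec![
     *     0, 1, 2,
     *     3, 4, 5,
     *     6, 7, 8
     * ]);
     * // The rows range fits, but 1..4 extends past the 3 columns so the input is rejected
     * assert!(TensorRange::from_all_strict(&tensor, [Some(0..2), Some(1..4)]).is_err());
     * ```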
     */
    pub fn from_all_strict<R>(
        source: S,
        range: [Option<R>; D],
    ) -> Result<TensorRange<T, S, D>, StrictIndexRangeValidationError<D, D>>
    where
        R: Into<IndexRange>,
    {
        let shape = source.view_shape();
        let range = range.map(|option| option.map(|range| range.into()));
        if range_exceeds_bounds(&shape, &range) {
            return Err(StrictIndexRangeValidationError::OutsideShape {
                shape,
                index_range: range,
            });
        }

        match TensorRange::clip_from(source, range) {
            Ok(tensor_range) => Ok(tensor_range),
            Err(invalid_shape) => Err(StrictIndexRangeValidationError::Error(
                IndexRangeValidationError::InvalidShape(invalid_shape),
            )),
        }
    }

    /**
     * Consumes the TensorRange, yielding the source it was created from.
     */
    #[allow(dead_code)]
    pub fn source(self) -> S {
        self.source
    }

    /**
     * Gives a reference to the TensorRange's source (in which the data is not clipped).
     */
    // # Safety
    //
    // Giving out a mutable reference to our source could allow it to be changed out from under us
    // and make our range checks invalid. However, since the source implements TensorRef,
    // interior mutability is not allowed, so we can give out shared references without breaking
    // our own integrity.
    #[allow(dead_code)]
    pub fn source_ref(&self) -> &S {
        &self.source
    }
}

fn range_exceeds_bounds<const D: usize>(
    source: &[(Dimension, usize); D],
    range: &[Option<IndexRange>; D],
) -> bool {
    for (d, (_, end)) in source.iter().enumerate() {
        let end = *end;
        match &range[d] {
            None => continue,
            Some(range) => {
                let range_end = range.start + range.length;
                if range_end > end {
                    return true;
                };
            }
        }
    }
    false
}

// Returns the shape the tensor will be left with once the range is applied, clipping any
// ranges that exceed the bounds of the tensor's shape.
fn clip_range_shape<const D: usize>(
    source: &[(Dimension, usize); D],
    range: &mut [IndexRange; D],
) -> [(Dimension, usize); D] {
    let mut shape = *source;
    for (d, (_, length)) in shape.iter_mut().enumerate() {
        let range = &mut range[d];
        range.clip(*length);
        // the length that remains is the length of the range
        *length = range.length;
    }
    shape
}
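
// A minimal sketch added for illustration (not part of the original test suite): a 0..4 range
// over a dimension of length 3 is clipped to fit, so that dimension keeps all 3 indexes, while
// the in bounds 1..3 range reduces the other dimension to a length of 2.
#[test]
fn test_clip_range_shape_clips_out_of_bounds_ranges() {
    let shape = [("rows", 3), ("columns", 3)];
    let mut ranges = [IndexRange::new(0, 4), IndexRange::new(1, 2)];
    assert_eq!(
        clip_range_shape(&shape, &mut ranges),
        [("rows", 3), ("columns", 2)]
    );
}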

impl<T, S, const D: usize> TensorMask<T, S, D>
where
    S: TensorRef<T, D>,
{
    /**
     * Constructs a TensorMask from a tensor and a set of dimension name/mask pairs.
     *
     * Returns the Err variant if any masked dimension would have a length of 0, if multiple
     * pairs with the same name are provided, or if any dimension names aren't in the source.
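     *
     * A minimal usage sketch (the values here are purely illustrative):
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::{TensorView, TensorMask};
     * let tensor = Tensor::from([("rows", 3), ("columns", 3)], vec![
     *     0, 1, 2,
     *     3, 4, 5,
     *     6, 7, 8
     * ]);
     * // Hide the middle row, leaving the first and last rows visible
     * let outer_rows = TensorView::from(
     *     TensorMask::from(&tensor, [("rows", 1..2)]).expect("the dimension name is valid")
     * );
     * assert_eq!(
     *     outer_rows,
     *     Tensor::from([("rows", 2), ("columns", 3)], vec![
     *         0, 1, 2,
     *         6, 7, 8
     *     ])
     * );
     * ```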
     */
    pub fn from<R, const P: usize>(
        source: S,
        masks: [(Dimension, R); P],
    ) -> Result<TensorMask<T, S, D>, IndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        let all_masks = from_named_to_all(&source, masks)?;
        match TensorMask::from_all(source, all_masks) {
            Ok(tensor_mask) => Ok(tensor_mask),
            Err(invalid_shape) => Err(IndexRangeValidationError::InvalidShape(invalid_shape)),
        }
    }

    /**
     * Constructs a TensorMask from a tensor and a set of dimension name/mask pairs.
     *
     * Returns the Err variant if any masked dimension would have a length of 0, if multiple
     * pairs with the same name are provided, if any dimension names aren't in the source,
     * or if any mask extends beyond the length of that dimension in the tensor.
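     *
     * A minimal sketch of the stricter behavior (the values here are purely illustrative):
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::TensorMask;
     * let tensor = Tensor::from([("rows", 3), ("columns", 3)], vec![
     *     0, 1, 2,
     *     3, 4, 5,
     *     6, 7, 8
     * ]);
     * // 1..4 extends past the 3 columns, so the strict constructor rejects it instead of clipping
     * assert!(TensorMask::from_strict(&tensor, [("columns", 1..4)]).is_err());
     * assert!(TensorMask::from_strict(&tensor, [("columns", 1..3)]).is_ok());
     * ```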
     */
    pub fn from_strict<R, const P: usize>(
        source: S,
        masks: [(Dimension, R); P],
    ) -> Result<TensorMask<T, S, D>, StrictIndexRangeValidationError<D, P>>
    where
        R: Into<IndexRange>,
    {
        use StrictIndexRangeValidationError as S;
        let all_masks = match from_named_to_all(&source, masks) {
            Ok(all_masks) => all_masks,
            Err(error) => return Err(S::Error(error)),
        };
        match TensorMask::from_all_strict(source, all_masks) {
            Ok(tensor_mask) => Ok(tensor_mask),
            Err(S::OutsideShape { shape, index_range }) => {
                Err(S::OutsideShape { shape, index_range })
            }
            Err(S::Error(IndexRangeValidationError::InvalidShape(error))) => {
                Err(S::Error(IndexRangeValidationError::InvalidShape(error)))
            }
            Err(S::Error(IndexRangeValidationError::InvalidDimensions(_))) => panic!(
                "Unexpected InvalidDimensions error case after validating for InvalidDimensions already"
            ),
        }
    }

    /**
     * Constructs a TensorMask from a tensor and a mask for each dimension in the tensor
     * (provided in the same order as the tensor's shape).
     *
     * Returns the Err variant if any masked dimension would have a length of 0.
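     *
     * A minimal usage sketch (the values here are purely illustrative):
     *
     * ```
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::{TensorView, TensorMask};
     * let tensor = Tensor::from([("rows", 3), ("columns", 3)], vec![
     *     0, 1, 2,
     *     3, 4, 5,
     *     6, 7, 8
     * ]);
     * // None leaves the rows dimension untouched, 0..1 hides the first column
     * let right_columns = TensorView::from(
     *     TensorMask::from_all(&tensor, [None, Some(0..1)]).expect("the mask leaves every dimension non empty")
     * );
     * assert_eq!(right_columns.shape(), [("rows", 3), ("columns", 2)]);
     * ```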
     */
    pub fn from_all<R>(
        source: S,
        mask: [Option<R>; D],
    ) -> Result<TensorMask<T, S, D>, InvalidShapeError<D>>
    where
        R: Into<IndexRange>,
    {
        TensorMask::clip_from(source, mask.map(|option| option.map(|mask| mask.into())))
    }

    fn clip_from(
        source: S,
        masks: [Option<IndexRange>; D],
    ) -> Result<TensorMask<T, S, D>, InvalidShapeError<D>> {
        let shape = source.view_shape();
        let mut masks = masks.map(|option| option.unwrap_or_else(|| IndexRange::new(0, 0)));
        let shape = InvalidShapeError {
            shape: clip_masked_shape(&shape, &mut masks),
        };
        if !shape.is_valid() {
            return Err(shape);
        }

        Ok(TensorMask {
            source,
            mask: masks,
            _type: PhantomData,
        })
    }

    /**
     * Constructs a TensorMask from a tensor and a mask for each dimension in the tensor
     * (provided in the same order as the tensor's shape), ensuring the mask is within the
     * lengths of the tensor.
     *
     * Returns the Err variant if any masked dimension would have a length of 0, or if any mask
     * extends beyond the length of that dimension in the tensor.
     */
    pub fn from_all_strict<R>(
        source: S,
        masks: [Option<R>; D],
    ) -> Result<TensorMask<T, S, D>, StrictIndexRangeValidationError<D, D>>
    where
        R: Into<IndexRange>,
    {
        let shape = source.view_shape();
        let masks = masks.map(|option| option.map(|mask| mask.into()));
        if mask_exceeds_bounds(&shape, &masks) {
            return Err(StrictIndexRangeValidationError::OutsideShape {
                shape,
                index_range: masks,
            });
        }

        match TensorMask::clip_from(source, masks) {
            Ok(tensor_mask) => Ok(tensor_mask),
            Err(invalid_shape) => Err(StrictIndexRangeValidationError::Error(
                IndexRangeValidationError::InvalidShape(invalid_shape),
            )),
        }
    }

    /**
     * Creates a TensorMask of this source that retains only the specified
     * number of elements at both the start and end of the dimension provided.
     * If twice the provided number of elements for a given dimension exceeds the
     * number of elements actually in that tensor's dimension, then the entire
     * dimension is retained in full.
     *
     * Returns the Err variant if the dimension is not in the source.
     *
     * ```
     * use std::num::NonZeroUsize;
     * use easy_ml::tensors::Tensor;
     * use easy_ml::tensors::views::{TensorView, TensorMask};
     * let tensor = Tensor::from([("x", 5), ("y", 5)], (0..25).collect());
     * let start_and_end = TensorView::from(
     *     TensorMask::start_and_end_of(
     *         tensor, "x", NonZeroUsize::new(1).unwrap()
     *     ).unwrap()
     * );
     * assert_eq!(
     *     start_and_end,
     *     Tensor::from([("x", 2), ("y", 5)], vec![
     *          0,  1,  2,  3,  4,
     *         20, 21, 22, 23, 24,
     *     ])
     * );
     * ```
     */
    pub fn start_and_end_of(
        source: S,
        dimension: Dimension,
        start_and_end: NonZeroUsize,
    ) -> Result<TensorMask<T, S, D>, InvalidDimensionsError<D, 1>> {
        let shape = source.view_shape();
        let range = match dimensions::length_of(&shape, dimension) {
            None => {
                return Err(InvalidDimensionsError::new(
                    [dimension],
                    dimensions::names_of(&shape),
                ));
            }
            Some(length) => {
                let x = start_and_end.get();
                let retain_start = std::cmp::min(x, length - 1);
                let retain_end = length.saturating_sub(x);
                let mut range: IndexRange = (retain_start..retain_end).into();
                range.clip(length - 1);
                range
            }
        };
        Ok(TensorMask {
            source,
            mask: std::array::from_fn(|d| {
                if shape[d].0 == dimension {
                    range.clone()
                } else {
                    IndexRange::new(0, 0)
                }
            }),
            _type: PhantomData,
        })
    }

    #[track_caller]
    pub(crate) fn panicking_start_and_end_of(
        source: S,
        dimension: Dimension,
        start_and_end: usize,
    ) -> TensorView<T, TensorMask<T, S, D>, D> {
        match NonZeroUsize::new(start_and_end) {
            Some(non_zero) => match TensorMask::start_and_end_of(source, dimension, non_zero) {
                Ok(tensor) => TensorView::from(tensor),
                Err(error) => panic!(
                    "Dimension name provided {:?} must be in the set of dimension names in the tensor: {:?}",
                    dimension, error.valid,
                ),
            },
            None => panic!("start_and_end must be greater than 0"),
        }
    }

    /**
     * Consumes the TensorMask, yielding the source it was created from.
     */
    #[allow(dead_code)]
    pub fn source(self) -> S {
        self.source
    }

    /**
     * Gives a reference to the TensorMask's source (in which the data is not masked).
     */
    // # Safety
    //
    // Giving out a mutable reference to our source could allow it to be changed out from under us
    // and make our mask checks invalid. However, since the source implements TensorRef,
    // interior mutability is not allowed, so we can give out shared references without breaking
    // our own integrity.
    #[allow(dead_code)]
    pub fn source_ref(&self) -> &S {
        &self.source
    }
}

// Returns the shape the tensor will be left with once the mask is applied, clipping any
// masks that exceed the bounds of the tensor's shape.
fn clip_masked_shape<const D: usize>(
    source: &[(Dimension, usize); D],
    mask: &mut [IndexRange; D],
) -> [(Dimension, usize); D] {
    let mut shape = *source;
    for (d, (_, length)) in shape.iter_mut().enumerate() {
        let mask = &mut mask[d];
        mask.clip(*length);
        // the length that remains is whatever the mask does not cover
        *length -= mask.length;
    }
    shape
}
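
// A minimal sketch added for illustration (not part of the original test suite): masking 1..4
// over a dimension of length 3 is clipped to hide indexes 1 and 2, leaving 1 index visible,
// while a zero length mask leaves the other dimension untouched.
#[test]
fn test_clip_masked_shape_clips_out_of_bounds_masks() {
    let shape = [("rows", 3), ("columns", 3)];
    let mut masks = [IndexRange::new(0, 0), IndexRange::new(1, 3)];
    assert_eq!(
        clip_masked_shape(&shape, &mut masks),
        [("rows", 3), ("columns", 1)]
    );
}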

fn mask_exceeds_bounds<const D: usize>(
    source: &[(Dimension, usize); D],
    mask: &[Option<IndexRange>; D],
) -> bool {
    // same test for a mask extending past a shape as for a range
    range_exceeds_bounds(source, mask)
}

fn map_indexes_by_range<const D: usize>(
    indexes: [usize; D],
    ranges: &[IndexRange; D],
) -> Option<[usize; D]> {
    let mut mapped = [0; D];
    for (d, (r, i)) in ranges.iter().zip(indexes.into_iter()).enumerate() {
        mapped[d] = r.map(i)?;
    }
    Some(mapped)
}
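
// A minimal sketch added for illustration (not part of the original test suite): indexes into
// the view are offset by each range's start, and indexes past a range's length map to None.
#[test]
fn test_map_indexes_by_range_offsets_and_bounds_checks() {
    let ranges = [IndexRange::new(1, 2), IndexRange::new(2, 1)];
    assert_eq!(map_indexes_by_range([0, 0], &ranges), Some([1, 2]));
    assert_eq!(map_indexes_by_range([1, 0], &ranges), Some([2, 2]));
    assert_eq!(map_indexes_by_range([2, 0], &ranges), None);
}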

// # Safety
//
// The type implementing TensorRef must implement it correctly, so by delegating to it
// and just hiding some of the valid indexes from view, we implement TensorRef correctly as well.
/**
 * A TensorRange implements TensorRef, with the dimension lengths reduced to the range the
 * TensorRange was created with.
 */
unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorRange<T, S, D>
where
    S: TensorRef<T, D>,
{
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        self.source
            .get_reference(map_indexes_by_range(indexes, &self.range)?)
    }

    fn view_shape(&self) -> [(Dimension, usize); D] {
        // Since when we were constructed we clipped the length of each range to no more than
        // our source, we can just return the length of each range now
        let mut shape = self.source.view_shape();
        // TODO: zip would work really nicely here but it's not stable yet
        for (pair, range) in shape.iter_mut().zip(self.range.iter()) {
            pair.1 = range.length;
        }
        shape
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        unsafe {
            // It is the caller's responsibility to always call with indexes in range,
            // therefore the unwrap() case should never happen because on an arbitrary TensorRef
            // it would be undefined behavior.
            self.source.get_reference_unchecked(
                map_indexes_by_range(indexes, &self.range).unwrap_unchecked(),
            )
        }
    }

    fn data_layout(&self) -> DataLayout<D> {
        // Our range means the view shape no longer matches up to a single
        // line of data in memory in the general case (ranges in 1D could still be linear
        // but DataLayout is not very meaningful till we get to 2D anyway).
        DataLayout::NonLinear
    }
}

// # Safety
//
// The type implementing TensorMut must implement it correctly, so by delegating to it
// and just hiding some of the valid indexes from view, we implement TensorMut correctly as well.
/**
 * A TensorRange implements TensorMut, with the dimension lengths reduced to the range the
 * TensorRange was created with.
 */
unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorRange<T, S, D>
where
    S: TensorMut<T, D>,
{
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        self.source
            .get_reference_mut(map_indexes_by_range(indexes, &self.range)?)
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        unsafe {
            // It is the caller's responsibility to always call with indexes in range,
            // therefore the unwrap() case should never happen because on an arbitrary TensorMut
            // it would be undefined behavior.
            self.source.get_reference_unchecked_mut(
                map_indexes_by_range(indexes, &self.range).unwrap_unchecked(),
            )
        }
    }
}

fn map_indexes_by_mask<const D: usize>(indexes: [usize; D], masks: &[IndexRange; D]) -> [usize; D] {
    let mut mapped = [0; D];
    for (d, (r, i)) in masks.iter().zip(indexes.into_iter()).enumerate() {
        mapped[d] = r.mask(i);
    }
    mapped
}
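
// A minimal sketch added for illustration (not part of the original test suite): view indexes
// before the masked region are unchanged, indexes at or after it skip over the masked region.
#[test]
fn test_map_indexes_by_mask_skips_the_masked_region() {
    let masks = [IndexRange::new(1, 1), IndexRange::new(0, 0)];
    assert_eq!(map_indexes_by_mask([0, 2], &masks), [0, 2]);
    assert_eq!(map_indexes_by_mask([1, 2], &masks), [2, 2]);
}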

// # Safety
//
// The type implementing TensorRef must implement it correctly, so by delegating to it
// and just hiding some of the valid indexes from view, we implement TensorRef correctly as well.
/**
 * A TensorMask implements TensorRef, with the dimension lengths reduced by the mask the
 * TensorMask was created with.
 */
unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorMask<T, S, D>
where
    S: TensorRef<T, D>,
{
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        self.source
            .get_reference(map_indexes_by_mask(indexes, &self.mask))
    }

    fn view_shape(&self) -> [(Dimension, usize); D] {
        // Since when we were constructed we clipped the length of each mask to no more than
        // our source, we can just subtract the length of each mask now
        let mut shape = self.source.view_shape();
        // TODO: zip would work really nicely here but it's not stable yet
        for (pair, mask) in shape.iter_mut().zip(self.mask.iter()) {
            pair.1 -= mask.length;
        }
        shape
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        unsafe {
            // It is the caller's responsibility to always call with indexes in range,
            // therefore out of bounds lookups created by map_indexes_by_mask should never happen.
            self.source
                .get_reference_unchecked(map_indexes_by_mask(indexes, &self.mask))
        }
    }

    fn data_layout(&self) -> DataLayout<D> {
        // Our mask means the view shape no longer matches up to a single
        // line of data in memory.
        DataLayout::NonLinear
    }
}

// # Safety
//
// The type implementing TensorMut must implement it correctly, so by delegating to it
// and just hiding some of the valid indexes from view, we implement TensorMut correctly as well.
/**
 * A TensorMask implements TensorMut, with the dimension lengths reduced by the mask the
 * TensorMask was created with.
 */
unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorMask<T, S, D>
where
    S: TensorMut<T, D>,
{
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        self.source
            .get_reference_mut(map_indexes_by_mask(indexes, &self.mask))
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        unsafe {
            // It is the caller's responsibility to always call with indexes in range,
            // therefore out of bounds lookups created by map_indexes_by_mask should never happen.
            self.source
                .get_reference_unchecked_mut(map_indexes_by_mask(indexes, &self.mask))
        }
    }
}

#[test]
#[rustfmt::skip]
fn test_constructors() {
    use crate::tensors::Tensor;
    use crate::tensors::views::TensorView;
    let tensor = Tensor::from([("rows", 3), ("columns", 3)], (0..9).collect());
    // Happy path
    assert_eq!(
        TensorView::from(TensorRange::from(&tensor, [("rows", IndexRange::new(1, 2))]).unwrap()),
        Tensor::from([("rows", 2), ("columns", 3)], vec![
            3, 4, 5,
            6, 7, 8
        ])
    );
    assert_eq!(
        TensorView::from(TensorRange::from(&tensor, [("columns", 2..3)]).unwrap()),
        Tensor::from([("rows", 3), ("columns", 1)], vec![
            2,
            5,
            8
        ])
    );
    assert_eq!(
        TensorView::from(TensorRange::from(&tensor, [("rows", (1, 1)), ("columns", (2, 1))]).unwrap()),
        Tensor::from([("rows", 1), ("columns", 1)], vec![5])
    );
    assert_eq!(
        TensorView::from(TensorRange::from(&tensor, [("columns", 1..3)]).unwrap()),
        Tensor::from([("rows", 3), ("columns", 2)], vec![
            1, 2,
            4, 5,
            7, 8
        ])
    );

    assert_eq!(
        TensorView::from(TensorMask::from(&tensor, [("rows", IndexRange::new(1, 1))]).unwrap()),
        Tensor::from([("rows", 2), ("columns", 3)], vec![
            0, 1, 2,
            6, 7, 8
        ])
    );
    assert_eq!(
        TensorView::from(TensorMask::from(&tensor, [("rows", 2..3), ("columns", 0..1)]).unwrap()),
        Tensor::from([("rows", 2), ("columns", 2)], vec![
            1, 2,
            4, 5
        ])
    );

    use IndexRangeValidationError as IRVError;
    use InvalidShapeError as ShapeError;
    use StrictIndexRangeValidationError::Error as SError;
    use StrictIndexRangeValidationError::OutsideShape;
    use InvalidDimensionsError as DError;
    // Dimension names that aren't present
    assert_eq!(
        TensorRange::from(&tensor, [("invalid", 1..2)]).unwrap_err(),
        IRVError::InvalidDimensions(DError::new(["invalid"], ["rows", "columns"]))
    );
    assert_eq!(
        TensorMask::from(&tensor, [("wrong", 0..1)]).unwrap_err(),
        IRVError::InvalidDimensions(DError::new(["wrong"], ["rows", "columns"]))
    );
    assert_eq!(
        TensorRange::from_strict(&tensor, [("invalid", 1..2)]).unwrap_err(),
        SError(IRVError::InvalidDimensions(DError::new(["invalid"], ["rows", "columns"])))
    );
    assert_eq!(
        TensorMask::from_strict(&tensor, [("wrong", 0..1)]).unwrap_err(),
        SError(IRVError::InvalidDimensions(DError::new(["wrong"], ["rows", "columns"])))
    );

    // Mask / Range creates a 0 length dimension
    assert_eq!(
        TensorRange::from(&tensor, [("rows", 0..0)]).unwrap_err(),
        IRVError::InvalidShape(ShapeError::new([("rows", 0), ("columns", 3)]))
    );
    assert_eq!(
        TensorMask::from(&tensor, [("columns", 0..3)]).unwrap_err(),
        IRVError::InvalidShape(ShapeError::new([("rows", 3), ("columns", 0)]))
    );
    assert_eq!(
        TensorRange::from_strict(&tensor, [("rows", 0..0)]).unwrap_err(),
        SError(IRVError::InvalidShape(ShapeError::new([("rows", 0), ("columns", 3)])))
    );
    assert_eq!(
        TensorMask::from_strict(&tensor, [("columns", 0..3)]).unwrap_err(),
        SError(IRVError::InvalidShape(ShapeError::new([("rows", 3), ("columns", 0)])))
    );

    // Dimension name specified twice
    assert_eq!(
        TensorRange::from(&tensor, [("rows", 1..2), ("rows", 2..3)]).unwrap_err(),
        IRVError::InvalidDimensions(DError::new(["rows", "rows"], ["rows", "columns"]))
    );
    assert_eq!(
        TensorMask::from(&tensor, [("columns", 1..2), ("columns", 2..3)]).unwrap_err(),
        IRVError::InvalidDimensions(DError::new(["columns", "columns"], ["rows", "columns"]))
    );
    assert_eq!(
        TensorRange::from_strict(&tensor, [("rows", 1..2), ("rows", 2..3)]).unwrap_err(),
        SError(IRVError::InvalidDimensions(DError::new(["rows", "rows"], ["rows", "columns"])))
    );
    assert_eq!(
        TensorMask::from_strict(&tensor, [("columns", 1..2), ("columns", 2..3)]).unwrap_err(),
        SError(IRVError::InvalidDimensions(DError::new(["columns", "columns"], ["rows", "columns"])))
    );

    // Mask / Range needs clipping
    assert!(
        TensorView::from(TensorRange::from(&tensor, [("rows", 0..4)]).unwrap()).eq(&tensor),
    );
    assert_eq!(
        TensorRange::from_strict(&tensor, [("rows", 0..4)]).unwrap_err(),
        OutsideShape {
            shape: [("rows", 3), ("columns", 3)],
            index_range: [Some(IndexRange::new(0, 4)), None],
        }
    );
    assert_eq!(
        TensorView::from(TensorMask::from(&tensor, [("columns", 1..4)]).unwrap()),
        Tensor::from([("rows", 3), ("columns", 1)], vec![
            0,
            3,
            6,
        ])
    );
    assert_eq!(
        TensorMask::from_strict(&tensor, [("columns", 1..4)]).unwrap_err(),
        OutsideShape {
            shape: [("rows", 3), ("columns", 3)],
            index_range: [None, Some(IndexRange::new(1, 3))],
        }
    );
}