Struct easy_ml::tensors::views::TensorMask
source · pub struct TensorMask<T, S, const D: usize> { /* private fields */ }
Expand description
A mask over a tensor in D dimensions, hiding the values inside the range from view.
The entire source is still owned by the TensorMask, however, so this does not permit creating multiple mutable masks into a single tensor, even if the masks wouldn’t overlap.
See also: TensorRange
use easy_ml::tensors::Tensor;
use easy_ml::tensors::views::{TensorView, TensorMask};
let numbers = Tensor::from([("batch", 4), ("rows", 8), ("columns", 8)], vec![
0, 0, 0, 1, 1, 0, 0, 0,
0, 0, 1, 1, 1, 0, 0, 0,
0, 0, 0, 1, 1, 0, 0, 0,
0, 0, 0, 1, 1, 0, 0, 0,
0, 0, 0, 1, 1, 0, 0, 0,
0, 0, 0, 1, 1, 0, 0, 0,
0, 0, 1, 1, 1, 1, 0, 0,
0, 0, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 2, 2, 0, 0, 0,
0, 0, 2, 0, 0, 2, 0, 0,
0, 0, 0, 0, 0, 2, 0, 0,
0, 0, 0, 0, 2, 0, 0, 0,
0, 0, 0, 2, 0, 0, 0, 0,
0, 0, 2, 0, 0, 0, 0, 0,
0, 0, 2, 2, 2, 2, 0, 0,
0, 0, 0, 3, 3, 0, 0, 0,
0, 0, 3, 0, 0, 3, 0, 0,
0, 0, 0, 0, 0, 3, 0, 0,
0, 0, 0, 0, 3, 0, 0, 0,
0, 0, 0, 0, 3, 0, 0, 0,
0, 0, 0, 0, 0, 3, 0, 0,
0, 0, 3, 0, 0, 3, 0, 0,
0, 0, 0, 3, 3, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 4, 0, 0, 0,
0, 0, 0, 4, 4, 0, 0, 0,
0, 0, 4, 0, 4, 0, 0, 0,
0, 4, 4, 4, 4, 4, 0, 0,
0, 0, 0, 0, 4, 0, 0, 0,
0, 0, 0, 0, 4, 0, 0, 0,
0, 0, 0, 0, 4, 0, 0, 0
]);
let one_and_four = TensorView::from(
TensorMask::from(&numbers, [("batch", 1..3)])
.expect("Input is constucted so that our mask is valid")
);
let corners = TensorView::from(
TensorMask::from(&numbers, [("rows", [3, 2]), ("columns", [3, 2])])
.expect("Input is constucted so that our mask is valid")
);
assert_eq!(one_and_four.shape(), [("batch", 2), ("rows", 8), ("columns", 8)]);
assert_eq!(corners.shape(), [("batch", 4), ("rows", 6), ("columns", 6)]);
println!("{}", corners.select([("batch", 2)]));
// D = 2
// ("rows", 6), ("columns", 6)
// [ 0, 0, 0, 0, 0, 0
// 0, 0, 3, 3, 0, 0
// 0, 0, 0, 3, 0, 0
// 0, 0, 0, 3, 0, 0
// 0, 0, 3, 3, 0, 0
// 0, 0, 0, 0, 0, 0 ]
Implementations§
source§impl<T, S, const D: usize> TensorMask<T, S, D>where
S: TensorRef<T, D>,
impl<T, S, const D: usize> TensorMask<T, S, D>where
S: TensorRef<T, D>,
sourcepub fn from<R, const P: usize>(
source: S,
masks: [(Dimension, R); P]
) -> Result<TensorMask<T, S, D>, IndexRangeValidationError<D, P>>where
R: Into<IndexRange>,
pub fn from<R, const P: usize>(
source: S,
masks: [(Dimension, R); P]
) -> Result<TensorMask<T, S, D>, IndexRangeValidationError<D, P>>where
R: Into<IndexRange>,
Constructs a TensorMask from a tensor and a set of dimension name/mask pairs.
Returns the Err variant if any masked dimension would have a length of 0, if multiple pairs with the same name are provided, or if any dimension names aren’t in the source.
sourcepub fn from_strict<R, const P: usize>(
source: S,
masks: [(Dimension, R); P]
) -> Result<TensorMask<T, S, D>, StrictIndexRangeValidationError<D, P>>where
R: Into<IndexRange>,
pub fn from_strict<R, const P: usize>(
source: S,
masks: [(Dimension, R); P]
) -> Result<TensorMask<T, S, D>, StrictIndexRangeValidationError<D, P>>where
R: Into<IndexRange>,
Constructs a TensorMask from a tensor and a set of dimension name/mask pairs, ensuring each mask is within the length of its dimension.
Returns the Err variant if any masked dimension would have a length of 0, if multiple pairs with the same name are provided, if any dimension names aren’t in the source, or if any mask extends beyond the length of that dimension in the tensor.
sourcepub fn from_all<R>(
source: S,
mask: [Option<R>; D]
) -> Result<TensorMask<T, S, D>, InvalidShapeError<D>>where
R: Into<IndexRange>,
pub fn from_all<R>(
source: S,
mask: [Option<R>; D]
) -> Result<TensorMask<T, S, D>, InvalidShapeError<D>>where
R: Into<IndexRange>,
Constructs a TensorMask from a tensor and a mask for each dimension in the tensor (provided in the same order as the tensor’s shape).
Returns the Err variant if any masked dimension would have a length of 0.
sourcepub fn from_all_strict<R>(
source: S,
masks: [Option<R>; D]
) -> Result<TensorMask<T, S, D>, StrictIndexRangeValidationError<D, D>>where
R: Into<IndexRange>,
pub fn from_all_strict<R>(
source: S,
masks: [Option<R>; D]
) -> Result<TensorMask<T, S, D>, StrictIndexRangeValidationError<D, D>>where
R: Into<IndexRange>,
Constructs a TensorMask from a tensor and a mask for each dimension in the tensor (provided in the same order as the tensor’s shape), ensuring the mask is within the lengths of the tensor.
Returns the Err variant if any masked dimension would have a length of 0 or any mask extends beyond the length of that dimension in the tensor.
Trait Implementations§
source§impl<T: Clone, S: Clone, const D: usize> Clone for TensorMask<T, S, D>
impl<T: Clone, S: Clone, const D: usize> Clone for TensorMask<T, S, D>
source§fn clone(&self) -> TensorMask<T, S, D>
fn clone(&self) -> TensorMask<T, S, D>
1.0.0 · source§fn clone_from(&mut self, source: &Self)
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source. Read more
source§impl<T, S, const D: usize> TensorMut<T, D> for TensorMask<T, S, D>where
S: TensorMut<T, D>,
impl<T, S, const D: usize> TensorMut<T, D> for TensorMask<T, S, D>where
S: TensorMut<T, D>,
A TensorMask implements TensorMut, with the dimension lengths reduced by the mask the TensorMask was created with.
source§fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T>
fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T>
source§unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T
unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T
source§impl<T, S, const D: usize> TensorRef<T, D> for TensorMask<T, S, D>where
S: TensorRef<T, D>,
impl<T, S, const D: usize> TensorRef<T, D> for TensorMask<T, S, D>where
S: TensorRef<T, D>,
A TensorMask implements TensorRef, with the dimension lengths reduced by the mask the TensorMask was created with.