ella_tensor/ops/
masked.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
use crate::{Const, Mask, MaskedValue, Shape, Tensor, TensorValue};

/// Types that can produce an owned [`Mask`] of shape `S`.
///
/// This lets mask-taking APIs (such as `Tensor::with_mask`) accept a boolean
/// tensor, a reference to one, or an existing `Mask` interchangeably.
pub trait AsMask<S>
where
    S: Shape,
{
    /// Builds an owned mask from `self` without consuming it.
    fn as_mask(&self) -> Mask<S>;
}

/// A boolean tensor converts into a mask via its `&Tensor<bool, S> -> Mask<S>` conversion.
impl<S> AsMask<S> for Tensor<bool, S>
where
    S: Shape,
{
    fn as_mask(&self) -> Mask<S> {
        Into::<Mask<S>>::into(self)
    }
}

/// A borrowed boolean tensor behaves exactly like the owned one: it strips the
/// outer reference and reuses the same `&Tensor<bool, S> -> Mask<S>` conversion.
impl<S> AsMask<S> for &Tensor<bool, S>
where
    S: Shape,
{
    fn as_mask(&self) -> Mask<S> {
        let tensor: &Tensor<bool, S> = self;
        Into::<Mask<S>>::into(tensor)
    }
}

impl<S: Shape> AsMask<S> for Mask<S> {
    fn as_mask(&self) -> Mask<S> {
        self.clone()
    }
}

/// Operations on tensors whose element type carries per-element validity.
impl<T, S> Tensor<T, S>
where
    T: MaskedValue,
    S: Shape,
{
    /// Returns the validity mask of this tensor.
    ///
    /// Delegates to `mask_inner`.
    pub fn mask(&self) -> Mask<S> {
        self.mask_inner()
    }

    /// Replaces every masked element with `value`, producing an unmasked tensor.
    ///
    /// Each element `x` for which `T::to_option(x)` is `Some(v)` keeps `v`; every
    /// other element becomes a clone of `value`. Shape is preserved.
    pub fn fill_masked(&self, value: T::Unmasked) -> Tensor<T::Unmasked, S> {
        // `unwrap_or_else` clones `value` only for masked elements; an eager
        // `unwrap_or(value.clone())` would clone it for every element, valid or not.
        self.map(|x| T::to_option(x).unwrap_or_else(|| value.clone()))
    }

    /// Discards the mask and reinterprets the stored values as `T::Unmasked`.
    ///
    /// Unlike [`fill_masked`](Self::fill_masked), no replacement value is written.
    /// The value buffer is cloned with its mask set to `None` and cast, so
    /// positions that were masked expose whatever raw value the buffer holds there.
    /// NOTE(review): confirm callers are fine seeing those underlying values.
    /// Shape and strides are copied from `self`.
    pub fn drop_mask(&self) -> Tensor<T::Unmasked, S> {
        let values = self.values().clone().with_mask(None);
        Tensor::new(values.cast(), self.shape().clone(), self.strides().clone())
    }

    /// Collects the valid (unmasked) elements, in `iter_valid` order, into a 1-D tensor.
    pub fn compress(&self) -> Tensor<T::Unmasked, Const<1>> {
        self.iter_valid().collect()
    }
}

/// Operations that turn a plain tensor into a masked (nullable) one.
impl<T, S> Tensor<T, S>
where
    T: TensorValue,
    S: Shape,
{
    /// Attaches `mask` to a copy of this tensor's values.
    ///
    /// `mask` may be anything implementing [`AsMask`] (a boolean tensor, a
    /// reference to one, or a `Mask`). The mask's underlying values are attached
    /// to a clone of this tensor's value buffer, which is then cast to
    /// `T::Masked`. Shape and strides are copied from `self`.
    ///
    /// NOTE(review): no shape or stride compatibility check between `mask` and
    /// `self` happens here. Confirm that the mask buffer layout matches the
    /// value buffer for non-contiguous views.
    pub fn with_mask<M>(&self, mask: M) -> Tensor<T::Masked, S>
    where
        M: AsMask<S>,
    {
        let owned_mask = mask.as_mask();
        let buffer = self.values().clone();
        let masked_buffer = buffer.with_mask(Some(owned_mask.into_values()));
        let shape = self.shape().clone();
        let strides = self.strides().clone();
        Tensor::new(masked_buffer.cast(), shape, strides)
    }

    /// Converts to the masked element type without attaching any mask.
    ///
    /// The value buffer is cloned and cast to `T::Masked`. Shape and strides are
    /// copied from `self`.
    pub fn nullable(&self) -> Tensor<T::Masked, S> {
        let cast_values = self.values().clone().cast();
        let (shape, strides) = (self.shape().clone(), self.strides().clone());
        Tensor::new(cast_values, shape, strides)
    }
}