// ella_common/ops.rs
use crate::TensorValue;

/// Binary operation between a tensor value (`Self`) and a right-hand operand
/// of type `Rhs`.
///
/// Implementations decide how a closure written against the operands'
/// `Unmasked` types is applied to `self` and `other`, and what type the
/// closure's result is delivered in.
pub trait TensorOp<Rhs: TensorValue>: TensorValue {
    /// Type returned by [`apply`](TensorOp::apply) when the closure yields `Out`.
    type Output<Out>;

    /// Combines `self` with `other` through `op` and returns the outcome as
    /// `Self::Output<O>`.
    ///
    /// `op` takes the `Unmasked` form of each operand; whether and how it is
    /// invoked is up to the implementation.
    fn apply<F, O>(self, other: Rhs, op: F) -> Self::Output<O>
    where
        O: TensorValue,
        Self::Output<O>: TensorValue,
        F: Fn(Self::Unmasked, Rhs::Unmasked) -> O;
}

/// Unary operation over a tensor value (`Self`).
///
/// Implementations decide how a closure written against `Self::Unmasked` is
/// applied to `self`, and what type the closure's result is delivered in.
pub trait TensorUnaryOp: TensorValue {
    /// Type returned by [`apply`](TensorUnaryOp::apply) when the closure yields `Out`.
    type Output<Out>;

    /// Transforms `self` through `op` and returns the outcome as
    /// `Self::Output<O>`.
    ///
    /// `op` takes the `Unmasked` form of `self`; whether and how it is invoked
    /// is up to the implementation.
    fn apply<F, O>(self, op: F) -> Self::Output<O>
    where
        O: TensorValue,
        Self::Output<O>: TensorValue,
        F: Fn(Self::Unmasked) -> O;
}

// `T <op> T`: both operands are their own `Unmasked` type, so the closure's
// result is handed back without any wrapping.
impl<T> TensorOp<T> for T
where
    T: TensorValue<Unmasked = T>,
{
    type Output<Out> = Out;

    /// Calls `op` once with `self` on the left and `other` on the right, and
    /// returns whatever it produces.
    #[inline]
    fn apply<F, O>(self, other: T, op: F) -> Self::Output<O>
    where
        Self::Output<O>: TensorValue,
        F: Fn(Self::Unmasked, <T as TensorValue>::Unmasked) -> O,
    {
        let (lhs, rhs) = (self, other);
        op(lhs, rhs)
    }
}

// `T <op> Option<T>`: the result is `None` exactly when the right-hand
// operand is `None`.
impl<T> TensorOp<Option<T>> for T
where
    Option<T>: TensorValue<Unmasked = T>,
    T: TensorValue<Unmasked = T, Masked = Option<T>>,
{
    type Output<Out> = Option<Out>;

    /// Applies `op` to `self` and the value inside `other`, wrapping the
    /// result in `Some`. Returns `None` without calling `op` when `other` is
    /// `None`.
    #[inline]
    fn apply<F, O>(self, other: Option<T>, op: F) -> Self::Output<O>
    where
        O: TensorValue,
        Self::Output<O>: TensorValue,
        F: Fn(Self::Unmasked, <Option<T> as TensorValue>::Unmasked) -> O,
    {
        // An absent right-hand side short-circuits before `op` runs.
        let rhs = other?;
        Some(op(self, rhs))
    }
}

// `Option<T> <op> T`: the result is `None` exactly when the left-hand
// operand is `None`.
impl<T> TensorOp<T> for Option<T>
where
    Option<T>: TensorValue<Unmasked = T>,
    T: TensorValue<Unmasked = T, Masked = Option<T>>,
{
    type Output<Out> = Option<Out>;

    /// Applies `op` to the value inside `self` and `other`, wrapping the
    /// result in `Some`. Returns `None` without calling `op` when `self` is
    /// `None`.
    #[inline]
    fn apply<F, O>(self, other: T, op: F) -> Self::Output<O>
    where
        O: TensorValue,
        Self::Output<O>: TensorValue,
        F: Fn(Self::Unmasked, <T as TensorValue>::Unmasked) -> O,
    {
        // An absent left-hand side short-circuits before `op` runs.
        let lhs = self?;
        Some(op(lhs, other))
    }
}

// `Option<T> <op> Option<T>`: the result is `Some` only when both operands
// are `Some`.
impl<T> TensorOp<Option<T>> for Option<T>
where
    Option<T>: TensorValue<Unmasked = T>,
    T: TensorValue<Unmasked = T, Masked = Option<T>>,
{
    type Output<Out> = Option<Out>;

    /// Applies `op` to the values inside `self` and `other`, wrapping the
    /// result in `Some`. Returns `None` without calling `op` if either
    /// operand is `None`.
    #[inline]
    fn apply<F, O>(self, other: Option<T>, op: F) -> Self::Output<O>
    where
        O: TensorValue,
        Self::Output<O>: TensorValue,
        F: Fn(Self::Unmasked, <Option<T> as TensorValue>::Unmasked) -> O,
    {
        // Either side being absent short-circuits before `op` runs.
        let lhs = self?;
        let rhs = other?;
        Some(op(lhs, rhs))
    }
}

// Unary op on `T` itself: the operand is its own `Unmasked` type, so the
// closure's result is handed back without any wrapping.
impl<T> TensorUnaryOp for T
where
    T: TensorValue<Unmasked = T, Masked = Option<T>>,
{
    type Output<Out> = Out;

    /// Calls `op` once with `self` and returns whatever it produces.
    fn apply<F, O>(self, op: F) -> Self::Output<O>
    where
        O: TensorValue,
        Self::Output<O>: TensorValue,
        F: Fn(Self::Unmasked) -> O,
    {
        let value = self;
        op(value)
    }
}

// Unary op on `Option<T>`: the result is `None` exactly when the operand is
// `None`.
impl<T> TensorUnaryOp for Option<T>
where
    T: TensorValue<Unmasked = T, Masked = Option<T>>,
{
    type Output<Out> = Option<Out>;

    /// Applies `op` to the value inside `self`, wrapping the result in
    /// `Some`. Returns `None` without calling `op` when `self` is `None`.
    fn apply<F, O>(self, op: F) -> Self::Output<O>
    where
        O: TensorValue,
        Self::Output<O>: TensorValue,
        F: Fn(Self::Unmasked) -> O,
    {
        // An absent operand short-circuits before `op` runs.
        let value = self?;
        Some(op(value))
    }
}