ella_tensor/ops/
unary_arith.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
use num_traits::{Float, Signed};

use crate::{Shape, Tensor, TensorValue};

use super::{unary_op, TensorUnaryOp};

/// Generates element-wise unary methods on `Tensor`.
///
/// Each `[name kernel]` entry expands to `pub fn name(&self)`, which calls
/// `unary_op(self, |x| x.apply(kernel))`. Any outer attributes (for example
/// extra `///` doc lines) placed before an entry are attached to the generated
/// method, after an auto-generated summary line.
///
/// The expansion refers to `T`, `S`, `Tensor` and `unary_op`, so it must be
/// invoked inside an `impl<T, S> Tensor<T, S>` block whose bounds make
/// `x.apply(kernel)` type-check.
macro_rules! impl_unary_ops {
    // `$op:ident` (rather than `tt`) rejects non-identifier names with a clear
    // error at the macro call site instead of inside the generated `fn`.
    ($($(#[$meta:meta])* [$op:ident $kernel:path])+) => {
        $(
            #[doc = concat!(
                "Maps [`",
                stringify!($kernel),
                "`] over the tensor's values, returning a new tensor."
            )]
            $(#[$meta])*
            pub fn $op(&self) -> Tensor<T::Output<T::Unmasked>, S> {
                unary_op(self, |x| x.apply($kernel))
            }
        )+
    };
}

/// Element-wise transcendental operations, available whenever the tensor's
/// unmasked value type implements [`Float`].
///
/// Every method below is generated by `impl_unary_ops!`; because of the
/// `Output<Unmasked> = T` bound, each one returns a new `Tensor<T, S>`.
impl<T, S> Tensor<T, S>
where
    S: Shape,
    T: TensorUnaryOp<Output<<T as TensorValue>::Unmasked> = T>,
    T::Unmasked: Float,
{
    impl_unary_ops! {
        // Trigonometric functions and their inverses.
        [sin Float::sin]
        [cos Float::cos]
        [tan Float::tan]
        [asin Float::asin]
        [acos Float::acos]
        [atan Float::atan]

        // Exponentials.
        [exp Float::exp]
        [exp2 Float::exp2]

        // Logarithms.
        [ln Float::ln]
        [log2 Float::log2]
        [log10 Float::log10]
    }
}

/// Element-wise absolute value, available whenever the tensor's unmasked
/// value type implements [`Signed`].
impl<T, S> Tensor<T, S>
where
    S: Shape,
    T: TensorUnaryOp<Output<<T as TensorValue>::Unmasked> = T>,
    T::Unmasked: Signed,
{
    /// Returns a new tensor holding the absolute value ([`Signed::abs`]) of
    /// each of this tensor's values.
    pub fn abs(&self) -> Tensor<T::Output<T::Unmasked>, S> {
        // `Signed::abs` takes `&self`, so it is called through a closure
        // instead of being passed as a bare function path.
        unary_op(self, |elem| elem.apply(|v| v.abs()))
    }
}