1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
use crate::{backend::AutodiffBackend, BasicOps, Bool, Float, Int, Tensor, TensorKind};

impl<const D: usize, B: AutodiffBackend> Tensor<B, D> {
    /// Runs the backward pass starting from this tensor.
    ///
    /// The tensor's primitive is cloned before being handed to the backend, so `self`
    /// remains usable afterwards.
    ///
    /// # Returns
    ///
    /// The [gradients](AutodiffBackend::Gradients) produced by the backend.
    pub fn backward(&self) -> B::Gradients {
        let primitive = self.primitive.clone();
        B::backward::<D>(primitive)
    }

    /// Gets the gradient of this tensor, if it exists.
    ///
    /// Returns a new reference to the same tensor, so the same grad tensor can be
    /// accessed multiple times. If you only need the gradient once, consider using
    /// [grad_remove](Tensor::grad_remove) for better performance.
    ///
    /// # Arguments
    ///
    /// * `grads` - The gradients to look this tensor up in.
    ///
    /// # Returns
    ///
    /// The gradient as a tensor on the inner backend, or `None` when the backend
    /// reports no gradient for this tensor.
    pub fn grad(&self, grads: &B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
        let primitive = B::grad(&self.primitive, grads)?;
        Some(Tensor::new(primitive))
    }

    /// Removes the grad tensor from the [grads](AutodiffBackend::Gradients) struct and
    /// returns it.
    ///
    /// # Arguments
    ///
    /// * `grads` - The gradients to remove this tensor's gradient from.
    ///
    /// # Returns
    ///
    /// The removed gradient as a tensor on the inner backend, or `None` when the
    /// backend reports no gradient for this tensor.
    pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
        let primitive = B::grad_remove(&self.primitive, grads)?;
        Some(Tensor::new(primitive))
    }

    /// Replaces the grad tensor in the [grads](AutodiffBackend::Gradients) struct with
    /// the provided gradient.
    ///
    /// # Arguments
    ///
    /// * `grads` - The gradients to update.
    /// * `grad` - The new gradient; it is consumed and its primitive is handed to the backend.
    pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor<B::InnerBackend, D>) {
        let replacement = grad.primitive;
        B::grad_replace(&self.primitive, grads, replacement);
    }
}

impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> Tensor<B, D, K> {
    /// Consumes the tensor and returns its inner counterpart, without the autodiff
    /// information.
    ///
    /// # Returns
    ///
    /// A tensor on the inner backend whose kind is `K::InnerKind`.
    pub fn inner(self) -> Tensor<B::InnerBackend, D, K::InnerKind> {
        // Kind-level dispatch: `K` selects which backend function performs the conversion.
        let primitive = K::inner::<D>(self.primitive);
        Tensor::new(primitive)
    }

    /// Converts an inner-backend tensor into a tensor on the autodiff backend.
    ///
    /// # Arguments
    ///
    /// * `inner` - The tensor to convert; it is consumed.
    ///
    /// # Returns
    ///
    /// The tensor converted to the autodiff backend.
    pub fn from_inner(inner: Tensor<B::InnerBackend, D, K::InnerKind>) -> Self {
        // Kind-level dispatch: `K` selects which backend function performs the conversion.
        let primitive = K::from_inner::<D>(inner.primitive);
        Self::new(primitive)
    }
}

impl<B: AutodiffBackend> BasicAutodiffOps<B> for Float {
    /// The inner kind of a float tensor is float as well.
    type InnerKind = Float;

    /// Forwards to `B::inner` to obtain the inner float primitive.
    fn inner<const D: usize>(
        tensor: <Float as TensorKind<B>>::Primitive<D>,
    ) -> <Float as TensorKind<B::InnerBackend>>::Primitive<D> {
        B::inner(tensor)
    }

    /// Forwards to `B::from_inner` to wrap an inner float primitive.
    fn from_inner<const D: usize>(
        inner: <Float as TensorKind<B::InnerBackend>>::Primitive<D>,
    ) -> <Float as TensorKind<B>>::Primitive<D> {
        B::from_inner(inner)
    }
}

impl<B: AutodiffBackend> BasicAutodiffOps<B> for Int {
    /// The inner kind of an int tensor is int as well.
    type InnerKind = Int;

    /// Forwards to `B::int_inner` to obtain the inner int primitive.
    fn inner<const D: usize>(
        tensor: <Int as TensorKind<B>>::Primitive<D>,
    ) -> <Int as TensorKind<B::InnerBackend>>::Primitive<D> {
        B::int_inner(tensor)
    }

    /// Forwards to `B::int_from_inner` to wrap an inner int primitive.
    fn from_inner<const D: usize>(
        inner: <Int as TensorKind<B::InnerBackend>>::Primitive<D>,
    ) -> <Int as TensorKind<B>>::Primitive<D> {
        B::int_from_inner(inner)
    }
}

impl<B: AutodiffBackend> BasicAutodiffOps<B> for Bool {
    /// The inner kind of a bool tensor is bool as well.
    type InnerKind = Bool;

    /// Forwards to `B::bool_inner` to obtain the inner bool primitive.
    fn inner<const D: usize>(
        tensor: <Bool as TensorKind<B>>::Primitive<D>,
    ) -> <Bool as TensorKind<B::InnerBackend>>::Primitive<D> {
        B::bool_inner(tensor)
    }

    /// Forwards to `B::bool_from_inner` to wrap an inner bool primitive.
    fn from_inner<const D: usize>(
        inner: <Bool as TensorKind<B::InnerBackend>>::Primitive<D>,
    ) -> <Bool as TensorKind<B>>::Primitive<D> {
        B::bool_from_inner(inner)
    }
}

/// Trait that lists all operations that can be applied on all tensors on an autodiff backend.
///
/// A tensor kind implementing this trait must implement [BasicOps] for both the autodiff
/// backend `B` and its inner backend `B::InnerBackend`.
///
/// # Warnings
///
/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
pub trait BasicAutodiffOps<B: AutodiffBackend>: BasicOps<B> + BasicOps<B::InnerBackend> {
    /// Tensor kind used on the inner backend once the autodiff information is removed.
    type InnerKind: BasicOps<B::InnerBackend>;

    /// Returns the inner tensor without the autodiff information.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The primitive of this kind on the autodiff backend `B`.
    ///
    /// # Returns
    ///
    /// The primitive of [InnerKind](BasicAutodiffOps::InnerKind) on the inner backend.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// Users should prefer the [Tensor::inner](Tensor::inner) function,
    /// which is more high-level and designed for public use.
    fn inner<const D: usize>(
        tensor: <Self as TensorKind<B>>::Primitive<D>,
    ) -> <Self::InnerKind as TensorKind<B::InnerBackend>>::Primitive<D>;

    /// Convert a tensor to the autodiff backend.
    ///
    /// # Arguments
    ///
    /// * `inner` - The primitive of [InnerKind](BasicAutodiffOps::InnerKind) on the inner backend.
    ///
    /// # Returns
    ///
    /// The primitive of this kind on the autodiff backend `B`.
    ///
    /// # Remarks
    ///
    /// This is a low-level function used internally by the library to call different backend functions
    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
    /// or use this function directly.
    ///
    /// Users should prefer the [Tensor::from_inner](Tensor::from_inner) function,
    /// which is more high-level and designed for public use.
    fn from_inner<const D: usize>(
        inner: <Self::InnerKind as TensorKind<B::InnerBackend>>::Primitive<D>,
    ) -> <Self as TensorKind<B>>::Primitive<D>;
}