burn_tensor/tensor/api/autodiff.rs

1use crate::{
2    BasicOps, Bool, Float, Int, Tensor, TensorKind, TensorPrimitive, backend::AutodiffBackend,
3};
4
5impl<const D: usize, B: AutodiffBackend> Tensor<B, D> {
6    /// Backward pass of the tensor.
7    pub fn backward(&self) -> B::Gradients {
8        B::backward(self.primitive.clone().tensor())
9    }
10
11    /// Get the gradients of a tensor if it exist.
12    ///
13    /// Returns a new reference to the same tensor. Therefore the same grad tensor can
14    /// be accessed multiple times. If you only need to get the gradients one time,
15    /// consider using [grad_remove](Tensor::grad_remove) for better performance.
16    pub fn grad(&self, grads: &B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
17        match &self.primitive {
18            TensorPrimitive::Float(tensor) => B::grad(tensor, grads)
19                .map(TensorPrimitive::Float)
20                .map(Tensor::new),
21            TensorPrimitive::QFloat(_tensor) => B::grad(&self.primitive.clone().tensor(), grads)
22                .map(TensorPrimitive::Float)
23                .map(Tensor::new),
24        }
25    }
26
27    /// Remove the grad tensor from the [grads](AutodiffBackend::Gradients) struct returning the result.
28    pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
29        match &self.primitive {
30            TensorPrimitive::Float(tensor) => B::grad_remove(tensor, grads)
31                .map(TensorPrimitive::Float)
32                .map(Tensor::new),
33            TensorPrimitive::QFloat(_tensor) => {
34                B::grad_remove(&self.primitive.clone().tensor(), grads)
35                    .map(TensorPrimitive::Float)
36                    .map(Tensor::new)
37            }
38        }
39    }
40
41    /// Replace the grad tensor from the [grads](AutodiffBackend::Gradients) struct with the provided
42    /// gradient.
43    pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor<B::InnerBackend, D>) {
44        match &self.primitive {
45            TensorPrimitive::Float(tensor) => {
46                B::grad_replace(tensor, grads, grad.primitive.tensor())
47            }
48            TensorPrimitive::QFloat(_tensor) => B::grad_replace(
49                &self.primitive.clone().tensor(),
50                grads,
51                grad.primitive.tensor(),
52            ),
53        }
54    }
55}
56
57impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> Tensor<B, D, K> {
58    /// Returns the inner tensor without the autodiff information.
59    pub fn inner(self) -> Tensor<B::InnerBackend, D, K::InnerKind> {
60        Tensor::new(K::inner(self.primitive))
61    }
62
63    /// Convert a tensor to the autodiff backend.
64    ///
65    /// # Arguments
66    ///
67    /// * `inner` - The tensor to convert.
68    ///
69    /// # Returns
70    ///
71    /// The tensor converted to the autodiff backend.
72    pub fn from_inner(inner: Tensor<B::InnerBackend, D, K::InnerKind>) -> Self {
73        Self::new(K::from_inner(inner.primitive))
74    }
75}
76
77impl<B: AutodiffBackend> BasicAutodiffOps<B> for Float {
78    type InnerKind = Float;
79
80    fn inner(
81        tensor: <Self as TensorKind<B>>::Primitive,
82    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
83        match tensor {
84            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::inner(tensor)),
85            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_inner(tensor)),
86        }
87    }
88
89    fn from_inner(
90        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
91    ) -> <Self as TensorKind<B>>::Primitive {
92        match inner {
93            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::from_inner(tensor)),
94            TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_from_inner(tensor)),
95        }
96    }
97}
98
99impl<B: AutodiffBackend> BasicAutodiffOps<B> for Int {
100    type InnerKind = Int;
101
102    fn inner(
103        tensor: <Self as TensorKind<B>>::Primitive,
104    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
105        B::int_inner(tensor)
106    }
107
108    fn from_inner(
109        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
110    ) -> <Self as TensorKind<B>>::Primitive {
111        B::int_from_inner(inner)
112    }
113}
114
115impl<B: AutodiffBackend> BasicAutodiffOps<B> for Bool {
116    type InnerKind = Bool;
117
118    fn inner(
119        tensor: <Self as TensorKind<B>>::Primitive,
120    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
121        B::bool_inner(tensor)
122    }
123
124    fn from_inner(
125        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
126    ) -> <Self as TensorKind<B>>::Primitive {
127        B::bool_from_inner(inner)
128    }
129}
130
131/// Trait that list all operations that can be applied on all tensors on an autodiff backend.
132///
133/// # Warnings
134///
135/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
136pub trait BasicAutodiffOps<B: AutodiffBackend>: BasicOps<B> + BasicOps<B::InnerBackend> {
137    /// Inner primitive tensor.
138    type InnerKind: BasicOps<B::InnerBackend>;
139
140    /// Returns the inner tensor without the autodiff information.
141    ///
142    /// # Remarks
143    ///
144    /// This is a low-level function used internally by the library to call different backend functions
145    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
146    /// or use this function directly.
147    ///
148    /// Users should prefer the [Tensor::inner](Tensor::inner) function,
149    /// which is more high-level and designed for public use.
150    fn inner(
151        tensor: <Self as TensorKind<B>>::Primitive,
152    ) -> <Self::InnerKind as TensorKind<B::InnerBackend>>::Primitive;
153
154    /// Convert a tensor to the autodiff backend.
155    ///
156    /// # Remarks
157    ///
158    /// This is a low-level function used internally by the library to call different backend functions
159    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
160    /// or use this function directly.
161    ///
162    /// Users should prefer the [Tensor::from_inner](Tensor::from_inner) function,
163    /// which is more high-level and designed for public use.
164    fn from_inner(
165        inner: <Self::InnerKind as TensorKind<B::InnerBackend>>::Primitive,
166    ) -> <Self as TensorKind<B>>::Primitive;
167}