1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
pub use burn_backend::tensor::BasicAutodiffOps;
use crate::{Tensor, TensorPrimitive, backend::AutodiffBackend};
impl<const D: usize, B: AutodiffBackend> Tensor<B, D> {
    /// Runs the backward pass from this tensor, producing the gradients container.
    pub fn backward(&self) -> B::Gradients {
        B::backward(self.primitive.clone().tensor())
    }

    /// Get the gradients of a tensor if they exist.
    ///
    /// Returns a new reference to the same tensor. Therefore the same grad tensor can
    /// be accessed multiple times. If you only need to get the gradients one time,
    /// consider using [grad_remove](Tensor::grad_remove) for better performance.
    pub fn grad(&self, grads: &B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
        // Quantized tensors are first converted to their float representation
        // before the gradient lookup; float tensors are queried directly.
        let grad = match &self.primitive {
            TensorPrimitive::Float(tensor) => B::grad(tensor, grads),
            TensorPrimitive::QFloat(_) => B::grad(&self.primitive.clone().tensor(), grads),
        };

        grad.map(|tensor| Tensor::new(TensorPrimitive::Float(tensor)))
    }

    /// Remove the grad tensor from the [grads](AutodiffBackend::Gradients) struct returning the result.
    pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
        // Same float/quantized dispatch as [grad](Tensor::grad), but the gradient
        // entry is taken out of the container instead of being cloned.
        let grad = match &self.primitive {
            TensorPrimitive::Float(tensor) => B::grad_remove(tensor, grads),
            TensorPrimitive::QFloat(_) => B::grad_remove(&self.primitive.clone().tensor(), grads),
        };

        grad.map(|tensor| Tensor::new(TensorPrimitive::Float(tensor)))
    }

    /// Replace the grad tensor from the [grads](AutodiffBackend::Gradients) struct with the provided
    /// gradient.
    pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor<B::InnerBackend, D>) {
        let new_grad = grad.primitive.tensor();

        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::grad_replace(tensor, grads, new_grad),
            TensorPrimitive::QFloat(_) => {
                B::grad_replace(&self.primitive.clone().tensor(), grads, new_grad)
            }
        }
    }
}
impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> Tensor<B, D, K> {
    /// Returns the inner tensor without the autodiff information.
    pub fn inner(self) -> Tensor<B::InnerBackend, D, K::InnerKind> {
        let primitive = K::inner(self.primitive);
        Tensor::new(primitive)
    }

    /// Convert a tensor to the autodiff backend.
    ///
    /// # Arguments
    ///
    /// * `inner` - The tensor to convert.
    ///
    /// # Returns
    ///
    /// The tensor converted to the autodiff backend.
    pub fn from_inner(inner: Tensor<B::InnerBackend, D, K::InnerKind>) -> Self {
        let primitive = K::from_inner(inner.primitive);
        Self::new(primitive)
    }
}
// TODO: a lot of the `tensor.inner` / `Tensor::from_inner(...)` are actually scoped to perform some operations
// so it might be cleaner and easier to manage the device etc. if we provide a method to scope the autodiff?