burn_tensor/tensor/api/
autodiff.rs1pub use burn_backend::tensor::BasicAutodiffOps;
2
3use crate::{Tensor, TensorPrimitive, backend::AutodiffBackend};
4
5impl<const D: usize, B: AutodiffBackend> Tensor<B, D> {
6 pub fn backward(&self) -> B::Gradients {
8 B::backward(self.primitive.clone().tensor())
9 }
10
11 pub fn grad(&self, grads: &B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
17 match &self.primitive {
18 TensorPrimitive::Float(tensor) => B::grad(tensor, grads)
19 .map(TensorPrimitive::Float)
20 .map(Tensor::new),
21 TensorPrimitive::QFloat(_tensor) => B::grad(&self.primitive.clone().tensor(), grads)
22 .map(TensorPrimitive::Float)
23 .map(Tensor::new),
24 }
25 }
26
27 pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
29 match &self.primitive {
30 TensorPrimitive::Float(tensor) => B::grad_remove(tensor, grads)
31 .map(TensorPrimitive::Float)
32 .map(Tensor::new),
33 TensorPrimitive::QFloat(_tensor) => {
34 B::grad_remove(&self.primitive.clone().tensor(), grads)
35 .map(TensorPrimitive::Float)
36 .map(Tensor::new)
37 }
38 }
39 }
40
41 pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor<B::InnerBackend, D>) {
44 match &self.primitive {
45 TensorPrimitive::Float(tensor) => {
46 B::grad_replace(tensor, grads, grad.primitive.tensor())
47 }
48 TensorPrimitive::QFloat(_tensor) => B::grad_replace(
49 &self.primitive.clone().tensor(),
50 grads,
51 grad.primitive.tensor(),
52 ),
53 }
54 }
55}
56
57impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> Tensor<B, D, K> {
58 pub fn inner(self) -> Tensor<B::InnerBackend, D, K::InnerKind> {
60 Tensor::new(K::inner(self.primitive))
61 }
62
63 pub fn from_inner(inner: Tensor<B::InnerBackend, D, K::InnerKind>) -> Self {
73 Self::new(K::from_inner(inner.primitive))
74 }
75}