burn_tensor/tensor/api/
autodiff.rs1use crate::{
2 BasicOps, Bool, Float, Int, Tensor, TensorKind, TensorPrimitive, backend::AutodiffBackend,
3};
4
5impl<const D: usize, B: AutodiffBackend> Tensor<B, D> {
6 pub fn backward(&self) -> B::Gradients {
8 B::backward(self.primitive.clone().tensor())
9 }
10
11 pub fn grad(&self, grads: &B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
17 match &self.primitive {
18 TensorPrimitive::Float(tensor) => B::grad(tensor, grads)
19 .map(TensorPrimitive::Float)
20 .map(Tensor::new),
21 TensorPrimitive::QFloat(_tensor) => B::grad(&self.primitive.clone().tensor(), grads)
22 .map(TensorPrimitive::Float)
23 .map(Tensor::new),
24 }
25 }
26
27 pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
29 match &self.primitive {
30 TensorPrimitive::Float(tensor) => B::grad_remove(tensor, grads)
31 .map(TensorPrimitive::Float)
32 .map(Tensor::new),
33 TensorPrimitive::QFloat(_tensor) => {
34 B::grad_remove(&self.primitive.clone().tensor(), grads)
35 .map(TensorPrimitive::Float)
36 .map(Tensor::new)
37 }
38 }
39 }
40
41 pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor<B::InnerBackend, D>) {
44 match &self.primitive {
45 TensorPrimitive::Float(tensor) => {
46 B::grad_replace(tensor, grads, grad.primitive.tensor())
47 }
48 TensorPrimitive::QFloat(_tensor) => B::grad_replace(
49 &self.primitive.clone().tensor(),
50 grads,
51 grad.primitive.tensor(),
52 ),
53 }
54 }
55}
56
57impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> Tensor<B, D, K> {
58 pub fn inner(self) -> Tensor<B::InnerBackend, D, K::InnerKind> {
60 Tensor::new(K::inner(self.primitive))
61 }
62
63 pub fn from_inner(inner: Tensor<B::InnerBackend, D, K::InnerKind>) -> Self {
73 Self::new(K::from_inner(inner.primitive))
74 }
75}
76
77impl<B: AutodiffBackend> BasicAutodiffOps<B> for Float {
78 type InnerKind = Float;
79
80 fn inner(
81 tensor: <Self as TensorKind<B>>::Primitive,
82 ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
83 match tensor {
84 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::inner(tensor)),
85 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_inner(tensor)),
86 }
87 }
88
89 fn from_inner(
90 inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
91 ) -> <Self as TensorKind<B>>::Primitive {
92 match inner {
93 TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::from_inner(tensor)),
94 TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_from_inner(tensor)),
95 }
96 }
97}
98
99impl<B: AutodiffBackend> BasicAutodiffOps<B> for Int {
100 type InnerKind = Int;
101
102 fn inner(
103 tensor: <Self as TensorKind<B>>::Primitive,
104 ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
105 B::int_inner(tensor)
106 }
107
108 fn from_inner(
109 inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
110 ) -> <Self as TensorKind<B>>::Primitive {
111 B::int_from_inner(inner)
112 }
113}
114
115impl<B: AutodiffBackend> BasicAutodiffOps<B> for Bool {
116 type InnerKind = Bool;
117
118 fn inner(
119 tensor: <Self as TensorKind<B>>::Primitive,
120 ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
121 B::bool_inner(tensor)
122 }
123
124 fn from_inner(
125 inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
126 ) -> <Self as TensorKind<B>>::Primitive {
127 B::bool_from_inner(inner)
128 }
129}
130
131pub trait BasicAutodiffOps<B: AutodiffBackend>: BasicOps<B> + BasicOps<B::InnerBackend> {
137 type InnerKind: BasicOps<B::InnerBackend>;
139
140 fn inner(
151 tensor: <Self as TensorKind<B>>::Primitive,
152 ) -> <Self::InnerKind as TensorKind<B::InnerBackend>>::Primitive;
153
154 fn from_inner(
165 inner: <Self::InnerKind as TensorKind<B::InnerBackend>>::Primitive,
166 ) -> <Self as TensorKind<B>>::Primitive;
167}