pub use burn_backend::tensor::BasicAutodiffOps;
use crate::{Tensor, TensorPrimitive, backend::AutodiffBackend};
impl<const D: usize, B: AutodiffBackend> Tensor<B, D> {
    /// Runs the backward pass from this tensor, producing the gradients of the
    /// computation graph it belongs to.
    pub fn backward(&self) -> B::Gradients {
        B::backward(self.primitive.clone().tensor())
    }

    /// Looks up the gradient computed for this tensor in `grads`.
    ///
    /// Returns `None` when no gradient was recorded for this tensor. The
    /// returned tensor lives on the inner (non-autodiff) backend.
    pub fn grad(&self, grads: &B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
        let grad = match &self.primitive {
            TensorPrimitive::Float(tensor) => B::grad(tensor, grads),
            // Quantized primitives are first converted to their float
            // representation before the gradient lookup.
            TensorPrimitive::QFloat(_) => B::grad(&self.primitive.clone().tensor(), grads),
        };
        grad.map(|tensor| Tensor::new(TensorPrimitive::Float(tensor)))
    }

    /// Removes the gradient computed for this tensor from `grads` and returns it.
    ///
    /// Returns `None` when no gradient was recorded for this tensor. The
    /// returned tensor lives on the inner (non-autodiff) backend.
    pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
        let grad = match &self.primitive {
            TensorPrimitive::Float(tensor) => B::grad_remove(tensor, grads),
            // Quantized primitives are first converted to their float
            // representation before the gradient removal.
            TensorPrimitive::QFloat(_) => {
                B::grad_remove(&self.primitive.clone().tensor(), grads)
            }
        };
        grad.map(|tensor| Tensor::new(TensorPrimitive::Float(tensor)))
    }

    /// Replaces the gradient stored for this tensor in `grads` with `grad`.
    pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor<B::InnerBackend, D>) {
        let new_grad = grad.primitive.tensor();
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::grad_replace(tensor, grads, new_grad),
            // Quantized primitives are first converted to their float
            // representation before the gradient replacement.
            TensorPrimitive::QFloat(_) => {
                B::grad_replace(&self.primitive.clone().tensor(), grads, new_grad)
            }
        }
    }
}
impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> Tensor<B, D, K> {
    /// Detaches this tensor from the autodiff graph, returning its
    /// counterpart on the inner (non-autodiff) backend.
    pub fn inner(self) -> Tensor<B::InnerBackend, D, K::InnerKind> {
        let primitive = K::inner(self.primitive);
        Tensor::new(primitive)
    }

    /// Lifts a tensor from the inner (non-autodiff) backend into the
    /// autodiff backend, so that operations on it are tracked.
    pub fn from_inner(inner: Tensor<B::InnerBackend, D, K::InnerKind>) -> Self {
        let primitive = K::from_inner(inner.primitive);
        Self::new(primitive)
    }
}