burn_tensor/tensor/api/
autodiff.rsuse crate::{
backend::AutodiffBackend, BasicOps, Bool, Float, Int, Tensor, TensorKind, TensorPrimitive,
};
impl<const D: usize, B: AutodiffBackend> Tensor<B, D> {
    /// Runs the backward pass rooted at this tensor and returns the collected gradients.
    ///
    /// The primitive is cloned and converted with `TensorPrimitive::tensor()` before it is
    /// handed to `B::backward`, so both float and quantized primitives are accepted.
    pub fn backward(&self) -> B::Gradients {
        let root = self.primitive.clone().tensor();
        B::backward(root)
    }

    /// Looks up this tensor's gradient in `grads` and returns it as an inner-backend tensor.
    ///
    /// The result is whatever `B::grad` yields, wrapped as a float tensor; `None` means the
    /// backend returned no gradient.
    ///
    /// NOTE(review): for a quantized tensor the lookup is made on a freshly converted copy
    /// (`self.primitive.clone().tensor()`) rather than on `self`'s own primitive — confirm the
    /// backend can ever resolve a gradient for that copy.
    pub fn grad(&self, grads: &B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
        let found = match &self.primitive {
            TensorPrimitive::Float(float) => B::grad(float, grads),
            TensorPrimitive::QFloat(_) => {
                let converted = self.primitive.clone().tensor();
                B::grad(&converted, grads)
            }
        };
        found.map(|primitive| Tensor::new(TensorPrimitive::Float(primitive)))
    }

    /// Removes this tensor's gradient from `grads` and returns it as an inner-backend tensor.
    ///
    /// The result is whatever `B::grad_remove` yields, wrapped as a float tensor; `None` means
    /// the backend returned no gradient.
    ///
    /// NOTE(review): as with [`Tensor::grad`], a quantized tensor is first converted to a new
    /// float primitive and the removal is requested for that copy — confirm this is intended.
    pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
        let removed = match &self.primitive {
            TensorPrimitive::Float(float) => B::grad_remove(float, grads),
            TensorPrimitive::QFloat(_) => {
                let converted = self.primitive.clone().tensor();
                B::grad_remove(&converted, grads)
            }
        };
        removed.map(|primitive| Tensor::new(TensorPrimitive::Float(primitive)))
    }

    /// Replaces the gradient stored for this tensor in `grads` with `grad`.
    ///
    /// `grad` is converted with `TensorPrimitive::tensor()` before it is passed to
    /// `B::grad_replace`.
    ///
    /// NOTE(review): for a quantized tensor the replacement targets a freshly converted copy
    /// of `self`'s primitive — confirm this is intended.
    pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor<B::InnerBackend, D>) {
        match &self.primitive {
            TensorPrimitive::Float(float) => {
                B::grad_replace(float, grads, grad.primitive.tensor());
            }
            TensorPrimitive::QFloat(_) => {
                let converted = self.primitive.clone().tensor();
                B::grad_replace(&converted, grads, grad.primitive.tensor());
            }
        }
    }
}
impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> Tensor<B, D, K> {
    /// Converts this tensor into the matching tensor of `B::InnerBackend`.
    ///
    /// The primitive conversion is delegated to `K::inner`; the resulting tensor has kind
    /// `K::InnerKind`.
    pub fn inner(self) -> Tensor<B::InnerBackend, D, K::InnerKind> {
        let unwrapped = K::inner(self.primitive);
        Tensor::new(unwrapped)
    }

    /// Builds a tensor on this autodiff backend from a tensor of `B::InnerBackend`.
    ///
    /// The primitive conversion is delegated to `K::from_inner`.
    pub fn from_inner(inner: Tensor<B::InnerBackend, D, K::InnerKind>) -> Self {
        let wrapped = K::from_inner(inner.primitive);
        Self::new(wrapped)
    }
}
impl<B: AutodiffBackend> BasicAutodiffOps<B> for Float {
    /// Float tensors stay float tensors on the inner backend.
    type InnerKind = Float;

    /// Converts a float primitive to its inner-backend counterpart, keeping the variant:
    /// quantized values go through `B::q_inner`, plain floats through `B::inner`.
    fn inner(
        tensor: <Self as TensorKind<B>>::Primitive,
    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
        match tensor {
            TensorPrimitive::QFloat(quant) => TensorPrimitive::QFloat(B::q_inner(quant)),
            TensorPrimitive::Float(float) => TensorPrimitive::Float(B::inner(float)),
        }
    }

    /// Converts an inner-backend float primitive back to this backend, keeping the variant:
    /// quantized values go through `B::q_from_inner`, plain floats through `B::from_inner`.
    fn from_inner(
        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
    ) -> <Self as TensorKind<B>>::Primitive {
        match inner {
            TensorPrimitive::QFloat(quant) => TensorPrimitive::QFloat(B::q_from_inner(quant)),
            TensorPrimitive::Float(float) => TensorPrimitive::Float(B::from_inner(float)),
        }
    }
}
impl<B: AutodiffBackend> BasicAutodiffOps<B> for Int {
    /// Integer tensors stay integer tensors on the inner backend.
    type InnerKind = Int;

    /// Lifts an inner-backend integer primitive onto this backend via `B::int_from_inner`.
    fn from_inner(
        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
    ) -> <Self as TensorKind<B>>::Primitive {
        B::int_from_inner(inner)
    }

    /// Lowers an integer primitive to the inner backend via `B::int_inner`.
    fn inner(
        tensor: <Self as TensorKind<B>>::Primitive,
    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
        B::int_inner(tensor)
    }
}
impl<B: AutodiffBackend> BasicAutodiffOps<B> for Bool {
    /// Boolean tensors stay boolean tensors on the inner backend.
    type InnerKind = Bool;

    /// Lifts an inner-backend boolean primitive onto this backend via `B::bool_from_inner`.
    fn from_inner(
        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
    ) -> <Self as TensorKind<B>>::Primitive {
        B::bool_from_inner(inner)
    }

    /// Lowers a boolean primitive to the inner backend via `B::bool_inner`.
    fn inner(
        tensor: <Self as TensorKind<B>>::Primitive,
    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
        B::bool_inner(tensor)
    }
}
/// Primitive conversions between a tensor kind on an autodiff backend `B` and its
/// counterpart kind on `B::InnerBackend`.
///
/// Implementors must support basic ops on both backends (the supertrait bounds). This file
/// implements the trait for [`Float`], [`Int`] and [`Bool`], each using itself as
/// `InnerKind`.
pub trait BasicAutodiffOps<B: AutodiffBackend>: BasicOps<B> + BasicOps<B::InnerBackend> {
    /// The tensor kind used for the converted primitive on the inner backend.
    type InnerKind: BasicOps<B::InnerBackend>;
    /// Converts a primitive of this kind on `B` into an `InnerKind` primitive on the
    /// inner backend.
    fn inner(
        tensor: <Self as TensorKind<B>>::Primitive,
    ) -> <Self::InnerKind as TensorKind<B::InnerBackend>>::Primitive;
    /// Converts an `InnerKind` primitive on the inner backend back into a primitive of
    /// this kind on `B`.
    fn from_inner(
        inner: <Self::InnerKind as TensorKind<B::InnerBackend>>::Primitive,
    ) -> <Self as TensorKind<B>>::Primitive;
}