Trait rai_core::Differentiable
source
pub trait Differentiable {
type Tensors: TensorIter + 'static;
type Gradient;
// Required methods
fn tensors(&self) -> Self::Tensors;
fn grad(
tensors: &Self::Tensors,
grad_map: &HashMap<usize, Tensor>
) -> Self::Gradient;
fn grad_map(
tensors: &Self::Tensors,
grad: Self::Gradient,
out: &mut HashMap<usize, Tensor>
);
}
Required Associated Types
type Tensors: TensorIter + 'static
type Gradient
Required Methods
fn tensors(&self) -> Self::Tensors
fn grad( tensors: &Self::Tensors, grad_map: &HashMap<usize, Tensor> ) -> Self::Gradient
fn grad_map( tensors: &Self::Tensors, grad: Self::Gradient, out: &mut HashMap<usize, Tensor> )
Object Safety
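This trait is not object safe (dyn compatible). The associated functions `grad` and `grad_map` take no `self` receiver and are not restricted with `where Self: Sized`, so they cannot be dispatched through a vtable. As a result, the trait cannot be used as a trait object such as `dyn Differentiable`.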