Struct gad::graph::Graph[][src]

pub struct Graph<C: Config> { /* fields omitted */ }

Main structure holding the computational graph (aka “tape”) used for automatic differentiation. In practice, the configuration is instantiated to build either Graph1 or GraphN, depending on whether higher-order differentials are needed.

Implementations

impl<C: Config> Graph<C>[src]

pub fn new() -> Self[src]

Create a new graph.

pub fn eval(&mut self) -> &mut C::EvalAlgebra[src]

impl<C: Config> Graph<C>[src]

pub fn make_node<D, G, F, Dims>(
    &mut self,
    data: D,
    inputs: Vec<Option<Id>>,
    update_func: F
) -> Value<D> where
    C::GradientAlgebra: CoreAlgebra<D, Value = G>,
    C::GradientStore: GradientStore<GradientId<D>, G>,
    D: HasDims<Dims = Dims>,
    G: HasDims<Dims = Dims> + Clone + 'static,
    Dims: PartialEq + Debug + Clone + 'static + Send + Sync,
    F: Fn(&mut C::GradientAlgebra, &mut C::GradientStore, G) -> Result<()> + 'static + Send + Sync
[src]

Create a computation node (used to define operators). During back-propagation, update_func must call store.add_gradient to propagate the gradient of each (non-constant) input.

pub fn make_generic_node<S, D, GS, GD, F, Dims>(
    &mut self,
    data: D,
    inputs: Vec<Option<Id>>,
    update_func: F
) -> Value<D> where
    C::GradientAlgebra: CoreAlgebra<S, Value = GS>,
    C::GradientAlgebra: CoreAlgebra<D, Value = GD>,
    C::GradientStore: GradientStore<GradientId<D>, GD>,
    C::GradientStore: GradientStore<GradientId<S>, GS>,
    D: HasDims<Dims = Dims>,
    GD: HasDims<Dims = Dims> + Clone + 'static,
    Dims: PartialEq + Debug + Clone + 'static + Send + Sync,
    F: Fn(&mut C::GradientAlgebra, &mut C::GradientStore, GD) -> Result<()> + 'static + Send + Sync
[src]

Create a computation node where the source type S may differ from the target type D.

impl<E: Default + Clone> Graph<Config1<E>>[src]

First-order only (this is the most common case).

pub fn evaluate_gradients<T>(
    &self,
    id: GradientId<T>,
    gradient: T
) -> Result<GenericGradientMap1> where
    E: CoreAlgebra<T, Value = T>,
    T: 'static, 
[src]

Propagate gradients backward, starting with the node id.

  • Allow the graph to be re-used.
  • Gradients are stored as pure data.

pub fn evaluate_gradients_once<T>(
    self,
    id: GradientId<T>,
    gradient: T
) -> Result<GenericGradientMap1> where
    E: CoreAlgebra<T, Value = T>,
    T: 'static, 
[src]

Propagate gradients backward, starting with the node id.

  • Clean up memory when possible and consume the graph.
  • Gradients are stored as pure data.

impl<E: Default + Clone> Graph<ConfigN<E>>[src]

Higher-order differentials.

pub fn compute_gradients<D>(
    &mut self,
    id: GradientId<D>,
    gradient: Value<D>
) -> Result<GenericGradientMapN> where
    Self: CoreAlgebra<D, Value = Value<D>>,
    D: 'static, 
[src]

Propagate gradients backward, starting with the node id.

  • Gradients are computed as graph values that can be differentiated later.
  • The graph is augmented with the nodes corresponding to gradient computations.

Trait Implementations

impl<D, E, Dims> AnalyticAlgebra<Value<D>> for Graph<Config1<E>> where
    E: Default + Clone + CoreAlgebra<D, Value = D> + AnalyticAlgebra<D> + ArithAlgebra<D> + ConstArithAlgebra<D, i16> + LinkedAlgebra<Value<D>, D>,
    D: HasDims<Dims = Dims> + Clone + 'static + Send + Sync,
    Dims: PartialEq + Debug + Clone + 'static + Send + Sync
[src]

impl<D, E, Dims> AnalyticAlgebra<Value<D>> for Graph<ConfigN<E>> where
    E: Default + Clone + CoreAlgebra<D, Value = D> + AnalyticAlgebra<D> + ArithAlgebra<D> + ConstArithAlgebra<D, i16> + LinkedAlgebra<Value<D>, D>,
    D: HasDims<Dims = Dims> + Clone + 'static + Send + Sync,
    Dims: PartialEq + Debug + Clone + 'static + Send + Sync
[src]

impl<D, E, Dims> ArithAlgebra<Value<D>> for Graph<Config1<E>> where
    E: Default + Clone + CoreAlgebra<D, Value = D> + ArithAlgebra<D> + LinkedAlgebra<Value<D>, D>,
    D: HasDims<Dims = Dims> + Clone + 'static + Send + Sync,
    Dims: PartialEq + Debug + Clone + 'static + Send + Sync
[src]

impl<D, E, Dims> ArithAlgebra<Value<D>> for Graph<ConfigN<E>> where
    E: Default + Clone + CoreAlgebra<D, Value = D> + ArithAlgebra<D> + LinkedAlgebra<Value<D>, D>,
    D: HasDims<Dims = Dims> + Clone + 'static + Send + Sync,
    Dims: PartialEq + Debug + Clone + 'static + Send + Sync
[src]

impl<D, E, T, Dims> ArrayAlgebra<Value<D>> for Graph<Config1<E>> where
    E: Default + Clone + CoreAlgebra<D, Value = D> + CoreAlgebra<T, Value = T> + LinkedAlgebra<Value<D>, D> + LinkedAlgebra<Value<T>, T> + ArrayAlgebra<D, Scalar = T, Dims = Dims>,
    Dims: PartialEq + Clone + Copy + Debug + Default + 'static + Send + Sync,
    D: HasDims<Dims = Dims> + Clone + 'static + Send + Sync,
    T: Number
[src]

type Dims = Dims

type Scalar = Value<T>

impl<D, E, T, Dims> ArrayAlgebra<Value<D>> for Graph<ConfigN<E>> where
    E: Default + Clone + CoreAlgebra<D, Value = D> + CoreAlgebra<T, Value = T> + LinkedAlgebra<Value<D>, D> + LinkedAlgebra<Value<T>, T> + ArrayAlgebra<D, Scalar = T, Dims = Dims>,
    Dims: PartialEq + Clone + Copy + Debug + Default + 'static + Send + Sync,
    D: HasDims<Dims = Dims> + Clone + 'static + Send + Sync,
    T: Number
[src]

type Dims = Dims

type Scalar = Value<T>

impl<D, E, T, Dims> ArrayCompareAlgebra<Value<D>> for Graph<Config1<E>> where
    E: Default + Clone + CoreAlgebra<D, Value = D> + CoreAlgebra<T, Value = T> + CompareAlgebra<D> + ArrayCompareAlgebra<D> + ArrayAlgebra<D, Dims = Dims> + ArithAlgebra<D> + ArrayAlgebra<D, Scalar = T, Dims = Dims> + LinkedAlgebra<Value<D>, D> + LinkedAlgebra<Value<T>, T>,
    T: Number,
    D: HasDims<Dims = Dims> + Clone + 'static + Send + Sync,
    Dims: PartialEq + Debug + Default + Copy + Clone + 'static + Send + Sync
[src]

impl<D, E, T, Dims> ArrayCompareAlgebra<Value<D>> for Graph<ConfigN<E>> where
    E: Default + Clone + CoreAlgebra<D, Value = D> + CoreAlgebra<T, Value = T> + CompareAlgebra<D> + ArrayCompareAlgebra<D> + ArrayAlgebra<D, Dims = Dims> + ArithAlgebra<D> + ArrayAlgebra<D, Scalar = T, Dims = Dims> + LinkedAlgebra<Value<D>, D> + LinkedAlgebra<Value<T>, T>,
    T: Number,
    D: HasDims<Dims = Dims> + Clone + 'static + Send + Sync,
    Dims: PartialEq + Debug + Default + Copy + Clone + 'static + Send + Sync
[src]

impl<C: Config> Clone for Graph<C>[src]

impl<D, E, Dims> CompareAlgebra<Value<D>> for Graph<Config1<E>> where
    E: Default + Clone + CoreAlgebra<D, Value = D> + CompareAlgebra<D> + LinkedAlgebra<Value<D>, D>,
    D: HasDims<Dims = Dims> + Clone + 'static + Send + Sync,
    Dims: PartialEq + Debug + Clone + 'static + Send + Sync
[src]

impl<D, E, Dims> CompareAlgebra<Value<D>> for Graph<ConfigN<E>> where
    E: Default + Clone + CoreAlgebra<D, Value = D> + CompareAlgebra<D> + LinkedAlgebra<Value<D>, D>,
    D: HasDims<Dims = Dims> + Clone + 'static + Send + Sync,
    Dims: PartialEq + Debug + Clone + 'static + Send + Sync
[src]

impl<D, E, Dims, C> ConstArithAlgebra<Value<D>, C> for Graph<Config1<E>> where
    E: Default + Clone + CoreAlgebra<D, Value = D> + ArithAlgebra<D> + ConstArithAlgebra<D, C> + LinkedAlgebra<Value<D>, D>,
    C: Sub<C, Output = C> + One + Clone + 'static + Send + Sync,
    D: HasDims<Dims = Dims> + Clone + 'static + Send + Sync,
    Dims: PartialEq + Debug + Clone + 'static + Send + Sync
[src]

impl<D, E, Dims, C> ConstArithAlgebra<Value<D>, C> for Graph<ConfigN<E>> where
    E: Default + Clone + CoreAlgebra<D, Value = D> + ArithAlgebra<D> + ConstArithAlgebra<D, C> + LinkedAlgebra<Value<D>, D>,
    C: Sub<C, Output = C> + One + Clone + 'static + Send + Sync,
    D: HasDims<Dims = Dims> + Clone + 'static + Send + Sync,
    Dims: PartialEq + Debug + Clone + 'static + Send + Sync
[src]

impl<D, E, Dims> CoreAlgebra<D> for Graph<Config1<E>> where
    E: Default + Clone + CoreAlgebra<D, Value = D>,
    D: HasDims<Dims = Dims> + Clone + 'static + Send + Sync,
    Dims: PartialEq + Debug + Clone + 'static + Send + Sync
[src]

type Value = Value<D>

Tracked values of underlying type Data.

impl<D, E, Dims> CoreAlgebra<D> for Graph<ConfigN<E>> where
    E: Default + Clone + CoreAlgebra<D, Value = D>,
    D: HasDims<Dims = Dims> + Clone + 'static + Send + Sync,
    Dims: PartialEq + Debug + Clone + 'static + Send + Sync
[src]

type Value = Value<D>

Tracked values of underlying type Data.

impl<C: Config> Debug for Graph<C>[src]

impl<C: Config> Default for Graph<C>[src]

impl<C: Config> HasGradientReader for Graph<C>[src]

impl<V, C: Config> LinkedAlgebra<V, V> for Graph<C>[src]

Assume that we link into a copy of the original graph.

impl<D, E, Dims> MatrixAlgebra<Value<D>> for Graph<Config1<E>> where
    E: Default + Clone + CoreAlgebra<D, Value = D> + LinkedAlgebra<Value<D>, D> + MatrixAlgebra<D>,
    D: HasDims<Dims = Dims> + Clone + 'static + Send + Sync,
    Dims: PartialEq + Debug + Clone + 'static + Send + Sync
[src]

impl<D, E, Dims> MatrixAlgebra<Value<D>> for Graph<ConfigN<E>> where
    E: Default + Clone + CoreAlgebra<D, Value = D> + LinkedAlgebra<Value<D>, D> + MatrixAlgebra<D>,
    D: HasDims<Dims = Dims> + Clone + 'static + Send + Sync,
    Dims: PartialEq + Debug + Clone + 'static + Send + Sync
[src]

Auto Trait Implementations

impl<C> !RefUnwindSafe for Graph<C>

impl<C> Send for Graph<C> where
    <C as Config>::EvalAlgebra: Send

impl<C> Sync for Graph<C> where
    <C as Config>::EvalAlgebra: Sync

impl<C> Unpin for Graph<C> where
    <C as Config>::EvalAlgebra: Unpin

impl<C> !UnwindSafe for Graph<C>

Blanket Implementations

impl<T> Any for T where
    T: 'static + ?Sized
[src]

impl<T> Borrow<T> for T where
    T: ?Sized
[src]

impl<T> BorrowMut<T> for T where
    T: ?Sized
[src]

impl<T> From<T> for T[src]

impl<T, U> Into<U> for T where
    U: From<T>, 
[src]

impl<D, A> LinkedAlgebra<Value<D>, D> for A where
    A: CoreAlgebra<D, Value = D>, 
[src]

impl<T> ToOwned for T where
    T: Clone
[src]

type Owned = T

The resulting type after obtaining ownership.

impl<T, U> TryFrom<U> for T where
    U: Into<T>, 
[src]

type Error = Infallible

The type returned in the event of a conversion error.

impl<T, U> TryInto<U> for T where
    U: TryFrom<T>, 
[src]

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.