pub struct Graph {
pub optimizer: Optimizer,
pub named_idxs: BTreeMap<String, Idx>,
/* private fields */
}
A differentiable computation graph. Use this struct to hold your differentiable program, which is a directed acyclic graph of Nodes, their associated values, and losses (gradients). The graph computes values moving forward in insertion order (see the forward method) and propagates losses backwards in reverse insertion order (see the backward method). The default graph comes with a Xavier initializer and a vanilla stochastic gradient descent optimizer.
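A hypothetical end-to-end sketch using only the methods documented below. This is not runnable standalone: it assumes the crate is in scope, that `xs` and `ys` are boxed `ArrayD<f32>` iterators, and that `grad` is a loss gradient you computed yourself.

```rust
// Sketch of the intended workflow (assumptions labeled above).
let mut g = Graph::default();

// Inputs stream ArrayD<f32> batches from their iterators each forward pass.
let x = g.input(Some(xs));

// A trainable parameter, filled in by the graph's initializer.
let w = g.param(&[4, 4]);

// Wire up the DAG; `mult` is one of the built-in op helpers below.
let pred = g.mult(&[w, x]);

for _ in 0..100 {
    g.forward();            // values flow in insertion order
    // compute the loss gradient at `pred` however you like, then:
    g.set_loss(pred, grad); // seed the backward pass
    g.backward();           // parameters are updated by the optimizer
}
```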
Fields
optimizer: Optimizer
named_idxs: BTreeMap<String, Idx>
Implementations
impl Graph
pub fn new(
    initializer: Box<dyn Fn(&[usize]) -> ArrayD<f32>>,
    optimizer: Optimizer
) -> Self
Consider using Graph::default() if you don't want to choose your own optimizer and initializer.
pub fn register(&mut self, node: Node) -> Idx
Inserts the node into the graph and returns its index.
pub fn param(&mut self, shape: &[usize]) -> Idx
Registers a parameter of the given shape and initializes the value using the graph’s initializer.
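The docs above say parameters use a Xavier initializer by default. As an illustration, the usual Glorot-uniform scheme samples from `[-limit, limit]` with `limit = sqrt(6 / (fan_in + fan_out))`; the function below is an assumption about that scheme, not the crate's actual code.

```rust
// Compute the Xavier/Glorot uniform sampling bound for a shape,
// treating the first axis as fan-in and the last as fan-out.
// (Illustrative sketch; the crate's initializer may differ.)
fn xavier_limit(shape: &[usize]) -> f32 {
    let fan_in = shape[0] as f32;
    let fan_out = shape[shape.len() - 1] as f32;
    (6.0 / (fan_in + fan_out)).sqrt()
}

fn main() {
    // For a 3x3 parameter: sqrt(6 / (3 + 3)) = sqrt(1) = 1.0.
    println!("{}", xavier_limit(&[3, 3]));
}
```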
pub fn input(&mut self, it: Option<Box<dyn Iterator<Item = ArrayD<f32>>>>) -> Idx
Registers an input node which advances the iterator it on each forward pass.
pub fn op(&mut self, op: impl Operation + 'static, inputs: &[Idx]) -> Idx
Registers an operation and its inputs.
pub fn constant(&mut self, c: ArrayD<f32>) -> Idx
Registers a constant, sets its value to c, then returns the idx.
pub fn forward(&mut self)
Computes values for each node in insertion order. Parameters are unaffected; Inputs set their value to the next output of their iterator; Operations compute a new value from the values of their inputs.
pub fn backward(&mut self)
Propagates gradients in reverse insertion order. Parameters apply their gradients with the graph's optimizer. Inputs are unaffected. Operations compute their gradients given the values of their inputs and the gradients of their outputs.
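For intuition, the "vanilla stochastic gradient descent" step the default optimizer is described as performing amounts to `w <- w - lr * grad` per element. The function name and signature below are assumptions for this sketch, not the crate's API.

```rust
// A plain-slice illustration of a vanilla SGD update step.
fn sgd_step(weights: &mut [f32], grads: &[f32], lr: f32) {
    for (w, g) in weights.iter_mut().zip(grads) {
        // Move each weight opposite the direction of its gradient.
        *w -= lr * *g;
    }
}

fn main() {
    let mut w = vec![1.0f32, -2.0];
    sgd_step(&mut w, &[0.5, -0.5], 0.1);
    println!("{:?}", w);
}
```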
pub fn remove(&mut self, idx: Idx)
Removes the node at idx as well as its associated value and loss.
pub fn clear_non_parameters(&mut self)
Removes every node from the graph that is not a parameter. This is useful for dynamic graphs and recurrent neural networks, when you want to rebuild the graph on each forward and backward pass of the network.
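A hypothetical dynamic-graph loop built from the methods on this page. Not runnable standalone: it assumes the crate is in scope and that `make_batch_iter()` is your own helper returning a boxed `ArrayD<f32>` iterator.

```rust
// Parameters persist across iterations; everything else is rebuilt.
let mut g = Graph::default();
let w = g.param(&[8, 8]);

for _ in 0..steps {
    let x = g.input(Some(make_batch_iter())); // hypothetical helper
    let h = g.mult(&[w, x]);
    g.forward();
    // ... set losses on the outputs, then ...
    g.backward();
    g.clear_non_parameters(); // drop x and h; keep w and its state
}
```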
pub fn set_value(&mut self, idx: Idx, val: ArrayD<f32>)
pub fn get_value(&self, idx: Idx) -> &ArrayD<f32>
pub fn set_loss(&mut self, idx: Idx, loss: ArrayD<f32>)
pub fn get_loss(&self, idx: Idx) -> &ArrayD<f32>
pub fn replace_input_iterator(
    &mut self,
    idx: Idx,
    new: Box<dyn Iterator<Item = ArrayD<f32>>>
) -> Result<(), String>
Replaces an Input node's iterator, or converts a Constant node into an Input driven by this iterator. Note that Input node iterators are not saved when serializing with serde.
pub fn add(&mut self, inputs: &[Idx]) -> Idx
pub fn mult(&mut self, inputs: &[Idx]) -> Idx
pub fn conv(
    &mut self,
    kernel: Idx,
    img: Idx,
    padding: Padding,
    stride: usize
) -> Idx
Registers a convolution operation node and returns the index.
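To reason about shapes, the usual convolution output-size formula is `out = (in + 2*pad - kernel) / stride + 1`. The crate's exact Padding semantics may differ; the helper below is an illustration, not the crate's code.

```rust
// Output length of one spatial axis after a convolution.
fn conv_out_dim(input: usize, kernel: usize, pad: usize, stride: usize) -> usize {
    (input + 2 * pad - kernel) / stride + 1
}

fn main() {
    // A 28-pixel axis, 3-wide kernel, no padding, stride 1 -> 26.
    println!("{}", conv_out_dim(28, 3, 0, 1));
    // Same axis with padding 1 and stride 2 -> 14.
    println!("{}", conv_out_dim(28, 3, 1, 2));
}
```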
pub fn global_pool(&mut self, x: Idx, pool: GlobalPool) -> Idx
Registers a pooling operation that takes a Batch * Height * Width * Channels image and reduces it to a Batch * Channels vector.
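As an illustration of that reduction, global average pooling collapses each (batch, channel) pair to the mean over the Height * Width positions. This plain-slice sketch assumes NHWC layout and averaging; the crate's GlobalPool variants may support other reductions.

```rust
// Global average pooling over a flattened B*H*W*C array -> B*C vector.
fn global_avg_pool(x: &[f32], b: usize, h: usize, w: usize, c: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; b * c];
    for bi in 0..b {
        for hi in 0..h {
            for wi in 0..w {
                for ci in 0..c {
                    // NHWC index: (((bi * h) + hi) * w + wi) * c + ci
                    out[bi * c + ci] += x[((bi * h + hi) * w + wi) * c + ci];
                }
            }
        }
    }
    for v in out.iter_mut() {
        *v /= (h * w) as f32; // divide by spatial size to get the mean
    }
    out
}

fn main() {
    // One image, 2x2 spatial, 1 channel: mean of [1,2,3,4] is 2.5.
    println!("{:?}", global_avg_pool(&[1.0, 2.0, 3.0, 4.0], 1, 2, 2, 1));
}
```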
pub fn relu(&mut self, x: Idx) -> Idx
Registers a ReLU operation which takes the elementwise maximum of the input array and 0.
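The elementwise rule described above, sketched on a plain slice rather than the crate's ndarray-based node:

```rust
// ReLU: max(x, 0) applied to every element.
fn relu(xs: &[f32]) -> Vec<f32> {
    xs.iter().map(|&x| x.max(0.0)).collect()
}

fn main() {
    // Negative values clamp to 0; non-negative values pass through.
    println!("{:?}", relu(&[-1.5, 0.0, 2.0]));
}
```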
pub fn sigmoid(&mut self, x: Idx) -> Idx
Registers a new sigmoid activation operation, an elementwise application of $\frac{1}{1 + e^{-x}}$.
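The same formula on a plain slice, for reference (illustrative, not the crate's implementation):

```rust
// Sigmoid: 1 / (1 + e^{-x}) applied to every element.
fn sigmoid(xs: &[f32]) -> Vec<f32> {
    xs.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect()
}

fn main() {
    // sigmoid(0) = 0.5; large positive inputs saturate toward 1.
    println!("{:?}", sigmoid(&[0.0, 10.0]));
}
```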