pub enum Node {
    Input {
        it: Option<Box<dyn Iterator<Item = ArrayD<f32>>>>,
    },
    Parameter(Box<[usize]>),
    Conv {
        kernel: Idx,
        img: Idx,
        conv: Conv,
    },
    Add {
        xs: Vec<Idx>,
    },
    Mult {
        xs: Vec<Idx>,
    },
    MatMul {
        mat: Idx,
        v: Idx,
    },
    Activation {
        x: Idx,
        a: Activation,
    },
    Embedding {
        emb: Idx,
        code: Idx,
    },
    GlobalPool {
        pool: GlobalPool,
        x: Idx,
    },
    Operation {
        inputs: Box<[Idx]>,
        operation: Box<dyn Operation>,
    },
    Constant,
}
Nodes are the building blocks of the computation graph.
The variants differ in how a node's value is produced and how its loss is propagated backward.
Users typically interact with a node through its index, Idx, which the graph returns when the node is registered or created.
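The index-based design described above can be pictured as an arena: the graph owns every node in a Vec and hands back a plain index on registration, and nodes refer to each other only through those indices. This is a minimal sketch under stated assumptions, not the crate's API: Graph, Idx, register, forward and value are hypothetical names, values are flat Vec<f32> rather than ArrayD<f32>, and only Add and Mult are modeled.

```rust
// Hypothetical sketch of index-based node handles; not the crate's actual API.

/// Handle returned by the graph when a node is registered.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Idx(usize);

/// Heavily reduced node type; values are flat Vec<f32> instead of ArrayD<f32>.
pub enum Node {
    /// A fixed value supplied at registration (hypothetical leaf variant).
    Value(Vec<f32>),
    Add { xs: Vec<Idx> },
    Mult { xs: Vec<Idx> },
}

pub struct Graph {
    nodes: Vec<Node>,
    values: Vec<Vec<f32>>,
}

impl Graph {
    pub fn new() -> Self {
        Graph { nodes: Vec::new(), values: Vec::new() }
    }

    /// Stores the node and returns its index; callers keep only the Idx.
    pub fn register(&mut self, node: Node) -> Idx {
        self.nodes.push(node);
        self.values.push(Vec::new());
        Idx(self.nodes.len() - 1)
    }

    /// Evaluates nodes in registration order, so inputs must be registered first.
    pub fn forward(&mut self) {
        for i in 0..self.nodes.len() {
            let v = match &self.nodes[i] {
                Node::Value(v) => v.clone(),
                Node::Add { xs } => self.combine(xs, |a, b| a + b),
                Node::Mult { xs } => self.combine(xs, |a, b| a * b),
            };
            self.values[i] = v;
        }
    }

    /// Folds the input values elementwise with `f`; assumes at least one input.
    fn combine(&self, xs: &[Idx], f: impl Fn(f32, f32) -> f32) -> Vec<f32> {
        let mut acc = self.values[xs[0].0].clone();
        for x in &xs[1..] {
            for (a, b) in acc.iter_mut().zip(&self.values[x.0]) {
                *a = f(*a, *b);
            }
        }
        acc
    }

    pub fn value(&self, idx: Idx) -> &[f32] {
        &self.values[idx.0]
    }
}

fn main() {
    let mut g = Graph::new();
    let a = g.register(Node::Value(vec![1.0, 2.0]));
    let b = g.register(Node::Value(vec![3.0, 4.0]));
    let sum = g.register(Node::Add { xs: vec![a, b] });
    let prod = g.register(Node::Mult { xs: vec![a, b, sum] });
    g.forward();
    println!("{:?} {:?}", g.value(sum), g.value(prod)); // prints "[4.0, 6.0] [12.0, 48.0]"
}
```

Because a node stores indices rather than references, the graph can own all nodes mutably while nodes still point at one another, which avoids borrow-checker conflicts between them.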
Variants
Input
Produces a value from outside the graph.
- In a forward pass, its value is updated by the iterator; the pass panics if the iterator is None.
- In a backward pass, its losses are currently calculated but unused.
- When serializing, the internal iterator is skipped; on deserialization it becomes None.
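The forward-pass behavior of an Input node can be sketched as pulling the next item from an optional iterator. This is a hypothetical model: InputNode and forward are illustrative names, values are Vec<f32> instead of ArrayD<f32>, and panicking when the iterator is exhausted is this sketch's own choice, since the documentation only specifies the panic for None.

```rust
// Hypothetical sketch of Input-node semantics; not the crate's actual API.

pub struct InputNode {
    /// None after deserialization, since an iterator cannot be serialized.
    pub it: Option<Box<dyn Iterator<Item = Vec<f32>>>>,
    pub value: Vec<f32>,
}

impl InputNode {
    /// Forward pass: replace the value with the iterator's next item.
    /// Panics if there is no iterator, and (assumption) if it is exhausted.
    pub fn forward(&mut self) {
        let it = self.it.as_mut().expect("Input node has no iterator");
        self.value = it.next().expect("Input iterator is exhausted");
    }
}

fn main() {
    let batches: Vec<Vec<f32>> = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let mut input = InputNode { it: Some(Box::new(batches.into_iter())), value: Vec::new() };
    input.forward();
    println!("{:?}", input.value); // prints "[1.0, 2.0]"

    // A freshly deserialized node would look like this; calling forward() on it panics.
    let detached = InputNode { it: None, value: Vec::new() };
    assert!(detached.it.is_none());
}
```

After deserializing a graph, the caller would need to attach a new iterator before running a forward pass.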
Parameter(Box<[usize]>)
Parameter nodes hold only a shape. Their values are initialized by the graph's initializer when the node is inserted into the graph.
- In a forward pass, parameters are ignored.
- In a backward pass, their losses are applied by the graph’s optimizer.
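The Parameter lifecycle (shape in, values from an initializer, losses consumed by an optimizer) can be sketched as follows. Everything here is an assumption for illustration: Parameter, insert, constant_init and Sgd are hypothetical names, the initializer is a deterministic constant fill rather than a random scheme, and the optimizer is plain SGD that clears the accumulated loss after each step.

```rust
// Hypothetical sketch of a Parameter's lifecycle; not the crate's actual API.

/// Number of elements implied by a shape such as [rows, cols].
fn numel(shape: &[usize]) -> usize {
    shape.iter().product()
}

/// Hypothetical initializer; a real one would typically draw random values.
fn constant_init(shape: &[usize], c: f32) -> Vec<f32> {
    vec![c; numel(shape)]
}

struct Parameter {
    shape: Box<[usize]>,
    value: Vec<f32>,
    /// Gradient accumulated by the backward pass.
    loss: Vec<f32>,
}

impl Parameter {
    /// Only the shape is supplied; the initializer produces the values.
    fn insert(shape: Box<[usize]>, init: impl Fn(&[usize]) -> Vec<f32>) -> Self {
        let value = init(&shape[..]);
        let loss = vec![0.0; value.len()];
        Parameter { shape, value, loss }
    }
}

/// Hypothetical optimizer: vanilla SGD that applies, then clears, the loss.
struct Sgd {
    learning_rate: f32,
}

impl Sgd {
    fn apply(&self, p: &mut Parameter) {
        for (v, g) in p.value.iter_mut().zip(p.loss.iter_mut()) {
            *v -= self.learning_rate * *g;
            *g = 0.0;
        }
    }
}

fn main() {
    let mut w = Parameter::insert(vec![2, 3].into_boxed_slice(), |s| constant_init(s, 0.5));
    // Pretend a backward pass accumulated a gradient of 1.0 in every element.
    w.loss.iter_mut().for_each(|g| *g = 1.0);
    Sgd { learning_rate: 0.1 }.apply(&mut w);
    println!("{:?} {:?}", w.shape, w.value);
}
```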
Conv
See Conv for more.
Add
See Add for more.
Mult
See Mult for more.
MatMul
See MatMul for more.
Activation
See Activation for more.
Embedding
See Embedding for more.
GlobalPool
See GlobalPool for more.
Operation
An Operation node holds an Operation trait object and the indices of its input values.
- In a forward pass, its value is computed by operation from the values indexed by inputs.
- In a backward pass, gradients are calculated, and the resulting losses are propagated backward and added to the losses indexed by inputs.
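The forward and backward behavior of an Operation node can be sketched with a simplified trait. The trait below is a hypothetical stand-in: the crate's real Operation trait works on ArrayD<f32> and its method names and signatures may differ, inputs are plain usize indices here instead of Idx, and ElementwiseMul is an illustrative operation. The key point it shows is that backward adds gradients into the input losses rather than overwriting them.

```rust
// Hypothetical sketch of an Operation node; not the crate's actual trait.

trait Operation {
    /// Forward: compute the output from the input values.
    fn eval(&self, inputs: &[&[f32]]) -> Vec<f32>;
    /// Backward: given the inputs and the loss on the output, return one gradient per input.
    fn grad(&self, inputs: &[&[f32]], loss: &[f32]) -> Vec<Vec<f32>>;
}

/// Elementwise product of two inputs: d(xy)/dx = y, d(xy)/dy = x.
struct ElementwiseMul;

impl Operation for ElementwiseMul {
    fn eval(&self, inputs: &[&[f32]]) -> Vec<f32> {
        inputs[0].iter().zip(inputs[1]).map(|(x, y)| x * y).collect()
    }

    fn grad(&self, inputs: &[&[f32]], loss: &[f32]) -> Vec<Vec<f32>> {
        let (x, y) = (inputs[0], inputs[1]);
        let dx: Vec<f32> = loss.iter().zip(y).map(|(l, yi)| l * yi).collect();
        let dy: Vec<f32> = loss.iter().zip(x).map(|(l, xi)| l * xi).collect();
        vec![dx, dy]
    }
}

struct OperationNode {
    inputs: Box<[usize]>,
    operation: Box<dyn Operation>,
}

/// Forward: value = operation(values at inputs).
fn forward(node: &OperationNode, values: &[Vec<f32>]) -> Vec<f32> {
    let ins: Vec<&[f32]> = node.inputs.iter().map(|&i| values[i].as_slice()).collect();
    node.operation.eval(&ins)
}

/// Backward: gradients are added to the input losses, so a node feeding
/// several consumers accumulates a contribution from each of them.
fn backward(node: &OperationNode, values: &[Vec<f32>], losses: &mut [Vec<f32>], out_loss: &[f32]) {
    let ins: Vec<&[f32]> = node.inputs.iter().map(|&i| values[i].as_slice()).collect();
    let grads = node.operation.grad(&ins, out_loss);
    for (&i, g) in node.inputs.iter().zip(grads) {
        for (l, d) in losses[i].iter_mut().zip(g) {
            *l += d;
        }
    }
}

fn main() {
    let values: Vec<Vec<f32>> = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    let mut losses: Vec<Vec<f32>> = vec![vec![0.0; 2], vec![0.0; 2]];
    let node = OperationNode { inputs: vec![0, 1].into_boxed_slice(), operation: Box::new(ElementwiseMul) };
    println!("{:?}", forward(&node, &values)); // prints "[3.0, 8.0]"
    backward(&node, &values, &mut losses, &[1.0f32, 1.0]);
    println!("{:?}", losses); // prints "[[3.0, 4.0], [1.0, 2.0]]"
}
```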
Constant
Ignored by the graph; you must set its value yourself.
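What "ignored by the graph" means for a Constant can be sketched as a forward pass that simply skips such nodes, leaving whatever the caller wrote in place. This is a hypothetical model: Kind, Graph and set_value are illustrative names, values are stored in a Vec indexed by node position, and Double stands in for any computed node.

```rust
// Hypothetical sketch of how Constant nodes are skipped; not the crate's actual API.

enum Kind {
    Constant,
    /// Doubles its single input; stands in for any computed node.
    Double(usize),
}

struct Graph {
    kinds: Vec<Kind>,
    values: Vec<Vec<f32>>,
}

impl Graph {
    fn forward(&mut self) {
        for i in 0..self.kinds.len() {
            match self.kinds[i] {
                // Constants are skipped: the caller-supplied value stays as-is.
                Kind::Constant => {}
                Kind::Double(x) => {
                    self.values[i] = self.values[x].iter().map(|&v| 2.0 * v).collect();
                }
            }
        }
    }

    /// The caller is responsible for writing a Constant's value.
    fn set_value(&mut self, i: usize, v: Vec<f32>) {
        self.values[i] = v;
    }
}

fn main() {
    let mut g = Graph { kinds: vec![Kind::Constant, Kind::Double(0)], values: vec![vec![], vec![]] };
    g.set_value(0, vec![1.5, -2.0]);
    g.forward();
    println!("{:?}", g.values[1]); // prints "[3.0, -4.0]"
}
```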
Implementations
Trait Implementations
impl<'de> Deserialize<'de> for Node
fn deserialize<__D>(__deserializer: __D) -> Result<Self, __D::Error>
where
    __D: Deserializer<'de>,
Deserialize this value from the given Serde deserializer.