// drug/nodes/mod.rs

//! This module holds the different types of nodes that exist in a computation graph. Nodes that
2//! represent a differentiable computation are implemented by a struct with the "Operation" trait.
3//! Use [Graph](../struct.Graph.html) methods to create and register nodes inside a graph.
4//! See [Node](enum.Node.html) for the types of node available.
5//! This module may eventually be made private...
6
7pub use self::activation::*;
8pub use self::arithmetic::{Add, Mult};
9pub use self::conv::Conv;
10pub use self::conv::Padding;
11pub use self::embedding::Embedding;
12pub use self::global_pool::GlobalPool;
13pub use self::matmul::MatMul;
14
15use graph::Idx;
16use ndarray::prelude::*;
17use std::fmt::Debug;
18mod activation;
19mod arithmetic;
20mod conv;
21mod embedding;
22mod global_pool;
23mod matmul;
24
25/// Represents a differentiable function in a computation graph.
26/// Operations hold their own hyperparameters but not their parameters, values or losses.
27/// Unfortunately boxed traits cannot be saved with serde. When reloaded they will be replaced
28/// by `Box<arithmetic::Add>` nodes. When reloading a model with custom Operations, you need to
29/// replace them manually.
pub trait Operation: Debug {
    /// Computes this operation's output value from the given input views.
    /// The slice holds one view per input node, in the order the node was registered with.
    /// TODO consider modifying output ArrayD<f32> in place
    fn eval(&self, inputs: &[ArrayViewD<f32>]) -> ArrayD<f32>;
    // fn eval(&self, inputs: Box<[ArrayViewD<f32>]>) -> ArrayD<f32>;

    /// Returns gradients of the loss with respect to each input, given the upstream `loss`
    /// (gradient flowing back into this node's output).
    /// Note the returned vector must have the same length as `inputs`.
    /// TODO consider modifying output ArrayD<f32>s in place
    fn grad(&self, inputs: &[ArrayViewD<f32>], loss: ArrayViewD<f32>) -> Vec<ArrayD<f32>>;
}
// Registers `Operation` trait objects with erased-serde so `Box<Operation>` fields can be
// serialized. Deserialization cannot restore the concrete type; it falls back to
// `<Box<Operation>>::default()` (an `Add`), which callers must replace manually.
serialize_trait_object!(Operation);
42
#[derive(DebugStub, Serialize, Deserialize)]
/// Nodes are the building blocks of the [computation graph](../struct.Graph.html).
/// The variants of a node differ in how the value is produced and how loss is propagated back.
/// Users typically interact with Nodes with their index `:Idx` which is returned by the graph
/// when registered / created.
pub enum Node {
    /// Produces a value from beyond the graph.
    /// * In a forward pass, its value is updated by the iterator; panics if the iterator is `None`.
    /// * In a backward pass, its losses are currently calculated but unused.
    /// * When serializing, the internal iterator is ignored. It deserializes to `None`.
    Input {
        #[serde(skip)]
        #[debug_stub = "Option<Box<Iterator<Item=ArrayD<f32>>>>"]
        it: Option<Box<Iterator<Item = ArrayD<f32>>>>,
    },

    /// Parameter nodes only hold a shape. Their values are initialized when inserted into the
    /// graph using the graph's initializer.
    /// * In a forward pass, parameters are ignored (their value is already set).
    /// * In a backward pass, their losses are applied by the graph's optimizer.
    Parameter(Box<[usize]>),
    /// See [Conv](struct.Conv.html) for more.
    Conv { kernel: Idx, img: Idx, conv: Conv },
    /// See [Add](struct.Add.html) for more.
    Add { xs: Vec<Idx> },
    /// See [Mult](struct.Mult.html) for more.
    Mult { xs: Vec<Idx> },
    /// See [MatMul](struct.MatMul.html) for more.
    MatMul { mat: Idx, v: Idx },
    /// See [Activation](enum.Activation.html) for more.
    Activation { x: Idx, a: Activation },
    /// See [Embedding](struct.Embedding.html) for more.
    Embedding { emb: Idx, code: Idx },
    /// See [GlobalPool](struct.GlobalPool.html) for more.
    GlobalPool { pool: GlobalPool, x: Idx },
    /// An Operation node holds an [Operation trait object](trait.Operation.html) and the indices
    /// referring to its input values.
    /// * In a forward pass, its value is updated by the `operation` and the values indexed by
    /// `inputs`.
    /// * In a backward pass, gradients are calculated and losses are propagated backwards and added
    /// to the losses indexed by `inputs`.
    /// * The boxed operation is skipped on deserialization and replaced by its `Default` (`Add`).
    Operation {
        inputs: Box<[Idx]>,
        #[serde(skip_deserializing)]
        operation: Box<Operation>,
    },

    /// Ignored by the graph in both passes; you have to set the values yourself.
    Constant,
}
93
94impl Node {
95    pub fn inputs(&self) -> Vec<Idx> {
96        match self {
97            Node::Conv { kernel, img, .. } => vec![*kernel, *img],
98            Node::Add { xs } => xs.to_vec(),
99            Node::Mult { xs } => xs.to_vec(),
100            Node::MatMul { mat, v } => vec![*mat, *v],
101            Node::Activation { x, .. } => vec![*x],
102            Node::Embedding { emb, code } => vec![*emb, *code],
103            Node::GlobalPool { x, .. } => vec![*x],
104            Node::Operation { inputs, .. } => inputs.to_vec(),
105            Node::Input { .. } | Node::Parameter(..) | Node::Constant => vec![],
106        }
107    }
108    pub fn forward(&mut self, inputs: &[ArrayViewD<f32>]) -> Option<ArrayD<f32>> {
109        match self {
110            Node::Conv { conv, .. } => Some(conv.eval(inputs)),
111            Node::Add { .. } => Some(Add().eval(inputs)),
112            Node::Mult { .. } => Some(Mult().eval(inputs)),
113            Node::MatMul { .. } => Some(MatMul().eval(inputs)),
114            Node::Activation { a, .. } => Some(a.eval(inputs)),
115            Node::Embedding { .. } => Some(Embedding().eval(inputs)),
116            Node::GlobalPool { pool, .. } => Some(pool.eval(inputs)),
117            Node::Operation { operation, .. } => Some(operation.eval(inputs)),
118            Node::Input { ref mut it } => it.as_mut().expect("Input node uninitialized.").next(),
119            Node::Parameter(..) | Node::Constant => None,
120        }
121    }
122    pub fn backward(&self, inputs: &[ArrayViewD<f32>], loss: &ArrayD<f32>) -> Vec<ArrayD<f32>> {
123        match self {
124            Node::Conv { conv, .. } => conv.grad(inputs, loss.view()),
125            Node::Add { .. } => Add().grad(inputs, loss.view()),
126            Node::Mult { .. } => Mult().grad(inputs, loss.view()),
127            Node::MatMul { .. } => MatMul().grad(inputs, loss.view()),
128            Node::Activation { a, .. } => a.grad(inputs, loss.view()),
129            Node::Embedding { .. } => Embedding().grad(inputs, loss.view()),
130            Node::GlobalPool { pool, .. } => pool.grad(inputs, loss.view()),
131            Node::Operation { operation, .. } => operation.grad(inputs, loss.view()),
132            Node::Input { .. } | Node::Constant | Node::Parameter(..) => vec![],
133        }
134    }
135}
136
137// TODO figure out serialization and deserialization of Boxed traits. This may not be possible :/
138impl Default for Box<Operation> {
139    fn default() -> Self {
140        Box::new(arithmetic::Add())
141    }
142}