1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
//! This module holds the different types of nodes that exist in a computation graph. Nodes that
//! represent a differentiable computation are implemented by a struct implementing the
//! "Operation" trait.
//! Use [Graph](../struct.Graph.html) methods to create and register nodes inside a graph.
//! See [Node](enum.Node.html) for the types of node available.
//! This module may eventually be made private...

pub use self::activation::*;
pub use self::arithmetic::{Add, Mult};
pub use self::conv::Conv;
pub use self::conv::Padding;
pub use self::embedding::Embedding;
pub use self::global_pool::GlobalPool;
pub use self::matmul::MatMul;

use graph::Idx;
use ndarray::prelude::*;
use std::fmt::Debug;
mod activation;
mod arithmetic;
mod conv;
mod embedding;
mod global_pool;
mod matmul;

/// Represents a differentiable function in a computation graph.
/// Operations hold their own hyperparameters but not their parameters, values or losses
/// (those live in the graph, keyed by the node's `Idx`).
/// Unfortunately boxed traits cannot be saved with serde. When reloaded they will be replaced
/// by `Box<arithmetic::Add>` nodes (see the `Default` impl at the bottom of this module).
/// When reloading a model with custom Operations, you need to replace them manually.
pub trait Operation: Debug {
    /// Computes this operation's output value from the given input views.
    /// Returns a freshly allocated array; inputs are not modified.
    /// TODO consider modifying output ArrayD<f32>  in place
    fn eval(&self, inputs: &[ArrayViewD<f32>]) -> ArrayD<f32>;
    // Alternative signature once considered:
    // fn eval(&self, inputs: Box<[ArrayViewD<f32>]>) -> ArrayD<f32>;

    /// Returns the gradient of the loss with respect to each input, given `loss` — the
    /// gradient with respect to this operation's output.
    /// Note the returned vector should have the same length as `inputs`.
    /// TODO consider modifying output ArrayD<f32>s  in place
    fn grad(&self, inputs: &[ArrayViewD<f32>], loss: ArrayViewD<f32>) -> Vec<ArrayD<f32>>;
}
// Enables serialization of `Operation` trait objects (macro presumably from
// `erased_serde` — imported elsewhere in the crate; TODO confirm).
serialize_trait_object!(Operation);

#[derive(DebugStub, Serialize, Deserialize)]
/// Nodes are the building blocks of the [computation graph](../struct.Graph.html).
/// The variants of a node differ in how the value is produced and how loss is propagated back.
/// Users typically interact with Nodes via their index `:Idx`, which is returned by the graph
/// when the node is registered / created.
pub enum Node {
    /// Produces a value from beyond the graph.
    /// * In a forward pass, its value is updated by the iterator; the pass panics if the
    /// iterator is `None`.
    /// * In a backward pass, its losses are currently calculated but unused.
    /// * When serializing, the internal iterator is ignored. It deserializes to `None`.
    Input {
        #[serde(skip)]
        #[debug_stub = "Option<Box<Iterator<Item=ArrayD<f32>>>>"]
        it: Option<Box<Iterator<Item = ArrayD<f32>>>>,
    },

    /// Parameter nodes only hold a shape. Their values are initialized when inserted into the
    /// graph using the graph's initializer.
    /// * In a forward pass, parameters are ignored.
    /// * In a backward pass, their losses are applied by the graph's optimizer.
    Parameter(Box<[usize]>),
    /// See [Conv](struct.Conv.html) for more.
    Conv { kernel: Idx, img: Idx, conv: Conv },
    /// See [Add](struct.Add.html) for more.
    Add { xs: Vec<Idx> },
    /// See [Mult](struct.Mult.html) for more.
    Mult { xs: Vec<Idx> },
    /// See [MatMul](struct.MatMul.html) for more.
    MatMul { mat: Idx, v: Idx },
    /// See [Activation](enum.Activation.html) for more.
    Activation { x: Idx, a: Activation },
    /// See [Embedding](struct.Embedding.html) for more.
    Embedding { emb: Idx, code: Idx },
    /// See [GlobalPool](struct.GlobalPool.html) for more.
    GlobalPool { pool: GlobalPool, x: Idx },
    /// An Operation node holds an [Operation trait object](trait.Operation.html) and the indices
    /// referring to its input values.
    /// * In a forward pass, its value is updated by the `operation` and the values indexed by
    /// `inputs`.
    /// * In a backward pass, gradients are calculated and losses are propagated backwards and added
    /// to the losses indexed by `inputs`.
    /// * When deserializing, the trait object cannot be restored and falls back to the
    /// `Default` impl (`Box<arithmetic::Add>`).
    Operation {
        inputs: Box<[Idx]>,
        #[serde(skip_deserializing)]
        operation: Box<Operation>,
    },

    /// Ignored by both the forward and backward pass; you have to set its value yourself.
    Constant,
}

impl Node {
    /// Collects the indices of every upstream value this node reads in a forward pass.
    /// Source-like nodes (`Input`, `Parameter`, `Constant`) read nothing and return an
    /// empty vector.
    pub fn inputs(&self) -> Vec<Idx> {
        match self {
            Node::Input { .. } | Node::Parameter(..) | Node::Constant => Vec::new(),
            Node::Conv { kernel, img, .. } => vec![*kernel, *img],
            Node::MatMul { mat, v } => vec![*mat, *v],
            Node::Embedding { emb, code } => vec![*emb, *code],
            Node::Activation { x, .. } | Node::GlobalPool { x, .. } => vec![*x],
            Node::Add { xs } | Node::Mult { xs } => xs.clone(),
            Node::Operation { inputs, .. } => inputs.iter().cloned().collect(),
        }
    }
    /// Runs this node's forward computation over `inputs` (the values of the indices
    /// reported by `inputs()`), returning the new value.
    /// Returns `None` for `Parameter` and `Constant` nodes (the graph manages those
    /// values elsewhere) and for an exhausted `Input` iterator.
    /// Panics if an `Input` node's iterator was never set.
    pub fn forward(&mut self, inputs: &[ArrayViewD<f32>]) -> Option<ArrayD<f32>> {
        match self {
            Node::Parameter(..) | Node::Constant => None,
            Node::Input { it } => it.as_mut().expect("Input node uninitialized.").next(),
            Node::Conv { conv, .. } => Some(conv.eval(inputs)),
            Node::Add { .. } => Some(Add().eval(inputs)),
            Node::Mult { .. } => Some(Mult().eval(inputs)),
            Node::MatMul { .. } => Some(MatMul().eval(inputs)),
            Node::Activation { a, .. } => Some(a.eval(inputs)),
            Node::Embedding { .. } => Some(Embedding().eval(inputs)),
            Node::GlobalPool { pool, .. } => Some(pool.eval(inputs)),
            Node::Operation { operation, .. } => Some(operation.eval(inputs)),
        }
    }
    /// Propagates `loss` (the gradient wrt this node's output) backwards, producing one
    /// gradient array per input (see `Operation::grad`).
    /// Source-like nodes (`Input`, `Parameter`, `Constant`) produce no gradients.
    pub fn backward(&self, inputs: &[ArrayViewD<f32>], loss: &ArrayD<f32>) -> Vec<ArrayD<f32>> {
        match self {
            Node::Input { .. } | Node::Parameter(..) | Node::Constant => Vec::new(),
            Node::Conv { conv, .. } => conv.grad(inputs, loss.view()),
            Node::Add { .. } => Add().grad(inputs, loss.view()),
            Node::Mult { .. } => Mult().grad(inputs, loss.view()),
            Node::MatMul { .. } => MatMul().grad(inputs, loss.view()),
            Node::Activation { a, .. } => a.grad(inputs, loss.view()),
            Node::Embedding { .. } => Embedding().grad(inputs, loss.view()),
            Node::GlobalPool { pool, .. } => pool.grad(inputs, loss.view()),
            Node::Operation { operation, .. } => operation.grad(inputs, loss.view()),
        }
    }
}

// TODO figure out serialization and deserialization of Boxed traits. This may not be possible :/
impl Default for Box<Operation> {
    fn default() -> Self {
        Box::new(arithmetic::Add())
    }
}