pub use self::activation::*;
pub use self::arithmetic::{Add, Mult};
pub use self::conv::Conv;
pub use self::conv::Padding;
pub use self::embedding::Embedding;
pub use self::global_pool::GlobalPool;
pub use self::matmul::MatMul;
use graph::Idx;
use ndarray::prelude::*;
use std::fmt::Debug;
mod activation;
mod arithmetic;
mod conv;
mod embedding;
mod global_pool;
mod matmul;
/// A differentiable computation that can back a `Node::Operation` in the graph.
///
/// `Node::forward` calls `eval` and `Node::backward` calls `grad` on the boxed operation.
/// The graph supplies the input values on every call.
pub trait Operation: Debug {
    /// Computes this operation's output from the values of its inputs.
    fn eval(&self, inputs: &[ArrayViewD<f32>]) -> ArrayD<f32>;
    /// Computes gradients for this operation.
    ///
    /// `inputs` are presumably the same values that were given to `eval`. `loss` is
    /// presumably the gradient flowing back into this operation's output. Presumably one
    /// gradient is returned per input, in input order. Confirm all three with implementors.
    fn grad(&self, inputs: &[ArrayViewD<f32>], loss: ArrayViewD<f32>) -> Vec<ArrayD<f32>>;
}
// `Node` derives `Serialize` and serializes its `Box<Operation>` field (only deserializing
// is skipped). The trait object itself must therefore be serializable. This macro presumably
// (erased_serde-style) generates that impl. NOTE(review): confirm which crate provides it.
serialize_trait_object!(Operation);
/// A node in the computation graph.
///
/// `Input`, `Parameter` and `Constant` are leaves with no inputs. Every other variant
/// computes a value from the nodes listed by `Node::inputs`.
///
/// NOTE(review): variant order may be part of the saved format for serde backends that encode
/// variant indices (e.g. bincode). Avoid reordering variants if saved graphs must still load.
#[derive(DebugStub, Serialize, Deserialize)]
pub enum Node {
    /// External data source. Each `forward` call pulls the next item from `it`.
    Input {
        // The boxed iterator cannot be serialized, so serde skips it entirely. After
        // deserializing, `it` is `None`, and `forward` panics with
        // "Input node uninitialized." until an iterator is installed.
        // The iterator has no `Debug` impl, so DebugStub prints the string below instead.
        #[serde(skip)]
        #[debug_stub = "Option<Box<Iterator<Item=ArrayD<f32>>>>"]
        it: Option<Box<Iterator<Item = ArrayD<f32>>>>,
    },
    /// A learnable parameter. The boxed slice is presumably its shape (TODO confirm).
    /// `forward` returns `None` for it, so its value is presumably stored outside the node.
    Parameter(Box<[usize]>),
    /// Convolution configured by `conv`. Its inputs are `[kernel, img]`, in that order.
    Conv { kernel: Idx, img: Idx, conv: Conv },
    /// Applies `Add` across all nodes in `xs`.
    Add { xs: Vec<Idx> },
    /// Applies `Mult` across all nodes in `xs`.
    Mult { xs: Vec<Idx> },
    /// Applies `MatMul`. Its inputs are `[mat, v]`, in that order.
    MatMul { mat: Idx, v: Idx },
    /// Applies the activation function `a` to node `x`.
    Activation { x: Idx, a: Activation },
    /// Applies `Embedding`. Its inputs are `[emb, code]`, in that order. Presumably `emb` is
    /// the embedding table and `code` holds the lookup indices (TODO confirm).
    Embedding { emb: Idx, code: Idx },
    /// Applies the pooling `pool` to node `x`.
    GlobalPool { pool: GlobalPool, x: Idx },
    /// An arbitrary user-supplied `Operation` over the nodes in `inputs`.
    Operation {
        inputs: Box<[Idx]>,
        // Serialized but never deserialized. On load, serde fills this field with
        // `Box<Operation>::default()`, which is `arithmetic::Add`. Any other operation must
        // be reinstalled after loading a saved graph.
        #[serde(skip_deserializing)]
        operation: Box<Operation>,
    },
    /// Leaf with no inputs. `forward` returns `None` and `backward` returns no gradients.
    Constant,
}
impl Node {
    /// Returns the graph indices of the nodes this node reads from.
    ///
    /// Leaf variants (`Input`, `Parameter`, `Constant`) yield an empty vector. For the other
    /// variants the indices follow field order, e.g. `[kernel, img]` for `Conv`. Presumably
    /// this is the order the graph uses to assemble the `inputs` slice handed to
    /// `forward`/`backward`. Confirm against the caller.
    pub fn inputs(&self) -> Vec<Idx> {
        match *self {
            Node::Input { .. } | Node::Parameter(..) | Node::Constant => Vec::new(),
            Node::Conv { kernel, img, .. } => vec![kernel, img],
            Node::MatMul { mat, v } => vec![mat, v],
            Node::Embedding { emb, code } => vec![emb, code],
            Node::Activation { x, .. } | Node::GlobalPool { x, .. } => vec![x],
            Node::Add { ref xs } | Node::Mult { ref xs } => xs.clone(),
            Node::Operation { ref inputs, .. } => inputs.to_vec(),
        }
    }

    /// Computes this node's value from the values of its inputs.
    ///
    /// Returns `None` for `Parameter` and `Constant`. For `Input`, returns the next item of
    /// its iterator, which is `None` once the iterator is exhausted. Every other variant
    /// evaluates its operation over `inputs` and returns `Some(value)`.
    ///
    /// # Panics
    ///
    /// Panics if this is an `Input` whose iterator has not been set. This includes an `Input`
    /// freshly deserialized, because serde skips that field.
    pub fn forward(&mut self, inputs: &[ArrayViewD<f32>]) -> Option<ArrayD<f32>> {
        // `self` is borrowed mutably, so payloads are bound with `ref mut`.
        let value = match *self {
            Node::Input { ref mut it } => {
                let source = it.as_mut().expect("Input node uninitialized.");
                return source.next();
            }
            Node::Parameter(..) | Node::Constant => return None,
            Node::Conv { ref mut conv, .. } => conv.eval(inputs),
            Node::Add { .. } => Add().eval(inputs),
            Node::Mult { .. } => Mult().eval(inputs),
            Node::MatMul { .. } => MatMul().eval(inputs),
            Node::Activation { ref mut a, .. } => a.eval(inputs),
            Node::Embedding { .. } => Embedding().eval(inputs),
            Node::GlobalPool { ref mut pool, .. } => pool.eval(inputs),
            Node::Operation { ref mut operation, .. } => operation.eval(inputs),
        };
        Some(value)
    }

    /// Computes the gradients produced by this node's operation.
    ///
    /// `inputs` are the values of this node's inputs. `loss` is the loss associated with this
    /// node's output. Leaf variants (`Input`, `Constant`, `Parameter`) have no inputs and
    /// return an empty vector. Every other variant delegates to its operation's `grad`.
    pub fn backward(&self, inputs: &[ArrayViewD<f32>], loss: &ArrayD<f32>) -> Vec<ArrayD<f32>> {
        // Taking a view is side-effect free, so it is built once up front.
        let upstream = loss.view();
        match *self {
            Node::Input { .. } | Node::Constant | Node::Parameter(..) => Vec::new(),
            Node::Conv { ref conv, .. } => conv.grad(inputs, upstream),
            Node::Add { .. } => Add().grad(inputs, upstream),
            Node::Mult { .. } => Mult().grad(inputs, upstream),
            Node::MatMul { .. } => MatMul().grad(inputs, upstream),
            Node::Activation { ref a, .. } => a.grad(inputs, upstream),
            Node::Embedding { .. } => Embedding().grad(inputs, upstream),
            Node::GlobalPool { ref pool, .. } => pool.grad(inputs, upstream),
            Node::Operation { ref operation, .. } => operation.grad(inputs, upstream),
        }
    }
}
/// Placeholder operation used when a `Node::Operation` is deserialized.
///
/// The `operation` field of `Node::Operation` is marked `#[serde(skip_deserializing)]`, so
/// serde fills it with this default. As a result, every operation comes back from a
/// save/load round-trip as `arithmetic::Add`, whatever it was originally. Callers using other
/// operations must reinstall them after loading.
impl Default for Box<Operation> {
    fn default() -> Box<Operation> {
        let placeholder = arithmetic::Add();
        Box::new(placeholder)
    }
}