use ndarray::{Array, ArrayD, ArrayViewD};
use nodes::*;
use std::collections::BTreeMap;
use std::fmt;
use optimizers::Optimizer;
use xavier_initialize;
/// Handle to a node registered in a [`Graph`].
///
/// Handles are produced by `Graph::register` (and the builder methods that call it) from a
/// counter that is only ever incremented, so an index is not handed out again after
/// `Graph::remove` deletes its node.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Hash, PartialEq, Eq)]
pub struct Idx {
    // Key into the graph's `nodes` / `values` / `losses` maps.
    idx: usize,
}
/// A computation graph: nodes keyed by index, plus each node's current value and loss.
///
/// `register`, `remove` and `clear_non_parameters` insert into / remove from `nodes`,
/// `values` and `losses` together, so every node index has an entry in all three maps.
#[derive(DebugStub, Serialize, Deserialize)]
pub struct Graph {
    /// Nodes by index.
    nodes: BTreeMap<usize, Node>,
    /// Current value of each node (overwritten by the forward pass and `set_value`).
    values: BTreeMap<usize, ArrayD<f32>>,
    /// Loss of each node; the backward pass adds input gradients into these.
    losses: BTreeMap<usize, ArrayD<f32>>,
    /// Next index to hand out; only ever incremented.
    num_inserted: usize,
    /// Builds a new parameter's starting value from its shape. Not serialized: a
    /// deserialized graph gets `Initializer::default()` here.
    #[debug_stub = "Initializer function"]
    #[serde(skip)]
    initializer: Initializer,
    /// Receives registered parameters and applies their gradients during the backward pass.
    pub optimizer: Optimizer,
    /// Name -> node index table. Not updated by `remove` / `clear_non_parameters`.
    pub named_idxs: BTreeMap<String, Idx>,
}
struct Initializer(Box<(Fn(&[usize]) -> ArrayD<f32>)>);
impl Default for Initializer {
    /// Falls back to `xavier_initialize` for producing parameter values.
    fn default() -> Self {
        // The boxed fn item is unsized into the trait object at the constructor call.
        let boxed = Box::new(xavier_initialize);
        Initializer(boxed)
    }
}
impl fmt::Display for Graph {
    /// Writes the optimizer, then one entry per node (in index order) containing the node's
    /// index, its `Debug` form, and the shapes of its current value and loss.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        writeln!(f, "Computation Graph with Optimizer:\n\t{}", self.optimizer)?;
        for (i, node) in self.nodes.iter() {
            // `i` is already a `&usize`, which is exactly the key type `BTreeMap` indexing
            // takes; `&i` would be a `&&usize` that only works via inference + coercion.
            writeln!(
                f,
                "\n{}\t{:?}\n\tvalue shape: {:?}\tloss shape: {:?}",
                i,
                node,
                self.values[i].shape(),
                self.losses[i].shape(),
            )?;
        }
        Ok(())
    }
}
impl Default for Graph {
    /// An empty graph using `xavier_initialize` for parameters and the default optimizer.
    fn default() -> Self {
        let init = Box::new(xavier_initialize);
        Self::new(init, Default::default())
    }
}
impl Graph {
    /// Creates an empty graph.
    ///
    /// `initializer` is called with a parameter's shape to produce its starting value
    /// (see `param`); `optimizer` is told about every registered parameter and applies
    /// gradients to parameters during `backward`.
    pub fn new(initializer: Box<(Fn(&[usize]) -> ArrayD<f32>)>, optimizer: Optimizer) -> Self {
        Graph {
            nodes: BTreeMap::new(),
            values: BTreeMap::new(),
            losses: BTreeMap::new(),
            named_idxs: BTreeMap::new(),
            num_inserted: 0,
            initializer: Initializer(initializer),
            optimizer,
        }
    }

    /// Adds `node` to the graph under a fresh index and returns that index.
    ///
    /// The node's value and loss both start as 0-dimensional zero arrays. A
    /// `Node::Parameter` is additionally registered with the optimizer under the new index.
    pub fn register(&mut self, node: Node) -> Idx {
        let idx = self.num_inserted;
        if let Node::Parameter(ref shape) = node {
            self.optimizer.register(Idx { idx }, shape)
        }
        self.nodes.insert(idx, node);
        self.values.insert(idx, Array::zeros(()).into_dyn());
        self.losses.insert(idx, Array::zeros(()).into_dyn());
        self.num_inserted += 1;
        Idx { idx }
    }

    /// Adds a parameter of the given `shape`.
    ///
    /// Its value comes from the graph's initializer and its loss is a zero array of the same
    /// shape.
    pub fn param(&mut self, shape: &[usize]) -> Idx {
        // `register` already advances `num_inserted`; advancing it again here (as before)
        // silently skipped an index for every parameter created.
        let idx = self.register(Node::Parameter(shape.to_vec().into_boxed_slice()));
        self.values.insert(idx.idx, (self.initializer.0)(shape));
        self.losses.insert(idx.idx, Array::zeros(shape));
        idx
    }

    /// Adds an input node, optionally carrying the iterator `it`.
    pub fn input(&mut self, it: Option<Box<Iterator<Item = ArrayD<f32>>>>) -> Idx {
        self.register(Node::Input { it })
    }

    /// Adds a node that applies the user-supplied `Operation` `op` to `inputs`.
    pub fn op(&mut self, op: impl Operation + 'static, inputs: &[Idx]) -> Idx {
        let o = Node::Operation {
            operation: Box::new(op),
            inputs: inputs.to_vec().into_boxed_slice(),
        };
        self.register(o)
    }

    /// Adds a `Node::Constant` whose value is `c`.
    pub fn constant(&mut self, c: ArrayD<f32>) -> Idx {
        let idx = self.register(Node::Constant);
        self.set_value(idx, c);
        idx
    }

    /// Recomputes node `i` from the current values of its inputs.
    ///
    /// If the node's `forward` returns `Some`, that becomes the node's new value. The node's
    /// loss is then reset to zeros shaped like its value.
    ///
    /// # Panics
    /// Panics if `i` (or one of its inputs) is not in the graph.
    fn _forward1(&mut self, i: usize) {
        if let Some(n) = self.nodes.get_mut(&i) {
            let inps = n.inputs();
            if let Some(v) = n.forward(&view_at_idxs(&inps, &self.values)) {
                self.values.insert(i, v);
            }
        }
        self.losses.insert(i, Array::zeros(self.values[&i].shape()));
    }

    /// Backpropagates through node `i`; does nothing if `i` is not in the graph.
    ///
    /// A parameter hands a mutable view of its value plus its loss to the optimizer. Any
    /// other node computes gradients from its inputs' values and its own loss, and adds
    /// each gradient into the matching input's loss.
    fn _backward1(&mut self, i: usize) {
        if let Some(n) = self.nodes.get_mut(&i) {
            if let Node::Parameter(..) = n {
                self.optimizer.apply_gradient(
                    Idx { idx: i },
                    self.values
                        .get_mut(&i)
                        .expect("every node index has an entry in `values`")
                        .view_mut(),
                    &self.losses[&i],
                );
            } else {
                let inps = n.inputs();
                let gradients = n.backward(&view_at_idxs(&inps, &self.values), &self.losses[&i]);
                for (grad, j) in gradients.iter().zip(inps.iter()) {
                    // Accumulate, so an input feeding several nodes sums their gradients.
                    if let Some(x) = self.losses.get_mut(&j.idx) {
                        *x += grad;
                    }
                }
            }
        }
    }

    /// Runs the forward step on every node in ascending index order.
    pub fn forward(&mut self) {
        let keys: Vec<usize> = self.nodes.keys().cloned().collect();
        for i in keys.into_iter() {
            self._forward1(i);
        }
    }

    /// Runs the backward step on every node in descending index order.
    ///
    /// NOTE(review): this assumes a node's inputs have smaller indices than the node
    /// itself, which holds when inputs are registered before the nodes that use them.
    pub fn backward(&mut self) {
        let keys: Vec<usize> = self.nodes.keys().rev().cloned().collect();
        for i in keys.into_iter() {
            self._backward1(i);
        }
    }

    /// Runs the forward step for the single node `i`.
    ///
    /// # Panics
    /// Panics if `i` (or one of its inputs) is not in the graph.
    pub fn forward1(&mut self, i: Idx) {
        self._forward1(i.idx);
    }

    /// Runs the backward step for the single node `i`; a missing node is ignored.
    pub fn backward1(&mut self, i: Idx) {
        self._backward1(i.idx);
    }

    /// Removes the node at `idx` together with its value and loss.
    ///
    /// `named_idxs` entries and any optimizer state for `idx` are left untouched.
    pub fn remove(&mut self, idx: Idx) {
        self.nodes.remove(&idx.idx);
        self.values.remove(&idx.idx);
        self.losses.remove(&idx.idx);
    }

    /// Removes every node that is not a `Node::Parameter`, with its value and loss.
    pub fn clear_non_parameters(&mut self) {
        // Collect first: the maps cannot be mutated while `nodes` is being iterated.
        let doomed: Vec<usize> = self
            .nodes
            .iter()
            .filter_map(|(&i, n)| match n {
                Node::Parameter(_) => None,
                _ => Some(i),
            })
            .collect();
        for k in doomed {
            self.remove(Idx { idx: k });
        }
    }

    /// Overwrites the value stored at `idx`.
    ///
    /// # Panics
    /// Panics if `idx` has been removed (or was never registered).
    pub fn set_value(&mut self, idx: Idx, val: ArrayD<f32>) {
        // Look up before writing so a bad index does not leave a stray value with no node.
        match self.values.get_mut(&idx.idx) {
            Some(slot) => *slot = val,
            None => panic!("Tried to set value at a removed index"),
        }
    }

    /// Returns the current value at `idx`.
    ///
    /// # Panics
    /// Panics if `idx` is not in the graph.
    pub fn get_value(&self, idx: Idx) -> &ArrayD<f32> {
        &self.values[&idx.idx]
    }

    /// Overwrites the loss stored at `idx`.
    ///
    /// # Panics
    /// Panics if `idx` has been removed (or was never registered).
    pub fn set_loss(&mut self, idx: Idx, loss: ArrayD<f32>) {
        // Look up before writing so a bad index does not leave a stray loss with no node.
        match self.losses.get_mut(&idx.idx) {
            Some(slot) => *slot = loss,
            None => panic!("Tried to set loss at a removed index"),
        }
    }

    /// Returns the current loss at `idx`.
    ///
    /// # Panics
    /// Panics if `idx` is not in the graph.
    pub fn get_loss(&self, idx: Idx) -> &ArrayD<f32> {
        &self.losses[&idx.idx]
    }

    /// Gives the node at `idx` a new input iterator.
    ///
    /// An `Input` node has its iterator replaced; a `Constant` node is turned into an
    /// `Input` node carrying `new`.
    ///
    /// # Errors
    /// Returns `Err` if `idx` is not in the graph or refers to any other kind of node.
    pub fn replace_input_iterator(
        &mut self,
        idx: Idx,
        new: Box<Iterator<Item = ArrayD<f32>>>,
    ) -> Result<(), String> {
        if let Some(n) = self.nodes.get_mut(&idx.idx) {
            match n {
                Node::Input { it } => *it = Some(new),
                Node::Constant => *n = Node::Input { it: Some(new) },
                _ => {
                    return Err("Tried to replace input iter at non Input/Constant node.".to_string())
                }
            }
            Ok(())
        } else {
            Err("Tried to replace input iterator at invalid index.".to_string())
        }
    }

    /// Adds a `Node::Add` over `inputs`.
    pub fn add(&mut self, inputs: &[Idx]) -> Idx {
        self.register(Node::Add {
            xs: inputs.to_vec(),
        })
    }

    /// Adds a `Node::Mult` over `inputs`.
    pub fn mult(&mut self, inputs: &[Idx]) -> Idx {
        self.register(Node::Mult {
            xs: inputs.to_vec(),
        })
    }

    /// Adds a `Node::Conv` of `kernel` over `img` with the given padding and stride.
    pub fn conv(&mut self, kernel: Idx, img: Idx, padding: Padding, stride: usize) -> Idx {
        self.register(Node::Conv {
            kernel,
            img,
            conv: Conv::new(padding, stride),
        })
    }

    /// Adds a `Node::GlobalPool` of `x` using `pool`.
    pub fn global_pool(&mut self, x: Idx, pool: GlobalPool) -> Idx {
        self.register(Node::GlobalPool { x, pool })
    }

    /// Adds an `Activation::Relu` (with `leak: 0.0`) of `x`.
    pub fn relu(&mut self, x: Idx) -> Idx {
        self.register(Node::Activation {
            x,
            a: Activation::Relu { leak: 0.0 },
        })
    }

    /// Adds an `Activation::Sigmoid` of `x`.
    pub fn sigmoid(&mut self, x: Idx) -> Idx {
        self.register(Node::Activation {
            x,
            a: Activation::Sigmoid,
        })
    }

    /// Adds an `Activation::Tanh` of `x`.
    pub fn tanh(&mut self, x: Idx) -> Idx {
        self.register(Node::Activation {
            x,
            a: Activation::Tanh,
        })
    }

    /// Adds a `Node::MatMul` of `mat` with `v`.
    pub fn matmul(&mut self, mat: Idx, v: Idx) -> Idx {
        self.register(Node::MatMul { mat, v })
    }

    /// Adds a `Node::Embedding` looking up `code` in `emb`.
    pub fn embedding(&mut self, emb: Idx, code: Idx) -> Idx {
        self.register(Node::Embedding { emb, code })
    }
}
/// Borrows the arrays stored under each of `indices`, preserving their order.
///
/// # Panics
/// Panics if any index has no entry in `nodes`.
fn view_at_idxs<'a>(
    indices: &[Idx],
    nodes: &'a BTreeMap<usize, ArrayD<f32>>,
) -> Box<[ArrayViewD<'a, f32>]> {
    indices
        .iter()
        .map(|node_idx| nodes[&node_idx.idx].view())
        .collect::<Vec<_>>()
        .into_boxed_slice()
}