use std::slice::Iter;
use activation::Activation;
/// Describes one layer of a `Topology`: how many values it consumes, how many
/// it produces, and the activation function associated with it.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct Layer{
/// Number of values fed into this layer. When built through
/// `TopologyBuilder`, this equals the previous layer's `outputs`
/// (or the input size for the first layer).
pub inputs: usize,
/// Number of values this layer produces.
pub outputs: usize,
/// Activation function configured for this layer.
pub activation: Activation
}
impl Layer {
fn new(inputs: usize, outputs: usize, activation: Activation) -> Self {
Layer{
inputs: inputs,
outputs: outputs,
activation: activation
}
}
}
/// Incremental builder for a `Topology`, obtained from `Topology::input`.
///
/// Hidden layers are appended with `layer` / `layers`; `output` appends the
/// final layer and yields the finished `Topology`. Every builder method
/// consumes and returns the builder, so a chain that never reaches `output`
/// produces nothing; `#[must_use]` turns that mistake into a compiler warning.
#[must_use = "a TopologyBuilder does nothing until `output` is called"]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TopologyBuilder {
    /// Output size of the most recently added layer, or the input size if no
    /// layer has been added yet; used as `inputs` of the next layer.
    last: usize,
    /// Layers added so far, in order from input side to output side.
    layers: Vec<Layer>,
}
/// A finished network topology: an ordered list of layers in which each
/// layer's `inputs` equals the previous layer's `outputs`.
///
/// Built via `TopologyBuilder::output`, which always appends a final layer,
/// so `layers` is expected to be non-empty.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Topology {
layers: Vec<Layer>
}
impl Topology {
    /// Starts building a topology whose input has `size` values.
    ///
    /// # Panics
    ///
    /// Panics if `size` is zero.
    pub fn input(size: usize) -> TopologyBuilder {
        assert!(size != 0, "cannot define a zero-sized input layer");
        TopologyBuilder {
            last: size,
            layers: Vec::new(),
        }
    }

    /// Number of values expected by the first layer.
    pub fn len_input(&self) -> usize {
        let first = self
            .layers
            .first()
            .expect("a finished disciple must have a valid first layer!");
        first.inputs
    }

    /// Number of values produced by the last layer.
    pub fn len_output(&self) -> usize {
        let last = self
            .layers
            .last()
            .expect("a finished disciple must have a valid last layer!");
        last.outputs
    }

    /// Iterates over the layers in order, from input side to output side.
    pub fn iter_layers(&self) -> Iter<Layer> {
        self.layers[..].iter()
    }
}
impl TopologyBuilder {
    /// Appends a layer with `layer_size` outputs, wired to the previous
    /// layer's output size, and makes `layer_size` the new tail size.
    fn push_layer(&mut self, layer_size: usize, act: Activation) {
        assert!(layer_size != 0, "cannot define a zero-sized hidden layer");
        let inputs = self.last;
        self.layers.push(Layer::new(inputs, layer_size, act));
        self.last = layer_size;
    }

    /// Adds one hidden layer with `layer_size` outputs and activation `act`.
    ///
    /// # Panics
    ///
    /// Panics if `layer_size` is zero.
    pub fn layer(mut self, layer_size: usize, act: Activation) -> TopologyBuilder {
        self.push_layer(layer_size, act);
        self
    }

    /// Adds several hidden layers in the given order, each described as a
    /// `(size, activation)` pair.
    ///
    /// # Panics
    ///
    /// Panics if any size is zero.
    pub fn layers(self, layers: &[(usize, Activation)]) -> TopologyBuilder {
        layers
            .iter()
            .fold(self, |builder, &(size, act)| builder.layer(size, act))
    }

    /// Adds the final layer and returns the completed `Topology`.
    ///
    /// # Panics
    ///
    /// Panics if `layer_size` is zero.
    pub fn output(mut self, layer_size: usize, act: Activation) -> Topology {
        assert!(layer_size != 0, "cannot define a zero-sized output layer");
        self.push_layer(layer_size, act);
        Topology { layers: self.layers }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn construction() {
        use self::Activation::{Identity, Logistic, ReLU, Tanh};
        let dis = Topology::input(2)
            .layer(5, Logistic)
            .layers(&[
                (10, Identity),
                (10, ReLU)
            ])
            .output(5, Tanh);

        // Boundary sizes come from the first layer's inputs and last layer's outputs.
        assert_eq!(dis.len_input(), 2);
        assert_eq!(dis.len_output(), 5);

        let mut it = dis.iter_layers().cloned();
        assert_eq!(it.next(), Some(Layer::new(2, 5, Logistic)));
        assert_eq!(it.next(), Some(Layer::new(5, 10, Identity)));
        assert_eq!(it.next(), Some(Layer::new(10, 10, ReLU)));
        assert_eq!(it.next(), Some(Layer::new(10, 5, Tanh)));
        // Guard against stray extra layers being appended.
        assert_eq!(it.next(), None);
    }

    #[test]
    fn input_directly_to_output() {
        use self::Activation::Tanh;
        let dis = Topology::input(3).output(4, Tanh);
        assert_eq!(dis.len_input(), 3);
        assert_eq!(dis.len_output(), 4);
        assert_eq!(dis.iter_layers().count(), 1);
    }

    #[test]
    #[should_panic(expected = "zero-sized input layer")]
    fn zero_sized_input_panics() {
        let _ = Topology::input(0);
    }

    #[test]
    #[should_panic(expected = "zero-sized hidden layer")]
    fn zero_sized_hidden_panics() {
        use self::Activation::ReLU;
        let _ = Topology::input(1).layer(0, ReLU);
    }

    #[test]
    #[should_panic(expected = "zero-sized output layer")]
    fn zero_sized_output_panics() {
        use self::Activation::Tanh;
        let _ = Topology::input(1).output(0, Tanh);
    }
}