#![feature(test)]
#[macro_use]
pub extern crate ndarray;
extern crate rand;
extern crate test;
#[macro_use]
extern crate debug_stub_derive;
#[macro_use]
extern crate serde_derive;
#[macro_use]
extern crate erased_serde;
extern crate serde;
#[cfg(test)]
#[macro_use(iproduct)]
extern crate itertools;
use ndarray::prelude::*;
use rand::distributions::{Distribution, Normal};
use rand::thread_rng;
mod graph;
pub mod nodes;
mod optimizers;
pub use graph::*;
pub use nodes::{GlobalPool, Operation, Padding};
pub use optimizers::Optimizer;
/// Xavier/Glorot-style initialization: samples from N(0, 2 / (n_in + n_out)).
///
/// Supported ranks:
/// * 4-D (conv filters) — fan-in/fan-out taken from the last two dims
/// * 2-D (dense weights) — fan-in/fan-out are the two dims
/// * 1-D — both fans are the single dim
///
/// Panics via `unimplemented!` for any other rank.
pub fn xavier_initialize(shape: &[usize]) -> ArrayD<f32> {
    let (n_in, n_out) = match shape.len() {
        4 => (shape[2], shape[3]),
        2 => (shape[0], shape[1]),
        1 => (shape[0], shape[0]),
        rank => unimplemented!("Initialize with {:?}", rank),
    };
    let variance = 2.0 / (n_in as f64 + n_out as f64);
    let dist = Normal::new(0.0, variance.sqrt());
    let mut rng = thread_rng();
    ArrayD::from_shape_fn(shape, |_| dist.sample(&mut rng) as f32)
}
/// Row-wise softmax of a 2-D batch of logits.
///
/// `logits` must be convertible to shape (batch, classes); panics otherwise
/// (via the `into_dimensionality` unwrap). Each row is shifted by its maximum
/// before exponentiation so the computation is numerically stable.
pub fn softmax(logits: &ArrayD<f32>) -> Array2<f32> {
    let mut softmax = logits.to_owned().into_dimensionality::<Ix2>().unwrap();
    // True per-row maximum. Seeding the fold with NEG_INFINITY (rather than
    // 0.0) keeps the shift correct when every logit in a row is negative;
    // with a 0.0 seed, a row of large-negative logits would not be shifted,
    // exp() could underflow the whole row to zero, and the normalization
    // below would divide by zero and produce NaNs.
    let max = softmax.fold_axis(Axis(1), std::f32::NEG_INFINITY, |m, &v| m.max(v));
    for ((b, _), x) in softmax.indexed_iter_mut() {
        *x = (*x - max[b]).exp();
    }
    let sum = softmax.sum_axis(Axis(1));
    for ((b, _), x) in softmax.indexed_iter_mut() {
        *x /= sum[b];
    }
    softmax
}
/// Softmax cross-entropy over a batch of logits with integer class labels.
///
/// Returns `(mean_log_loss, gradient)`. The gradient is the classic
/// `softmax - one_hot(labels)` form, reshaped back to a dynamic-dim array.
///
/// NOTE(review): the loss is averaged over the batch but the gradient is
/// NOT divided by the batch size — presumably the optimizer accounts for
/// this; confirm against callers. An empty `labels` slice yields NaN loss
/// (0.0 / 0.0).
pub fn softmax_cross_entropy_loss(logits: &ArrayD<f32>, labels: &[usize]) -> (f32, ArrayD<f32>) {
    let mut grad = softmax(logits);
    let mut total_loss = 0.0;
    for (batch_idx, &label) in labels.iter().enumerate() {
        // Negative log-likelihood of the correct class...
        total_loss -= grad[(batch_idx, label)].ln();
        // ...and the corresponding gradient entry: p_correct - 1.
        grad[(batch_idx, label)] -= 1.0;
    }
    let mean_loss = total_loss / labels.len() as f32;
    (mean_loss, grad.into_dyn())
}
#[cfg(test)]
mod tests {
    // Renamed from `libc`: that name collided with the well-known `libc`
    // crate and said nothing about the module's purpose; `tests` is the
    // Rust convention for a #[cfg(test)] module.
    use super::*;
    use graph::Graph;
    use std::f32;

    /// Parameters created through the graph get the requested shape.
    #[test]
    fn param_initialize() {
        let mut g = Graph::default();
        let x = g.param(&[3, 3, 1, 8]);
        assert_eq!(g.get_value(x).shape(), [3, 3, 1, 8]);
    }

    /// `softmax` matches independently computed reference values
    /// (element-wise, within f32 epsilon).
    #[test]
    fn softmax_vs_correct() {
        let logits = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let correct = arr2(&[
            [
                9.003057317038046e-2,
                0.24472847105479767,
                0.6652409557748219,
            ],
            [
                9.003057317038045e-2,
                0.24472847105479764,
                0.6652409557748219,
            ],
        ]);
        let softmax = softmax(&logits.into_dyn());
        for i in 0..2 {
            for j in 0..3 {
                assert!((softmax[(i, j)] - correct[(i, j)]).abs() < f32::EPSILON);
            }
        }
    }
}