extern crate ndarray;
use ndarray_ext::NdArray;
use std::cmp::Ordering;
use std::collections::btree_set::BTreeSet;
use tensor::Tensor;
pub fn gradient_check<'a, 'b, T>(
objective: &Tensor,
gradients: &[T],
variables: &[&Tensor],
feeds: &[(&Tensor, &NdArray)],
eps: f32,
tol: f32,
) where
T: AsRef<Tensor>,
{
let theoretical_grads = ::runtime::eval(gradients, feeds);
for (var_node, th_grad) in variables.iter().zip(theoretical_grads) {
let v_arr = unsafe {
var_node
.persistent_array
.as_ref()
.expect("This is not variable")
.get_as_variable_mut()
};
let head_ptr: *mut f32 = v_arr.as_mut_ptr();
for i in 0..v_arr.len() as isize {
let evacuated;
unsafe {
evacuated = *head_ptr.offset(i);
*head_ptr.offset(i) = evacuated + eps;
}
let obj_pos = objective.eval(feeds).unwrap();
unsafe {
*head_ptr.offset(i) = evacuated - eps;
}
let obj_neg = objective.eval(feeds).unwrap();
unsafe {
*head_ptr.offset(i) = evacuated;
}
let g_num = (obj_pos - obj_neg).scalar_sum() / (2. * eps);
let g_th = unsafe { *th_grad.as_ref().unwrap().as_ptr().offset(i) };
let diff = (g_num - g_th).abs();
if diff > tol {
panic!(
"Gradient checking failed with too large error: num={}, bp={}",
g_num, g_th
);
}
}
}
}
pub fn visit_once<F>(t: &Tensor, f: &mut F)
where
F: FnMut(&Tensor) -> (),
{
visit_once_internal(t, f, &mut BTreeSet::new())
}
fn visit_once_internal<'a, F>(t: &'a Tensor, f: &mut F, visited: &mut BTreeSet<&'a Tensor>)
where
F: FnMut(&'a Tensor) -> (),
{
if visited.contains(&t) {
return; } else {
visited.insert(t); }
f(&t);
for child in t.inputs.iter() {
visit_once_internal(child, f, visited)
}
}
impl<'a> Ord for &'a Tensor {
#[inline]
fn cmp(&self, other: &&'a Tensor) -> Ordering {
let a = (*self) as *const Tensor;
let b = (*other) as *const Tensor;
a.cmp(&b)
}
}
impl<'a> PartialOrd for &'a Tensor {
#[inline]
fn partial_cmp(&self, other: &&'a Tensor) -> Option<Ordering> {
Some(self.cmp(&other))
}
}
#[doc(hidden)]
#[macro_export]
macro_rules! eval_with_time {
($x:expr) => {{
use std::time::{Duration, Instant};
let start = Instant::now();
let result = $x;
let end = start.elapsed();
println!(
"{}.{:03} sec",
end.as_secs(),
end.subsec_nanos() / 1_000_000
);
result
}};
}