extern crate ndarray;
use ndarray_ext::NdArray;
use op;
use std::collections::btree_map::Entry;
use std::collections::BTreeMap;
use std::mem;
use std::rc::Rc;
use tensor::Tensor;
/// Builder that accumulates tensors to be evaluated together in one `run` call.
pub struct Eval<'a> {
    // Tensors collected via `push`/`extend`; evaluated as a batch by `run`.
    buf: Vec<&'a Tensor>,
}
impl<'t> Eval<'t> {
    /// Creates an empty evaluation set.
    pub fn new() -> Self {
        Self { buf: Vec::new() }
    }

    /// Adds one tensor to the evaluation set; returns `self` for chaining.
    pub fn push(&mut self, x: &'t Tensor) -> &mut Self {
        self.buf.push(x);
        self
    }

    /// Adds every tensor in `xs` to the evaluation set; returns `self` for chaining.
    pub fn extend<A>(&mut self, xs: &'t [A]) -> &mut Self
    where
        A: AsRef<Tensor>,
    {
        for x in xs {
            self.buf.push(x.as_ref());
        }
        self
    }

    /// Evaluates all collected tensors with the given placeholder feeds,
    /// returning one optional array per collected tensor.
    pub fn run<'tpl, 'tsr: 'tpl, 'arr: 'tpl, F>(&self, feed: F) -> Vec<Option<NdArray>>
    where
        F: IntoIterator<Item = &'tpl (&'tsr Tensor, &'arr ndarray::Array<f32, ndarray::IxDyn>)>,
    {
        eval(&self.buf, feed)
    }
}
/// Context handed to an op's `compute`: the node being evaluated plus its
/// already-resolved input arrays.
pub struct OpComputeContext<'a, 'b> {
    // `node`: the tensor whose op is running; `xs`: its input arrays, in input order.
    pub node: &'a Tensor, pub xs: Vec<&'b NdArray>,
}
impl<'a, 'b> OpComputeContext<'a, 'b> {
    /// Returns the resolved input arrays of the node being computed.
    #[inline]
    pub fn grab_inputs(&self) -> &[&NdArray] {
        self.xs.as_slice()
    }

    /// Reinterprets the shared input references as mutable ones, for in-place ops.
    ///
    /// NOTE(review): this transmutes `&[&NdArray]` into `&mut [&mut NdArray]`;
    /// soundness depends on the caller guaranteeing exclusive access to these
    /// arrays for the duration of the borrow — confirm that invariant holds at
    /// every call site.
    #[inline]
    #[allow(mutable_transmutes)]
    pub unsafe fn grab_assignable_inputs(&mut self) -> &mut [&mut NdArray] {
        mem::transmute(self.xs.as_slice())
    }

    /// Returns the `i`-th input tensor of the node being computed.
    #[inline]
    pub fn grab_input_node(&self, i: usize) -> &Tensor {
        &self.node.inputs[i]
    }

    /// Collects the input arrays for `node`, or `None` if any input has no
    /// usable value yet.
    #[inline]
    fn _grab_inputs<'n, 's: 'n>(
        node: &'s Tensor,
        store: &'n ResourceStore,
        feed_store: &FeedStore<'n>,
    ) -> Option<Vec<&'n NdArray>> {
        // Resolves one input's array, chasing `Delegate` results down the
        // input chain until a concrete array (or a dead end) is found.
        fn recurse<'n, 's: 'n>(
            x: &'s Tensor,
            store: &'n ResourceStore,
            feed_store: &FeedStore<'n>,
            value_index: usize,
        ) -> Option<&'n NdArray> {
            if let Some(ref per) = x.persistent_array {
                // Variables/constants carry their own array.
                Some(per.get_array())
            } else if x.is_placeholder {
                // Placeholders read the fed array by their lookup key.
                Some(feed_store[x.resource_lookup_key.get()])
            } else {
                match store[x.resource_lookup_key.get()].value[value_index] {
                    Ok(ref a) => Some(a),
                    Err(::op::ComputeError::Delegate { to: i }) => {
                        // The op delegated this output to its `i`-th input.
                        recurse(&x.inputs[i], store, feed_store, x.input_indices[i])
                    }
                    // Any other error means the value is unavailable.
                    _ => None, }
            }
        }
        let input_nodes = &node.inputs;
        let mut input_arrays = Vec::with_capacity(input_nodes.len());
        for (x, &i) in input_nodes.into_iter().zip(node.input_indices.iter()) {
            if let Some(res) = recurse(x, store, feed_store, i) {
                input_arrays.push(res);
            } else {
                // One missing input prevents the whole node from running.
                return None;
            }
        }
        Some(input_arrays)
    }
}
/// A computed graph node paired with its output value(s).
struct NodeWithValue<'a> {
    // The node this value belongs to.
    node: &'a Tensor,
    // The op's computed outputs (or per-output errors).
    value: op::ComputeResult,
    // How many requested outputs still need this value; the last consumer
    // may move the value out instead of cloning it.
    pending_count: usize,
}
impl<'a> Tensor {
    /// Pairs this tensor with a freshly computed value; no consumers are
    /// registered yet (`pending_count` starts at zero).
    #[inline]
    fn with_value(&'a self, val: op::ComputeResult) -> NodeWithValue<'a> {
        NodeWithValue {
            pending_count: 0,
            value: val,
            node: self,
        }
    }
}
/// Evaluates `tensors` with the given placeholder `feeds`, returning one
/// optional array per requested tensor (`None` when the op produced no output).
pub fn eval<'a, 'b: 'a, 'c: 'a, T, U>(tensors: &[T], feeds: U) -> Vec<Option<NdArray>>
where
    T: AsRef<Tensor>,
    U: IntoIterator<Item = &'a (&'b Tensor, &'c ndarray::Array<f32, ndarray::IxDyn>)>,
{
    let feeds = feeds.into_iter().collect::<Vec<_>>();
    // Run the graph; `output_storage` holds a value slot per computed node.
    let mut output_storage = eval_internal(&tensors.iter().map(|t| t.as_ref()).collect(), &feeds);

    // For each requested tensor, locate the node that actually owns its value
    // (ops may delegate their output to an input; see `find_resource_creator`).
    let creators = tensors
        .iter()
        .map(|x| {
            let x = x.as_ref();
            if x.is_placeholder || x.persistent_array.is_some() {
                x
            } else {
                let creator = find_resource_creator(&output_storage, x);
                // Count how many requested outputs consume this creator's value
                // so the last consumer can move it out instead of cloning.
                output_storage[creator.resource_lookup_key.get()].pending_count += 1;
                creator
            }
        })
        .collect::<Vec<&Tensor>>();

    // Drop unreferenced intermediates; key the survivors by lookup key.
    let mut key2res: BTreeMap<usize, NodeWithValue> = finalize_resource_store(output_storage);

    creators
        .iter()
        .map(|creator| {
            if let Some(ref per) = creator.persistent_array {
                // Variables/constants: clone their stored array.
                Some(per.get_array().clone())
            } else if creator.is_placeholder {
                // Placeholders: clone the fed array.
                Some(find_fed_resource(creator, &feeds).clone())
            } else {
                match key2res.entry(creator.resource_lookup_key.get()) {
                    Entry::Occupied(mut ent) => {
                        if ent.get().pending_count == 1 {
                            // Last consumer: move the value out of the store.
                            let mut got = ent.remove();
                            map_err(got.value.remove(0))
                        } else {
                            // Other consumers remain: hand out a clone.
                            let got = ent.get_mut();
                            got.pending_count -= 1;
                            map_err(got.value[0].clone())
                        }
                    }
                    // Every non-placeholder, non-persistent creator was counted
                    // into `pending_count` above, so its entry must exist.
                    _ => unreachable!(),
                }
            }
        })
        .collect()
}
/// Evaluates `tensors` for their side effects only, discarding the results.
pub fn run<'a, 'b: 'a, 'c: 'a, T, U>(tensors: &[T], feeds: U)
where
    T: AsRef<Tensor>,
    U: IntoIterator<Item = &'a (&'b Tensor, &'c NdArray)>,
{
    let targets: Vec<&Tensor> = tensors.iter().map(|t| t.as_ref()).collect();
    let feeds: Vec<_> = feeds.into_iter().collect();
    // The returned resource store is intentionally dropped.
    eval_internal(&targets, &feeds);
}
/// Follows `Delegate` results from `x` to the node that actually owns the
/// output value. (The unused `'a` lifetime parameter was removed; the
/// recursion is rewritten as a loop.)
fn find_resource_creator<'b>(storage: &ResourceStore, x: &'b Tensor) -> &'b Tensor {
    let mut node = x;
    while let Err(::op::ComputeError::Delegate { to: i }) =
        storage[node.resource_lookup_key.get()].value[0]
    {
        // The op delegated its output to its `i`-th input; keep chasing.
        node = &node.inputs[i];
    }
    node
}
/// Converts an op's compute result into the public output form:
/// `NoOutput` becomes `None`. (The unused `'a` lifetime parameter was removed.)
#[inline]
fn map_err(res: Result<NdArray, ::op::ComputeError>) -> Option<NdArray> {
    match res {
        Ok(arr) => Some(arr),
        Err(::op::ComputeError::NoOutput) => None,
        // `Delegate` is always chased to its creator before this is called.
        _ => unreachable!(),
    }
}
// Computed per-node values, indexed by `Tensor::resource_lookup_key`.
type ResourceStore<'a> = Vec<NodeWithValue<'a>>;
// Arrays fed for placeholders, indexed by `Tensor::resource_lookup_key`.
type FeedStore<'a> = Vec<&'a NdArray>;
/// Looks up the array fed for `node`, comparing placeholders by pointer
/// identity. Takes a slice instead of `&Vec` (clippy `ptr_arg`); existing
/// `&Vec` call sites deref-coerce.
///
/// # Panics
/// Panics when no feed entry matches `node`.
#[inline]
fn find_fed_resource<'a>(node: &Tensor, feeds: &[&(&Tensor, &'a NdArray)]) -> &'a NdArray {
    feeds
        .iter()
        .find(|feed| Rc::ptr_eq(feed.0, node))
        .map(|feed| feed.1)
        .unwrap_or_else(|| panic!("Placeholder unfilled. See backtrace."))
}
/// Evaluates `targets` (and their transitive inputs) with an iterative
/// post-order DFS, returning the store of per-node computed values.
fn eval_internal<'a>(
    targets: &Vec<&'a Tensor>,
    feeds: &Vec<&(&'a Tensor, &NdArray)>,
) -> ResourceStore<'a> {
    let mut res_store = Vec::new();
    let mut feed_store = Vec::new();
    // Stack entries are (node, is_parent): a node is pushed once with `false`
    // (visit children first) and once with `true` (children done, compute it).
    let mut dfs_stack: Vec<(&Tensor, bool)> = targets.iter().map(|&x| (x, false)).collect();
    while let Some((node, is_parent)) = dfs_stack.pop() {
        if is_parent {
            if node.is_placeholder {
                // Placeholders index into `feed_store` rather than `res_store`.
                node.resource_lookup_key.set(feed_store.len());
                feed_store.push(find_fed_resource(node, &feeds));
            } else {
                node.resource_lookup_key.set(res_store.len());
                // Persistent arrays (variables/constants) need no computation.
                if node.persistent_array.is_none() {
                    let y = {
                        let ins = OpComputeContext::_grab_inputs(node, &res_store, &feed_store);
                        if let Some(xs) = ins {
                            node.op.compute(OpComputeContext { node, xs })
                        } else {
                            // Inputs unavailable: record a delegation so
                            // downstream lookups chase it instead of panicking.
                            vec![Err(::op::ComputeError::Delegate { to: 0 })]
                        }
                    };
                    res_store.push(node.with_value(y));
                }
            }
        } else {
            dfs_stack.push((node, true));
            for child in &node.inputs {
                // A child counts as visited when its lookup key points at its
                // own entry in `res_store`. NOTE(review): placeholders and
                // persistent nodes never get a `res_store` entry, so they can
                // fail this test and be processed again — confirm that this
                // re-visit is intended/harmless.
                let visited = {
                    let k = child.resource_lookup_key.get();
                    k < res_store.len() && Rc::ptr_eq(child, res_store[k].node)
                };
                if !visited {
                    dfs_stack.push((child, false));
                }
            }
        }
    }
    res_store
}
/// Drops values with no pending consumers and keys the survivors by their
/// original index (i.e. their `resource_lookup_key`).
///
/// The original used a gratuitous `&mut vec` reborrow plus `&mut **mut_ref`
/// and a deletion counter; this is the same in-place compaction written with
/// a plain write cursor (`swap` targets and final length are identical).
fn finalize_resource_store(mut vec: ResourceStore) -> BTreeMap<usize, NodeWithValue> {
    let mut retained_keys = Vec::new();
    // Compact in place: move every retained entry to the front, keeping order.
    let mut write = 0;
    for read in 0..vec.len() {
        if vec[read].pending_count == 0 {
            // Nobody asked for this value; discard it.
            continue;
        }
        retained_keys.push(read);
        vec.swap(write, read);
        write += 1;
    }
    vec.truncate(write);
    debug_assert_eq!(vec.len(), retained_keys.len());
    // Retained order matches `retained_keys` order, so zipping is correct.
    retained_keys.into_iter().zip(vec).collect()
}
#[test]
fn test_eval() {
    // The gradient of sum(squeeze(v)) w.r.t. v must have v's shape.
    let ref v = ::ops::placeholder(&[3, 2, 1]);
    let ref z = ::ops::reduce_sum(&::ops::squeeze(v, &[2]), &[0, 1], false);
    let ref g = ::ops::grad(&[z], &[v]);
    let eval_result = &eval(g, &[(v, &::ndarray_ext::ones(&[3, 2, 1]))])[0];
    assert_eq!(eval_result.as_ref().unwrap().shape(), &[3, 2, 1]);
}
#[test]
fn test_constant_eval() {
    // A variable evaluates to exactly the array it was created from.
    let arr = ndarray::arr1(&[0., 0., 0.]);
    assert_eq!(Some(arr.clone().into_dyn()), ::variable(arr).eval(&[]));
}
#[test]
fn test_placeholder_eval() {
    // A placeholder evaluates to exactly the array fed for it.
    let arr = ::ndarray_ext::ones(&[3, 2, 1]);
    let ref v = ::ops::placeholder(&[3, 2, 1]);
    let eval_result = eval(&[v], &[(v, &arr)]);
    assert_eq!(eval_result[0], Some(arr));
}
#[test]
fn test_eval_internal() {
    // The resource store should list computed nodes in post-order, and
    // exclude placeholders and persistent arrays (which are never stored).
    let ref v = ::ops::placeholder(&[3, 2, 1]);
    let ref z = ::ops::squeeze(v, &[2]);
    let ref g = ::ops::grad_with_default(&[z], &[v], &[&::ones(&z.shape())]);
    let storage = eval_internal(&vec![&g[0]], &vec![&(v, &::ndarray_ext::ones(&[3, 2, 1]))]);
    assert_eq!(
        storage.iter().map(|x| x.node.op.name()).collect::<Vec<_>>(),
        vec![
            "ConvertToTensor",
            "Squeeze", "Shape",
            "Ones",
            "ExpandDims",
        ]
    );
}