use crate::ndarray_ext::{NdArray, NdArrayView};
use crate::op::{self, ComputeContext, InputArray, OpInput};
use crate::smallvec::SmallVec;
use crate::tensor::{Tensor, TensorInternal};
use crate::FxHashMap;
use crate::{Float, Graph};
use std::cell::UnsafeCell;
use std::mem;
use std::sync::{RwLockReadGuard, RwLockWriteGuard};
const NUM_MAX_EVAL_BUF: usize = 8;
type EvalBuf<T> = SmallVec<[T; NUM_MAX_EVAL_BUF]>;
pub struct Eval<'view, 'feed, 'graph, F: Float> {
scope: &'graph Graph<F>,
buf: EvalBuf<Tensor<'graph, F>>,
feeds: Option<&'feed [crate::runtime::Feed<'view, F>]>,
}
impl<'feed, 'tensor, 'view, 'graph, F: Float> Eval<'view, 'feed, 'graph, F> {
#[inline]
pub fn new(scope: &'graph Graph<F>) -> Self {
Eval {
feeds: None,
scope,
buf: EvalBuf::new(),
}
}
#[inline]
pub fn push<A>(&mut self, x: A) -> &mut Self
where
A: AsRef<Tensor<'graph, F>>,
{
self.buf.push(*x.as_ref());
self
}
pub fn feed(&mut self, feeds: &'feed [crate::Feed<'view, F>]) -> &mut Self {
self.feeds = Some(feeds);
self
}
#[inline]
pub fn extend<A>(&mut self, xs: &'tensor [A]) -> &mut Self
where
A: AsRef<Tensor<'graph, F>>,
{
self.buf.extend(xs.iter().map(|x| *x.as_ref()));
self
}
#[inline]
pub fn run(&'tensor self) -> Vec<Result<NdArray<F>, crate::EvalError>> {
self.scope
.eval(self.buf.as_slice(), self.feeds.unwrap_or(&[]))
}
}
pub struct Feed<'feed, T: Float> {
placeholder_id: usize,
value: NdArrayView<'feed, T>,
}
impl<'feed, F: Float> Feed<'feed, F> {
#[inline]
pub(crate) fn new(placeholder_id: usize, value: NdArrayView<'feed, F>) -> Self {
Feed {
placeholder_id,
value,
}
}
}
enum ValueType {
Owned,
View,
Empty,
}
struct ValueInfo {
ty: ValueType,
key: usize,
}
impl ValueInfo {
#[inline]
fn new(ty: ValueType, key: usize) -> Self {
ValueInfo { ty, key }
}
}
struct OutputStorage<'view, F: Float> {
inner: UnsafeCell<OutputStorageInner<'view, F>>,
}
struct OutputStorageInner<'view, F: Float> {
owned: Vec<Option<NdArray<F>>>,
borrowed: Vec<NdArrayView<'view, F>>,
}
impl<'tensor, 'view, 'lock, F: Float> OutputStorage<'view, F> {
#[inline]
fn new() -> Self {
OutputStorage {
inner: UnsafeCell::new(OutputStorageInner {
owned: Vec::new(),
borrowed: Vec::new(),
}),
}
}
#[inline]
fn owned_mut(&self) -> &mut Vec<Option<NdArray<F>>> {
unsafe { &mut (*self.inner.get()).owned }
}
#[inline]
fn owned(&self) -> &[Option<NdArray<F>>] {
unsafe { &(*self.inner.get()).owned }
}
#[inline]
fn view_mut(&self) -> &mut Vec<NdArrayView<'view, F>> {
unsafe { &mut (*self.inner.get()).borrowed }
}
#[inline]
fn view(&self) -> &[NdArrayView<'view, F>] {
unsafe { &(*self.inner.get()).borrowed }
}
}
fn validate_feed_shapes<F: Float>(feeds: &[Feed<F>], g: &Graph<F>) {
for feed in feeds {
let shape = feed.value.shape();
g.access_node(feed.placeholder_id)
.validate_feed_shape(shape);
}
}
#[inline]
fn retrieve_feed<'feeds, 'feed, F: Float>(
feeds: &'feeds [Feed<'feed, F>],
in_node_id: usize,
) -> NdArrayView<'feeds, F> {
for feed in feeds {
if feed.placeholder_id == in_node_id {
return feed.value.view();
}
}
panic!("Placeholder unfilled");
}
fn install_compute_results<'view, F: Float>(
results: crate::op::Results<'view, F>,
storage: &OutputStorage<'view, F>,
) -> Result<op::OutputArray<ValueInfo>, op::OpError> {
let mut value_info_list = op::OutputArray::new();
for y in results {
match y {
Some(Ok(crate::ArrRepr::Owned(val))) => {
storage.owned_mut().push(Some(val));
value_info_list.push(ValueInfo::new(ValueType::Owned, storage.owned().len() - 1));
}
Some(Ok(crate::ArrRepr::View(val))) => {
storage.view_mut().push(val);
value_info_list.push(ValueInfo::new(ValueType::View, storage.view().len() - 1));
}
Some(Err(e)) => {
return Err(e);
}
None => {
value_info_list.push(ValueInfo::new(ValueType::Empty, 0))
}
};
}
Ok(value_info_list)
}
#[inline]
fn aggregate_op_inputs<'ret, 'tensor: 'ret, 'slice: 'ret, 'feed: 'slice, F: Float>(
node: &'tensor TensorInternal<F>,
g: &Graph<F>,
node_info_map: &FxHashMap<usize, Result<op::OutputArray<ValueInfo>, op::OpError>>,
feeds: &'slice [Feed<'feed, F>],
storage: &'ret OutputStorage<'ret, F>,
input_values: &mut InputArray<OpInput<'ret, F>>,
read_guards: &mut InputArray<RwLockReadGuard<'tensor, NdArray<F>>>,
write_guards: &mut InputArray<RwLockWriteGuard<'tensor, NdArray<F>>>,
) -> Result<(), op::OpError> {
let mut input_status = Ok(());
for (in_node, &in_idx) in node.in_edges.iter().zip(&node.input_indices) {
let input_inner: &TensorInternal<F> = in_node.get_inner(g);
let x = if input_inner.is_placeholder {
Ok(OpInput::new(retrieve_feed(feeds, in_node.id)))
} else if let Some(ref lock) = input_inner.variable_array {
unsafe {
if in_node.mut_usage {
write_guards.push(lock.write().unwrap());
let inserted = write_guards.len() - 1;
Ok(OpInput::new_mut(
(*(&mut write_guards[inserted] as *mut RwLockWriteGuard<NdArray<F>>))
.view_mut(),
))
} else {
read_guards.push(lock.read().unwrap());
let inserted = read_guards.len() - 1;
Ok(OpInput::new(
(*(&mut read_guards[inserted] as *mut RwLockReadGuard<NdArray<F>>)).view(),
))
}
}
} else if let Some(ref arr) = input_inner.get_constant_array_inner() {
Ok(OpInput::new(arr.view()))
} else {
match &node_info_map.get(&in_node.id).unwrap() {
Err(e) => Err(e.clone()),
Ok(vi_list) => {
let vi = &vi_list[in_idx];
match vi.ty {
ValueType::Owned => Ok(OpInput::new(
storage.owned()[vi.key].as_ref().unwrap().view(),
)),
ValueType::View => Ok(OpInput::new(storage.view()[vi.key].clone())),
ValueType::Empty => {
panic!(
"Attempting to use {}'s output which is empty.",
input_inner.op.name()
);
}
}
}
}
};
match x {
Ok(x) => input_values.push(x),
Err(e) => {
input_status = Err(e);
break;
}
}
}
input_status
}
impl<F: Float> Graph<F> {
pub fn eval<'feed, 'tensor, 'scope, A>(
&'scope self,
tensors: &'tensor [A],
feeds: &[Feed<'feed, F>],
) -> Vec<Result<NdArray<F>, crate::EvalError>>
where
A: AsRef<Tensor<'scope, F>> + Copy,
{
validate_feed_shapes(feeds, self);
let mut node_info_map: FxHashMap<usize, Result<op::OutputArray<ValueInfo>, op::OpError>> =
FxHashMap::default();
let storage = OutputStorage::new();
let mut dfs_stack = Vec::<(&TensorInternal<F>, bool)>::with_capacity(100);
for t in tensors.iter() {
dfs_stack.push((t.as_ref().inner(), false));
}
while let Some((node, is_parent)) = dfs_stack.pop() {
if is_parent {
if would_not_visit(node, &node_info_map) {
continue;
}
let mut xs = InputArray::new();
let (mut write_guards, mut read_guards) = (InputArray::new(), InputArray::new());
let input_status = aggregate_op_inputs(
node,
self,
&node_info_map,
feeds,
&storage,
&mut xs,
&mut read_guards,
&mut write_guards,
);
let installed_node_info = input_status.and_then(|()| {
let mut ctx = ComputeContext::new(node, xs);
node.op.compute(&mut ctx);
let ys = ctx.extract_outputs();
if ys.is_empty() {
panic!("Bad op implementation: empty return value");
}
install_compute_results(ys, &storage)
});
node_info_map.insert(node.id(), installed_node_info);
} else {
dfs_stack.push((node, true));
for child in &node.in_edges {
let child = child.get(self).scoped_inner();
if !would_not_visit(child, &node_info_map) {
dfs_stack.push((child, false));
}
}
}
}
let mut ret = Vec::with_capacity(tensors.len());
for t in tensors {
let t = t.as_ref();
let arr = if let Some(per) = t.clone_persistent_array() {
Ok(per)
} else if t.is_placeholder() {
Ok(retrieve_feed(feeds, t.id()).to_owned())
} else {
match &node_info_map.get(&t.id()).unwrap() {
Ok(value_info_list) => {
let info = &value_info_list[0];
match &info.ty {
ValueType::Owned => {
Ok(mem::replace(&mut storage.owned_mut()[info.key], None).unwrap())
}
ValueType::View => Ok(storage.view()[info.key].to_owned()),
ValueType::Empty => Err(crate::EvalError::Empty),
}
}
Err(e) => {
Err(crate::EvalError::OpError(e.clone()))
}
}
};
ret.push(arr);
}
ret
}
}
#[inline]
fn would_not_visit<F: Float>(
node: &TensorInternal<F>,
info_map: &FxHashMap<usize, Result<op::OutputArray<ValueInfo>, op::OpError>>,
) -> bool {
node.is_placeholder || node.has_persistent_array || info_map.contains_key(&node.id())
}
#[test]
fn test_eval2() {
crate::with(|g: &mut crate::Graph<f32>| {
let a = g.ones(&[1, 1]);
let b = g.sigmoid(a);
b.eval(&[]).unwrap();
})
}
#[test]
fn test_eval() {
crate::with(|g| {
let v: Tensor<f32> = g.placeholder(&[3, 2, 1]);
let z = g.reduce_sum(g.squeeze(v, &[2]), &[0, 1], false);
let g = g.grad(&[z], &[v]);
let eval_result = g[0].eval(&[v.given(crate::ndarray_ext::ones(&[3, 2, 1]).view())]);
assert_eq!(eval_result.as_ref().unwrap().shape(), &[3, 2, 1]);
})
}
#[test]
fn test_variable_eval() {
use crate::tensor::Variable;
crate::with(|g| {
let arr = ndarray::arr1(&[0., 0., 0.]).into_dyn();
assert_eq!(Ok(arr.clone()), g.variable(arr).eval(&[]));
});
}
#[test]
fn test_constant_eval() {
use crate::tensor::Constant;
crate::with(|g| {
let arr = ndarray::arr1(&[0., 0., 0.]).into_dyn();
assert_eq!(Ok(arr.clone()), g.constant(arr).eval(&[]));
});
}
#[test]
fn test_placeholder_eval() {
crate::with(|g| {
let arr: NdArray<f32> = crate::ndarray_ext::ones(&[3, 2, 1]);
let v = g.placeholder(&[3, 2, 1]);
let eval_result = v.eval(&[v.given(arr.view())]);
assert_eq!(Ok(arr), eval_result);
});
}