use crate::ndarray_ext::{NdArray, NdArrayView};
use crate::op::{self, ComputeContext, InputArray, OpInput};
use crate::smallvec::SmallVec;
use crate::tensor::{Tensor, TensorInternal};
use crate::FxHashMap;
use crate::{Float, Graph};
use std::cell::UnsafeCell;
use std::sync::{RwLockReadGuard, RwLockWriteGuard};
/// Inline capacity of [`EvalBuf`]: up to this many queued tensors are stored without a heap
/// allocation.
const NUM_MAX_EVAL_BUF: usize = 8;

/// Small-vector buffer holding the tensors queued on an [`Eval`].
type EvalBuf<T> = SmallVec<[T; NUM_MAX_EVAL_BUF]>;

/// Builder that collects tensors (and optional placeholder feeds) and evaluates them together
/// with a single call to `Graph::eval`, so work shared between the tensors is done once.
pub struct Eval<'view, 'feed, 'graph, F: Float> {
    /// Placeholder values for the run; `None` means no feeds were supplied.
    feeds: Option<&'feed [crate::runtime::Feed<'view, F>]>,
    /// Graph that owns every queued tensor.
    scope: &'graph Graph<F>,
    /// Tensors to evaluate, in the order they were queued.
    buf: EvalBuf<Tensor<'graph, F>>,
}
impl<'feed, 'tensor, 'view, 'graph, F: Float> Eval<'view, 'feed, 'graph, F> {
    /// Creates an evaluator bound to `scope` with no queued tensors and no feeds.
    #[inline]
    pub fn new(scope: &'graph Graph<F>) -> Self {
        Self {
            scope,
            buf: EvalBuf::new(),
            feeds: None,
        }
    }

    /// Queues a single tensor for evaluation and returns `self` for chaining.
    #[inline]
    pub fn push<A>(&mut self, x: A) -> &mut Self
    where
        A: AsRef<Tensor<'graph, F>>,
    {
        let tensor: Tensor<'graph, F> = *x.as_ref();
        self.buf.push(tensor);
        self
    }

    /// Sets the placeholder values used by [`Eval::run`], replacing any feeds set earlier.
    pub fn feed(&mut self, feeds: &'feed [crate::Feed<'view, F>]) -> &mut Self {
        self.feeds = Some(feeds);
        self
    }

    /// Queues every tensor in `xs`, preserving their order.
    #[inline]
    pub fn extend<A>(&mut self, xs: &'tensor [A]) -> &mut Self
    where
        A: AsRef<Tensor<'graph, F>>,
    {
        let tensors = xs.iter().map(|x| *x.as_ref());
        self.buf.extend(tensors);
        self
    }

    /// Evaluates all queued tensors with the configured feeds (an empty feed list if
    /// [`Eval::feed`] was never called). Returns one result per queued tensor, in queue order.
    #[inline]
    pub fn run(&'tensor self) -> Vec<Result<NdArray<F>, crate::EvalError>> {
        let tensors = self.buf.as_slice();
        match self.feeds {
            Some(feeds) => self.scope.eval(tensors, feeds),
            None => self.scope.eval(tensors, &[]),
        }
    }
}
/// A value supplied for one placeholder tensor during evaluation.
///
/// Feeds are matched to placeholders by node id.
pub struct Feed<'feed, T: Float> {
    /// Node id of the placeholder this value fills.
    placeholder_id: usize,
    /// Borrowed array bound to that placeholder.
    value: NdArrayView<'feed, T>,
}

impl<'feed, F: Float> Feed<'feed, F> {
    /// Pairs `value` with the placeholder whose node id is `placeholder_id`.
    #[inline]
    pub(crate) fn new(placeholder_id: usize, value: NdArrayView<'feed, F>) -> Self {
        Self {
            value,
            placeholder_id,
        }
    }
}
/// Which `OutputStorage` vector an op output was placed in.
#[derive(Copy, Clone)]
enum ValueType {
    /// Stored in the owned-array vector.
    Owned,
    /// Stored in the view vector.
    View,
    /// The op produced no value for this output slot.
    Empty,
}

/// Location of one op output: the storage vector (`ty`) and the index inside it (`key`).
/// For `ValueType::Empty` the key is meaningless and is set to 0.
#[derive(Copy, Clone)]
struct ValueInfo {
    ty: ValueType,
    key: usize,
}

impl ValueInfo {
    /// Builds a location record from its two parts.
    #[inline]
    fn new(ty: ValueType, key: usize) -> Self {
        Self { ty, key }
    }
}
/// Per-evaluation arena for the op outputs produced during one `Graph::eval` call.
///
/// Owned arrays and views are kept in two separate vectors; a `ValueInfo` records which
/// vector an output went into and its index there.
///
/// NOTE(review): `UnsafeCell` lets outputs be pushed through `&self` while views handed out
/// from earlier outputs are still alive. Those views presumably point into each array's own
/// heap buffer, so growing `value_storage` (which moves the `NdArray` headers) should not
/// invalidate them. Still, producing `&mut OutputStorageInner` while other borrows derived
/// from it exist may violate Rust's aliasing rules — confirm with Miri.
struct OutputStorage<'view, F: Float> {
inner: UnsafeCell<OutputStorageInner<'view, F>>,
}
/// Backing vectors of `OutputStorage`.
struct OutputStorageInner<'view, F: Float> {
// Owned outputs; a slot becomes `None` after `take_from_owned` moves its array out.
value_storage: Vec<Option<NdArray<F>>>,
// Outputs returned by ops as views (`ArrRepr::View`), valid for `'view`.
view_storage: Vec<NdArrayView<'view, F>>,
}
impl<'tensor, 'view, 'lock, F: Float> OutputStorage<'view, F> {
/// Creates an empty storage.
#[inline]
fn new() -> Self {
OutputStorage {
inner: UnsafeCell::new(OutputStorageInner {
value_storage: Vec::new(),
view_storage: Vec::new(),
}),
}
}
/// Shared access to the backing vectors.
///
/// # Safety
/// The caller must ensure no `&mut` obtained from `inner_mut` is alive at the same time.
#[inline]
unsafe fn inner(&self) -> &OutputStorageInner<'view, F> {
&*self.inner.get()
}
/// Mutable access to the backing vectors through `&self`.
///
/// # Safety
/// The caller must ensure no other reference obtained from `inner`/`inner_mut` is alive
/// while the returned `&mut` is used.
#[inline]
unsafe fn inner_mut(&self) -> &mut OutputStorageInner<'view, F> {
&mut *self.inner.get()
}
/// Appends an owned output and returns its index in `value_storage`.
#[inline]
fn push_owned(&self, val: NdArray<F>) -> usize {
unsafe {
let s = &mut self.inner_mut().value_storage;
let ret = s.len();
s.push(Some(val));
ret
}
}
/// Appends a view output and returns its index in `view_storage`.
#[inline]
fn push_view(&self, view: NdArrayView<'view, F>) -> usize {
unsafe {
let s = &mut self.inner_mut().view_storage;
let ret = s.len();
s.push(view);
ret
}
}
/// Returns a copy of the view stored at index `i`. Panics if `i` is out of range.
#[inline]
fn get_from_view(&self, i: usize) -> NdArrayView<'view, F> {
unsafe { self.inner().view_storage[i].clone() }
}
/// Returns a view of the owned output at index `i`.
/// Panics if `i` is out of range or if that output was already moved out by `take_from_owned`.
#[inline]
fn get_from_owned(&self, i: usize) -> NdArrayView<F> {
unsafe { self.inner().value_storage[i].as_ref().unwrap().view() }
}
/// Moves the owned output at index `i` out of storage, leaving `None` in its slot.
/// Panics if `i` is out of range or if the output was already taken.
#[inline]
fn take_from_owned(&self, i: usize) -> NdArray<F> {
unsafe { self.inner_mut().value_storage[i].take().unwrap() }
}
/// Resolves `vi` to a view of the stored output.
///
/// `node` is used only in the panic message. Panics if `vi` marks an empty output (or
/// under the same conditions as `get_from_owned` / `get_from_view`).
#[inline]
fn get(&'view self, node: &TensorInternal<F>, vi: ValueInfo) -> NdArrayView<'view, F> {
match vi.ty {
ValueType::Owned => self.get_from_owned(vi.key),
ValueType::View => self.get_from_view(vi.key),
ValueType::Empty => {
panic!(
"Attempting to use {}'s output which is empty.",
node.op.name()
);
}
}
}
}
/// Returns a view of the value fed for the placeholder whose node id is `in_node_id`.
///
/// If several feeds target the same placeholder, the first one in `feeds` is used.
///
/// # Panics
/// Panics with "Placeholder unfilled" when no feed matches `in_node_id`.
#[inline]
fn retrieve_feed<'feeds, 'feed, F: Float>(
    feeds: &'feeds [Feed<'feed, F>],
    in_node_id: usize,
) -> NdArrayView<'feeds, F> {
    match feeds.iter().find(|f| f.placeholder_id == in_node_id) {
        Some(hit) => hit.value.view(),
        None => panic!("Placeholder unfilled"),
    }
}
/// Moves every op output from `results` into `storage` and returns, for each output slot
/// in order, where it was stored.
///
/// `None` outputs are recorded as `ValueType::Empty`. The first `Err` output aborts the
/// loop and is returned; outputs stored before it remain in `storage`.
fn install_compute_results<'view, F: Float>(
    results: crate::op::Results<'view, F>,
    storage: &OutputStorage<'view, F>,
) -> Result<op::OutputArray<ValueInfo>, op::OpError> {
    let mut locations = op::OutputArray::new();
    for result in results {
        let location = match result {
            None => ValueInfo::new(ValueType::Empty, 0),
            Some(Err(e)) => return Err(e),
            Some(Ok(crate::ArrRepr::View(arr))) => {
                ValueInfo::new(ValueType::View, storage.push_view(arr))
            }
            Some(Ok(crate::ArrRepr::Owned(arr))) => {
                ValueInfo::new(ValueType::Owned, storage.push_owned(arr))
            }
        };
        locations.push(location);
    }
    Ok(locations)
}
/// Gathers the inputs of `node`'s op into `input_values`, one per entry of `node.in_edges`
/// and in the same order.
///
/// Each input is resolved according to the producing node:
/// - placeholder: the matching entry in `feeds`;
/// - node with a `variable_array`: a lock guard is taken (write lock when the edge is marked
///   `mut_usage`, read lock otherwise) and pushed into `write_guards` / `read_guards`; the op
///   input is a view through that guard;
/// - node with a constant array: a view of that array;
/// - otherwise: output number `in_idx` of the already-evaluated producer, found through
///   `node_info_map` and `storage`.
///
/// The caller must keep `read_guards` / `write_guards` alive, and must not remove from them,
/// while `input_values` is in use, because the variable views borrow the locked arrays.
///
/// # Errors
/// If an input's producer failed, returns a clone of that producer's error. Collection stops
/// at that input, so `input_values` may be partially filled.
///
/// # Panics
/// Panics if a placeholder has no feed, if a variable's lock is poisoned, if a computed
/// input has no entry in `node_info_map`, or if the requested producer output is empty.
/// NOTE(review): if the same variable feeds this node twice and one edge is `mut_usage`,
/// the second lock call comes from the same thread. Per `std::sync::RwLock` docs this may
/// deadlock or panic — confirm such graphs cannot occur.
#[inline]
fn aggregate_op_inputs<'ret, 'tensor: 'ret, 'slice: 'ret, 'feed: 'slice, F: Float>(
node: &'tensor TensorInternal<F>,
g: &Graph<F>,
node_info_map: &FxHashMap<usize, Result<op::OutputArray<ValueInfo>, op::OpError>>,
feeds: &'slice [Feed<'feed, F>],
storage: &'ret OutputStorage<'ret, F>,
input_values: &mut InputArray<OpInput<'ret, F>>,
read_guards: &mut InputArray<RwLockReadGuard<'tensor, NdArray<F>>>,
write_guards: &mut InputArray<RwLockWriteGuard<'tensor, NdArray<F>>>,
) -> Result<(), op::OpError> {
let mut input_status = Ok(());
// `in_idx` selects which output of the producer node this edge consumes.
for (in_node, &in_idx) in node.in_edges.iter().zip(&node.input_indices) {
let input_inner: &TensorInternal<F> = in_node.get(g);
let x = if input_inner.is_placeholder {
Ok(OpInput::new(retrieve_feed(feeds, in_node.id)))
} else if let Some(ref lock) = input_inner.variable_array {
// SAFETY (as relied on here): the raw-pointer round trip detaches the view's lifetime
// from the short `&mut *_guards` borrow so it can be typed with `'tensor`. The guard
// itself stays in `write_guards` / `read_guards`, which the caller keeps alive.
// NOTE(review): later pushes may move earlier guard objects inside the InputArray. The
// view presumably points into the locked array rather than into the guard, so it
// should remain valid — confirm.
unsafe {
if in_node.mut_usage {
write_guards.push(lock.write().unwrap());
let inserted = write_guards.len() - 1;
Ok(OpInput::new_mut(
(*(&mut write_guards[inserted] as *mut RwLockWriteGuard<NdArray<F>>))
.view_mut(),
))
} else {
read_guards.push(lock.read().unwrap());
let inserted = read_guards.len() - 1;
Ok(OpInput::new(
(*(&mut read_guards[inserted] as *mut RwLockReadGuard<NdArray<F>>)).view(),
))
}
}
} else if let Some(ref arr) = input_inner.get_constant_array_inner() {
Ok(OpInput::new(arr.view()))
} else {
// Computed input: the DFS in `Graph::eval` guarantees the producer ran first.
match &node_info_map.get(&in_node.id).unwrap() {
Err(e) => Err(e.clone()),
Ok(vi_list) => Ok(OpInput::new(storage.get(input_inner, vi_list[in_idx]))),
}
};
match x {
Ok(x) => input_values.push(x),
Err(e) => {
// Stop at the first failed producer; the caller skips running this op.
input_status = Err(e);
break;
}
}
}
input_status
}
impl<F: Float> Graph<F> {
pub fn eval<'feed, 'tensor, 'scope, A>(
&'scope self,
tensors: &'tensor [A],
feeds: &[Feed<'feed, F>],
) -> Vec<Result<NdArray<F>, crate::EvalError>>
where
A: AsRef<Tensor<'scope, F>> + Copy,
{
let mut node_info_map: FxHashMap<usize, Result<op::OutputArray<ValueInfo>, op::OpError>> =
FxHashMap::default();
let storage = OutputStorage::new();
let mut dfs_stack = Vec::<(&TensorInternal<F>, bool)>::with_capacity(100);
for t in tensors.iter() {
dfs_stack.push((t.as_ref().inner(), false));
}
while let Some((node, is_parent)) = dfs_stack.pop() {
if is_parent {
if would_not_visit(node, &node_info_map) {
continue;
}
let mut xs = InputArray::new();
let (mut write_guards, mut read_guards) = (InputArray::new(), InputArray::new());
let input_status = aggregate_op_inputs(
node,
self,
&node_info_map,
feeds,
&storage,
&mut xs,
&mut read_guards,
&mut write_guards,
);
let installed_node_info = input_status.and_then(|()| {
let mut ctx = ComputeContext::new(node, xs);
node.op.compute(&mut ctx);
let ys = ctx.extract_outputs();
debug_assert!(!ys.is_empty(), "Bad op implementation: empty return value");
install_compute_results(ys, &storage)
});
node_info_map.insert(node.id(), installed_node_info);
} else {
dfs_stack.push((node, true));
for child in &node.in_edges {
let child = child.get(self);
if !would_not_visit(child, &node_info_map) {
dfs_stack.push((child, false));
}
}
}
}
let mut ret = Vec::with_capacity(tensors.len());
for t in tensors {
let t = t.as_ref();
let arr = if let Some(per) = t.clone_persistent_array() {
Ok(per)
} else if t.is_placeholder() {
Ok(retrieve_feed(feeds, t.id()).to_owned())
} else {
match &node_info_map.get(&t.id()).unwrap() {
Ok(value_info_list) => match value_info_list[0] {
ValueInfo {
ty: ValueType::Owned,
key,
} => Ok(storage.take_from_owned(key)),
ValueInfo {
ty: ValueType::View,
key,
} => Ok(storage.get_from_view(key).to_owned()),
ValueInfo {
ty: ValueType::Empty,
key: _,
} => Err(crate::EvalError::Empty),
},
Err(e) => {
Err(crate::EvalError::OpError(e.clone()))
}
}
};
ret.push(arr);
}
ret
}
}
/// Returns `true` when the evaluator must not schedule `node`.
///
/// This covers placeholders and nodes with a persistent array, which are never computed,
/// and nodes already recorded in `info_map`, which were computed (or failed) earlier in
/// this evaluation.
#[inline]
fn would_not_visit<F: Float>(
    node: &TensorInternal<F>,
    info_map: &FxHashMap<usize, Result<op::OutputArray<ValueInfo>, op::OpError>>,
) -> bool {
    let never_computed = node.is_placeholder || node.has_persistent_array;
    never_computed || info_map.contains_key(&node.id())
}
#[test]
fn test_eval2() {
    crate::with(|g: &mut crate::Graph<f32>| {
        // A graph without placeholders must evaluate successfully with an empty feed list.
        let ones = g.ones(&[1, 1]);
        let activated = g.sigmoid(ones);
        activated.eval(&[]).unwrap();
    })
}
#[test]
fn test_eval() {
    crate::with(|g| {
        // The gradient with respect to a placeholder must keep the placeholder's shape.
        let v: Tensor<f32> = g.placeholder(&[3, 2, 1]);
        let z = g.reduce_sum(g.squeeze(v, &[2]), &[0, 1], false);
        let grads = g.grad(&[z], &[v]);
        let input = crate::ndarray_ext::ones(&[3, 2, 1]);
        let eval_result = grads[0].eval(&[v.given(input.view())]);
        assert_eq!(eval_result.as_ref().unwrap().shape(), &[3, 2, 1]);
    })
}
#[test]
fn test_variable_eval() {
    use crate::tensor::Variable;
    crate::with(|g| {
        // Evaluating a variable yields exactly the array it was created from.
        let arr = ndarray::arr1(&[0., 0., 0.]).into_dyn();
        let expected = arr.clone();
        assert_eq!(Ok(expected), g.variable(arr).eval(&[]));
    });
}
#[test]
fn test_constant_eval() {
    use crate::tensor::Constant;
    crate::with(|g| {
        // Evaluating a constant yields exactly the array it was created from.
        let arr = ndarray::arr1(&[0., 0., 0.]).into_dyn();
        let expected = arr.clone();
        assert_eq!(Ok(expected), g.constant(arr).eval(&[]));
    });
}
#[test]
fn test_placeholder_eval() {
    crate::with(|g| {
        // Evaluating a placeholder yields an array equal to the one that was fed.
        let fed: NdArray<f32> = crate::ndarray_ext::ones(&[3, 2, 1]);
        let placeholder = g.placeholder(&[3, 2, 1]);
        let result = placeholder.eval(&[placeholder.given(fed.view())]);
        assert_eq!(Ok(fed), result);
    });
}